aboutsummaryrefslogtreecommitdiff
path: root/backend/microservice/api/ml_service.py
diff options
context:
space:
mode:
Diffstat (limited to 'backend/microservice/api/ml_service.py')
-rw-r--r--backend/microservice/api/ml_service.py12
1 files changed, 5 insertions, 7 deletions
diff --git a/backend/microservice/api/ml_service.py b/backend/microservice/api/ml_service.py
index 68595a89..ea562212 100644
--- a/backend/microservice/api/ml_service.py
+++ b/backend/microservice/api/ml_service.py
@@ -124,17 +124,15 @@ def train(dataset, params, callback):
#
# Test
#
+ model_name = params['_id']
+ y_pred=classifier.predict(x_test)
if(problem_type == "regresioni"):
classifier.evaluate(x_test, y_test)
- classifier.save("temp/"+model_name, save_format='h5')
elif(problem_type == "binarni-klasifikacioni"):
- y_pred=classifier.predict(x_test)
y_pred=(y_pred>=0.5).astype('int')
- #y_pred=(y_pred * 100).astype('int')
- y_pred=y_pred.flatten()
- result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred})
- model_name = params['_id']
- classifier.save("temp/"+model_name, save_format='h5')
+ y_pred=y_pred.flatten()
+ result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred})
+ classifier.save("temp/"+model_name, save_format='h5')
#
# Metrike
#