aboutsummaryrefslogtreecommitdiff
path: root/backend/microservice/api
diff options
context:
space:
mode:
Diffstat (limited to 'backend/microservice/api')
-rw-r--r--backend/microservice/api/ml_service.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/backend/microservice/api/ml_service.py b/backend/microservice/api/ml_service.py
index efd24fdc..68595a89 100644
--- a/backend/microservice/api/ml_service.py
+++ b/backend/microservice/api/ml_service.py
@@ -34,6 +34,7 @@ class TrainingResult:
tpr: float
def train(dataset, params, callback):
+ problem_type = params["type"]
data = pd.DataFrame()
for col in params["inputColumns"]:
data[col]=dataset[col]
@@ -123,13 +124,17 @@ def train(dataset, params, callback):
#
# Test
#
- y_pred=classifier.predict(x_test)
- y_pred=(y_pred>=0.5).astype('int')
- #y_pred=(y_pred * 100).astype('int')
- y_pred=y_pred.flatten()
- result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred})
- model_name = params['_id']
- classifier.save("temp/"+model_name, save_format='h5')
+ if(problem_type == "regresioni"):
+ classifier.evaluate(x_test, y_test)
+ classifier.save("temp/"+model_name, save_format='h5')
+ elif(problem_type == "binarni-klasifikacioni"):
+ y_pred=classifier.predict(x_test)
+ y_pred=(y_pred>=0.5).astype('int')
+ #y_pred=(y_pred * 100).astype('int')
+ y_pred=y_pred.flatten()
+ result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred})
+ model_name = params['_id']
+ classifier.save("temp/"+model_name, save_format='h5')
#
# Metrike
#