aboutsummaryrefslogtreecommitdiff
path: root/backend/microservice
diff options
context:
space:
mode:
authorSonja Galovic <galovicsonja@gmail.com>2022-03-30 23:55:26 +0200
committerSonja Galovic <galovicsonja@gmail.com>2022-03-30 23:55:26 +0200
commit5a223516a7481098fadab81ad062e5ec5b38144c (patch)
tree24dc2fe7cde46194b6addfd35e1e1a0ea44242dc /backend/microservice
parent0467667df8e5beaa08f6546cb6ef93ebd3c8db8d (diff)
parent39fc1f0cc9871b4436b839acb6ce4260e6c33931 (diff)
Merge branch 'dev' of http://gitlab.pmf.kg.ac.rs/igrannonica/neuronstellar into dev
Diffstat (limited to 'backend/microservice')
-rw-r--r--backend/microservice/api/ml_service.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/backend/microservice/api/ml_service.py b/backend/microservice/api/ml_service.py
index efd24fdc..68595a89 100644
--- a/backend/microservice/api/ml_service.py
+++ b/backend/microservice/api/ml_service.py
@@ -34,6 +34,7 @@ class TrainingResult:
tpr: float
def train(dataset, params, callback):
+ problem_type = params["type"]
data = pd.DataFrame()
for col in params["inputColumns"]:
data[col]=dataset[col]
@@ -123,13 +124,17 @@ def train(dataset, params, callback):
#
# Test
#
- y_pred=classifier.predict(x_test)
- y_pred=(y_pred>=0.5).astype('int')
- #y_pred=(y_pred * 100).astype('int')
- y_pred=y_pred.flatten()
- result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred})
- model_name = params['_id']
- classifier.save("temp/"+model_name, save_format='h5')
+ if(problem_type == "regresioni"):
+ classifier.evaluate(x_test, y_test)
+ classifier.save("temp/"+model_name, save_format='h5')
+ elif(problem_type == "binarni-klasifikacioni"):
+ y_pred=classifier.predict(x_test)
+ y_pred=(y_pred>=0.5).astype('int')
+ #y_pred=(y_pred * 100).astype('int')
+ y_pred=y_pred.flatten()
+ result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred})
+ model_name = params['_id']
+ classifier.save("temp/"+model_name, save_format='h5')
#
# Metrike
#