diff options
author | Sonja Galovic <galovicsonja@gmail.com> | 2022-03-30 23:55:26 +0200 |
---|---|---|
committer | Sonja Galovic <galovicsonja@gmail.com> | 2022-03-30 23:55:26 +0200 |
commit | 5a223516a7481098fadab81ad062e5ec5b38144c (patch) | |
tree | 24dc2fe7cde46194b6addfd35e1e1a0ea44242dc /backend/microservice | |
parent | 0467667df8e5beaa08f6546cb6ef93ebd3c8db8d (diff) | |
parent | 39fc1f0cc9871b4436b839acb6ce4260e6c33931 (diff) |
Merge branch 'dev' of http://gitlab.pmf.kg.ac.rs/igrannonica/neuronstellar into dev
Diffstat (limited to 'backend/microservice')
-rw-r--r-- | backend/microservice/api/ml_service.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/backend/microservice/api/ml_service.py b/backend/microservice/api/ml_service.py index efd24fdc..68595a89 100644 --- a/backend/microservice/api/ml_service.py +++ b/backend/microservice/api/ml_service.py @@ -34,6 +34,7 @@ class TrainingResult: tpr: float def train(dataset, params, callback): + problem_type = params["type"] data = pd.DataFrame() for col in params["inputColumns"]: data[col]=dataset[col] @@ -123,13 +124,17 @@ def train(dataset, params, callback): # # Test # - y_pred=classifier.predict(x_test) - y_pred=(y_pred>=0.5).astype('int') - #y_pred=(y_pred * 100).astype('int') - y_pred=y_pred.flatten() - result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred}) - model_name = params['_id'] - classifier.save("temp/"+model_name, save_format='h5') + if(problem_type == "regresioni"): + classifier.evaluate(x_test, y_test) + classifier.save("temp/"+model_name, save_format='h5') + elif(problem_type == "binarni-klasifikacioni"): + y_pred=classifier.predict(x_test) + y_pred=(y_pred>=0.5).astype('int') + #y_pred=(y_pred * 100).astype('int') + y_pred=y_pred.flatten() + result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred}) + model_name = params['_id'] + classifier.save("temp/"+model_name, save_format='h5') # # Metrike # |