diff options
Diffstat (limited to 'backend/microservice/api')
-rw-r--r-- | backend/microservice/api/ml_service.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/backend/microservice/api/ml_service.py b/backend/microservice/api/ml_service.py index efd24fdc..68595a89 100644 --- a/backend/microservice/api/ml_service.py +++ b/backend/microservice/api/ml_service.py @@ -34,6 +34,7 @@ class TrainingResult: tpr: float def train(dataset, params, callback): + problem_type = params["type"] data = pd.DataFrame() for col in params["inputColumns"]: data[col]=dataset[col] @@ -123,13 +124,17 @@ def train(dataset, params, callback): # # Test # - y_pred=classifier.predict(x_test) - y_pred=(y_pred>=0.5).astype('int') - #y_pred=(y_pred * 100).astype('int') - y_pred=y_pred.flatten() - result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred}) - model_name = params['_id'] - classifier.save("temp/"+model_name, save_format='h5') + if(problem_type == "regresioni"): + classifier.evaluate(x_test, y_test) + classifier.save("temp/"+model_name, save_format='h5') + elif(problem_type == "binarni-klasifikacioni"): + y_pred=classifier.predict(x_test) + y_pred=(y_pred>=0.5).astype('int') + #y_pred=(y_pred * 100).astype('int') + y_pred=y_pred.flatten() + result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred}) + model_name = params['_id'] + classifier.save("temp/"+model_name, save_format='h5') # # Metrike # |