From b573b0069512f4444aefdac0a24eb99c8fcb6586 Mon Sep 17 00:00:00 2001 From: Danijel Andjelkovic Date: Thu, 31 Mar 2022 13:48:13 +0200 Subject: Malo sredio add-model stranicu, popravio bug sa prikazom javnih datasetova na naslovnoj strani. ML socket salje poruke za epohe backendu. --- backend/microservice/api/controller.py | 1 + backend/microservice/api/ml_service.py | 12 +++++------- backend/microservice/api/ml_socket.py | 2 ++ 3 files changed, 8 insertions(+), 7 deletions(-) (limited to 'backend/microservice/api') diff --git a/backend/microservice/api/controller.py b/backend/microservice/api/controller.py index ceed02ad..059af317 100644 --- a/backend/microservice/api/controller.py +++ b/backend/microservice/api/controller.py @@ -16,6 +16,7 @@ class train_callback(tf.keras.callbacks.Callback): # def on_epoch_end(self, epoch, logs=None): print(epoch) + ml_socket.send(epoch) #print('Evaluation: ', self.model.evaluate(self.x_test,self.y_test),"\n") #broj parametara zavisi od izabranih metrika loss je default @app.route('/train', methods = ['POST']) diff --git a/backend/microservice/api/ml_service.py b/backend/microservice/api/ml_service.py index 68595a89..ea562212 100644 --- a/backend/microservice/api/ml_service.py +++ b/backend/microservice/api/ml_service.py @@ -124,17 +124,15 @@ def train(dataset, params, callback): # # Test # + model_name = params['_id'] + y_pred=classifier.predict(x_test) if(problem_type == "regresioni"): classifier.evaluate(x_test, y_test) - classifier.save("temp/"+model_name, save_format='h5') elif(problem_type == "binarni-klasifikacioni"): - y_pred=classifier.predict(x_test) y_pred=(y_pred>=0.5).astype('int') - #y_pred=(y_pred * 100).astype('int') - y_pred=y_pred.flatten() - result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred}) - model_name = params['_id'] - classifier.save("temp/"+model_name, save_format='h5') + y_pred=y_pred.flatten() + result=pd.DataFrame({"Actual":y_test,"Predicted":y_pred}) + classifier.save("temp/"+model_name, save_format='h5') # # Metrike # diff --git a/backend/microservice/api/ml_socket.py b/backend/microservice/api/ml_socket.py index 65dd7321..cab157eb 100644 --- a/backend/microservice/api/ml_socket.py +++ b/backend/microservice/api/ml_socket.py @@ -25,4 +25,6 @@ async def start(): get_or_create_eventloop().run_forever() async def send(msg): + print("WS sending message:") + print(msg) await websocket.send(msg) \ No newline at end of file -- cgit v1.2.3