diff --git a/python_api/api.py b/python_api/api.py index 4816458..4ae734f 100644 --- a/python_api/api.py +++ b/python_api/api.py @@ -341,6 +341,7 @@ async def predict(request: PredictRequest): # experiment_id=create_expriement( cur, conn) trial= get_trial_by_model_and_input( model_id, inputs) + # print(trial) if not experiment_id: cur,conn=get_db_cur_con() experiment_id=create_expriement(cur, conn) @@ -401,11 +402,29 @@ async def delete_trial(trial_id: str): @app.get("/trial/{trial_id}") async def get_trial(trial_id: str): cur,conn=get_db_cur_con() + + # check if trial has a source trial + cur.execute(""" + SELECT * FROM trials t + WHERE t.id = %s + """, (trial_id,)) + row = cur.fetchone() + if row["source_trial_id"] is not None: + # print("\n\n\n\n\n\n\n\n\n\n") + # print(row["source_trial_id"]) + source_trial= await get_trial(row["source_trial_id"]) + return source_trial + print(row) + # else + + cur.close() + cur,conn=get_db_cur_con() + cur.execute(""" SELECT t.id AS trial_id, t.result, - t.source_trial_id as source_trial, + t.source_trial_id as source_trial_id, t.completed_at, ti.url AS input_url, m.id AS modelId, @@ -450,10 +469,15 @@ async def get_trial(trial_id: str): row = cur.fetchone() if not row: - # raise Exception(f"No trial found with ID {trial_id}") - return None - if row["source_trial"] is not None: - return get_trial(row["source_trial"]) + raise Exception(f"No trial found with ID {trial_id}") + # return None + print(row) + + if row["source_trial_id"] is not None: + # print("\n\n\n\n\n\n\n\n\n\n") + # print(row["source_trial_id"]) + + return get_trial(row["source_trial_id"]) # print(row) # Prepare the response structure result = { diff --git a/python_api/db.py b/python_api/db.py index a39cfbb..64d0703 100644 --- a/python_api/db.py +++ b/python_api/db.py @@ -32,7 +32,10 @@ def close_db_cur_con(cur, conn): def create_trial( model_id, experiment_id, cur, conn,source_id="",completed_at=None): trial_id= str(uuid.uuid4()) - cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,completed_at,experiment_id,source_trial_id) VALUES (%s,%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() ,datetime.now() , experiment_id,source_id)) + if source_id!="": + cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,completed_at,experiment_id,source_trial_id) VALUES (%s,%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() ,datetime.now() , experiment_id,source_id)) + else: + cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,experiment_id) VALUES (%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() , experiment_id)) conn.commit() return trial_id @@ -77,7 +80,7 @@ def get_trial_by_model_and_input(model_id, input_urls): JOIN experiments ON trials.experiment_id = experiments.id JOIN models ON trials.model_id = models.id JOIN trial_inputs ON trials.id = trial_inputs.trial_id - WHERE trials.completed_at IS NOT NULL + WHERE trials.results IS NOT NULL AND trials.model_id = %s AND ({input_query}) """