Skip to content

Commit

Permalink
Refactor get_trial function to handle source trials recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
amirnd51 committed Sep 3, 2024
1 parent 73ae791 commit 129c735
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
34 changes: 29 additions & 5 deletions python_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ async def predict(request: PredictRequest):
# experiment_id=create_expriement( cur, conn)

trial= get_trial_by_model_and_input( model_id, inputs)
# print(trial)
if not experiment_id:
cur,conn=get_db_cur_con()
experiment_id=create_expriement(cur, conn)
Expand Down Expand Up @@ -401,11 +402,29 @@ async def delete_trial(trial_id: str):
@app.get("/trial/{trial_id}")
async def get_trial(trial_id: str):
cur,conn=get_db_cur_con()

# check if trial has a source trial
cur.execute("""
SELECT * FROM trials t
WHERE t.id = %s
""", (trial_id,))
row = cur.fetchone()
if row["source_trial_id"] is not None:
# print("\n\n\n\n\n\n\n\n\n\n")
# print(row["source_trial_id"])
source_trial= await get_trial(row["source_trial_id"])
return source_trial
print(row)
# else

cur.close()
cur,conn=get_db_cur_con()


cur.execute("""
SELECT t.id AS trial_id,
t.result,
t.source_trial_id as source_trial,
t.source_trial_id as source_trial_id,
t.completed_at,
ti.url AS input_url,
m.id AS modelId,
Expand Down Expand Up @@ -450,10 +469,15 @@ async def get_trial(trial_id: str):
row = cur.fetchone()

if not row:
# raise Exception(f"No trial found with ID {trial_id}")
return None
if row["source_trial"] is not None:
return get_trial(row["source_trial"])
raise Exception(f"No trial found with ID {trial_id}")
# return None
print(row)

if row["source_trial_id"] is not None:
# print("\n\n\n\n\n\n\n\n\n\n")
# print(row["source_trial_id"])

return get_trial(row["source_trial_id"])
# print(row)
# Prepare the response structure
result = {
Expand Down
7 changes: 5 additions & 2 deletions python_api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def close_db_cur_con(cur, conn):

def create_trial( model_id, experiment_id, cur, conn,source_id="",completed_at=None):
trial_id= str(uuid.uuid4())
cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,completed_at,experiment_id,source_trial_id) VALUES (%s,%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() ,datetime.now() , experiment_id,source_id))
if source_id!="":
cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,completed_at,experiment_id,source_trial_id) VALUES (%s,%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() ,datetime.now() , experiment_id,source_id))
else:
cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,experiment_id) VALUES (%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() , experiment_id))
conn.commit()
return trial_id

Expand Down Expand Up @@ -77,7 +80,7 @@ def get_trial_by_model_and_input(model_id, input_urls):
JOIN experiments ON trials.experiment_id = experiments.id
JOIN models ON trials.model_id = models.id
JOIN trial_inputs ON trials.id = trial_inputs.trial_id
WHERE trials.completed_at IS NOT NULL
WHERE trials.results IS NOT NULL
AND trials.model_id = %s
AND ({input_query})
"""
Expand Down

0 comments on commit 129c735

Please sign in to comment.