Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Submission queue - add missing continue #11

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ run
env

config.yaml
models.yaml

.DS_Store
29 changes: 29 additions & 0 deletions nmma_api/services/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import traceback
from datetime import datetime
from astropy.table import Table

import tornado.escape
import tornado.web
Expand All @@ -11,6 +12,7 @@
from nmma_api.utils.logs import make_log
from nmma_api.utils.mongo import Mongo, init_db
from nmma_api.tools.expanse import validate_credentials
from nmma_api.tools.enums import verify_and_match_filter

log = make_log("main")

Expand Down Expand Up @@ -39,6 +41,33 @@ def validate(data: dict) -> str:
f"model {model} is not allowed, must be one of: {ALLOWED_MODELS.join(',')}"
)

if "photometry" in data["inputs"]:
if (
isinstance(data["inputs"]["photometry"], str)
and len(data["inputs"]["photometry"]) > 0
):
temp = Table.read(data["inputs"]["photometry"], format="ascii.csv")
skipped = 0
skipped_filters = []
for row in temp:
try:
row["filter"] = verify_and_match_filter(model, row["filter"])
except ValueError:
skipped += 1
continue
if skipped == len(temp):
log(
"No valid filters found in photometry data for this model, cancelling analysis submission."
)
return "no valid filters found in photometry data"
elif skipped > 0:
log(
f"Will skip {skipped} rows in photometry data due to invalid filters for this model: {', '.join(list(set(skipped_filters)))}"
)

else:
return "photometry data must be a ascii csv string"

return None


Expand Down
3 changes: 2 additions & 1 deletion nmma_api/services/retrieval_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def retrieval_queue():
continue

# analysis has been running for too long, cancel the job and set the status to job_expired
# the submission queue will take of starting the plot generation job
# the submission queue will take care of starting the plot generation job
# and setting the status to "running_plot"
if analysis["status"] == "running" and analysis.get(
"submitted_at"
Expand All @@ -78,6 +78,7 @@ def retrieval_queue():
{"_id": analysis["_id"]},
{"$set": {"status": "job_expired"}},
)
continue

# analysis failed to submit to expanse, update the status upstream
if analysis["status"] == "failed_submission_to_upload":
Expand Down
5 changes: 3 additions & 2 deletions nmma_api/services/submission_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def submission_queue():
jobs = submit(analysis_requests)
for analysis_request in analysis_requests:
job = jobs.get(analysis_request["_id"], {})
error = jobs.get(analysis_request["_id"], {}).get("error", "")
message = jobs.get(analysis_request["_id"], {}).get("message", "")
if job.get("job_id") is not None:
mongo.db.analysis.update_one(
{"_id": analysis_request["_id"]},
Expand All @@ -44,6 +44,7 @@ def submission_queue():
else "running",
"job_id": job.get("job_id"),
"submitted_at": job.get("submitted_at"),
"warning": message,
}
},
)
Expand All @@ -53,7 +54,7 @@ def submission_queue():
{
"$set": {
"status": "failed_submission_to_upload",
"error": error,
"error": message,
"job_id": None,
}
},
Expand Down
56 changes: 56 additions & 0 deletions nmma_api/tools/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import requests
import yaml
import os

# we map sncosmo filters for which we have no trained models to similar filters for which we do have trained models

# Canonical location of the models.yaml manifest that lists each trained
# model and the filters it supports.
REPO = "https://gitlab.com/Theodlz/nmma-models/raw/main/models.yaml"

# SDSS filter name -> closest Pan-STARRS1 filter for which trained models
# exist; used as a fallback when the requested filter is unsupported.
FILTERS_MAPPER = {
    "sdssg": "ps1__g",
    "sdssi": "ps1__i",
    "sdssr": "ps1__r",
    "sdssz": "ps1__z",
    "sdssu": "ps1__u",
}


def fetch_models():
    """Return the model -> metadata mapping, preferring a local cache.

    Tries to load ``models.yaml`` from the working directory first; on any
    failure (missing file, unreadable content) falls back to downloading
    the manifest from ``REPO`` and caches it locally for later calls.

    Returns:
        dict: parsed YAML mapping of model names to their metadata
        (e.g. supported filters).

    Raises:
        requests.RequestException: if no local cache exists and the
            download fails or returns an HTTP error status.
    """
    # Best-effort read of the local cache; any problem falls through to a
    # fresh download rather than aborting.
    try:
        if os.path.exists("models.yaml"):
            with open("models.yaml", "r") as f:
                return yaml.safe_load(f)
    except Exception:
        pass

    # Fix: the original request had no timeout and never checked the HTTP
    # status, so a hung server could block startup indefinitely and an
    # error page could be silently parsed and cached as the manifest.
    response = requests.get(REPO, timeout=30)
    response.raise_for_status()
    models = yaml.safe_load(response.content.decode("utf-8"))

    # Cache for subsequent calls so we avoid refetching on every startup.
    with open("models.yaml", "w") as f:
        yaml.dump(models, f)

    return models


# Models parameterized by central wavelength: any filter is accepted as-is.
CENTRAL_WAVELENGTH_MODELS = ["Me2017", "Piro2021", "nugent-hyper", "TrPi2018"]
# Model -> metadata mapping, loaded once at import time (local cache with
# a network fallback — see fetch_models).
FIXED_FILTERS_MODELS = fetch_models()


def verify_and_match_filter(model, filter):
    """Validate *filter* for *model*, substituting a close equivalent if needed.

    Central-wavelength models accept any filter, so the input is returned
    untouched. For fixed-filter models the filter must appear in the
    model's supported list, possibly after remapping via FILTERS_MAPPER
    (e.g. sdssg -> ps1__g).

    Returns:
        str: the original filter, or its supported replacement.

    Raises:
        ValueError: if the model is unknown, or the filter (even after
            remapping) is not supported by the model.
    """
    if model in CENTRAL_WAVELENGTH_MODELS:
        return filter

    if model not in FIXED_FILTERS_MODELS:
        raise ValueError(f"Model {model} not found")

    supported = FIXED_FILTERS_MODELS[model].get("filters", [])
    if filter in supported:
        return filter

    # Filter itself is unsupported; see whether a mapped stand-in
    # filter is available for this model.
    replacement = FILTERS_MAPPER.get(filter)
    if replacement is not None and replacement in supported:
        return replacement

    raise ValueError(f"Filter {filter} not found in model {model}")
19 changes: 18 additions & 1 deletion nmma_api/tools/expanse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nmma_api.utils.logs import make_log
from nmma_api.utils.config import load_config
from nmma_api.tools.enums import verify_and_match_filter


config = load_config()
Expand Down Expand Up @@ -118,6 +119,8 @@ def submit(analyses: list[dict], **kwargs) -> bool:
except Exception as e:
raise ValueError(f"input data is not in the expected format {e}")

skipped = 0
skipped_filters = []
try:
# Set trigger time based on first detection
TT = np.min(data[data["mag"] != np.ma.masked]["mjd"])
Expand All @@ -138,10 +141,17 @@ def submit(analyses: list[dict], **kwargs) -> bool:
]
for row in data:
tt = Time(row["mjd"], format="mjd").isot
filt = row["filter"]
try:
filt = verify_and_match_filter(MODEL, row["filter"])
except ValueError:
skipped += 1
skipped_filters.append(row["filter"])
continue
mag = row["mag"]
magerr = row["magerr"]
f.write(f"{tt} {filt} {mag} {magerr}\n")
if skipped == len(data):
raise ValueError("no valid filters found in photometry data")
except Exception as e:
raise ValueError(f"failed to format data {e}")

Expand Down Expand Up @@ -175,6 +185,10 @@ def submit(analyses: list[dict], **kwargs) -> bool:
"message": "",
"submitted_at": datetime.timestamp(datetime.utcnow()),
}
if skipped > 0:
jobs[data_dict["_id"]][
"message"
] = f"Skipped {skipped} observations with filters: {', '.join(list(set(skipped_filters)))} as they are not supported by the model."
log(f"Submitted job {job_id} for analysis {data_dict['_id']}")
except Exception as e:
log(f"Failed to submit analysis {data_dict['_id']} to expanse: {e}")
Expand Down Expand Up @@ -250,6 +264,9 @@ def retrieve(analysis: dict) -> dict:
pop_list = ["samples", "nested_samples"]
[result.pop(x) for x in pop_list]

if "warning" in analysis:
result["warning"] = analysis["warning"]

f = tempfile.NamedTemporaryFile(suffix=".png", prefix="nmmaplot_", delete=False)
f.close()

Expand Down
Loading