Merge pull request #11 from Theodlz/missing-continue
* track available models and filters, map missing filters to similar ones when submitting analysis jobs
Theodlz authored Nov 29, 2023
2 parents 85a3aef + 6148457 commit 1d4b360
Showing 6 changed files with 115 additions and 4 deletions.
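Before the per-file diffs: the new filter matching relies on a models.yaml manifest fetched from the nmma-models GitLab repository and cached locally (hence the new .gitignore entry below). The manifest's exact contents are not part of this diff; the following is a hedged sketch of the parsed structure that enums.py assumes, with illustrative model and filter names:

# Hypothetical shape of what fetch_models() returns after yaml.safe_load:
# a mapping from "_tf" model name to the filters that model was trained on.
# The entries below are illustrative, not copied from the real manifest.
FIXED_FILTERS_MODELS = {
    "Bu2019lm_tf": {"filters": ["ps1__g", "ps1__r", "ps1__i", "ps1__z", "ztfg", "ztfr"]},
    "Bu2022Ye_tf": {"filters": ["ps1__g", "ps1__r", "ztfg", "ztfr", "ztfi"]},
}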
1 change: 1 addition & 0 deletions .gitignore
@@ -5,5 +5,6 @@ run
 env
 
 config.yaml
+models.yaml
 
 .DS_Store
30 changes: 30 additions & 0 deletions nmma_api/services/api.py
@@ -3,6 +3,7 @@
 import json
 import traceback
 from datetime import datetime
+from astropy.table import Table
 
 import tornado.escape
 import tornado.web
@@ -11,6 +12,7 @@
 from nmma_api.utils.logs import make_log
 from nmma_api.utils.mongo import Mongo, init_db
 from nmma_api.tools.expanse import validate_credentials
+from nmma_api.tools.enums import verify_and_match_filter
 
 log = make_log("main")
 
@@ -39,6 +41,34 @@ def validate(data: dict) -> str:
             f"model {model} is not allowed, must be one of: {ALLOWED_MODELS.join(',')}"
         )
 
+    if "photometry" in data["inputs"]:
+        if (
+            isinstance(data["inputs"]["photometry"], str)
+            and len(data["inputs"]["photometry"]) > 0
+        ):
+            temp = Table.read(data["inputs"]["photometry"], format="ascii.csv")
+            skipped = 0
+            skipped_filters = []
+            for row in temp:
+                try:
+                    row["filter"] = verify_and_match_filter(model, row["filter"])
+                except ValueError:
+                    skipped += 1
+                    skipped_filters.append(row["filter"])
+                    continue
+            if skipped == len(temp):
+                log(
+                    "No valid filters found in photometry data for this model, cancelling analysis submission."
+                )
+                return "no valid filters found in photometry data"
+            elif skipped > 0:
+                log(
+                    f"Will skip {skipped} rows in photometry data due to invalid filters for this model: {', '.join(list(set(skipped_filters)))}"
+                )
+
+        else:
+            return "photometry data must be an ascii csv string"
+
     return None
 
 
3 changes: 2 additions & 1 deletion nmma_api/services/retrieval_queue.py
@@ -65,7 +65,7 @@ def retrieval_queue():
             continue
 
         # analysis has been running for too long, cancel the job and set the status to job_expired
-        # the submission queue will take of starting the plot generation job
+        # the submission queue will take care of starting the plot generation job
         # and setting the status to "running_plot"
         if analysis["status"] == "running" and analysis.get(
             "submitted_at"
@@ -78,6 +78,7 @@
                 {"_id": analysis["_id"]},
                 {"$set": {"status": "job_expired"}},
             )
+            continue
 
         # analysis failed to submit to expanse, update the status upstream
        if analysis["status"] == "failed_submission_to_upload":
5 changes: 3 additions & 2 deletions nmma_api/services/submission_queue.py
@@ -33,7 +33,7 @@ def submission_queue():
         jobs = submit(analysis_requests)
         for analysis_request in analysis_requests:
             job = jobs.get(analysis_request["_id"], {})
-            error = jobs.get(analysis_request["_id"], {}).get("error", "")
+            message = jobs.get(analysis_request["_id"], {}).get("message", "")
             if job.get("job_id") is not None:
                 mongo.db.analysis.update_one(
                     {"_id": analysis_request["_id"]},
@@ -44,6 +44,7 @@
                             else "running",
                             "job_id": job.get("job_id"),
                             "submitted_at": job.get("submitted_at"),
+                            "warning": message,
                         }
                     },
                 )
@@ -53,7 +54,7 @@
                 {
                     "$set": {
                         "status": "failed_submission_to_upload",
-                        "error": error,
+                        "error": message,
                         "job_id": None,
                     }
                 },
61 changes: 61 additions & 0 deletions nmma_api/tools/enums.py
@@ -0,0 +1,61 @@
+import requests
+import yaml
+import os
+
+# we map sncosmo filters for which we have no trained models to similar filters for which we do have trained models
+
+REPO = "https://gitlab.com/Theodlz/nmma-models/raw/main/models.yaml"
+
+FILTERS_MAPPER = {
+    "sdssg": "ps1__g",
+    "sdssi": "ps1__i",
+    "sdssr": "ps1__r",
+    "sdssz": "ps1__z",
+    "sdssu": "ps1__u",
+}
+
+
+def fetch_models():
+    # check if the file exists
+    try:
+        if os.path.exists("models.yaml"):
+            with open("models.yaml", "r") as f:
+                return yaml.safe_load(f)
+    except Exception:
+        pass
+
+    response = requests.get(REPO)
+    content = response.content.decode("utf-8")
+    models = yaml.safe_load(content)
+    # save to file
+    with open("models.yaml", "w") as f:
+        yaml.dump(models, f)
+
+    return models
+
+
+CENTRAL_WAVELENGTH_MODELS = ["Me2017", "Piro2021", "nugent-hyper", "TrPi2018"]
+FIXED_FILTERS_MODELS = fetch_models()
+
+
+def verify_and_match_filter(model, filter):
+    if model in CENTRAL_WAVELENGTH_MODELS:
+        return filter
+
+    # we only support _tf models for now, so if the model does not end with _tf, we add it
+    if not model.endswith("_tf"):
+        model = model + "_tf"
+
+    if model not in FIXED_FILTERS_MODELS:
+        raise ValueError(f"Model {model} not found")
+
+    if filter not in FIXED_FILTERS_MODELS[model].get("filters", []):
+        # see if there is a similar filter
+        replacement = FILTERS_MAPPER.get(filter)
+        if replacement and replacement in FIXED_FILTERS_MODELS[model].get(
+            "filters", []
+        ):
+            return replacement
+        raise ValueError(f"Filter {filter} not found in model {model}")
+
+    return filter
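A minimal, self-contained sketch of the lookup order implemented by verify_and_match_filter above: exact match first, then the FILTERS_MAPPER fallback, then ValueError. The models dict stands in for the real models.yaml payload; model and filter names are illustrative assumptions:

# Stand-in for FIXED_FILTERS_MODELS; the real mapping is fetched from GitLab.
models = {"Bu2019lm_tf": {"filters": ["ps1__g", "ps1__r", "ps1__i"]}}
mapper = {"sdssg": "ps1__g", "sdssu": "ps1__u"}


def match(model: str, filt: str) -> str:
    if not model.endswith("_tf"):
        model += "_tf"  # only _tf models are supported
    if model not in models:
        raise ValueError(f"Model {model} not found")
    trained = models[model].get("filters", [])
    if filt in trained:
        return filt  # exact match, keep the filter as-is
    replacement = mapper.get(filt)
    if replacement in trained:
        return replacement  # fall back to a similar trained filter
    raise ValueError(f"Filter {filt} not found in model {model}")


print(match("Bu2019lm", "ps1__g"))  # ps1__g (direct hit)
print(match("Bu2019lm", "sdssg"))   # ps1__g (mapped from sdssg)
# match("Bu2019lm", "sdssu") raises ValueError: it maps to ps1__u, which this
# hypothetical model was not trained on, so callers count the row as skipped.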
19 changes: 18 additions & 1 deletion nmma_api/tools/expanse.py
@@ -15,6 +15,7 @@
 
 from nmma_api.utils.logs import make_log
 from nmma_api.utils.config import load_config
+from nmma_api.tools.enums import verify_and_match_filter
 
 
 config = load_config()
@@ -118,6 +119,8 @@ def submit(analyses: list[dict], **kwargs) -> bool:
         except Exception as e:
             raise ValueError(f"input data is not in the expected format {e}")
 
+        skipped = 0
+        skipped_filters = []
         try:
             # Set trigger time based on first detection
             TT = np.min(data[data["mag"] != np.ma.masked]["mjd"])
@@ -138,10 +141,17 @@
             ]
             for row in data:
                 tt = Time(row["mjd"], format="mjd").isot
-                filt = row["filter"]
+                try:
+                    filt = verify_and_match_filter(MODEL, row["filter"])
+                except ValueError:
+                    skipped += 1
+                    skipped_filters.append(row["filter"])
+                    continue
                 mag = row["mag"]
                 magerr = row["magerr"]
                 f.write(f"{tt} {filt} {mag} {magerr}\n")
+            if skipped == len(data):
+                raise ValueError("no valid filters found in photometry data")
         except Exception as e:
             raise ValueError(f"failed to format data {e}")
 
@@ -175,6 +185,10 @@ def submit(analyses: list[dict], **kwargs) -> bool:
                 "message": "",
                 "submitted_at": datetime.timestamp(datetime.utcnow()),
             }
+            if skipped > 0:
+                jobs[data_dict["_id"]][
+                    "message"
+                ] = f"Skipped {skipped} observations with filters: {', '.join(list(set(skipped_filters)))} as they are not supported by the model."
             log(f"Submitted job {job_id} for analysis {data_dict['_id']}")
         except Exception as e:
             log(f"Failed to submit analysis {data_dict['_id']} to expanse: {e}")
@@ -250,6 +264,9 @@ def retrieve(analysis: dict) -> dict:
     pop_list = ["samples", "nested_samples"]
     [result.pop(x) for x in pop_list]
 
+    if "warning" in analysis:
+        result["warning"] = analysis["warning"]
+
     f = tempfile.NamedTemporaryFile(suffix=".png", prefix="nmmaplot_", delete=False)
     f.close()
 
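For reference, a sketch of one entry in the dict returned by submit(), assembled from the fields visible in this diff; the _id, job id, and timestamp values are placeholders:

# Keyed by the analysis _id; "message" carries the skipped-filter warning.
jobs = {
    "6568f0a1b2c3d4e5f6a7b8c9": {  # placeholder analysis _id
        "job_id": 12345678,  # placeholder Slurm job id on Expanse
        "message": "Skipped 2 observations with filters: sdssu as they are not supported by the model.",
        "submitted_at": 1701216000.0,  # datetime.timestamp(datetime.utcnow())
    }
}
# The submission queue stores "message" under "warning" on success and under
# "error" on failed submission; retrieve() later copies "warning" into the
# result returned upstream.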
