Skip to content

Commit

Permalink
Update plotestimators.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeshingles committed Sep 18, 2024
1 parent 6d5143f commit c611a2f
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions artistools/estimators/plotestimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def plot_series(
if isinstance(variable, pl.Expr):
colexpr = variable
else:
assert variable in estimators.collect_schema().names()
assert variable in estimators.collect_schema().names(), f"Variable {variable} not found in estimators"
colexpr = pl.col(variable)

variablename = colexpr.meta.output_name()
Expand Down Expand Up @@ -595,6 +595,7 @@ def get_xlist(
.lazy()
.collect()
)
assert len(pointgroups) > 0, "No data found for x-axis variable"

return (
pointgroups["xvalue"].to_list(),
Expand Down Expand Up @@ -764,7 +765,6 @@ def make_plot(
args=args,
)
startfromzero = (xvariable.startswith("velocity") or xvariable == "beta") and not args.markersonly

xmin = args.xmin if args.xmin >= 0 else min(xlist)
xmax = args.xmax if args.xmax > 0 else max(xlist)

Expand Down Expand Up @@ -1035,6 +1035,7 @@ def main(args: argparse.Namespace | None = None, argsraw: t.Sequence[str] | None
estimators = at.estimators.scan_estimators(
modelpath=modelpath, modelgridindex=args.modelgridindex, timestep=tuple(timesteps_included)
)

assert estimators is not None
tmids = at.get_timestep_times(modelpath, loc="mid")
estimators = estimators.join(
Expand Down Expand Up @@ -1089,10 +1090,13 @@ def main(args: argparse.Namespace | None = None, argsraw: t.Sequence[str] | None
if args.x == "velocity" and modelmeta["vmax_cmps"] > 0.3 * 29979245800:
args.x = "beta"

dfmodel = dfmodel.filter(pl.col("vel_r_mid") <= modelmeta["vmax_cmps"])
estimators = estimators.join(dfmodel, on="modelgridindex")
estimators = estimators.with_columns(
rho_init=pl.col("rho"), rho=pl.col("rho") * (modelmeta["t_model_init_days"] / pl.col("time_mid")) ** 3
dfmodel = dfmodel.filter(pl.col("vel_r_mid") <= modelmeta["vmax_cmps"]).rename({
colname: f"init_{colname}"
for colname in dfmodel.collect_schema().names()
if not colname.startswith("vel_") and colname not in {"inputcellid", "modelgridindex", "mass_g"}
})
estimators = estimators.join(dfmodel, on="modelgridindex", suffix="_initmodel").with_columns(
rho=pl.col("init_rho") * (modelmeta["t_model_init_days"] / pl.col("time_mid")) ** 3
)

if args.readonlymgi:
Expand Down

0 comments on commit c611a2f

Please sign in to comment.