Merge branch 'master' into update_docs
isaacmg authored Jan 29, 2021
2 parents 0c483e0 + a750601 commit 1b20d85
Showing 6 changed files with 60 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -36,7 +36,7 @@ jobs:
name: install dependencies
command: |
sudo pip install flake8
flake8 .
flake8 .
evaluator_test:
<<: *defaults
77 changes: 47 additions & 30 deletions flood_forecast/deployment/inference.py
@@ -11,72 +11,84 @@


class InferenceMode(object):
def __init__(self, hours_to_forecast: int, num_prediction_samples: int, model_params, csv_path: str, weight_path,
def __init__(self, forecast_steps: int, num_prediction_samples: int, model_params, csv_path: str, weight_path,
wandb_proj: str = None, torch_script=False):
"""Class to handle inference for models,
:param hours_to_forecast: Number of time-steps to forecasts (doesn't have to be hours)
:type hours_to_forecast: int
:param forecast_steps: Number of time-steps to forecast (doesn't have to be hours)
:type forecast_steps: int
:param num_prediction_samples: Number of prediction samples
:type num_prediction_samples: int
:param model_params: [description]
:type model_params: [type]
:param csv_path: [description]
:param model_params: A dictionary of model parameters (ideally this should come from saved JSON config file)
:type model_params: Dict
:param csv_path: Path to the CSV test file you want to use for inference. Even if you aren't using
:type csv_path: str
:param weight_path: [description]
:type weight_path: [type]
:param wandb_proj: [description], defaults to None
:param weight_path: Path to the model weights
:type weight_path: str
:param wandb_proj: The name of the W&B project; leave blank if you don't want to log to Wandb, defaults to None
:type wandb_proj: str, optional
"""
self.hours_to_forecast = hours_to_forecast
self.hours_to_forecast = forecast_steps
self.csv_path = csv_path
self.n_targets = model_params.get("n_targets")
self.targ_cols = model_params["dataset_params"]["target_col"]
self.model = load_model(model_params.copy(), csv_path, weight_path)
self.inference_params = model_params["inference_params"]
if "scaling" in self.inference_params["dataset_params"]:
s = scaling_function({}, self.inference_params["dataset_params"])["scaling"]
self.inference_params["dataset_params"]["scaling"] = s
self.inference_params["hours_to_forecast"] = hours_to_forecast
self.inference_params["hours_to_forecast"] = forecast_steps
self.inference_params["num_prediction_samples"] = num_prediction_samples
if wandb_proj:
date = datetime.now()
wandb.init(name=date.strftime("%H-%M-%D-%Y") + "_prod", project=wandb_proj)
wandb.config.update(model_params)
wandb.config.update(model_params, allow_val_change=True)
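
As a quick orientation, a minimal sketch of how this class might be instantiated after this change (the config, CSV, and weight paths below are hypothetical placeholders, not files from this repo):

import json

from flood_forecast.deployment.inference import InferenceMode

# Load the model config saved at training time (hypothetical path).
with open("config.json") as f:
    model_params = json.load(f)

# Forecast 20 steps with 30 prediction samples; data and weight paths are placeholders.
inference = InferenceMode(20, 30, model_params, "test_data.csv", "model_weights.pth")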

def infer_now(self, some_date: datetime, csv_path=None, save_buck=None, save_name=None, use_torch_script=False):
"""Performs inference at a specified datatime
"""Performs inference on a CSV file at a specified datatime
:param some_date: The date you want inference to begin on.
:param csv_path: [description], defaults to None
:type csv_path: [type], optional
:param save_buck: [description], defaults to None
:type save_buck: [type], optional
:param csv_path: A path to a CSV you want to perform inference on, defaults to None
:type csv_path: str, optional
:param save_buck: The GCP bucket where you want to save predictions, defaults to None
:type save_buck: str, optional
:param save_name: The name of the file to save the Pandas data-frame to GCP as, defaults to None
:type save_name: [type], optional
:type save_name: str, optional
:param use_torch_script: Optional parameter which allows you to use a saved torch script version of your model.
:return: Returns a tuple consisting of the Pandas dataframe with predictions + history,
the prediction tensor, a tensor of the historical values, the forecast start index, and the test
:rtype: [type]
the prediction tensor, a tensor of the historical values, the forecast start index, the test loader, and
a dataframe of the prediction samples (e.g. the confidence interval preds)
:rtype: tuple(pd.DataFrame, torch.Tensor, torch.Tensor, int, CSVTestLoader, pd.DataFrame)
"""
forecast_history = self.inference_params["dataset_params"]["forecast_history"]
self.inference_params["datetime_start"] = some_date
if csv_path:
self.inference_params["test_csv_path"] = csv_path
self.inference_params["dataset_params"]["file_path"] = csv_path
df, tensor, history, forecast_start, test, samples = infer_on_torch_model(self.model, **self.inference_params)
if test.scale:
if test.scale and self.n_targets:
for i in range(0, self.n_targets):
unscaled = test.inverse_scale(tensor.numpy())
df["pred_" + self.targ_cols[i]] = 0
print("Shape of unscaled is: ")
print(unscaled.shape)
df["pred_" + self.targ_cols[i]][forecast_history:] = unscaled[0, :, i].numpy()
elif test.scale:
unscaled = test.inverse_scale(tensor.numpy().reshape(-1, 1))
df["preds"][forecast_history:] = unscaled.numpy()[:, 0]
if len(samples) > 1:
samples[:forecast_history] = 0
if len(samples) > 0:
for i in range(0, len(samples)):
samples[i][:forecast_history] = 0
if save_buck:
df.to_csv("temp3.csv")
upload_file(save_buck, save_name, "temp3.csv", self.model.gcs_client)
return df, tensor, history, forecast_start, test, samples
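
A hedged sketch of calling infer_now and unpacking the six values it returns (continuing the hypothetical inference instance above; the start date is arbitrary):

from datetime import datetime

# Begin forecasting at the chosen datetime; save_buck/save_name trigger an optional GCP upload.
df, tensor, history, forecast_start, test_loader, samples = inference.infer_now(
    datetime(2020, 5, 1), csv_path="test_data.csv")
print(df.head())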

def make_plots(self, date: datetime, csv_path: str = None, csv_bucket: str = None,
save_name=None, wandb_plot_id=None):
"""
"""Function to create plots in inference mode.
:param date: [description]
:param date: The datetime to start inference
:type date: datetime
:param csv_path: A path to the CSV you want to perform inference on, defaults to None
:type csv_path: str, optional
@@ -87,7 +99,7 @@ def make_plots(self, date: datetime, csv_path: str = None, csv_bucket: str = Non
:param wandb_plot_id: The id to log the W&B plots under, defaults to None
:type wandb_plot_id: str, optional
:return: The prediction tensor, a tensor of the historical values, the test loader, and the plot
:rtype: [type]
:rtype: tuple(torch.Tensor, torch.Tensor, CSVTestLoader, matplotlib.pyplot.plot)
"""
if csv_path is None:
csv_path = self.csv_path
@@ -96,9 +108,10 @@ def make_plots(self, date: datetime, csv_path: str = None, csv_bucket: str = Non
for sample, targ in zip(samples, self.model.params["dataset_params"]["target_col"]):
plt = plot_df_test_with_confidence_interval(df, sample, forecast_start, self.model.params, targ)
if wandb_plot_id:
wandb.log({wandb_plot_id: plt})
deep_explain_model_summary_plot(self.model, test, date)
deep_explain_model_heatmap(self.model, test, date)
wandb.log({wandb_plot_id + targ: plt})
if not self.n_targets:
deep_explain_model_summary_plot(self.model, test, date)
deep_explain_model_heatmap(self.model, test, date)
return tensor, history, test, plt
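
Continuing the same hypothetical setup, a sketch of make_plots; the bucket and plot id mirror the test file further down:

# Plot predictions with confidence intervals and optionally log them to W&B.
tensor, history, test_loader, plt_obj = inference.make_plots(
    datetime(2020, 5, 1), csv_bucket="flow_datasets",
    save_name="preds.csv", wandb_plot_id="prod_plot")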


@@ -126,7 +139,11 @@ def convert_to_torch_script(model: PyTorchForecast, save_path: str) -> PyTorchFo
return model
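
A minimal sketch of invoking the TorchScript conversion helper (its body is in the hidden hunk above), assuming the loaded model from the earlier examples and a hypothetical save path:

# Convert the loaded model to TorchScript and save it to the given path.
scripted = convert_to_torch_script(inference.model, "model_scripted.pt")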


def load_model(model_params_dict, file_path, weight_path: str) -> PyTorchForecast:
def convert_to_onnx():
pass


def load_model(model_params_dict, file_path: str, weight_path: str) -> PyTorchForecast:
"""Function to load a PyTorchForecast model from an existing config file.
:param model_params_dict: Dictionary of model parameters
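
And a one-line sketch of load_model itself, with placeholder paths matching the earlier examples:

# Rebuild a PyTorchForecast model from its config dict, test CSV, and saved weights.
model = load_model(model_params.copy(), "test_data.csv", "model_weights.pth")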
6 changes: 0 additions & 6 deletions tests/csv_loader_test.py

This file was deleted.

1 change: 1 addition & 0 deletions tests/multi_config.json
@@ -0,0 +1 @@
{"model_name": "CustomTransformerDecoder", "model_type": "PyTorch", "n_targets": 2, "model_params": {"dropout": 0.1, "seq_length": 11, "n_time_series": 18, "output_dim": 2, "output_seq_length": 1, "n_layers_encoder": 2, "use_mask": true}, "dataset_params": {"class": "default", "num_workers": 5, "forecast_test_len": 20, "pin_memory": true, "training_path": "/content/flow-forecast/miami_f.csv", "validation_path": "/content/flow-forecast/miami_f.csv", "test_path": "/content/flow-forecast/miami_f.csv", "batch_size": 10, "forecast_history": 11, "forecast_length": 1, "scaler": "StandardScaler", "train_start": 0, "train_end": 170, "valid_start": 170, "valid_end": 310, "sort_column": "date", "test_start": 170, "test_end": 310, "target_col": ["rolling_7", "rolling_deaths"], "relevant_cols": ["rolling_7", "rolling_deaths", "mobility_retail_recreation", "mobility_grocery_pharmacy", "mobility_parks", "mobility_transit_stations", "mobility_workplaces", "mobility_residential", "avg_temperature", "min_temperature", "max_temperature", "relative_humidity", "specific_humidity", "pressure"], "feature_param": {"datetime_params": {"day_of_week": "cyclical", "month": "cyclical"}}, "interpolate": false}, "training_params": {"criterion": "MSE", "optimizer": "SGD", "optim_params": {"lr": 0.0001}, "epochs": 10, "batch_size": 10}, "early_stopping": {"patience": 3}, "GCS": true, "sweep": true, "wandb": false, "forward_params": {}, "metrics": ["MSE"], "inference_params": {"datetime_start": "2020-12-14", "hours_to_forecast": 18, "num_prediction_samples": 20, "test_csv_path": "/content/flow-forecast/miami_f.csv", "decoder_params": {"decoder_function": "simple_decode", "unsqueeze_dim": 1}, "dataset_params": {"file_path": "/content/flow-forecast/miami_f.csv", "sort_column": "date", "scaling": "StandardScaler", "forecast_history": 11, "forecast_length": 1, "relevant_cols": ["rolling_7", "rolling_deaths", "mobility_retail_recreation", "mobility_grocery_pharmacy", "mobility_parks", "mobility_transit_stations", "mobility_workplaces", "mobility_residential", "avg_temperature", "min_temperature", "max_temperature", "relative_humidity", "specific_humidity", "pressure"], "target_col": ["rolling_7", "rolling_deaths"], "interpolate_param": false, "feature_params": {"datetime_params": {"day_of_week": "cyclical", "month": "cyclical"}}}}, "meta_data": false, "run": [{"epoch": 0, "train_loss": "1.1954958769492805", "validation_loss": "85.43445341289043"}, {"epoch": 1, "train_loss": "1.1476804590784013", "validation_loss": "84.1799928843975"}, {"epoch": 2, "train_loss": "1.065674600424245", "validation_loss": "84.03104758262634"}, {"epoch": 3, "train_loss": "1.0211504658218473", "validation_loss": "84.54550993442535"}, {"epoch": 4, "train_loss": "0.9789167386479676", "validation_loss": "85.40744817256927"}, {"epoch": 5, "train_loss": "0.9342440171167254", "validation_loss": "86.52448198199272"}]}
1 change: 1 addition & 0 deletions tests/probabilistic_linear_regression_test.json
@@ -15,6 +15,7 @@
"batch_size":4,
"forecast_history":10,
"forecast_length":1,
"forecast_test_len": 30,
"train_start": 1,
"train_end": 300,
"valid_start":301,
10 changes: 10 additions & 0 deletions tests/test_deployment.py
@@ -12,8 +12,12 @@ def setUp(self):
"""
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.json")) as y:
self.config_test = json.load(y)
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "multi_config.json")) as y:
self.multi_config_test = json.load(y)
self.new_csv_path = "gs://flow_datasets/Massachusetts_Middlesex_County.csv"
self.weight_path = "gs://coronaviruspublicdata/experiments/01_July_202009_44PM_model.pth"
self.multi_path = "gs://flow_datasets/miami_multi.csv"
self.multi_weight_path = "gs://coronaviruspublicdata/experiments/28_January_202102_14AM_model.pth"
self.infer_class = InferenceMode(20, 30, self.config_test, self.new_csv_path, self.weight_path, "covid-core")

def test_load_model(self):
@@ -28,6 +32,12 @@ def test_infer_mode(self):
def test_plot_model(self):
self.infer_class.make_plots(datetime(2020, 5, 1), self.new_csv_path, "flow_datasets", "tes1/t.csv", "prod_plot")

def test_infer_multi(self):
infer_multi = InferenceMode(20, 30, self.multi_config_test, self.multi_path, self.multi_weight_path,
"covid-core")
infer_multi.make_plots(datetime(2020, 12, 10), csv_bucket="flow_datasets",
save_name="tes1/t2.csv", wandb_plot_id="prod_plot")

def test_speed(self):
pass

