Skip to content

Commit

Permalink
Merge pull request #544 from AIStream-Peelout/infer_classification
Browse files Browse the repository at this point in the history
Infer classification
  • Loading branch information
isaacmg authored May 24, 2022
2 parents 15d4b7d + 00b5082 commit 8627440
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 3 deletions.
33 changes: 32 additions & 1 deletion flood_forecast/deployment/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from flood_forecast.evaluator import infer_on_torch_model
from flood_forecast.plot_functions import plot_df_test_with_confidence_interval
from flood_forecast.explain_model_output import deep_explain_model_heatmap, deep_explain_model_summary_plot
from torch.utils.data import DataLoader
from flood_forecast.time_model import scaling_function
# from flood_forecast.preprocessing.buil_dataset import get_data
from flood_forecast.gcp_integration.basic_utils import upload_file
Expand Down Expand Up @@ -30,8 +31,10 @@ def __init__(self, forecast_steps: int, n_samp: int, model_params, csv_path: Uni
:param wandb_proj: The name of the WB project leave blank if you don't want to log to Wandb, defaults to None
:type wandb_proj: str, optionals
"""
self.hours_to_forecast = forecast_steps
if "inference_params" not in model_params:
model_params["inference_params"] = {"dataset_params": {}}
self.csv_path = csv_path
self.hours_to_forecast = forecast_steps
self.n_targets = model_params.get("n_targets")
self.targ_cols = model_params["dataset_params"]["target_col"]
self.model = load_model(model_params.copy(), csv_path, weight_path)
Expand Down Expand Up @@ -88,6 +91,32 @@ def infer_now(self, some_date: datetime, csv_path=None, save_buck=None, save_nam
upload_file(save_buck, save_name, "temp3.csv", self.model.gcs_client)
return df, tensor, history, forecast_start, test, samples

def infer_now_classification(self, data=None, over_lap_seq=True, save_buck=None, save_name=None, batch_size=1):
"""Function to preform classification/anomaly detection on sequences in real-time
:param data
:type data: Union[pd.DataFrame, str], optional
:param over_lap_seq: Whether to increment by one throughout the df or by sequence length
:type over_lap_seq: bool,
:param batch_size: The batch size to use, defaults to 1
"""
if data:
dataset_params = self.model.params["dataset_params"].copy()
dataset_params["class"] = "GeneralClassificationLoader"
dataset_1 = self.model.make_data_load(data, dataset_params, "custom")
inferL = DataLoader(dataset_1, batch_size=batch_size)
else:
loader = self.model.test_data
inferL = DataLoader(loader, batch_size=batch_size)
seq_list = []
if over_lap_seq:
for x, y in inferL:
seq_list.append(self.model.model(x))
else:
for i in range(0, len(loader), dataset_params["sequence_length"]):
loader[i] # TODO finish implementing
return seq_list

def make_plots(self, date: datetime, csv_path: str = None, csv_bucket: str = None,
save_name=None, wandb_plot_id=None):
"""Function to create plots in inference mode.
Expand Down Expand Up @@ -144,6 +173,7 @@ def convert_to_torch_script(model: PyTorchForecast, save_path: str) -> PyTorchFo


def convert_to_onnx():
""""""
pass


Expand All @@ -166,5 +196,6 @@ def load_model(model_params_dict, file_path: str, weight_path: str) -> PyTorchFo
if "weight_path_add" in model_params_dict:
if "excluded_layers" in model_params_dict["weight_path_add"]:
del model_params_dict["weight_path_add"]["excluded_layers"]
# do stuff
m = PyTorchForecast(model_params_dict["model_name"], file_path, file_path, file_path, model_params_dict)
return m
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name='flood_forecast',
version='0.9956dev',
version='0.9988dev',
packages=[
'flood_forecast',
'flood_forecast.transformer_xl',
Expand Down
1 change: 1 addition & 0 deletions tests/24_May_202202_25PM_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model_name": "CustomTransformerDecoder", "n_targets": 2, "model_type": "PyTorch", "model_params": {"n_time_series": 4, "seq_length": 26, "output_seq_length": 1, "output_dim": 2, "n_layers_encoder": 6}, "dataset_params": {"class": "GeneralClassificationLoader", "n_classes": 2, "training_path": "tests/test_data/ff_test.csv", "validation_path": "tests/test_data/ff_test.csv", "test_path": "tests/test_data/ff_test.csv", "sequence_length": 26, "batch_size": 100, "forecast_history": 26, "train_end": 80, "valid_start": 4, "valid_end": 90, "target_col": ["anomalous_rain"], "relevant_cols": ["anomalous_rain", "tmpf", "cfs", "dwpf", "height"], "scaler": "StandardScaler", "interpolate": {"method": "back_forward_generic", "params": {"relevant_columns": ["cfs", "tmpf", "p01m", "dwpf"]}}, "forecast_length": 1}, "training_params": {"criterion": "CrossEntropyLoss", "optimizer": "Adam", "optim_params": {}, "lr": 0.03, "epochs": 4, "batch_size": 100, "shuffle": false}, "GCS": false, "wandb": {"name": "flood_forecast_circleci", "tags": ["dummy_run", "circleci", "multi_head", "classification"], "project": "repo-flood_forecast"}, "forward_params": {}, "metrics": ["CrossEntropyLoss"], "run": [{"epoch": 0, "train_loss": "0.0680028973509454", "validation_loss": "113.27512431330979"}, {"epoch": 1, "train_loss": "0.05458420396058096", "validation_loss": "108.99862844124436"}, {"epoch": 2, "train_loss": "0.054659905693390305", "validation_loss": "106.30307429283857"}, {"epoch": 3, "train_loss": "0.054730736438391936", "validation_loss": "104.98548858333379"}]}
2 changes: 1 addition & 1 deletion tests/decoder_test.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"n_time_series":3,
"seq_length":5,
"output_seq_length": 1,
"n_layers_encoder": 6
"n_layers_encoder": 4
},
"dataset_params":
{ "class": "default",
Expand Down
Loading

0 comments on commit 8627440

Please sign in to comment.