From 52318c844f74d5ac181fc1ef37b2a8f6897646ce Mon Sep 17 00:00:00 2001 From: mozin Date: Tue, 16 Mar 2021 00:27:16 +0530 Subject: [PATCH] torch init --- d6tflow/targets/torch.py | 37 +++++++++++++++++++++++++++++++++++++ d6tflow/tasks/torch.py | 10 ++++++++++ requirements.txt | 1 + 3 files changed, 48 insertions(+) create mode 100644 d6tflow/targets/torch.py create mode 100644 d6tflow/tasks/torch.py diff --git a/d6tflow/targets/torch.py b/d6tflow/targets/torch.py new file mode 100644 index 0000000..39e598a --- /dev/null +++ b/d6tflow/targets/torch.py @@ -0,0 +1,37 @@ +from d6tflow.targets import DataTarget +import torch + + +class PyTorchModel(DataTarget): + + + def load(self, cached=False, **kwargs): + """ + Load saved model + + Args: + cached (bool): keep data cached in memory + **kwargs: arguments to pass to pd.read_parquet + + Returns: pandas dataframe + + """ + return super().load(torch.load, cached, **kwargs) + + + + def save(self, model, **kwargs): + """ + Save torch model + + Args: + model (obj): python object + kwargs : additional arguments to pass to torch.save + + Returns: filename + + """ + + (self.path).parent.mkdir(parents=True, exist_ok=True) + torch.save(model, self.path, **kwargs) + return self.path \ No newline at end of file diff --git a/d6tflow/tasks/torch.py b/d6tflow/tasks/torch.py new file mode 100644 index 0000000..d182b00 --- /dev/null +++ b/d6tflow/tasks/torch.py @@ -0,0 +1,10 @@ +from d6tflow.tasks import TaskData +from d6tflow.targets.torch import PyTorchModel + +class PyTorch(TaskData): + """ + Task which saves to .pt models + """ + target_class = PyTorchModel + target_ext = '.pt' + diff --git a/requirements.txt b/requirements.txt index cd88dda..d4ffaba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ toolz dask[dataframe] d6tcollect pyarrow +torch