Skip to content

Commit

Permalink
fix: load_model fallback to BytesIO for Py3.6
Browse files Browse the repository at this point in the history
Catch io.UnsupportedOperation raised in Python <3.7 and buffer file
contents into a BytesIO to work around the error.
  • Loading branch information
athewsey authored and eduardocarvp committed Aug 4, 2020
1 parent 1d6b5ef commit 55c09e5
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.base import BaseEstimator
from torch.utils.data import DataLoader
from copy import deepcopy
import io
import json
from pathlib import Path
import shutil
Expand Down Expand Up @@ -305,7 +306,13 @@ def load_model(self, filepath):
with z.open("model_params.json") as f:
loaded_params = json.load(f)
with z.open("network.pt") as f:
saved_state_dict = torch.load(f)
try:
saved_state_dict = torch.load(f)
except io.UnsupportedOperation:
# In Python <3.7, the returned file object is not seekable (which at least
# some versions of PyTorch require) - so we'll try buffering it in to a
# BytesIO instead:
saved_state_dict = torch.load(io.BytesIO(f.read()))
except KeyError:
raise KeyError("Your zip file is missing at least one component")

Expand Down

0 comments on commit 55c09e5

Please sign in to comment.