Skip to content

Commit

Permalink
Import torchtext #1426 4be2792
Browse files Browse the repository at this point in the history
Summary: Import from GitHub

Reviewed By: Nayef211

Differential Revision: D31962042

fbshipit-source-id: 0308ae0cfe402e8c3eb133cb5a205b65f98ad1df
  • Loading branch information
parmeet authored and facebook-github-bot committed Oct 28, 2021
1 parent 30ca7cf commit 6b437b1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ We recommend Anaconda as a Python package management system. Please refer to `py
1.10.0, 0.11.0, ">=3.6, <=3.9"
1.9.1, 0.10.1, ">=3.6, <=3.9"
1.9, 0.10, ">=3.6, <=3.9"
1.8.2, 0.9.2, ">=3.6, <=3.9"
1.8.1, 0.9.1, ">=3.6, <=3.9"
1.8, 0.9, ">=3.6, <=3.9"
1.7.1, 0.8.1, ">=3.6, <=3.9"
1.7, 0.8, ">=3.6, <=3.8"
Expand Down
1 change: 1 addition & 0 deletions torchtext/_download_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tqdm import tqdm
# This is to allow monkey-patching in fbcode
from torch.hub import load_state_dict_from_url # noqa
from torchtext._internal.module_utils import is_module_available

if is_module_available("torchdata"):
from torchdata.datapipes.iter import HttpReader # noqa F401
Expand Down
8 changes: 7 additions & 1 deletion torchtext/models/roberta/bundler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ class RobertaModelBundle:
_head: Optional[Module] = None
transform: Optional[Callable] = None

def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwargs=None) -> RobertaModel:

if load_weights:
assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"

if head is not None:
input_head = head
Expand All @@ -67,6 +70,9 @@ def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> Roberta

model = _get_model(self._params, input_head)

if not load_weights:
return model

dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
if input_head is not None:
Expand Down

0 comments on commit 6b437b1

Please sign in to comment.