diff --git a/safe/converter.py b/safe/converter.py
index 786870d..fc0d252 100644
--- a/safe/converter.py
+++ b/safe/converter.py
@@ -376,10 +376,9 @@ def decode(
 ):
     """Convert input SAFE representation to smiles
     Args:
-        inp: input SAFE representation to decode as a valid molecule or smiles
+        safe_str: input SAFE representation to decode as a valid molecule or smiles
         as_mol: whether to return a molecule object or a smiles string
         canonical: whether to return a canonical smiles or a randomized smiles
-        standardize: whether to standardize the molecule
         fix: whether to fix the SAFE representation to take into account non-connected attachment points
         remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string.
         remove_dummies: whether to remove dummy atoms from the SAFE representation
diff --git a/safe/tokenizer.py b/safe/tokenizer.py
index 8177719..365c4b0 100644
--- a/safe/tokenizer.py
+++ b/safe/tokenizer.py
@@ -3,6 +3,7 @@
 from typing import Any
 from typing import Iterator
 from typing import Union
+from typing import Dict
 
 import re
 import os
@@ -230,8 +231,7 @@ def train_from_iterator(self, data: Iterator, **kwargs: Any):
 
     def __len__(self):
         r"""
-        Returns: Gets the count of tokens in vocab along with special tokens.
-
+        Gets the count of tokens in vocab along with special tokens.
         """
         return len(self.tokenizer.get_vocab().keys())
 
@@ -511,6 +511,7 @@ def from_pretrained(
         local_files_only: bool = False,
         token: Optional[Union[str, bool]] = None,
         return_fast_tokenizer: Optional[bool] = False,
+        proxies: Optional[Dict[str, str]] = None,
         **kwargs,
     ):
         r"""
@@ -533,7 +534,6 @@ def from_pretrained(
             cache_dir: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
             force_download: Whether or not to force the (re-)download the vocabulary files and override the cached versions if they exist.
-            resume_download: Whether or not to delete incompletely received files. Attempt to resume the download if such a file exists.
            proxies: A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
             token: The token to use as HTTP bearer authorization for remote files.
@@ -555,7 +555,6 @@ def from_pretrained(
         ```
         """
         resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
         use_auth_token = kwargs.pop("use_auth_token", None)
         subfolder = kwargs.pop("subfolder", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
diff --git a/safe/trainer/data_utils.py b/safe/trainer/data_utils.py
index 3d408bd..eb1e706 100644
--- a/safe/trainer/data_utils.py
+++ b/safe/trainer/data_utils.py
@@ -1,11 +1,17 @@
 from typing import Optional
 from typing import Callable
+from typing import Any
+from typing import Union
+from typing import Dict
+
 from collections.abc import Mapping
 from tqdm.auto import tqdm
 from functools import partial
+
 import itertools
 import upath
 import datasets
+
 from safe.tokenizer import SAFETokenizer
@@ -14,10 +20,12 @@ def take(n, iterable):
     return list(itertools.islice(iterable, n))
 
 
-def get_dataset_column_names(dataset):
+def get_dataset_column_names(dataset: Union[datasets.Dataset, datasets.IterableDataset, Mapping]):
     """Get the column names in a dataset
+
     Args:
         dataset: dataset to get the column names from
+
     """
     if isinstance(dataset, (datasets.IterableDatasetDict, Mapping)):
         column_names = {split: dataset[split].column_names for split in dataset}
@@ -29,8 +37,8 @@
 
 
 def tokenize_fn(
-    row,
-    tokenizer,
+    row: Dict[str, Any],
+    tokenizer: Callable,
     tokenize_column: str = "inputs",
     max_length: Optional[int] = None,
     padding: bool = False,
diff --git a/safe/trainer/model.py b/safe/trainer/model.py
index 90a9560..09fcc0b 100644
--- a/safe/trainer/model.py
+++ b/safe/trainer/model.py
@@ -152,8 +152,6 @@ def forward(
             mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*):
                 Labels for computing the supervized loss for regularization.
             inputs: List of inputs, put here because the trainer removes information not in signature
-
-        Returns:
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         transformer_outputs = self.transformer(
diff --git a/safe/utils.py b/safe/utils.py
index a67280c..8445a02 100644
--- a/safe/utils.py
+++ b/safe/utils.py
@@ -288,8 +288,9 @@ def link_fragments(
 
 
 @contextmanager
-def attr_as(obj, field, value):
+def attr_as(obj: Any, field: str, value: Any):
     """Temporary replace the value of an object
+
     Args:
         obj: object to temporary patch
         field: name of the key to change
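For reviewers, a minimal usage sketch of the user-facing changes above: the `safe_str` rename in `decode`, `proxies` promoted to an explicit `from_pretrained` keyword argument, and the newly annotated `attr_as` context manager. It assumes the public `safe` package API; the "datamol-io/safe-gpt" checkpoint name and the proxy address are illustrative placeholders, not part of this patch.

```python
from types import SimpleNamespace

import safe
from safe.tokenizer import SAFETokenizer
from safe.utils import attr_as

# decode() documents its first argument as `safe_str` (renamed from `inp`);
# round-trip a SMILES through SAFE to produce one.
safe_str = safe.encode("CC(Cc1ccc(cc1)C(C)C(=O)O)C")  # ibuprofen -> SAFE
smiles = safe.decode(safe_str, as_mol=False, canonical=True)

# from_pretrained() now takes `proxies` as an explicit keyword argument
# instead of popping it out of **kwargs.
tokenizer = SAFETokenizer.from_pretrained(
    "datamol-io/safe-gpt",  # illustrative checkpoint name
    proxies={"http": "foo.bar:3128"},  # illustrative proxy, per the docstring example
)

# attr_as() temporarily replaces an attribute and restores it on exit.
cfg = SimpleNamespace(field="old")
with attr_as(cfg, "field", "new"):
    assert cfg.field == "new"
assert cfg.field == "old"
```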