fix doc args and clean up parameter space
maclandrol committed Oct 28, 2023
1 parent 34b1ac8 commit ebd6ad7
Showing 5 changed files with 17 additions and 12 deletions.
3 changes: 1 addition & 2 deletions safe/converter.py
@@ -376,10 +376,9 @@ def decode(
):
"""Convert input SAFE representation to smiles
Args:
inp: input SAFE representation to decode as a valid molecule or smiles
safe_str: input SAFE representation to decode as a valid molecule or smiles
as_mol: whether to return a molecule object or a smiles string
canonical: whether to return a canonical smiles or a randomized smiles
standardize: whether to standardize the molecule
fix: whether to fix the SAFE representation to take into account non-connected attachment points
remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string.
remove_dummies: whether to remove dummy atoms from the SAFE representation
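
For context, a minimal round-trip sketch of the renamed `safe_str` argument: this assumes `encode` and `decode` are exposed at the `safe` package root, and the input SMILES is illustrative only, so treat it as a sketch rather than the library's canonical example.

```python
# Hedged sketch of decode() with the renamed `safe_str` argument.
# Assumes `safe.encode` / `safe.decode` exist at the package root;
# the input SMILES (ibuprofen) is illustrative only.
import safe

smiles = "CC(C)Cc1ccc(cc1)C(C)C(=O)O"
safe_str = safe.encode(smiles)  # SMILES -> SAFE string
decoded = safe.decode(safe_str, as_mol=False, canonical=True)
print(decoded)  # canonical SMILES recovered from the SAFE string
```
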
7 changes: 3 additions & 4 deletions safe/tokenizer.py
Expand Up @@ -3,6 +3,7 @@
from typing import Any
from typing import Iterator
from typing import Union
from typing import Dict

import re
import os
@@ -230,8 +231,7 @@ def train_from_iterator(self, data: Iterator, **kwargs: Any):

def __len__(self):
r"""
Returns: Gets the count of tokens in vocab along with special tokens.
Gets the count of tokens in vocab along with special tokens.
"""
return len(self.tokenizer.get_vocab().keys())

@@ -511,6 +511,7 @@ def from_pretrained(
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
return_fast_tokenizer: Optional[bool] = False,
proxies: Optional[Dict[str, str]] = None,
**kwargs,
):
r"""
@@ -533,7 +534,6 @@
cache_dir: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
standard cache should not be used.
force_download: Whether or not to force the (re-)download the vocabulary files and override the cached versions if they exist.
resume_download: Whether or not to delete incompletely received files. Attempt to resume the download if such a file exists.
proxies: A dictionary of proxy servers to use by protocol or endpoint, e.g.,
`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
token: The token to use as HTTP bearer authorization for remote files.
@@ -555,7 +555,6 @@
```
"""
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
subfolder = kwargs.pop("subfolder", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
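
With `proxies` promoted from a `kwargs.pop` to an explicit keyword argument, callers can now pass it directly. A minimal sketch, assuming a published checkpoint; the name `datamol-io/safe-gpt` is a placeholder:

```python
# Sketch of from_pretrained() with the now-explicit `proxies` argument.
# The checkpoint name is a placeholder; substitute a real tokenizer repo.
from safe.tokenizer import SAFETokenizer

tokenizer = SAFETokenizer.from_pretrained(
    "datamol-io/safe-gpt",
    proxies={"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"},
)
print(len(tokenizer))  # token count including special tokens, per __len__ above
```
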
14 changes: 11 additions & 3 deletions safe/trainer/data_utils.py
@@ -1,11 +1,17 @@
from typing import Optional
from typing import Callable
from typing import Any
from typing import Union
from typing import Dict

from collections.abc import Mapping
from tqdm.auto import tqdm
from functools import partial

import itertools
import upath
import datasets

from safe.tokenizer import SAFETokenizer


@@ -14,10 +20,12 @@ def take(n, iterable):
return list(itertools.islice(iterable, n))


def get_dataset_column_names(dataset):
def get_dataset_column_names(dataset: Union[datasets.Dataset, datasets.IterableDataset, Mapping]):
"""Get the column names in a dataset
Args:
dataset: dataset to get the column names from
"""
if isinstance(dataset, (datasets.IterableDatasetDict, Mapping)):
column_names = {split: dataset[split].column_names for split in dataset}
@@ -29,8 +37,8 @@ def get_dataset_column_names(dataset):


def tokenize_fn(
row,
tokenizer,
row: Dict[str, Any],
tokenizer: Callable,
tokenize_column: str = "inputs",
max_length: Optional[int] = None,
padding: bool = False,
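
The new annotations make the expected inputs explicit: `get_dataset_column_names` accepts a `Dataset`, an `IterableDataset`, or a mapping of splits, and `tokenize_fn` takes a row dict plus any callable tokenizer. A small sketch under those assumptions (the checkpoint name and the `get_pretrained()` conversion are assumptions, not part of this diff):

```python
# Sketch of the annotated helpers. The SAFE strings are placeholders and
# the tokenizer is assumed to behave like a Hugging Face callable.
import datasets
from safe.tokenizer import SAFETokenizer
from safe.trainer.data_utils import get_dataset_column_names, tokenize_fn

dset = datasets.Dataset.from_dict({"inputs": ["C1CC1", "c1ccccc1"]})
print(get_dataset_column_names(dset))  # -> ['inputs']

hf_tok = SAFETokenizer.from_pretrained("datamol-io/safe-gpt").get_pretrained()
batch = tokenize_fn({"inputs": ["C1CC1"]}, hf_tok, tokenize_column="inputs", padding=True)
```
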
2 changes: 0 additions & 2 deletions safe/trainer/model.py
@@ -152,8 +152,6 @@ def forward(
mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*):
Labels for computing the supervised loss for regularization.
inputs: List of inputs, put here because the trainer removes information not in signature
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
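
For reference, the documented `(batch_size, n_tasks)` shape of `mc_labels` can be sketched with plain tensors; the model call is left commented out because constructing the model is out of scope here, and `model` stands for a hypothetical instance.

```python
# Shape sketch for the `mc_labels` argument documented above.
import torch

batch_size, seq_len, n_tasks = 2, 16, 3
input_ids = torch.randint(0, 100, (batch_size, seq_len))
mc_labels = torch.randn(batch_size, n_tasks)  # targets for the supervised loss

# outputs = model(input_ids=input_ids, mc_labels=mc_labels, return_dict=True)
```
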
3 changes: 2 additions & 1 deletion safe/utils.py
@@ -288,8 +288,9 @@ def link_fragments(


@contextmanager
def attr_as(obj, field, value):
def attr_as(obj: Any, field: str, value: Any):
"""Temporary replace the value of an object
Args:
obj: object to temporary patch
field: name of the key to change
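
A minimal usage sketch of the `attr_as` context manager, with a throwaway object for illustration:

```python
# Usage sketch of attr_as(): patch an attribute inside the `with` block;
# the original value is restored on exit. The object here is illustrative.
from types import SimpleNamespace

from safe.utils import attr_as

obj = SimpleNamespace(verbose=False)
with attr_as(obj, "verbose", True):
    assert obj.verbose is True   # patched inside the context
assert obj.verbose is False      # restored afterwards
```
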
