Skip to content

Commit

Permalink
Add a return value to env.require_dataset().
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Mar 3, 2021
1 parent 3623b6e commit bcb8b34
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def _reward_view_type(self):
"""
return RewardView

def require_datasets(self, datasets: List[Union[str, Dataset]]) -> None:
def require_datasets(self, datasets: List[Union[str, Dataset]]) -> bool:
"""Require that the given datasets are available to the environment.
Example usage:
Expand All @@ -857,6 +857,8 @@ def require_datasets(self, datasets: List[Union[str, Dataset]]) -> None:
:param datasets: A list of datasets to require. Each dataset is the name
of an available dataset, the URL of a dataset to download, or a
:class:`Dataset` instance.
:return: :code:`True` if one or more datasets were downloaded, or
:code:`False` if all datasets were already available.
"""
dataset_installed = False
for dataset in datasets:
Expand All @@ -872,15 +874,18 @@ def require_datasets(self, datasets: List[Union[str, Dataset]]) -> None:
),
)
self.make_manifest_file()
return dataset_installed

def require_dataset(self, dataset: Union[str, Dataset]) -> None:
def require_dataset(self, dataset: Union[str, Dataset]) -> bool:
"""Require that the given dataset is available to the environment.
Alias for
:meth:`env.require_datasets([dataset]) <compiler_gym.envs.CompilerEnv.require_datasets>`.
:param dataset: The name of the dataset to download, the URL of the dataset, or a
:class:`Dataset` instance.
:return: :code:`True` if the dataset was downloaded, or :code:`False` if
the dataset was already available.
"""
return self.require_datasets([dataset])

Expand Down

0 comments on commit bcb8b34

Please sign in to comment.