diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index ee0f6af42..b479a4e01 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -840,7 +840,7 @@ def _reward_view_type(self): """ return RewardView - def require_datasets(self, datasets: List[Union[str, Dataset]]) -> None: + def require_datasets(self, datasets: List[Union[str, Dataset]]) -> bool: """Require that the given datasets are available to the environment. Example usage: @@ -857,6 +857,8 @@ def require_datasets(self, datasets: List[Union[str, Dataset]]) -> None: :param datasets: A list of datasets to require. Each dataset is the name of an available dataset, the URL of a dataset to download, or a :class:`Dataset` instance. + :return: :code:`True` if one or more datasets were downloaded, or + :code:`False` if all datasets were already available. """ dataset_installed = False for dataset in datasets: @@ -872,8 +874,9 @@ def require_datasets(self, datasets: List[Union[str, Dataset]]) -> None: ), ) self.make_manifest_file() + return dataset_installed - def require_dataset(self, dataset: Union[str, Dataset]) -> None: + def require_dataset(self, dataset: Union[str, Dataset]) -> bool: """Require that the given dataset is available to the environment. Alias for @@ -881,6 +884,8 @@ def require_dataset(self, dataset: Union[str, Dataset]) -> None: :param dataset: The name of the dataset to download, the URL of the dataset, or a :class:`Dataset` instance. + :return: :code:`True` if the dataset was downloaded, or :code:`False` if + the dataset was already available. """ return self.require_datasets([dataset])