diff --git a/executorlib/__init__.py b/executorlib/__init__.py index 159031a3..bab1ffc5 100644 --- a/executorlib/__init__.py +++ b/executorlib/__init__.py @@ -5,6 +5,9 @@ from executorlib.standalone.inputcheck import ( check_plot_dependency_graph as _check_plot_dependency_graph, ) +from executorlib.standalone.inputcheck import ( + check_pysqa_config_directory as _check_pysqa_config_directory, +) from executorlib.standalone.inputcheck import ( check_refresh_rate as _check_refresh_rate, ) @@ -194,6 +197,7 @@ def __new__( init_function=init_function, ) elif not disable_dependencies: + _check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory) return ExecutorWithDependencies( max_workers=max_workers, backend=backend, @@ -210,6 +214,7 @@ def __new__( plot_dependency_graph=plot_dependency_graph, ) else: + _check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory) _check_plot_dependency_graph(plot_dependency_graph=plot_dependency_graph) _check_refresh_rate(refresh_rate=refresh_rate) return create_executor( diff --git a/executorlib/cache/executor.py b/executorlib/cache/executor.py index 271a1f2e..cf750780 100644 --- a/executorlib/cache/executor.py +++ b/executorlib/cache/executor.py @@ -9,6 +9,9 @@ ) from executorlib.standalone.inputcheck import ( check_executor, + check_flux_executor_pmi_mode, + check_hostname_localhost, + check_max_workers_and_cores, check_nested_flux_executor, ) from executorlib.standalone.thread import RaisingThread @@ -89,18 +92,6 @@ def create_file_executor( ): if cache_directory is None: cache_directory = "executorlib_cache" - if max_workers != 1: - raise ValueError( - "The number of workers cannot be controlled with the pysqa based backend." - ) - if max_cores != 1: - raise ValueError( - "The number of cores cannot be controlled with the pysqa based backend." - ) - if hostname_localhost is not None: - raise ValueError( - "The option to connect to hosts based on their hostname is not available with the pysqa based backend." - ) if block_allocation: raise ValueError( "The option block_allocation is not available with the pysqa based backend." @@ -109,10 +100,9 @@ def create_file_executor( raise ValueError( "The option to specify an init_function is not available with the pysqa based backend." ) - if flux_executor_pmi_mode is not None: - raise ValueError( - "The option to specify the flux pmi mode is not available with the pysqa based backend." - ) + check_flux_executor_pmi_mode(flux_executor_pmi_mode=flux_executor_pmi_mode) + check_max_workers_and_cores(max_cores=max_cores, max_workers=max_workers) + check_hostname_localhost(hostname_localhost=hostname_localhost) check_executor(executor=flux_executor) check_nested_flux_executor(nested_flux_executor=flux_executor_nesting) return FileExecutor( diff --git a/executorlib/standalone/inputcheck.py b/executorlib/standalone/inputcheck.py index d1d9800f..9a466265 100644 --- a/executorlib/standalone/inputcheck.py +++ b/executorlib/standalone/inputcheck.py @@ -131,6 +131,41 @@ def check_init_function(block_allocation: bool, init_function: Callable) -> None raise ValueError("") +def check_max_workers_and_cores(max_workers: int, max_cores: int) -> None: + if max_workers != 1: + raise ValueError( + "The number of workers cannot be controlled with the pysqa based backend." + ) + if max_cores != 1: + raise ValueError( + "The number of cores cannot be controlled with the pysqa based backend." + ) + + +def check_hostname_localhost(hostname_localhost: Optional[bool]) -> None: + if hostname_localhost is not None: + raise ValueError( + "The option to connect to hosts based on their hostname is not available with the pysqa based backend." + ) + + +def check_flux_executor_pmi_mode(flux_executor_pmi_mode: Optional[str]) -> None: + if flux_executor_pmi_mode is not None: + raise ValueError( + "The option to specify the flux pmi mode is not available with the pysqa based backend." + ) + + +def check_pysqa_config_directory(pysqa_config_directory: Optional[str]) -> None: + """ + Check if pysqa_config_directory is None and raise a ValueError if it is not. + """ + if pysqa_config_directory is not None: + raise ValueError( + "pysqa_config_directory parameter is only supported for pysqa backend." + ) + + def validate_number_of_cores(max_cores: int, max_workers: int) -> int: """ Validate the number of cores and return the appropriate value. diff --git a/tests/test_shared_input_check.py b/tests/test_shared_input_check.py index 4444dc48..8899a75e 100644 --- a/tests/test_shared_input_check.py +++ b/tests/test_shared_input_check.py @@ -13,6 +13,10 @@ check_refresh_rate, check_resource_dict, check_resource_dict_is_empty, + check_flux_executor_pmi_mode, + check_max_workers_and_cores, + check_hostname_localhost, + check_pysqa_config_directory, ) @@ -69,3 +73,25 @@ def test_check_nested_flux_executor(self): def test_check_plot_dependency_graph(self): with self.assertRaises(ValueError): check_plot_dependency_graph(plot_dependency_graph=True) + + def test_check_flux_executor_pmi_mode(self): + with self.assertRaises(ValueError): + check_flux_executor_pmi_mode(flux_executor_pmi_mode="test") + + def test_check_max_workers_and_cores(self): + with self.assertRaises(ValueError): + check_max_workers_and_cores(max_workers=2, max_cores=1) + with self.assertRaises(ValueError): + check_max_workers_and_cores(max_workers=1, max_cores=2) + with self.assertRaises(ValueError): + check_max_workers_and_cores(max_workers=2, max_cores=2) + + def test_check_hostname_localhost(self): + with self.assertRaises(ValueError): + check_hostname_localhost(hostname_localhost=True) + with self.assertRaises(ValueError): + check_hostname_localhost(hostname_localhost=False) + + def test_check_pysqa_config_directory(self): + with self.assertRaises(ValueError): + check_pysqa_config_directory(pysqa_config_directory="path/to/config")