Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional input checks #463

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions executorlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from executorlib.standalone.inputcheck import (
check_plot_dependency_graph as _check_plot_dependency_graph,
)
from executorlib.standalone.inputcheck import (
check_pysqa_config_directory as _check_pysqa_config_directory,
)
from executorlib.standalone.inputcheck import (
check_refresh_rate as _check_refresh_rate,
)
Expand Down Expand Up @@ -194,6 +197,7 @@ def __new__(
init_function=init_function,
)
elif not disable_dependencies:
_check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory)
return ExecutorWithDependencies(
max_workers=max_workers,
backend=backend,
Expand All @@ -210,6 +214,7 @@ def __new__(
plot_dependency_graph=plot_dependency_graph,
)
else:
_check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory)
_check_plot_dependency_graph(plot_dependency_graph=plot_dependency_graph)
_check_refresh_rate(refresh_rate=refresh_rate)
return create_executor(
Expand Down
22 changes: 6 additions & 16 deletions executorlib/cache/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
)
from executorlib.standalone.inputcheck import (
check_executor,
check_flux_executor_pmi_mode,
check_hostname_localhost,
check_max_workers_and_cores,
check_nested_flux_executor,
)
from executorlib.standalone.thread import RaisingThread
Expand Down Expand Up @@ -89,18 +92,6 @@ def create_file_executor(
):
if cache_directory is None:
cache_directory = "executorlib_cache"
if max_workers != 1:
raise ValueError(
"The number of workers cannot be controlled with the pysqa based backend."
)
if max_cores != 1:
raise ValueError(
"The number of cores cannot be controlled with the pysqa based backend."
)
if hostname_localhost is not None:
raise ValueError(
"The option to connect to hosts based on their hostname is not available with the pysqa based backend."
)
if block_allocation:
raise ValueError(
"The option block_allocation is not available with the pysqa based backend."
Expand All @@ -109,10 +100,9 @@ def create_file_executor(
raise ValueError(
"The option to specify an init_function is not available with the pysqa based backend."
)
if flux_executor_pmi_mode is not None:
raise ValueError(
"The option to specify the flux pmi mode is not available with the pysqa based backend."
)
check_flux_executor_pmi_mode(flux_executor_pmi_mode=flux_executor_pmi_mode)
check_max_workers_and_cores(max_cores=max_cores, max_workers=max_workers)
check_hostname_localhost(hostname_localhost=hostname_localhost)
check_executor(executor=flux_executor)
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
return FileExecutor(
Expand Down
35 changes: 35 additions & 0 deletions executorlib/standalone/inputcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,41 @@ def check_init_function(block_allocation: bool, init_function: Callable) -> None
raise ValueError("")


def check_max_workers_and_cores(max_workers: int, max_cores: int) -> None:
if max_workers != 1:
raise ValueError(
"The number of workers cannot be controlled with the pysqa based backend."
)
if max_cores != 1:
raise ValueError(
"The number of cores cannot be controlled with the pysqa based backend."
)


def check_hostname_localhost(hostname_localhost: Optional[bool]) -> None:
if hostname_localhost is not None:
raise ValueError(
"The option to connect to hosts based on their hostname is not available with the pysqa based backend."
)


def check_flux_executor_pmi_mode(flux_executor_pmi_mode: Optional[str]) -> None:
if flux_executor_pmi_mode is not None:
raise ValueError(
"The option to specify the flux pmi mode is not available with the pysqa based backend."
)


def check_pysqa_config_directory(pysqa_config_directory: Optional[str]) -> None:
"""
Check if pysqa_config_directory is None and raise a ValueError if it is not.
"""
if pysqa_config_directory is not None:
raise ValueError(
"pysqa_config_directory parameter is only supported for pysqa backend."
)


def validate_number_of_cores(max_cores: int, max_workers: int) -> int:
"""
Validate the number of cores and return the appropriate value.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_shared_input_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
check_refresh_rate,
check_resource_dict,
check_resource_dict_is_empty,
check_flux_executor_pmi_mode,
check_max_workers_and_cores,
check_hostname_localhost,
check_pysqa_config_directory,
)


Expand Down Expand Up @@ -69,3 +73,25 @@ def test_check_nested_flux_executor(self):
def test_check_plot_dependency_graph(self):
with self.assertRaises(ValueError):
check_plot_dependency_graph(plot_dependency_graph=True)

def test_check_flux_executor_pmi_mode(self):
with self.assertRaises(ValueError):
check_flux_executor_pmi_mode(flux_executor_pmi_mode="test")

def test_check_max_workers_and_cores(self):
with self.assertRaises(ValueError):
check_max_workers_and_cores(max_workers=2, max_cores=1)
with self.assertRaises(ValueError):
check_max_workers_and_cores(max_workers=1, max_cores=2)
with self.assertRaises(ValueError):
check_max_workers_and_cores(max_workers=2, max_cores=2)

def test_check_hostname_localhost(self):
with self.assertRaises(ValueError):
check_hostname_localhost(hostname_localhost=True)
with self.assertRaises(ValueError):
check_hostname_localhost(hostname_localhost=False)

def test_check_pysqa_config_directory(self):
with self.assertRaises(ValueError):
check_pysqa_config_directory(pysqa_config_directory="path/to/config")
Loading