Skip to content

Commit

Permalink
FIX: Adding an overwriteable slurmified function dectorator
Browse files Browse the repository at this point in the history
  • Loading branch information
d-krupke committed Feb 27, 2024
1 parent b97b518 commit 32a9648
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
25 changes: 24 additions & 1 deletion src/slurminade/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,31 @@ def dec(func) -> SlurmFunction:

return dec

def _slurmify(
allow_overwrite: bool, **args
) -> typing.Union[typing.Callable[[typing.Callable], SlurmFunction], SlurmFunction]:
"""
Decorator: Make a function distributable to slurm.
Usage:
.. code-block:: python
@slurmify_()
def func(a, b):
pass
:param f: Function
:param args: Special slurm options for this function.
:return: A decorated function, callable with slurm.
"""

def dec(func) -> SlurmFunction:
func_id = FunctionMap.register(func, allow_overwrite=allow_overwrite)
return SlurmFunction(args, func, func_id)

return dec

@slurmify()
@_slurmify(allow_overwrite=True)
def shell(cmd: typing.Union[str, typing.List[str]]):
"""
Execute a command.
Expand Down
8 changes: 6 additions & 2 deletions src/slurminade/function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,23 @@ def check_compatibility(func: typing.Callable):
raise ValueError(msg)

@staticmethod
def register(func: typing.Callable) -> str:
def register(func: typing.Callable, allow_overwrite: bool=False) -> str:
"""
Register a function, allowing it to be called just by its id.
:param func: The function to be stored. Needs to be a proper function.
:return: The function's id.
"""
FunctionMap.check_compatibility(func)
func_id = FunctionMap.get_id(func)
if func_id in FunctionMap._data:
if func_id in FunctionMap._data and not allow_overwrite:
msg = "Multiple function definitions!"
raise RuntimeError(msg)
FunctionMap._data[func_id] = func
return func_id

@staticmethod
def exists(func_id: str) -> bool:
return func_id in FunctionMap._data

@staticmethod
def call(
Expand Down

0 comments on commit 32a9648

Please sign in to comment.