[Feature Request] support for custom logic in submitit plugin's checkpoint function #2042
Comments
Hi @tesfaldet, I'm not an expert in using submitit, but let me nevertheless share an idea. Here is what I am thinking.

Here is the main python script (`$ cat my_app.py`):

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def app(cfg: DictConfig) -> None:
    print(cfg)


if __name__ == "__main__":
    app()
```

Here is the yaml configuration file (`$ cat conf/config.yaml`):

```yaml
defaults:
  - override hydra/launcher: submitit_local
  # - override hydra/launcher: submitit_slurm
  - _self_

hydra:
  launcher:
    _target_: hydra_plugins.custom_submitit_logic.MyLocalLauncher
    # _target_: hydra_plugins.custom_submitit_logic.MySlurmLauncher
```

And here is the custom plugin subclassing `BaseSubmititLauncher` (`$ cat hydra_plugins/custom_submitit_logic.py`):

```python
from typing import Any

from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher


class MyBaseSubmititLauncher(BaseSubmititLauncher):
    def __init__(self, *args, **kwargs) -> None:
        print("INITIALIZING CUSTOM SUBMITIT LAUNCHER")
        ...
        super().__init__(*args, **kwargs)

    def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
        """This method is a modified version of the BaseSubmititLauncher.checkpoint method"""
        #########################
        ### CUSTOM LOGIC HERE ###
        #########################
        # modified_args / modified_kwargs would be produced by the custom logic above
        return super().checkpoint(*modified_args, **modified_kwargs)


class MyLocalLauncher(MyBaseSubmititLauncher):
    _EXECUTOR = "local"


class MySlurmLauncher(MyBaseSubmititLauncher):
    _EXECUTOR = "slurm"
```

When running:

Would subclassing the submitit plugin in this way be sufficiently flexible for your use-case?
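The subclass-and-override pattern above can be exercised even without `hydra_submitit_launcher` installed. Below is a minimal, self-contained sketch with a stub standing in for `BaseSubmititLauncher`; the stub's return value and the `"note"` kwarg are purely illustrative, not the real plugin's API (the real `checkpoint` resubmits the job via submitit).

```python
from typing import Any


class StubBaseLauncher:
    """Stand-in for BaseSubmititLauncher; the real one resubmits the job."""

    def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
        return ("resubmit", args, kwargs)


class MyLauncher(StubBaseLauncher):
    def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
        # custom logic here: e.g. rewrite the kwargs before delegating
        modified_kwargs = {**kwargs, "note": "checkpointed"}  # hypothetical tweak
        return super().checkpoint(*args, **modified_kwargs)


result = MyLauncher().checkpoint(1, 2)
```

The override keeps the base behavior intact while letting arbitrary logic run first, which is the flexibility being proposed.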
I have tried passing a callable class instance (instead of passing a function) to hydra.main:

```python
from typing import Any

import hydra
from omegaconf import DictConfig


class MyTaskFunction:
    def __init__(self, context: Any) -> None:
        self._context = context

    def __call__(self, cfg: DictConfig) -> None:
        print(cfg)


if __name__ == "__main__":
    my_task_function = MyTaskFunction(context=123)
    app = hydra.main(config_path="conf", config_name="config")(my_task_function)
    app()
```

It is not working currently, but should be easy to support.
Thanks for the quick response @Jasha10. Your suggestion is a great one and I think it could work. Specifically, my suggestion of using a `TaskFunction`-style class could look like this:

```python
from functools import wraps
from typing import Any, Callable, Optional

from omegaconf import DictConfig, OmegaConf

from hydra.core.utils import JobReturn

# standalone stand-in for hydra.types.TaskFunction
TaskFunction = Callable[[Any], Any]
_UNSPECIFIED_: Any = object()


def main(config_path: Optional[str] = _UNSPECIFIED_,
         config_name: Optional[str] = None) -> Callable[[TaskFunction], Any]:
    def main_decorator(task_function: TaskFunction) -> Callable[[], None]:
        @wraps(task_function)
        def decorator_main(cfg_passthrough: Optional[DictConfig] = None,
                           checkpoint: Optional[bool] = False) -> Any:
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)
            else:
                conf = OmegaConf.create({'config_name': config_name,
                                         'config_path': config_path})
                launcher = BaseSubmititLauncher()
                launcher.setup(task_function=task_function, config=conf)
                ret = launcher.launch()
                if checkpoint:
                    launcher.checkpoint()
                return ret
        return decorator_main
    return main_decorator


def run_job(task_function: TaskFunction, config: DictConfig):
    return task_function(config)


class BaseSubmititLauncher:
    def __init__(self) -> None:
        self.config: Optional[DictConfig] = None
        self.task_function: Optional[TaskFunction] = None

    def setup(self, *, task_function: TaskFunction, config: DictConfig) -> None:
        self.config = config
        self.task_function = task_function

    def __call__(self) -> JobReturn:
        assert self.config is not None
        assert self.task_function is not None
        return run_job(task_function=self.task_function, config=self.config)

    def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
        """Resubmit the current callable at its current state with the same initial arguments."""
        print(f'in checkpoint with model {self.task_function.model}')

    def launch(self) -> JobReturn:
        return self()


class Train(TaskFunction):
    def __init__(self):
        self.model = None
        self.config = None

    # main train loop code
    def __call__(self, cfg: DictConfig) -> float:
        self.model = 'PyTorchModel'
        self.config = cfg
        return 3.3

    def __str__(self) -> str:
        return f'Train(self.model = {self.model}, self.config = {self.config})'


if __name__ == '__main__':
    train = Train()
    print(f'in __main__ after instantiating Train object: {train}')
    wrapped = main(config_path='./conf', config_name='config')(train)
    print(f'in __main__ after wrapping Train object with main decorator: {train}')
    wrapped()
    print(f'in __main__ after calling train() with no config passthrough: {train}')
    wrapped(checkpoint=True)
    print(f'in __main__ after calling train() with no config passthrough and calling checkpoint: {train}')
    wrapped(OmegaConf.create({'foo': 'bar'}))
    print(f'in __main__ after calling train() with config passthrough: {train}')
```

outputs:

What do you think? It follows a similar execution path as Hydra to show that this might be possible without any change to Hydra's code :) Basically, since the launcher already has access to the `task_function`, its `checkpoint` function can reach the task's context.
I just took a closer look at your example and it's basically the same as mine 😅 I'm surprised it didn't work for you! I haven't tested mine with Hydra's actual `main`. I'm not at my computer for the weekend, so I can't test it out myself with Hydra's `main` until then.
Correct me if I'm wrong, but doesn't the submitit plugin launch a subprocess for each list of sweep params, with each subprocess executing your app's task function with those params? Wouldn't that mean that a global variable outside of the app's task function won't be accessible to each subprocess unless there's some IPC implemented?
When I run the example from my previous comment, I get the following error:
I am almost done with a PR to patch this so that a callable can be passed to `hydra.main`.
Aah yes, good point. I haven't tried it myself, but I suspect you're right.
@Jasha10 I believe the error you're experiencing occurs because you're passing the callable object and not its `__call__` method:

```python
import hydra
from omegaconf import DictConfig


class MyTaskFunction:
    def __init__(self) -> None:
        self._context = None

    def __call__(self, cfg: DictConfig) -> None:
        self._context = 123
        print(cfg, self._context)


if __name__ == "__main__":
    my_task_function = MyTaskFunction()
    print(my_task_function._context)
    app = hydra.main(config_path=None, config_name=None)(my_task_function.__call__)
    app()
    print(my_task_function._context)
```

outputs:
Nvm, the PR is super necessary haha. Although the above works, the problem with passing in `__call__` is that you lose a direct handle on the underlying object instance.
I just realized that you could access the object through `__self__`.
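For reference, `__self__` on a bound method gives back the owning instance, so whoever holds the `__call__` reference can still reach the object's state. A minimal illustration with a made-up `Train` class (the `"resnet"` value is just a placeholder):

```python
class Train:
    def __init__(self) -> None:
        self.model = None

    def __call__(self, cfg: dict) -> None:
        self.model = "resnet"  # hypothetical state set during the run


train = Train()
task_function = train.__call__  # what would be handed to hydra.main

task_function({"lr": 0.1})

# a launcher holding only the bound method can recover the instance:
recovered = task_function.__self__
model = recovered.model
```

So passing `my_task_function.__call__` does not actually lose the instance; it is one attribute access away.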
I finished trying it out and it worked :) It was a hassle to get working properly for a couple of reasons:
You need to put this at the top of your main app file:

```python
import cloudpickle

# otherwise pickle.load from the submitit launcher won't find it, since
# cloudpickle registers by reference by default
import hydra_plugins.launchers

cloudpickle.register_pickle_by_value(hydra_plugins.launchers)
```

As a matter of fact, you need to do this there specifically. There are a couple of weird gotchas, but I'm too tired to list them out right now. I've managed to get it working with my own custom submitit launcher with its own checkpoint function, and had it send me a wandb alert during pre-emption, re-using the `run` object initialized inside the task.
I see, very interesting! I'm glad it's working for you. I'm a bit surprised that creating a file under `hydra_plugins` was necessary.
It's working up to a point! Here's a more detailed explanation of the issue(s) I've been experiencing. Let's say the below is an example folder hierarchy:
Cloudpickle (and pickle) serializes files by reference by default. Cloudpickle can automatically serialize by value, but only during an interactive session. However, you could use their experimental `register_pickle_by_value` API to register a module to be serialized by value. You'd think doing that would be enough, but it wasn't in my case.
The below error is what I got if I were to put my own custom plugin within a `hydra_plugins` folder:
I finally fixed the above issue! Specifically, the `checkpoint` method now looks like this:

```python
def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
    """This method is a modified version of the BaseSubmititLauncher.checkpoint method"""
    # assumes `import wandb` at module scope
    run = self.task_function.__self__.run
    run.alert(
        title='Job Pre-empted/Timed-out',
        text=f'Job {run.name} in group {run.group} has either been pre-empted or timed-out.',
        level=wandb.AlertLevel.INFO
    )
    import cloudpickle
    cloudpickle.register_pickle_by_value(self.task_function.__self__.src)
    return super().checkpoint(*args, **kwargs)
```

As you can tell, this meant I had to pass a reference to the `src` module through the task object:

```python
import hydra
import cloudpickle
from omegaconf import DictConfig

from hydra.types import TaskFunction

import src

# otherwise pickle.load from the submitit launcher won't find src, since
# cloudpickle registers by reference by default
cloudpickle.register_pickle_by_value(src)


class Train(TaskFunction):
    def __init__(self):
        self.model_and_trainer = None
        self.run = None
        self.cfg = None
        self.src = src

    # main train loop code
    def __call__(self, cfg: DictConfig) -> float:
        ...


if __name__ == '__main__':
    train = Train()
    app = hydra.main(config_path='conf', config_name='config')(train.__call__)
    app()
```

It's... quite an annoying solution. Anyways, I hope that in your PR you consider the peculiarities of having a different working directory than the code's directory in combination with using submitit. You might bump into the same issues. The PR has the potential to clean this up.
Wow, nice work! Seems like very tricky logic. |
🚀 Feature Request
Currently, when preemption/timeout occurs, the submitit plugin resubmits the job with the same initial arguments. There is no clear way of adding custom logic within the `checkpoint` function of the plugin:

hydra/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py, lines 79 to 84 in 046837b
Motivation
My train loop saves (pytorch) checkpoints every N iterations. When preemption/timeout occurs and the submitit plugin resubmits the job, after the job starts it will continue from the most recently saved checkpoint, which could be several hundred (or thousands) of iterations old. The job will have to re-do training iterations that were already done before, which is a waste of time and resources.
It would be nice to have it such that once the submitit plugin's `checkpoint` function is called (in reaction to preemption/timeout), it could re-use the checkpoint save code that my train loop uses (or whatever logic I'd want, really), so that upon job restart it will start from a checkpoint at most 1 iteration old.

Pitch
Describe the solution you'd like
From what I see so far, the `checkpoint` function is exposed to the current callable's context (via `self`):

hydra/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py, lines 45 to 77 in 046837b
The callable returns a `JobReturn` object, and it is where the `task_function(...)` (i.e., the train loop, which contains all the relevant train loop data, such as the model, its state dict, the optimizer's state dict, etc.) is executed. Specifically, it's executed within `run_job(...)`:

hydra/hydra/core/utils.py, line 160 in 046837b
An issue I see here is that the `task_function`'s context is inaccessible, so it's not possible to pass through objects such as the model's current state, etc. That means the `checkpoint` function also won't have access to that data. You'd need to figure out how to pass the `task_function`'s context up to `__call__(...)` so that when `checkpoint(...)` is called it will have access to job-specific data, and the user can inject custom logic (maybe through a yaml config with `_target_` set to a `save()` function) that will be able to act upon this job-specific data, such as checkpointing/dumping it.

Here are some steps forward I think would make sense (correct me if I'm wrong):
1. Accept a `TaskFunction` class with a `__call__` function instead of a plain `task_function`, from what I see here: hydra/hydra/main.py, lines 15 to 18 in b16baa7. That way, you can treat each `task_function` as an instance and save your relevant train data using `self.<whatever> = <whatever>` in `__call__`. From there, you can access each `task_function`'s context after it's been called. More importantly, `run_job(...)` will have access to its context.
2. Have `run_job(...)` optionally accept a `Launcher` context so that it could modify it. This would allow it to pass the `task_function`'s context back up to `BaseSubmititLauncher`, which would make it available to its `checkpoint` function. The `task_function`'s context can be saved in `self.task_function_context`.
3. Have the `checkpoint` function check for an `on_job_preempt_or_timeout` callback using `_get_callbacks_for_run_job(...)` and execute the callback while passing in `self.task_function_context`.
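The three steps could be prototyped roughly as follows. Everything here is hypothetical API sketched for this proposal: the names `task_function_context` and `on_job_preempt_or_timeout`, and the simplified `run_job`/`Launcher`, are not Hydra's real interfaces.

```python
from typing import Any, Callable, Dict, Optional


class Callbacks:
    # step 3: a hypothetical callback invoked on preemption/timeout
    def on_job_preempt_or_timeout(self, context: Dict[str, Any]) -> None:
        context["checkpointed"] = True  # e.g. dump model/optimizer state here


def run_job(task_function: Callable, config: Dict[str, Any],
            launcher: Optional["Launcher"] = None) -> Any:
    result = task_function(config)
    if launcher is not None:
        # step 2: hand the task instance's context back up to the launcher
        launcher.task_function_context = task_function.__self__
    return result


class Launcher:
    def __init__(self) -> None:
        self.task_function_context: Any = None
        self.callbacks = Callbacks()

    def checkpoint(self) -> None:
        # step 3: run the callback with the job-specific context
        self.callbacks.on_job_preempt_or_timeout(self.task_function_context.state)


class Train:
    # step 1: the task function is a class instance holding its own state
    def __init__(self) -> None:
        self.state: Dict[str, Any] = {}

    def __call__(self, cfg: Dict[str, Any]) -> float:
        self.state["cfg"] = cfg
        return 0.0


train = Train()
launcher = Launcher()
run_job(train.__call__, {"lr": 0.1}, launcher=launcher)
launcher.checkpoint()
```

After `checkpoint()`, the callback has acted on the live task state, which is the behavior the steps above are asking for.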
(Also, as a bonus feature request that could be tackled simultaneously with the above steps, it'd be nice to be able to pass a `task_function`'s context to each of the callbacks in `run_job(...)`. I could submit a separate feature request for this. The quick motivation is that you can make each task's wandb run object, initialized within the task using `self.run = wandb.init(...)`, accessible to callbacks so that you could make use of their Alerts interface to send Slack alerts. There are ways to do this already, but they involve reinitializing a wandb run within each callback separately, with the same run ID you used before, just to send an alert, which introduces unnecessary overhead due to the `init` process.)

Describe alternatives you've considered
I've considered saving checkpoints every iteration but that's just time and space consuming...
Are you willing to open a pull request? (See CONTRIBUTING)
I could be willing to fork and try out the steps above then submit a PR, but I'd like to get a feel for what y'all think about this request before I try implementing it myself.
Additional context
None.