diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index 8dcd6dab4e1d..77a838d7095f 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -4,8 +4,10 @@
 # DeepSpeed Team
 
 import torch
+import contextlib
 import functools
 from deepspeed.utils.torch import required_torch_version
+from deepspeed.accelerator import get_accelerator
 
 try:
     from torch.compiler import is_compiling as torch_is_compiling
@@ -16,6 +18,11 @@
         # Torch does not have compiler support
         torch_is_compiling = lambda: False
 
+if required_torch_version(min_version="2.6.0a"):
+    from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
+else:
+    from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable
+
 
 def is_compile_supported():
     return required_torch_version(min_version=2.1)
@@ -71,3 +78,15 @@ def wrapper(*args, **kwargs):
 
 def is_compiling():
     return torch_is_compiling()
+
+
+@contextlib.contextmanager
+def compiled_autograd(enabled, kwargs):
+    try:
+        if enabled:
+            with compiled_autograd_enable(torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs)):
+                yield
+        else:
+            yield
+    finally:
+        pass
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 12867437d9dd..01f3aa055ace 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -105,7 +105,7 @@
 
 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import is_compile_supported
+from .compiler import is_compile_supported, compiled_autograd
 from ..ops.adam import FusedAdam
 from ..moe.sharded_moe import TopKGate, MOELayer
 from ..moe.layer import MoE
@@ -420,6 +420,9 @@ def __init__(self,
         self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather)
         self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states)
+        self._is_compiled_autograd_enabled = False
+        self._compile_kwargs = {}
+
 
     def _optimized_linear_offload_setup(self):
         self.optimized_linear_base_weight_sharding = False
         self.optimized_linear_lora_enabled = False
@@ -2359,8 +2362,9 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
 
         self._start_timers(self.engine_timers.backward_timers)
         loss = self._backward_prologue(loss, scale_wrt_gas)
-        self._do_optimizer_backward(loss, retain_graph)
-        self._backward_epilogue()
+        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
+            self._do_optimizer_backward(loss, retain_graph)
+            self._backward_epilogue()
         self._stop_timers(self.engine_timers.backward_timers)
 
         return loss
@@ -4078,7 +4082,11 @@ def empty_partition_cache(self):
         gc.collect()
         get_accelerator().empty_cache()
 
-    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
+    def compile(self,
+                backend=get_accelerator().get_compile_backend(),
+                compile_kwargs={},
+                schedule=None,
+                compiled_autograd_enabled=False) -> None:
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
@@ -4144,6 +4152,13 @@ def passes_name_to_fn(passes):
            raise
 
         self._is_compiled = True
+        self._compile_kwargs = compile_kwargs
+        if compiled_autograd_enabled:
+            if not self._deepcompile_active:
+                self._is_compiled_autograd_enabled = compiled_autograd_enabled
+            else:
+                logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
+                self._is_compiled_autograd_enabled = False
 
     def _set_deepcompile_active(self, active: bool) -> None:
         """Toggle DeepCompile runtime state and manage forward hooks accordingly."""