diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 92988887ed06e..11f1c67211e40 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -402,15 +402,11 @@ def _capture_compile_kwargs(compile_fn: Callable) -> Callable: @wraps(compile_fn) def _capture(*args: Any, **kwargs: Any) -> Any: - if not args: - # torch.compile is being applied as a decorator + if not args or not isinstance(args[0], nn.Module): + # either torch.compile is being applied as a decorator or we're compiling something else return compile_fn(*args, **kwargs) model = args[0] - if not isinstance(model, nn.Module): - # compiling something else - return compile_fn(*args, **kwargs) - compiled_model = compile_fn(model, **kwargs) compiled_model._compile_kwargs = deepcopy(kwargs) return compiled_model