Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ def register_sophia_optimizers() -> List[str]:
_optim = getattr(Sophia, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module(module=_optim)
optimizers.append(module_name)
try:
OPTIMIZERS.register_module(module=_optim)
except Exception as e:
warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
return optimizers


Expand All @@ -146,7 +148,8 @@ def register_bitsandbytes_optimizers() -> List[str]:
dadaptation_optimizers = []
try:
import bitsandbytes as bnb
except ImportError:
# import bnb may trigger cuda related error without nvidia gpu resources
except (ImportError, RuntimeError):
pass
else:
optim_classes = inspect.getmembers(
Expand All @@ -155,7 +158,10 @@ def register_bitsandbytes_optimizers() -> List[str]:
for name, optim_cls in optim_classes:
if name in OPTIMIZERS:
name = f'bnb_{name}'
OPTIMIZERS.register_module(module=optim_cls, name=name)
try:
OPTIMIZERS.register_module(module=optim_cls, name=name)
except Exception as e:
warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
dadaptation_optimizers.append(name)
return dadaptation_optimizers

Expand All @@ -170,7 +176,10 @@ def register_transformers_optimizers():
except ImportError:
pass
else:
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
try:
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
except Exception as e:
warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
transformer_optimizers.append('Adafactor')
return transformer_optimizers

Expand Down
Loading