From 8c446b1743a821638b55602eb700d276b1f69445 Mon Sep 17 00:00:00 2001
From: hrukalive
Date: Sat, 30 Dec 2023 14:13:54 -0600
Subject: [PATCH 1/3] Better strategy auto selection with kwargs override
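
get_strategy() used to look only at hparams['pl_trainer_strategy'] and simply
returned the string 'auto' in the default case, so strategy kwargs overrides
were silently dropped unless a strategy name was set explicitly. Pass the
devices, num_nodes, accelerator and precision flags through as well, and run
Lightning's _AcceleratorConnector selection so that the strategy picked by
"auto" is re-created from the registry with the user's kwargs applied on top
of its default init params.

Illustrative call (the strategy dict carries a 'name' key plus arbitrary
strategy kwargs; find_unused_parameters and the precision string below are
example Lightning options, not defaults from this repo):

    # values are illustrative, not repo defaults
    strategy = get_strategy(
        devices="auto",
        num_nodes=1,
        accelerator="gpu",
        strategy={"name": "ddp", "find_unused_parameters": False},
        precision="16-mixed",
    )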
---
 basics/base_task.py     |  8 +++++-
 utils/training_utils.py | 57 +++++++++++++++++++++++++++++++----------
 2 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/basics/base_task.py b/basics/base_task.py
index 5e6890ee..8a180e0b 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -408,7 +408,13 @@ def start(cls):
             accelerator=hparams['pl_trainer_accelerator'],
             devices=hparams['pl_trainer_devices'],
             num_nodes=hparams['pl_trainer_num_nodes'],
-            strategy=get_strategy(hparams['pl_trainer_strategy']),
+            strategy=get_strategy(
+                hparams['pl_trainer_devices'],
+                hparams['pl_trainer_num_nodes'],
+                hparams['pl_trainer_accelerator'],
+                hparams['pl_trainer_strategy'],
+                hparams['pl_trainer_precision'],
+            ),
             precision=hparams['pl_trainer_precision'],
             callbacks=[
                 DsModelCheckpoint(
diff --git a/utils/training_utils.py b/utils/training_utils.py
index c38e7aae..b9c057a1 100644
--- a/utils/training_utils.py
+++ b/utils/training_utils.py
@@ -364,17 +364,48 @@ def __getstate__(self):
         del state["_all_rank_experiment"]
         return state
 
-
-def get_strategy(strategy):
-    if strategy['name'] == 'auto':
-        return 'auto'
-
+def get_strategy(
+    devices = "auto",
+    num_nodes = 1,
+    accelerator = "auto",
+    strategy = "auto",
+    precision = None,
+):
+    from lightning.pytorch.trainer.connectors import accelerator_connector
+    from lightning.pytorch.accelerators import AcceleratorRegistry
     from lightning.pytorch.strategies import StrategyRegistry
-    if strategy['name'] not in StrategyRegistry:
-        available_names = ", ".join(sorted(StrategyRegistry.keys())) or "none"
-        raise ValueError(f"Invalid strategy name {strategy['name']}. Available names: {available_names}")
-
-    data = StrategyRegistry[strategy['name']]
-    params = data['init_params']
-    params.update({k: v for k, v in strategy.items() if k != 'name'})
-    return data['strategy'](**utils.filter_kwargs(params, data['strategy']))
+    class _DsAcceleratorConnector(accelerator_connector._AcceleratorConnector):
+        def __init__(self) -> None:
+            accelerator_connector._register_external_accelerators_and_strategies()
+            self._registered_strategies = StrategyRegistry.available_strategies()
+            self._accelerator_types = AcceleratorRegistry.available_accelerators()
+            self._strategy_flag = "auto"
+            self._accelerator_flag = "auto"
+            self._precision_plugin_flag = None
+            self._parallel_devices = []
+            self.checkpoint_io = None
+            self._check_config_and_set_final_flags(
+                strategy=strategy['name'],
+                accelerator=accelerator,
+                precision=precision,
+                plugins=[],
+                sync_batchnorm=False,
+            )
+            if self._accelerator_flag == "auto":
+                self._accelerator_flag = self._choose_auto_accelerator()
+            elif self._accelerator_flag == "gpu":
+                self._accelerator_flag = self._choose_gpu_accelerator_backend()
+            self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
+            self._set_parallel_devices_and_init_accelerator()
+            if self._strategy_flag == "auto":
+                self._strategy_flag = self._choose_strategy()
+            self._check_strategy_and_fallback()
+            self._init_strategy()
+    connector = _DsAcceleratorConnector()
+    for k in StrategyRegistry.available_strategies():
+        if StrategyRegistry[k]['strategy'] is connector.strategy.__class__:  # type: ignore
+            data = StrategyRegistry[k]
+            params = data['init_params']
+            params.update({k: v for k, v in strategy.items() if k != 'name'})
+            return data['strategy'](**utils.filter_kwargs(params, data['strategy']))
+    raise ValueError(f"Strategy {strategy['name']} not found")

From 5cf42a68e097c0289b37c571acc7cae13edcf739 Mon Sep 17 00:00:00 2001
From: hrukalive
Date: Wed, 3 Jan 2024 23:43:12 -0600
Subject: [PATCH 2/3] Apply strategy kwargs overrides inside _init_strategy
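
Rather than instantiating the connector and then scanning the registry for a
class matching the strategy it selected, override _init_strategy itself so the
kwargs override happens exactly where the connector builds the strategy. The
override replicates the single-device handling of PL's _choose_strategy (root
CUDA/MPS device resolution, CPU fallback, HPU and TPU devices), applies the
user's strategy kwargs on top of the registry init params, and falls back to
the inferred strategy with a warning when it cannot take custom
configurations. Registry entries that need unavailable extras (colossalai,
bagua, hpu, ipu) are dropped from the candidate set.

For instance, on a single-GPU machine (illustrative):

    # illustrative: auto selection on a one-GPU box
    get_strategy(devices=1, accelerator="gpu", strategy={"name": "auto"})

should now yield a SingleDeviceStrategy bound to the root CUDA device instead
of the bare string "auto".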
---
 utils/training_utils.py | 60 +++++++++++++++++++++++++++++++----------
 1 file changed, 46 insertions(+), 14 deletions(-)

diff --git a/utils/training_utils.py b/utils/training_utils.py
index b9c057a1..e7547e20 100644
--- a/utils/training_utils.py
+++ b/utils/training_utils.py
@@ -371,19 +371,19 @@ def get_strategy(
     strategy = "auto",
     precision = None,
 ):
-    from lightning.pytorch.trainer.connectors import accelerator_connector
+    from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
     from lightning.pytorch.accelerators import AcceleratorRegistry
-    from lightning.pytorch.strategies import StrategyRegistry
+    from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+    from lightning.pytorch.accelerators.mps import MPSAccelerator
+    from lightning.pytorch.strategies import Strategy, SingleDeviceStrategy, StrategyRegistry
+    from lightning.pytorch.trainer.connectors import accelerator_connector
+    from lightning.pytorch.utilities.rank_zero import rank_zero_warn
     class _DsAcceleratorConnector(accelerator_connector._AcceleratorConnector):
         def __init__(self) -> None:
             accelerator_connector._register_external_accelerators_and_strategies()
             self._registered_strategies = StrategyRegistry.available_strategies()
             self._accelerator_types = AcceleratorRegistry.available_accelerators()
-            self._strategy_flag = "auto"
-            self._accelerator_flag = "auto"
-            self._precision_plugin_flag = None
             self._parallel_devices = []
-            self.checkpoint_io = None
             self._check_config_and_set_final_flags(
                 strategy=strategy['name'],
                 accelerator=accelerator,
@@ -401,11 +401,43 @@ def __init__(self) -> None:
                 self._strategy_flag = self._choose_strategy()
             self._check_strategy_and_fallback()
             self._init_strategy()
-    connector = _DsAcceleratorConnector()
-    for k in StrategyRegistry.available_strategies():
-        if StrategyRegistry[k]['strategy'] is connector.strategy.__class__:  # type: ignore
-            data = StrategyRegistry[k]
-            params = data['init_params']
-            params.update({k: v for k, v in strategy.items() if k != 'name'})
-            return data['strategy'](**utils.filter_kwargs(params, data['strategy']))
-    raise ValueError(f"Strategy {strategy['name']} not found")
+            for k in ['colossalai', 'bagua', 'hpu', 'hpu_parallel', 'hpu_single', 'ipu', 'ipu_strategy']:
+                if k in StrategyRegistry:
+                    StrategyRegistry.remove(k)
+        def _init_strategy(self) -> None:
+            assert isinstance(self._strategy_flag, (str, Strategy))
+            if isinstance(self._strategy_flag, str):
+                if self._strategy_flag not in StrategyRegistry:
+                    available_names = ", ".join(sorted(StrategyRegistry.available_strategies())) or "none"
+                    raise KeyError(f"Invalid strategy name {self._strategy_flag}. Available names: {available_names}")
+                data = StrategyRegistry[self._strategy_flag]
+                params = {}
+                # Replicate additional logic for _choose_strategy when dealing with single device strategies
+                if issubclass(data['strategy'], SingleDeviceStrategy):
+                    if self._accelerator_flag == "hpu":
+                        params = {"device": torch.device('hpu')}
+                    elif self._accelerator_flag == "tpu":
+                        params = {"device": self._parallel_devices[0]}
+                    elif data['strategy'] is SingleDeviceStrategy:
+                        if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
+                            isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+                        ):
+                            params = {"device": _determine_root_gpu_device(self._parallel_devices)}
+                        else:
+                            params = {"device": "cpu"}
+                    else:
+                        raise NotImplementedError
+                params.update(data['init_params'])
+                params.update({k: v for k, v in strategy.items() if k != 'name'})
+                self.strategy = data['strategy'](**utils.filter_kwargs(params, data['strategy']))
+            elif isinstance(self._strategy_flag, SingleDeviceStrategy):
+                params = {'device': self._strategy_flag.root_device}
+                params.update({k: v for k, v in strategy.items() if k != 'name'})
+                self.strategy = self._strategy_flag.__class__(**utils.filter_kwargs(params, self._strategy_flag.__class__))
+            else:
+                rank_zero_warn(
+                    f"Inferred strategy {self._strategy_flag.__class__.__name__} cannot take custom configurations. " \
+                    f"To use custom configurations, please specify the strategy name explicitly."
+                )
+                self.strategy = self._strategy_flag
+    return _DsAcceleratorConnector().strategy

From 98ba675759f10b8dce69c40158f9f39a0d9a1a59 Mon Sep 17 00:00:00 2001
From: hrukalive
Date: Thu, 4 Jan 2024 00:15:59 -0600
Subject: [PATCH 3/3] Formatting; fix default value of the strategy argument
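
Normalize quotes to double quotes and add a missing blank line between
methods. Also change the default of the strategy argument from the bare
string "auto" to {"name": "auto"}: the function unconditionally reads
strategy["name"], so the old default made argument-less calls fail. With the
new default this (illustrative) call is valid:

    # illustrative: all defaults, fully automatic selection
    strategy = get_strategy()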
---
 utils/training_utils.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/utils/training_utils.py b/utils/training_utils.py
index e7547e20..98fb6da4 100644
--- a/utils/training_utils.py
+++ b/utils/training_utils.py
@@ -365,11 +365,11 @@ def __getstate__(self):
         return state
 
 def get_strategy(
-    devices = "auto",
-    num_nodes = 1,
-    accelerator = "auto",
-    strategy = "auto",
-    precision = None,
+    devices="auto",
+    num_nodes=1,
+    accelerator="auto",
+    strategy={"name": "auto"},
+    precision=None,
 ):
     from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
     from lightning.pytorch.accelerators import AcceleratorRegistry
@@ -385,7 +385,7 @@ def __init__(self) -> None:
             self._accelerator_types = AcceleratorRegistry.available_accelerators()
             self._parallel_devices = []
             self._check_config_and_set_final_flags(
-                strategy=strategy['name'],
+                strategy=strategy["name"],
                 accelerator=accelerator,
                 precision=precision,
                 plugins=[],
@@ -401,9 +401,10 @@ def __init__(self) -> None:
                 self._strategy_flag = self._choose_strategy()
             self._check_strategy_and_fallback()
             self._init_strategy()
-            for k in ['colossalai', 'bagua', 'hpu', 'hpu_parallel', 'hpu_single', 'ipu', 'ipu_strategy']:
+            for k in ["colossalai", "bagua", "hpu", "hpu_parallel", "hpu_single", "ipu", "ipu_strategy"]:
                 if k in StrategyRegistry:
                     StrategyRegistry.remove(k)
+
         def _init_strategy(self) -> None:
             assert isinstance(self._strategy_flag, (str, Strategy))
             if isinstance(self._strategy_flag, str):
@@ -413,12 +414,12 @@ def _init_strategy(self) -> None:
                 data = StrategyRegistry[self._strategy_flag]
                 params = {}
                 # Replicate additional logic for _choose_strategy when dealing with single device strategies
-                if issubclass(data['strategy'], SingleDeviceStrategy):
+                if issubclass(data["strategy"], SingleDeviceStrategy):
                     if self._accelerator_flag == "hpu":
-                        params = {"device": torch.device('hpu')}
+                        params = {"device": torch.device("hpu")}
                     elif self._accelerator_flag == "tpu":
                         params = {"device": self._parallel_devices[0]}
-                    elif data['strategy'] is SingleDeviceStrategy:
+                    elif data["strategy"] is SingleDeviceStrategy:
                         if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                             isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
                         ):
@@ -427,17 +428,18 @@ def _init_strategy(self) -> None:
                             params = {"device": "cpu"}
                     else:
                         raise NotImplementedError
-                params.update(data['init_params'])
-                params.update({k: v for k, v in strategy.items() if k != 'name'})
-                self.strategy = data['strategy'](**utils.filter_kwargs(params, data['strategy']))
+                params.update(data["init_params"])
+                params.update({k: v for k, v in strategy.items() if k != "name"})
+                self.strategy = data["strategy"](**utils.filter_kwargs(params, data["strategy"]))
             elif isinstance(self._strategy_flag, SingleDeviceStrategy):
-                params = {'device': self._strategy_flag.root_device}
-                params.update({k: v for k, v in strategy.items() if k != 'name'})
+                params = {"device": self._strategy_flag.root_device}
+                params.update({k: v for k, v in strategy.items() if k != "name"})
                 self.strategy = self._strategy_flag.__class__(**utils.filter_kwargs(params, self._strategy_flag.__class__))
             else:
                 rank_zero_warn(
-                    f"Inferred strategy {self._strategy_flag.__class__.__name__} cannot take custom configurations. " \
+                    f"Inferred strategy {self._strategy_flag.__class__.__name__} cannot take custom configurations. "
                     f"To use custom configurations, please specify the strategy name explicitly."
                 )
                 self.strategy = self._strategy_flag
+
     return _DsAcceleratorConnector().strategy