Skip to content

Commit 9e64f81

Browse files
NAS export refactor + skip conversion on minitron restore (#424)
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 99c44d3 commit 9e64f81

File tree

12 files changed

+211
-162
lines changed

12 files changed

+211
-162
lines changed

examples/megatron-lm/README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ Coming soon ...
110110

111111
### ⭐ Pruning
112112

113+
Check out the pruning [getting started section](../pruning/README.md#getting-started) and [guidelines](../pruning/README.md#pruning-guidelines) in the pruning README for configuring pruning parameters.
114+
113115
Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning options are:
114116

115117
- `TARGET_FFN_HIDDEN_SIZE`
@@ -121,14 +123,20 @@ Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Availab
121123
- `TARGET_NUM_LAYERS`
122124
- `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)
123125

126+
Example for depth pruning Qwen3-8B from 36 to 24 layers:
127+
124128
```sh
125129
PP=1 \
126130
TARGET_NUM_LAYERS=24 \
127131
HF_MODEL_CKPT=<pretrained_model_name_or_path> \
128-
MLM_MODEL_SAVE=/tmp/Qwen3-8B-DPruned \
132+
MLM_MODEL_SAVE=Qwen3-8B-Pruned \
129133
bash megatron-lm/examples/post_training/modelopt/prune.sh qwen/Qwen3-8B
130134
```
131135

136+
> [!TIP]
137+
> If the number of layers in the model is not divisible by the pipeline parallel size (PP), you can configure uneven
138+
> PP by setting `MLM_EXTRA_ARGS="--decoder-first-pipeline-num-layers <X> --decoder-last-pipeline-num-layers <Y>"`
139+
132140
## Learn More About Configuration
133141

134142
For simplicity, we use `shell` scripts and variables as arguments. Each script has at least 1 positional

modelopt/torch/nas/autonas.py

Lines changed: 11 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,7 @@
3030
from pydantic import create_model
3131
from torch.nn.modules.batchnorm import _BatchNorm
3232

33-
from modelopt.torch.opt.config import (
34-
ModeloptBaseConfig,
35-
ModeloptField,
36-
get_kwargs_for_create_model_with_rules,
37-
)
33+
from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
3834
from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule
3935
from modelopt.torch.opt.mode import (
4036
ConvertEntrypoint,
@@ -56,34 +52,35 @@
5652
stats,
5753
torch_detach,
5854
torch_to,
59-
unwrap_model,
6055
)
6156

6257
from .algorithms import ConstraintsFunc, get_constraints_func
6358
from .conversion import NASModeRegistry
6459
from .patch import PatchData, PatchManager, _modelopt_eval_recursion_guard, prep_for_eval
6560
from .registry import DMRegistry
66-
from .search_space import SearchSpace, generate_search_space
67-
from .utils import MODELOPT_BN_CALIB_ITERS, MODELOPT_QUEUE_MAXLEN, get_subnet_config, sample, select
61+
from .search_space import generate_search_space
62+
from .utils import get_subnet_config, sample, select
6863

6964
__all__ = [
7065
"AutoNASConfig",
7166
"AutoNASModeDescriptor",
7267
"AutoNASPatchManager",
7368
"EvolveSearcher",
74-
"ExportConfig",
75-
"ExportModeDescriptor",
7669
"IterativeSearcher",
7770
"RandomSearcher",
7871
"convert_autonas_searchspace",
7972
"convert_searchspace",
80-
"export_searchspace",
8173
"restore_autonas_searchspace",
82-
"restore_export",
8374
"restore_searchspace",
8475
"update_autonas_metadata",
8576
]
8677

78+
# we have two different numbers here since during training it might take longer to stabilize
79+
MODELOPT_QUEUE_MAXLEN = 50 # indicates length of modelopt data queue for BN calib
80+
MODELOPT_BN_CALIB_ITERS = (
81+
100 # indicates # iters in train mode 'til we trust BN stats without calib
82+
)
83+
8784

8885
def _get_ratio_list():
8986
return (0.5, 0.67, 1.0)
@@ -132,25 +129,6 @@ def _norm_lin_config():
132129
)
133130

134131

135-
class ExportConfig(ModeloptBaseConfig):
136-
"""Configuration for the export mode.
137-
138-
This mode is used to export a model after NAS search.
139-
"""
140-
141-
strict: bool = ModeloptField(
142-
default=True,
143-
title="Strict export",
144-
description="Enforces that the subnet configuration must exactly match during export.",
145-
)
146-
147-
calib: bool = ModeloptField(
148-
default=False,
149-
title="Calibration",
150-
description="Whether to calibrate the subnet before exporting.",
151-
)
152-
153-
154132
class AutoNASPatchManager(PatchManager):
155133
"""A class to handle the monkey patching of the model for automode."""
156134

@@ -676,48 +654,6 @@ def update_autonas_metadata(
676654
metadata["subnet_config"] = get_subnet_config(model)
677655

678656

679-
def export_searchspace(model: nn.Module, config: ExportConfig) -> ConvertReturnType:
680-
"""Export a subnet configuration of the search space to a regular model."""
681-
# sanity check to avoid DP/DDP here in the entrypoint
682-
model = unwrap_model(model, raise_error=True)
683-
684-
# store config from model if we can find it for a future convert/restore process
685-
subnet_config = get_subnet_config(model)
686-
687-
# Check for patching and calibration
688-
if PatchManager.is_patched(model):
689-
manager = PatchManager.get_manager(model)
690-
if config.calib:
691-
manager.call_post_eval()
692-
manager.unpatch()
693-
694-
# export model in-place
695-
model = SearchSpace(model).export()
696-
697-
# construct metadata
698-
metadata = {
699-
"subnet_config": subnet_config,
700-
}
701-
702-
return model, metadata
703-
704-
705-
def restore_export(model: nn.Module, config: ExportConfig, metadata: MetadataDict) -> nn.Module:
706-
"""Restore & export the subnet configuration of the search space to a regular model."""
707-
# select subnet config provided in metadata
708-
select(model, metadata["subnet_config"], strict=config["strict"])
709-
710-
# run export
711-
model, metadata_new = export_searchspace(model, config)
712-
713-
# double check metadata
714-
unmatched_keys = compare_dict(metadata, metadata_new)
715-
if unmatched_keys:
716-
raise ApplyModeError(f"Unmatched metadata={unmatched_keys}!")
717-
718-
return model
719-
720-
721657
@NASModeRegistry.register_mode
722658
class AutoNASModeDescriptor(ModeDescriptor):
723659
"""Class to describe the ``"autonas"`` mode.
@@ -738,12 +674,12 @@ def config_class(self) -> type[ModeloptBaseConfig]:
738674
@property
739675
def next_modes(self) -> set[str] | None:
740676
"""Modes that must immediately follow this mode."""
741-
return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
677+
return {"export_nas", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
742678

743679
@property
744680
def export_mode(self) -> str | None:
745681
"""The mode that corresponds to the export mode of this mode."""
746-
return "export"
682+
return "export_nas"
747683

748684
@property
749685
def search_algorithm(self) -> type[BaseSearcher]:
@@ -769,40 +705,3 @@ def update_for_save(self) -> UpdateEntrypoint:
769705
def update_for_new_mode(self) -> UpdateEntrypoint:
770706
"""The mode's entrypoint for updating the models state before new mode."""
771707
return update_autonas_metadata
772-
773-
774-
@NASModeRegistry.register_mode
775-
class ExportModeDescriptor(ModeDescriptor):
776-
"""Class to describe the ``"export"`` mode.
777-
778-
The properties of this mode can be inspected via the source code.
779-
"""
780-
781-
@property
782-
def name(self) -> str:
783-
"""Returns the value (str representation) of the mode."""
784-
return "export"
785-
786-
@property
787-
def config_class(self) -> type[ModeloptBaseConfig]:
788-
"""Specifies the config class for the mode."""
789-
return ExportConfig
790-
791-
@property
792-
def is_export_mode(self) -> bool:
793-
"""Whether the mode is an export mode.
794-
795-
Returns:
796-
True if the mode is an export mode, False otherwise. Defaults to False.
797-
"""
798-
return True
799-
800-
@property
801-
def convert(self) -> ConvertEntrypoint:
802-
"""The mode's entrypoint for converting a model."""
803-
return export_searchspace
804-
805-
@property
806-
def restore(self) -> RestoreEntrypoint:
807-
"""The mode's entrypoint for restoring a model."""
808-
return restore_export

modelopt/torch/nas/conversion.py

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,28 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Main APIs+entrypoints for model pruning."""
16+
"""Main APIs+entrypoints for NAS conversion and export."""
1717

1818
from torch import nn
1919

20-
from modelopt.torch.opt.conversion import apply_mode
21-
from modelopt.torch.opt.mode import ModeLike, _ModeRegistryCls
22-
from modelopt.torch.utils import ModelLike, unwrap_model
23-
24-
__all__ = ["convert", "export"]
20+
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
21+
from modelopt.torch.opt.conversion import ApplyModeError, apply_mode
22+
from modelopt.torch.opt.mode import (
23+
ConvertEntrypoint,
24+
ConvertReturnType,
25+
MetadataDict,
26+
ModeDescriptor,
27+
ModeLike,
28+
RestoreEntrypoint,
29+
_ModeRegistryCls,
30+
)
31+
from modelopt.torch.utils import ModelLike, compare_dict, unwrap_model
32+
33+
from .patch import PatchManager
34+
from .search_space import SearchSpace
35+
from .utils import get_subnet_config, select
36+
37+
__all__ = ["ExportConfig", "ExportNASModeDescriptor", "convert", "export"]
2538

2639
NASModeRegistry = _ModeRegistryCls("nas")
2740

@@ -89,6 +102,108 @@ def convert(
89102
return apply_mode(model, mode, registry=registry)
90103

91104

105+
class ExportConfig(ModeloptBaseConfig):
106+
"""Configuration for the export mode.
107+
108+
This mode is used to export a model after NAS search.
109+
"""
110+
111+
strict: bool = ModeloptField(
112+
default=True,
113+
title="Strict export",
114+
description="Enforces that the subnet configuration must exactly match during export.",
115+
)
116+
117+
calib: bool = ModeloptField(
118+
default=False,
119+
title="Calibration",
120+
description="Whether to calibrate the subnet before exporting.",
121+
)
122+
123+
124+
def export_searchspace(model: nn.Module, config: ExportConfig) -> ConvertReturnType:
125+
"""Export a subnet configuration of the search space to a regular model."""
126+
# sanity check to avoid DP/DDP here in the entrypoint
127+
model = unwrap_model(model, raise_error=True)
128+
129+
# store config from model if we can find it for a future convert/restore process
130+
subnet_config = get_subnet_config(model)
131+
132+
# Check for patching and calibration
133+
if PatchManager.is_patched(model):
134+
manager = PatchManager.get_manager(model)
135+
if config.calib:
136+
manager.call_post_eval()
137+
manager.unpatch()
138+
139+
# export model in-place
140+
model = SearchSpace(model).export()
141+
142+
# construct metadata
143+
metadata = {
144+
"subnet_config": subnet_config,
145+
}
146+
147+
return model, metadata
148+
149+
150+
def restore_export(model: nn.Module, config: ExportConfig, metadata: MetadataDict) -> nn.Module:
151+
"""Restore & export the subnet configuration of the search space to a regular model."""
152+
# Megatron save_sharded_modelopt_state does not save subnet_config
153+
if "subnet_config" not in metadata:
154+
return model
155+
156+
# select subnet config provided in metadata
157+
select(model, metadata["subnet_config"], strict=config["strict"])
158+
159+
# run export
160+
model, metadata_new = export_searchspace(model, config)
161+
162+
# double check metadata
163+
unmatched_keys = compare_dict(metadata, metadata_new)
164+
if unmatched_keys:
165+
raise ApplyModeError(f"Unmatched metadata={unmatched_keys}!")
166+
167+
return model
168+
169+
170+
@NASModeRegistry.register_mode
171+
class ExportNASModeDescriptor(ModeDescriptor):
172+
"""Class to describe the ``"export_nas"`` mode.
173+
174+
The properties of this mode can be inspected via the source code.
175+
"""
176+
177+
@property
178+
def name(self) -> str:
179+
"""Returns the value (str representation) of the mode."""
180+
return "export_nas"
181+
182+
@property
183+
def config_class(self) -> type[ModeloptBaseConfig]:
184+
"""Specifies the config class for the mode."""
185+
return ExportConfig
186+
187+
@property
188+
def is_export_mode(self) -> bool:
189+
"""Whether the mode is an export mode.
190+
191+
Returns:
192+
True if the mode is an export mode, False otherwise. Defaults to False.
193+
"""
194+
return True
195+
196+
@property
197+
def convert(self) -> ConvertEntrypoint:
198+
"""The mode's entrypoint for converting a model."""
199+
return export_searchspace
200+
201+
@property
202+
def restore(self) -> RestoreEntrypoint:
203+
"""The mode's entrypoint for restoring a model."""
204+
return restore_export
205+
206+
92207
def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Module:
93208
"""Export a pruned subnet to a regular model.
94209
@@ -118,4 +233,4 @@ def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Mod
118233

119234
# apply export mode and return model
120235
config = {"strict": strict, "calib": calib}
121-
return apply_mode(model, [("export", config)], registry=NASModeRegistry)
236+
return apply_mode(model, [("export_nas", config)], registry=NASModeRegistry)

modelopt/torch/nas/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
__all__ = ["DMRegistry"]
2121

2222

23-
DMRegistry = _DMRegistryCls(prefix="Dynamic") # global instance for the registry
23+
DMRegistry = _DMRegistryCls(prefix="Dynamic") # global instance for the NAS registry

modelopt/torch/nas/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,6 @@
6161
"replace_forward",
6262
]
6363

64-
# we have two different numbers here since during training it might take longer to stabilize
65-
MODELOPT_QUEUE_MAXLEN = 50 # indicates length of modelopt data queue for BN calib
66-
MODELOPT_BN_CALIB_ITERS = (
67-
100 # indicates # iters in train mode 'til we trust BN stats without calib
68-
)
69-
7064

7165
@contextmanager
7266
def batch_norm_ignored_flops():

modelopt/torch/prune/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121

2222
# nas is required - so let's check if it's available
2323
import modelopt.torch.nas
24+
from modelopt.torch.utils import import_plugin
2425

2526
from . import fastnas, gradnas, plugins
2627
from .pruning import *
28+
29+
with import_plugin("mcore_minitron", verbose=False):
30+
from .plugins import mcore_minitron

0 commit comments

Comments
 (0)