Merge branch 'main' into nemotron-export
ericharper authored Jul 6, 2024
2 parents 1d1437d + 613e1f1 commit 2279217
Showing 4 changed files with 17 additions and 10 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/multimodal/parts/utils.py
@@ -525,8 +525,8 @@ def create_image_processor(mm_cfg):
     else:
         raise (ValueError("Currently only support CLIPImageProcessor and SiglipImageProcessor from Huggingface"))
 
-    crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224))
-    if hasattr(image_processor, 'crop_size'):
+    crop_size = mm_cfg.vision_encoder.get("crop_size")
+    if hasattr(image_processor, 'crop_size') and crop_size is not None:
         assert crop_size == (
             image_processor.crop_size['height'],
             image_processor.crop_size['width'],
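
To illustrate the behavior change in this hunk: crop_size previously defaulted to (224, 224), so the consistency assert could fire for processors with a different native crop; after the change the check only runs when crop_size is configured explicitly. A stand-alone sketch (FakeProcessor and cfg are illustrative stand-ins, not objects from the diff):

```python
# Illustrative sketch only; FakeProcessor and cfg stand in for the real
# image_processor and mm_cfg.vision_encoder objects.
class FakeProcessor:
    crop_size = {'height': 384, 'width': 384}  # e.g. a processor with a 384x384 crop

cfg = {}  # the user did not set vision_encoder.crop_size

# Old logic: fell back to (224, 224), so the assert below would have failed.
old_crop = cfg.get("crop_size", (224, 224))

# New logic: no default, and the check is skipped entirely when unset.
new_crop = cfg.get("crop_size")
if new_crop is not None:
    assert new_crop == (FakeProcessor.crop_size['height'], FakeProcessor.crop_size['width'])
```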
4 changes: 2 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -573,8 +573,8 @@ def _integrate_original_checkpoint_data(self, checkpoint: Dict[str, Any]) -> Dic
             ]['optimizer']['param_groups']
         else:
             checkpoint['optimizer_states'][0]['param_groups'] = original_checkpoint['optimizer_states'][0][
-                'optimizer'
-            ]['param_groups']
+                'param_groups'
+            ]
 
         return checkpoint
 
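The bug fixed here: the else branch indexed the original checkpoint one level too deep, assuming an extra 'optimizer' key that this code path does not have. A toy sketch of the two layouts (values illustrative, keys taken from the diff):

```python
# Layout handled by the if-branch (unchanged): param_groups nested under 'optimizer'.
nested = {'optimizer_states': [{'optimizer': {'param_groups': [{'lr': 1e-4}]}}]}
# Layout handled by the else-branch (fixed): param_groups sit directly in the entry.
flat = {'optimizer_states': [{'param_groups': [{'lr': 1e-4}]}]}

pg_nested = nested['optimizer_states'][0]['optimizer']['param_groups']
pg_flat = flat['optimizer_states'][0]['param_groups']  # old code raised KeyError('optimizer') here
assert pg_nested == pg_flat
```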
15 changes: 11 additions & 4 deletions nemo/core/optim/optimizer_with_main_params.py
@@ -119,7 +119,7 @@ def zero(self):
         self.data.zero_()
 
     def allreduce_buffer(self):
-        """Synchronous buffer data allreduce """
+        """Synchronous buffer data allreduce"""
         self.data.div_(get_data_parallel_world_size())
         torch.distributed.all_reduce(self.data, group=self._data_group)

@@ -175,7 +175,7 @@ class MainParamsOptimizerWrapper(torch.optim.Optimizer):
     Arguments:
         optimizer: base optimizer such as Adam or SGD.
         fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce.
-        contiguous_grad_bucket: to enable allocating the master gradients in the 
+        contiguous_grad_bucket: to enable allocating the master gradients in the
             contiguous memory space to reduce memory fragmentation.
         async_grad_allreduce: enable asynchronous gradient allreduce that is executed
             along with the training step backprop.
@@ -339,6 +339,7 @@ def __init__(
 
     def _make_param_hook(self, param, main_param, i, grad_chunk_info, is_expert_group):
         """Create the grad accumulation and all-reduce hook for backprop."""
+
         # Hook used for back-prop.
         def param_hook(*unused):
             # Accumulates gradients on main gradients
@@ -361,7 +362,9 @@ def allreduce_grads(use_fused_div, tensor, data_group, grad_mult):
             else:
                 tensor.div_(grad_mult)
                 torch.distributed.all_reduce(
-                    tensor, group=data_group, async_op=True,
+                    tensor,
+                    group=data_group,
+                    async_op=True,
                 )
 
         # Asynchronous gradients allreduce accross data_parallel ranks
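
The reformatting of the all_reduce call above is purely cosmetic, but the surrounding pattern is worth spelling out: the tensor is pre-scaled by the gradient multiplier and the all-reduce is launched asynchronously so it can overlap with the rest of backprop. A minimal sketch under the assumption that torch.distributed has already been initialized (the guard just keeps it safe to run stand-alone):

```python
import torch
import torch.distributed as dist

def allreduce_grads_sketch(tensor: torch.Tensor, data_group, grad_mult: float):
    # Pre-divide so the summed result becomes an average across ranks.
    tensor.div_(grad_mult)
    if dist.is_available() and dist.is_initialized():
        # async_op=True returns a work handle; the caller waits on it before
        # the optimizer step so the all-reduce overlaps with remaining backprop.
        return dist.all_reduce(tensor, group=data_group, async_op=True)
    return None
```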
@@ -473,12 +476,16 @@ def load_state_dict(self, state_dict):
         if optimizer_key not in state_dict:
             optimizer_key = 'optimizer_state_dict'
             logging.info('***WARNING*** loading optimizer from ' 'an old checkpoint ...')
+        if 'state' not in state_dict[optimizer_key]:
+            state_dict[optimizer_key]['state'] = {}
         self.optimizer.load_state_dict(state_dict[optimizer_key])
 
         # Copy data for the main params.
         fp32_from_float16_params_key = 'fp32_from_fp16_params'
         if fp32_from_float16_params_key not in state_dict:
             fp32_from_float16_params_key = 'fp32_from_fp16'
+        if fp32_from_float16_params_key not in state_dict:
+            state_dict[fp32_from_float16_params_key] = []
         for current_group, saved_group in zip(self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
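
The two guards added in this hunk make load_state_dict tolerant of checkpoints that lack an optimizer 'state' entry or the fp32 master-weight copies. A toy sketch of the same checks (layout illustrative, keys taken from the diff):

```python
state_dict = {'optimizer': {'param_groups': [{'lr': 1e-4}]}}  # no 'state', no fp32 copies
optimizer_key = 'optimizer'

if 'state' not in state_dict[optimizer_key]:
    state_dict[optimizer_key]['state'] = {}  # torch optimizers expect this key to exist

fp32_key = 'fp32_from_fp16_params'
if fp32_key not in state_dict:
    state_dict[fp32_key] = []  # nothing saved, so there is nothing to copy into the main params

assert state_dict[optimizer_key]['state'] == {} and state_dict[fp32_key] == []
```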
@@ -489,7 +496,7 @@ def allreduce_main_grads(self):
 
     @contextmanager
     def no_sync(self):
-        """ A context manager to disable gradient synchronizations across
+        """A context manager to disable gradient synchronizations across
         data-parallel ranks."""
         old_require_backward_grad_sync = self._require_backward_grad_sync
         self._require_backward_grad_sync = False
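
The no_sync context manager touched by the last hunk follows the usual DDP-style pattern: a flag controls whether the backward hook issues an all-reduce, and the context manager turns it off temporarily, e.g. for all but the last micro-batch of a gradient-accumulation window. A simplified, stand-alone sketch of that pattern (class and names are illustrative, not from the file):

```python
from contextlib import contextmanager

class GradSyncSwitch:
    """Toy stand-in for the wrapper's _require_backward_grad_sync handling."""

    def __init__(self):
        self._require_backward_grad_sync = True

    @contextmanager
    def no_sync(self):
        old = self._require_backward_grad_sync
        self._require_backward_grad_sync = False
        try:
            yield
        finally:
            self._require_backward_grad_sync = old

switch = GradSyncSwitch()
with switch.no_sync():
    assert not switch._require_backward_grad_sync  # gradients accumulate locally here
assert switch._require_backward_grad_sync          # synchronization restored afterwards
```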
4 changes: 2 additions & 2 deletions nemo/export/trt_llm/converter/utils.py
@@ -439,14 +439,14 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
         split_w3s = np.split(w3, split_factor, axis=1)
 
         split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)]
-        key = f'{layer_prefix}.mlp.experts_weight_1'
+        key = f'{layer_prefix}.mlp.fc.weight'
        save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor)
 
     elif "experts.linear_fc2.weight" in key:
         cat_dim = -1
         val = np.concatenate(vals, axis=cat_dim)
         split_vals = np.split(val, split_factor, axis=cat_dim)
-        key = f'{layer_prefix}.mlp.experts_weight_2'
+        key = f'{layer_prefix}.mlp.proj.weight'
         save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor)
     else:
         print(f"[WARNING] {key} not handled by converter")
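
The only functional change in this file is the renaming of the exported expert-MLP keys to the fc/proj layout the converter now targets. A small sketch of the implied mapping (layer_prefix is an example value, not from the diff):

```python
layer_prefix = "transformer.layers.0"  # example prefix, not from the diff

# Expert MLP weights are now written under fc/proj names
# instead of the earlier experts_weight_1/2 keys.
old_to_new = {
    f"{layer_prefix}.mlp.experts_weight_1": f"{layer_prefix}.mlp.fc.weight",
    f"{layer_prefix}.mlp.experts_weight_2": f"{layer_prefix}.mlp.proj.weight",
}
for old_key, new_key in old_to_new.items():
    print(f"{old_key} -> {new_key}")
```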
