
Commit

clean up
xrsrke committed Dec 10, 2024
1 parent 99524b4 commit e4b2ea9
Showing 5 changed files with 11 additions and 207 deletions.
58 changes: 10 additions & 48 deletions src/nanotron/optim/gradient_accumulator.py
@@ -2,14 +2,13 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional, Tuple, cast
from typing import Callable, Dict, Iterator, Optional, Tuple

import torch
from torch.distributed import GradBucket

import nanotron.distributed as dist
from nanotron import logging
from nanotron.optim.zero import SlicedFlatTensor
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage

@@ -91,57 +90,20 @@ def __init__(
if not param.requires_grad:
continue

global_buffer_start_idx = length
global_buffer_end_idx = global_buffer_start_idx + param.numel()

start = length
end_weight = start + param.numel()
assert name not in segment_index
param = cast(SlicedFlatTensor, param)
segment_index[name] = (
(global_buffer_start_idx, global_buffer_end_idx),
(param.start_offset, param.end_offset),
param,
)
length = global_buffer_end_idx
segment_index[name] = (start, end_weight, param)
length = end_weight

big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda")

self.parameters = {}
for name, (
(global_start_idx, global_end_idx),
(dp_weight_start_idx, dp_weight_end_idx),
param,
) in segment_index.items():
# if name == "model.final_layer_norm.pp_block.weight":
# assert 1 == 1

fp32_p = big_flat_buffer[global_start_idx:global_end_idx].view_as(param)
# sliced_fp32_p = get_sliced_tensor(
# fp32_p,
# start_offset=dp_weight_start_idx,
# end_offset=dp_weight_end_idx,
# is_sharded=True,
# )
# assert (
# sliced_fp32_p.numel() == param.numel()
# ), f"Expected {name} to have the same number of elements, dp_weight_start_idx: {dp_weight_start_idx}, dp_weight_end_idx: {dp_weight_end_idx}, param.numel(): {param.numel()}, sliced_fp32_p.numel(): {sliced_fp32_p.numel()}"
self.parameters[name] = {
"fp32": fp32_p,
self.parameters = {
name: {
"fp32": big_flat_buffer[start_weight:end_weight].view_as(param),
"half": param,
}

# self.parameters = {
# name: {
# # "fp32": big_flat_buffer[global_start_idx:global_end_idx].view_as(param),
# # NOTE: save the way we shard stuff in dp for zero-1, so we can reshard it
# "fp32": get_sliced_tensor(
# big_flat_buffer[global_start_idx:global_end_idx].view_as(param),
# start_offset=dp_weight_start_idx,
# end_offset=dp_weight_end_idx,
# ),
# "half": param,
# }
# for name, ((global_start_idx, global_end_idx), (dp_weight_start_idx, dp_weight_end_idx), param) in segment_index.items()
# }
for name, (start_weight, end_weight, param) in segment_index.items()
}

with torch.inference_mode():
for _, elt in self.parameters.items():
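For readers skimming the rewritten constructor above: it now keeps a single flat fp32 master buffer and a plain `(start, end, param)` tuple per parameter, instead of tracking both global-buffer ranges and DP offsets. The following is a minimal standalone sketch of that pattern, with hypothetical toy parameters and CPU tensors instead of CUDA; it is an illustration, not the committed code.

```python
import torch

# Hypothetical toy parameters standing in for the model's half-precision weights.
half_params = {
    "layer1.weight": torch.randn(4, 3, dtype=torch.bfloat16),
    "layer1.bias": torch.randn(4, dtype=torch.bfloat16),
}

# Each parameter gets a contiguous [start, end) slice of one flat buffer.
segment_index = {}
length = 0
for name, param in half_params.items():
    start = length
    end = start + param.numel()
    segment_index[name] = (start, end, param)
    length = end

# Single flat fp32 master buffer; every fp32 "parameter" is a view into it.
big_flat_buffer = torch.empty(length, dtype=torch.float)
parameters = {
    name: {
        "fp32": big_flat_buffer[start:end].view_as(param),
        "half": param,
    }
    for name, (start, end, param) in segment_index.items()
}

# Initialize the fp32 master copies from the half-precision weights.
with torch.inference_mode():
    for elt in parameters.values():
        elt["fp32"].copy_(elt["half"])

assert parameters["layer1.bias"]["fp32"].shape == half_params["layer1.bias"].shape
```

Because each `fp32` entry is a view into `big_flat_buffer`, in-place updates to a master weight are visible through the flat buffer as well.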
22 changes: 0 additions & 22 deletions src/nanotron/optim/inherit_from_other_optimizer.py
@@ -50,28 +50,6 @@ def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, to
return self.optimizer.load_state_dict(state_dict, map_location=map_location)

def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
# NOTE: error: RuntimeError: params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout
# NOTE: add assert device, dtype, layout are the same

params = [p for group in self.optimizer.param_groups for p in group["params"]]
[p.grad for p in params if p.grad is not None]
[state["exp_avg"] for state in self.optimizer.state_dict()["state"].values()]
[state["exp_avg_sq"] for state in self.optimizer.state_dict()["state"].values()]

# Check if all required attributes have the same device, dtype, and layout
# ref_device = params[0].device
# ref_dtype = params[0].dtype
# ref_layout = params[0].layout

# for attr_list, name in zip(
# [params, grads, exp_avgs, exp_avg_sqs],
# ["params", "grads", "exp_avgs", "exp_avg_sqs"]
# ):
# for idx, attr in enumerate(attr_list):
# assert attr.device == ref_device, f"{name}[{idx}] has device {attr.device}, expected {ref_device}"
# assert attr.dtype == ref_dtype, f"{name}[{idx}] has dtype {attr.dtype}, expected {ref_dtype}"
# assert attr.layout == ref_layout, f"{name}[{idx}] has layout {attr.layout}, expected {ref_layout}"

return self.optimizer.step(closure=closure)

def get_base_optimizer(self):
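The lines removed here were a commented-out debugging aid for the fused-optimizer error `params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout` mentioned in the old note. If such a guard is ever needed again, a standalone sketch (a reconstruction for illustration, not part of this commit) could be:

```python
import torch


def check_adam_buffers_consistency(optimizer: torch.optim.Optimizer) -> None:
    """Assert that params, grads, and Adam moments share device, dtype, and layout.

    Debugging aid for the fused-Adam error about mismatched dtype/device/layout.
    Call it right before optimizer.step() while debugging.
    """
    params = [p for group in optimizer.param_groups for p in group["params"]]
    if not params:
        return
    ref = params[0]
    for p in params:
        assert (p.device, p.dtype, p.layout) == (ref.device, ref.dtype, ref.layout), p
        if p.grad is not None:
            assert (p.grad.device, p.grad.dtype) == (ref.device, ref.dtype), p
        # optimizer.state maps each param to its state dict (may be empty before the first step).
        state = optimizer.state.get(p, {})
        for key in ("exp_avg", "exp_avg_sq"):
            if key in state:
                buf = state[key]
                assert (buf.device, buf.dtype, buf.layout) == (ref.device, ref.dtype, ref.layout), key
```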
80 changes: 1 addition & 79 deletions src/nanotron/serialize/optimizer.py
@@ -23,13 +23,6 @@
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors

# TODO(xrsrke): take rank instead of parallel_context
# def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
# if is_zero is True:
# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
# else:
# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"


def get_optimizer_filename(
tp_topology: Tuple[int, int],
@@ -129,7 +122,6 @@ def convert_to_string(input_item):
torch.save(
optimizer.state_dict(),
root_folder
# / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
/ get_optimizer_filename(
tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()),
pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()),
@@ -385,19 +377,9 @@ def round_robin_map(numbers, min_val, max_val):
range_size = max_val - min_val + 1
return [(num - 1) % range_size + min_val for num in numbers]

# if int(ckp_dp_size) != int(parallel_context.dp_pg.size()):
# pass
# else:

# TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
# state_dict = torch.load(
# root_folder
# / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
# map_location=map_location,
# )
# NOTE: since here we only load the optimizer states,
# then we shard it according to the current data parallel dimension

# TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
state_dict = torch.load(
root_folder
/ get_optimizer_filename(
@@ -416,24 +398,6 @@ def round_robin_map(numbers, min_val, max_val):
)

if isinstance(optimizer, ZeroDistributedOptimizer):

# NOTE: optimizer state topology-agnostic loading
# NOTE: only reshard after merging tp shards
# or we get a new dp_size
# if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
# # NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
# current_dp_rank = dist.get_rank(parallel_context.dp_pg)
# OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
# for param_index in state_dict["state"]:
# param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
# for state_name in OPTIMIZER_STATE_NAMES:
# sliced_tensor = get_sliced_tensor(
# param=state_dict["state"][param_index][state_name],
# start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
# end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
# )
# state_dict["state"][param_index][state_name] = sliced_tensor

shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}_exp-*-of-{ckpt_expert_parallel_size}.pt"
@@ -474,8 +438,6 @@ def get_key_by_value(d, target_value):
(int(tp_rank), int(dp_rank))
][p_idx][key]

assert 1 == 1

# NOTE: now merge optimizer states across data parallel dimension
for param_index in state_dict["state"]:
param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
@@ -493,27 +455,6 @@ def get_key_by_value(d, target_value):
# NOTE: reshard gradient_accumulator if different dp size from checkpoint
if int(ckp_dp_size) != parallel_context.dp_pg.size():
assert int(ckp_tp_size) == parallel_context.tp_pg.size(), "Don't support changing TP size for ZeRO-1"
# merged_grad_accumulator = {}
# for name, param in state_dict["gradient_accumulator"].items():
# # NOTE: assume that we shard a parameter evenly across all DPs
# # TODO: ideally refactor a map between sharding and resharding, so
# # we don't have to assume things
# # merged_p = torch.zeros(param.numel()*int(ckp_dp_size), device="cuda")
# # merged_p = [torch.zeros_like(param) for _ in range(int(ckp_dp_size))]
# # dist.all_gather(merged_p, param.to("cuda"), group=parallel_context.dp_pg)
# # merged_grad_accumulator[name] = torch.cat(merged_p, dim=-1).to(map_location)

# merged_p_shape = ckp_optimizer_config["configs"]["orig_param_shapes"][name]
# merged_p_shape = tuple(int(x) for x in merged_p_shape)
# merged_p = torch.zeros(merged_p_shape).view(-1)
# dp_rank = dist.get_rank(parallel_context.dp_pg)
# dp_offset = ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"][name][str(dp_rank)]
# merged_p[int(dp_offset[0]):int(dp_offset[1])] = param.view(-1)
# dist.all_reduce(merged_p, group=parallel_context.dp_pg)
# merged_p = merged_p.view(merged_p_shape)
# merged_grad_accumulator[name] = merged_p.to(map_location)

assert 1 == 1
ckp_sharded_grad_accum = {}
for shard_path in shard_paths:
pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
@@ -546,27 +487,8 @@ def get_key_by_value(d, target_value):
int(new_offset[0]) : int(new_offset[1])
]

# NOTE: reshard the gradient_accumulator

try:
assert state_dict["state"][0]["exp_avg"].numel() > 0
except:
assert 1 == 1

optimizer.load_state_dict(state_dict, map_location=map_location)

try:
assert state_dict["state"][0]["exp_avg"].numel() > 0
except:
assert 1 == 1

try:
assert optimizer.state_dict()["state"][0]["exp_avg"].numel() > 0
except:
assert 1 == 1

assert 1 == 1


def load_lr_scheduler(
lr_scheduler,
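The `round_robin_map` helper visible in the hunks above survives the clean-up. Its call sites fall outside the shown context, so the interpretation below is an assumption, but as a pure function it wraps 1-based indices onto the inclusive range `[min_val, max_val]`:

```python
def round_robin_map(numbers, min_val, max_val):
    """Wrap 1-based indices onto the inclusive range [min_val, max_val]."""
    range_size = max_val - min_val + 1
    return [(num - 1) % range_size + min_val for num in numbers]


# E.g. mapping checkpoint ranks 1..6 onto three current ranks 0..2:
assert round_robin_map([1, 2, 3, 4, 5, 6], 0, 2) == [0, 1, 2, 0, 1, 2]
```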
15 changes: 0 additions & 15 deletions src/nanotron/trainer.py
@@ -173,21 +173,6 @@ def __init__(
)
self.model = self.init_model() # Defines self.model

# from torch import nn
# def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]:
# """
# Return all the leaf modules (modules without any child modules) in a PyTorch module.
# """
# leaf_modules = []
# for n, m in module.named_modules():
# if not list(m.children()):
# leaf_modules.append((n, m))
# return leaf_modules

# leaf_modules = get_leaf_modules(self.model)
# for name, param in self.model.named_parameters():
# print(name, param.shape)

self.unwrapped_model: NanotronModel = (
self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
)
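The block deleted here was a commented-out debugging helper that collected leaf modules and printed parameter shapes. Reconstructed from those comments purely for reference (it is not part of the commit), a standalone version would be:

```python
from typing import List, Tuple

from torch import nn


def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]:
    """Return all leaf modules (modules without any child modules) of a PyTorch module."""
    leaf_modules = []
    for name, submodule in module.named_modules():
        if not list(submodule.children()):  # no children -> leaf module
            leaf_modules.append((name, submodule))
    return leaf_modules


# Example: inspect parameter shapes of the leaves of a small model.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    for name, leaf in get_leaf_modules(model):
        print(name, [tuple(p.shape) for p in leaf.parameters()])
```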
43 changes: 0 additions & 43 deletions tests/test_zero.py
@@ -543,46 +543,3 @@ def _test_sliced_flat_tensor(parallel_context: ParallelContext):
assert not isinstance(c, SlicedFlatTensor)

parallel_context.destroy()


@rerun_if_address_is_in_use()
def test_wrap_slice_tensor_around_a_sharded_tensor():
init_distributed(1, 1, 1)(_test_wrap_slice_tensor_around_a_sharded_tensor)()


def _test_wrap_slice_tensor_around_a_sharded_tensor(parallel_context: ParallelContext):
a = torch.randn(2, 3, requires_grad=True)
grad = torch.randn(2, 3)
a.grad = grad

start_offset, end_offset = 1, 5
b = SlicedFlatTensor(a, start_offset=start_offset, end_offset=end_offset)

torch.testing.assert_close(a.grad, grad, atol=0, rtol=0)
torch.testing.assert_close(b.grad, grad.view(-1)[start_offset:end_offset])

# Deallocate the gradient by setting it to None
a.grad = None

assert a.grad is None
assert b.grad is None

# Setting gradient to None on the sliced tensor works
a.grad = grad
assert a.grad is not None
assert b.grad is not None
b.grad = None
assert b.grad is None
assert a.grad is None

with assert_fail_with(NotImplementedError):
b.grad = torch.randn(1, 5)

with assert_fail_with(NotImplementedError):
del b.grad

c = b[:3]
# It's important not to contaminate everyone.
assert not isinstance(c, SlicedFlatTensor)

parallel_context.destroy()
