diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 7f2d5c5..dcd2cf6 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -45,5 +45,5 @@ jobs: python examples/example_llama3.py python examples/example_dcp.py python examples/example_local_map.py - python examples/example_ds3_local_map.py python examples/example_pp_graph_passes.py + torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index f9293ce..694b866 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -1529,9 +1529,11 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): ) self.model_args = model_args - def init_weights(self, buffer_device: torch.device | None = None) -> None: - _init_weights_tok_embeddings(self) - _init_weights_layers(self, buffer_device) + def init_weights( + self, buffer_device: torch.device | None = None, seed: int | None = None + ) -> None: + _init_weights_tok_embeddings(self, seed) + _init_weights_layers(self, buffer_device, seed) _init_weights_norm_and_output(self) def forward( @@ -1585,8 +1587,10 @@ def forward(self, h): h = layer(h, self.freqs_cis) return h - def init_weights(self, buffer_device: torch.device | None = None) -> None: - _init_weights_layers(self, buffer_device) + def init_weights( + self, buffer_device: torch.device | None = None, seed: int | None = None + ) -> None: + _init_weights_layers(self, buffer_device, seed) class DeepSeekV3Stage0(DeepSeekV3StageI): @@ -1600,9 +1604,11 @@ def forward(self, tokens): # torch.Size([1024, 1024, 2048]) return super().forward(h) - def init_weights(self, buffer_device: torch.device | None = None) -> None: - _init_weights_tok_embeddings(self) - super().init_weights(buffer_device=buffer_device) + def init_weights( + self, buffer_device: torch.device | None = None, seed: int | None = None + ) -> None: + 
_init_weights_tok_embeddings(self, seed) + super().init_weights(buffer_device, seed) class DeepSeekV3StageN(DeepSeekV3StageI): @@ -1618,8 +1624,10 @@ def forward(self, h): output = self.output(h) if self.output is not None else h return output - def init_weights(self, buffer_device: torch.device | None = None) -> None: - super().init_weights(buffer_device=buffer_device) + def init_weights( + self, buffer_device: torch.device | None = None, seed: int | None = None + ) -> None: + super().init_weights(buffer_device, seed) _init_weights_norm_and_output(self) @@ -1628,7 +1636,11 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: ###################### -def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]): +def _init_weights_tok_embeddings( + self: Union[DeepSeekV3Model, DeepSeekV3Stage0], seed: int | None = None +): + if seed is not None: + torch.manual_seed(seed) if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) @@ -1636,15 +1648,18 @@ def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]) def _init_weights_layers( self: Union[DeepSeekV3Model, DeepSeekV3StageI], buffer_device: torch.device | None, + seed: int | None = None, ): if buffer_device is None: buffer_device = self.freqs_cis.device # type: ignore[assignment] with torch.device(buffer_device): # type: ignore[arg-type] self.freqs_cis = precompute_freqs_cis(self.model_args) - for layer in self.layers.values(): + for i, layer in enumerate(self.layers.values()): + if seed is not None: + torch.manual_seed(seed) if layer is not None: assert isinstance(layer, TransformerBlock) - layer.init_weights(buffer_device=buffer_device) # type: ignore[arg-type] + layer.init_weights(buffer_device) # type: ignore[arg-type] def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]): diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 2f77739..b96f4db 100644 --- a/autoparallel/utils.py +++ 
b/autoparallel/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from pathlib import Path from typing import Any, Iterable import torch @@ -377,3 +378,88 @@ def print_rank_by_rank(msg: Any): print(msg) print(f"{rank=} done") torch.distributed.barrier() + + +def hash_tensor(t: torch.Tensor) -> str: + if isinstance(t, torch.distributed.tensor.DTensor): + t = t.to_local() + return f"DTensor({hash_tensor(t)})" + + if t.is_complex(): + return f"real={hash_tensor(t.real)}, imag={hash_tensor(t.imag)})" + + return f"{torch.hash_tensor(t)}" + + +class NumericsLogger: + def __init__(self, base_dir: str): + self.base = Path(base_dir) + self.base.mkdir(parents=True, exist_ok=True) + self.rank = torch.distributed.get_rank() + self.dir = self._create_run_dir() + + def _create_run_dir(self) -> Path: + """ + Find the next available integer directory name under base_dir. + Example: base_dir/0, base_dir/1, base_dir/2, ... 
+ """ + existing = [ + int(p.name) for p in self.base.iterdir() if p.is_dir() and p.name.isdigit() + ] + next_id = (max(existing) + 1) if existing else 0 + run_dir = self.base / str(next_id) + torch.distributed.barrier() + if self.rank == 0: + run_dir.mkdir() + torch.distributed.barrier() + return run_dir + + def log_model_weights(self, parallel_mod): + if self.rank == 0: + path = self.dir / "weights.log" + + logs = [] + for name, param in parallel_mod.named_parameters(): + logs.append(f"{name=} hash={hash_tensor(param)}") + for name, buf in parallel_mod.named_buffers(): + logs.append(f"{name=} hash={hash_tensor(buf)}") + + with open(path, "a") as f: + f.write("\n".join(logs) + "\n") + + print(f"Weight hashes written to {path}") + + def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks): + path = self.dir / "pp_weights.log" + + torch.distributed.barrier() + # First print the params of every stage + for i in range(num_world_stages): + if self.rank in ranks and i in stage_mods: + param_logs = [] + real_params = dict(stage_mods[i].named_parameters()) + for name, _ in orig_mod.named_parameters(): + if name not in real_params: + continue + param = real_params[name] + param_logs.append(f"{name=} hash={hash_tensor(param)}") + with open(path, "a") as f: + f.write("\n".join(param_logs) + "\n") + torch.distributed.barrier() + + # Then print the buffers of every stage + for i in range(num_world_stages): + if self.rank in ranks and i in stage_mods: + buffer_logs = [] + real_buffers = dict(stage_mods[i].named_buffers()) + for name, _ in orig_mod.named_buffers(): + if name not in real_buffers: + continue + buffer = real_buffers[name] + buffer_logs.append(f"{name=} hash={hash_tensor(buffer)}") + with open(path, "a") as f: + f.write("\n".join(buffer_logs) + "\n") + torch.distributed.barrier() + + if self.rank == 0: + print(f"Weight hashes written to {path}") diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py index 
def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
    """Auto-parallelize a DeepSeek V3 model on a (dp, ep) mesh and run one fwd/bwd.

    Args:
        fake_evaluate: when True, initialize a 256-rank "fake" process group
            and run the step under ``FakeTensorMode``; when False, expect to be
            launched by torchrun with ``WORLD_SIZE == 4`` and use NCCL.
        rng_seed: forwarded to ``init_weights(seed=...)``; when not None the
            initialized weights are also hashed to ``logs_dir`` through
            ``NumericsLogger``.
        logs_dir: base directory handed to ``NumericsLogger``.

    The process group is destroyed before returning.
    """
    seq_len = 1024
    if fake_evaluate:
        # must symbolically evaluate to run on 32 dp ranks
        # world_size = 2048

        world_size = 256

        fake_store = FakeStore()
        torch.distributed.init_process_group(
            "fake", store=fake_store, rank=0, world_size=world_size
        )
        local_rank = torch.distributed.get_rank()
        mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda",
            (world_size // 64, 64),
            mesh_dim_names=(
                "dp",
                "ep",
            ),
        )

        config = DeepSeekV3ModelArgs(
            vocab_size=102400,
            max_seq_len=seq_len,
            dim=2048,
            inter_dim=10944,
            moe_inter_dim=1408,
            n_layers=1,  # 27,
            n_dense_layers=0,  # 1,
            n_heads=16,
            moe_args=MoEArgs(
                num_experts=64,
                num_shared_experts=2,
                top_k=6,
                score_func="softmax",
                route_norm=False,
                score_before_experts=False,
                mesh=mesh,
            ),
            q_lora_rank=0,
            kv_lora_rank=512,
            qk_nope_head_dim=128,
            qk_rope_head_dim=64,
            v_head_dim=128,
            mscale=0.70,
            use_flex_attn=False,
            attn_mask_type="causal",
        )
    else:
        dp_degree = 2
        ep_degree = 2
        world_size = dp_degree * ep_degree

        assert (
            "WORLD_SIZE" in os.environ
        ), f"run with torchrun --standalone --nproc-per-node {world_size}"
        # The mesh below is exactly (dp_degree, ep_degree), so the check is an
        # equality; the message now says so instead of "at least".
        assert int(os.environ["WORLD_SIZE"]) == world_size, (
            f"Need exactly {world_size} ranks for real evaluation, "
            f"got WORLD_SIZE={os.environ['WORLD_SIZE']}"
        )
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.distributed.init_process_group(backend="nccl")
        mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda",
            (dp_degree, ep_degree),
            mesh_dim_names=(
                "dp",
                "ep",
            ),
        )

        config = DeepSeekV3ModelArgs(
            vocab_size=2048,
            max_seq_len=seq_len,
            dim=256,
            inter_dim=1024,
            moe_inter_dim=256,
            n_layers=4,
            n_dense_layers=0,
            n_heads=16,
            moe_args=MoEArgs(
                num_experts=4,
                num_shared_experts=2,
                top_k=2,
                score_func="softmax",
                route_norm=False,
                score_before_experts=False,
                mesh=mesh,
            ),
            q_lora_rank=0,
            kv_lora_rank=512,
            qk_nope_head_dim=128,
            qk_rope_head_dim=64,
            v_head_dim=128,
            mscale=0.70,
        )

    # Global batch: 4 samples per rank of the (dp, ep) mesh.
    bs = 4 * mesh.shape[0] * mesh.shape[1]
    device = torch.device(f"cuda:{local_rank}")

    # parallelize the model
    with torch.device("meta"):
        model = DeepSeekV3Model(config).bfloat16()

    def input_fn():
        return torch.randint(
            0,
            config.vocab_size,
            (bs, seq_len),
            device=device,
        )

    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)

        # x_sharding = (Shard(0), Replicate())
        x_sharding = (Shard(0), Shard(0))

        autop.add_input_constraints([x_sharding])
        autop.add_output_constraints([x_sharding])

        sharding_placement = autop.optimize_placement(verbose=False)
        parallel_mod = autop.apply_placement(sharding_placement)

    parallel_mod.to_empty(device=device)
    # run weight init on our sharded DTensor params
    # TODO: plumb init_std through
    # parallel_mod.init_weights(
    #     init_std=0.02, buffer_device="cuda"
    # )  # maybe not correct value
    parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
    if rng_seed is not None:
        NumericsLogger(logs_dir).log_model_weights(parallel_mod)

    # Per-rank input: the global batch divided by both mesh dimensions.
    x = (
        torch.randint(
            0,
            config.vocab_size,
            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
            device=device,
        ),
    )

    # Symbolically evaluate in case you want to test running a graph bigger than your gpu
    if fake_evaluate:
        # all gather on the tokens takes 128 GiB (4GiB * 32 ranks)
        shape_env = ShapeEnv()
        with FakeTensorMode(
            allow_non_fake_inputs=True,
            shape_env=shape_env,
        ):
            # # now let's run it
            out = parallel_mod(*x)
            out.backward(torch.randn_like(out))
    else:
        out = parallel_mod(*x)
        out.backward(torch.randn_like(out))

    print("All good!")

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.cuda.synchronize()
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    import argparse

    # The description previously read "pipeline parallel example", copied from
    # the PP script; this script is the local_map example.
    parser = argparse.ArgumentParser(
        description="Run DeepSeek V3 local_map example"
    )
    parser.add_argument(
        "--fake-evaluate",
        action="store_true",
        default=False,
        help="Use fake evaluation mode with FakeTensorMode (default: False)",
    )
    parser.add_argument(
        "--rng-seed",
        type=int,
        default=None,
        help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).",
    )
    parser.add_argument(
        "--logs-dir",
        type=str,
        default="out/",
        help="Directory to store logs (default: ./out/).",
    )
    args = parser.parse_args()

    if args.rng_seed is not None:
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(args.rng_seed)

    run_test(
        fake_evaluate=args.fake_evaluate, rng_seed=args.rng_seed, logs_dir=args.logs_dir
    )
@@ -50,7 +50,7 @@ stage_reshard, stage_unshard, ) -from autoparallel.utils import print_rank_by_rank +from autoparallel.utils import NumericsLogger # Configure logging to show DEBUG messages logging.basicConfig( @@ -100,7 +100,7 @@ def build_pipeline_schedule( return schedule -def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]): +def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): if not fake_evaluate: pp_degree = 2 dp_mod_ep_degree = 2 @@ -372,7 +372,8 @@ def shape_inference_output_fn_last_stage(): torch.save(cache, stage_file) pp_mod.to_empty(device=device) - pp_mod.init_weights(buffer_device=device) + # run weight init on our sharded DTensor params + pp_mod.init_weights(buffer_device=device, seed=rng_seed) # Store each stage's information in stage_mods, stage_graphs, and stage_graph_metas stage_mods[stage_idx] = pp_mod @@ -409,7 +410,12 @@ def shape_inference_output_fn_last_stage(): == len(stage_graph_metas) ) - # run weight init on our sharded DTensor params + world_size = torch.distributed.get_world_size() + num_world_stages = world_size * len(stage_mods) + if rng_seed is not None: + NumericsLogger(logs_dir).log_pp_model_weights( + model, stage_mods, num_world_stages, ranks=[0, 4] + ) stages = [] # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata @@ -446,9 +452,8 @@ def shape_inference_output_fn_last_stage(): ) assert isinstance(schedule, _PipelineScheduleRuntime) # Step 6. 
Override the pipeline runner's action implementations - numerics_logs = [] schedule.register_custom_function( - FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs) + FORWARD, functools.partial(stage_forward, numerics_logs=None) ) schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad) @@ -475,9 +480,6 @@ def shape_inference_output_fn_last_stage(): else: graph_pp_runner.step() - if debug_numerics: - print_rank_by_rank("\n".join(numerics_logs)) - print("All good!") if torch.distributed.is_initialized(): @@ -504,10 +506,18 @@ def shape_inference_output_fn_last_stage(): default=None, help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).", ) + parser.add_argument( + "--logs-dir", + type=str, + default="out/", + help="Directory to store logs (default: ./out/).", + ) args = parser.parse_args() if args.rng_seed is not None: torch.use_deterministic_algorithms(True) torch.manual_seed(args.rng_seed) - run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None) + run_test( + fake_evaluate=args.fake_evaluate, rng_seed=args.rng_seed, logs_dir=args.logs_dir + )