Compare microbatch forward outputs and gradients #246
base: main
Diff for the first changed file:

```diff
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )

     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,6 +429,20 @@ def log_model_weights(self, parallel_mod):
         print(f"Weight hashes written to {path}")

+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_diff(self, t, rank=0, prefix="?"):
+        if self.rank == rank:
+            path = self.dir / "diff.log"
+            if isinstance(t, torch.distributed.tensor.DTensor):
+                t = t.to_local()
+            with open(path, "a") as f:
+                f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")
+
     def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
         path = self.dir / "pp_weights.log"
@@ -463,3 +477,40 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
         if self.rank == 0:
             print(f"Weight hashes written to {path}")

+    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
```
Contributor comment (on the `log_pp_grads` definition): What is …
```diff
+        path = self.dir / "diff.log"
+
+        torch.distributed.barrier()
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                grad_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    grad = real_params[name].grad
+                    if grad is None:
+                        grad_logs.append(f"[grad {name}] None")
+                    else:
+                        grad = grad.to_local()
+                        grad_logs.append(
+                            f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
+                        )
+                with open(path, "a") as f:
+                    f.write("\n".join(grad_logs) + "\n")
+            torch.distributed.barrier()
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
+
+    run._boxed_call = True
+    return run
```
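Both `log_diff` and `log_pp_grads` append hash/norm lines to `diff.log`, so two runs (for example a full-batch reference run and a microbatched run) can be compared line by line. A small standalone helper along these lines could do that comparison; it is not part of the PR, and the file paths are simply whatever `NumericsLogger` wrote:

```python
# Hypothetical helper, not part of this PR: compare the diff.log files
# produced by NumericsLogger.log_diff / log_pp_grads on two runs.
from pathlib import Path
import sys


def compare_numeric_logs(ref_path: Path, test_path: Path) -> int:
    """Count line-by-line mismatches between two NumericsLogger outputs."""
    ref_lines = ref_path.read_text().splitlines()
    test_lines = test_path.read_text().splitlines()
    # A length difference already counts as a mismatch.
    mismatches = abs(len(ref_lines) - len(test_lines))
    for i, (ref, test) in enumerate(zip(ref_lines, test_lines)):
        if ref != test:
            mismatches += 1
            print(f"line {i}:\n  ref : {ref}\n  test: {test}")
    return mismatches


if __name__ == "__main__":
    # e.g. python compare_logs.py ref_logs/diff.log test_logs/diff.log
    sys.exit(1 if compare_numeric_logs(Path(sys.argv[1]), Path(sys.argv[2])) else 0)
```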
Diff for the second changed file:

```diff
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
         mscale=0.70,
     )

-    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    local_batch_size = 2
+    global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
     device = torch.device(f"cuda:{local_rank}")

     # parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
         return torch.randint(
             0,
             config.vocab_size,
-            (bs, seq_len),
+            (global_batch_size, seq_len),
             device=device,
         )

-    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+    numerics_logger = None
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+    with AutoParallel(
+        model, input_fn, mesh, dynamic=True, numerics_logger=None
```
Contributor comment (on `numerics_logger=None`): should this be `numerics_logger = numerics_logger`?
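A sketch of the change the reviewer appears to be suggesting, assuming `AutoParallel` does accept a `numerics_logger` keyword as the call above implies (this is not in the PR as posted, which still passes `None`):

```diff
-        model, input_fn, mesh, dynamic=True, numerics_logger=None
+        model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
```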
```diff
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)

         # x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
     # ) # maybe not correct value
     parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
     if rng_seed is not None:
-        numerics_logger = NumericsLogger(logs_dir)
         numerics_logger.log_model_weights(parallel_mod)

-    x = (
-        torch.randint(
-            0,
-            config.vocab_size,
-            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
-            device=device,
-        ),
+    torch.manual_seed(rng_seed)
+
+    n_microbatches = 16
+    full_batch = torch.randint(
+        0,
+        config.vocab_size,
+        (local_batch_size * n_microbatches, seq_len),
+        device=device,
+    )
+    microbatches = torch.split(full_batch, local_batch_size, dim=0)
+    assert len(microbatches) == n_microbatches
+    if rng_seed:
+        numerics_logger.log_diff(
+            full_batch.to(torch.float32), prefix="full batch input"
+        )

     # Symbolically evaluate in case you want to test running a graph bigger than your gpu
     if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
             allow_non_fake_inputs=True,
             shape_env=shape_env,
         ):
-            # # now let's run it
-            out = parallel_mod(*x)
-            out.backward(torch.randn_like(out))
+            # now let's run it
+            for x in microbatches:
+                out = parallel_mod(x)
+                out.backward(torch.ones_like(out))
     else:
-        out = parallel_mod(*x)
-        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        for i, x in enumerate(microbatches):
+            assert x.shape[0] == 2
+            out = parallel_mod(x)
+            assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+            out.backward(torch.ones_like(out))
+            if rng_seed is not None:
+                numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")

         if rng_seed is not None:
-            numerics_logger.log_forward_output(out)
-            out.backward(torch.randn_like(out))
+            for k, v in parallel_mod.named_parameters():
+                numerics_logger.log_diff(v.grad, prefix=f"grad {k}")

     print("All good!")
```
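Replacing `backward(torch.randn_like(out))` with `backward(torch.ones_like(out))` presumably makes the cotangent independent of how the batch is split, so the gradients accumulated over the 16 microbatches should match a single full-batch backward. A self-contained sketch of that invariant on a toy module (nothing below comes from the PR):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 4)
full_batch = torch.randn(16, 8)

# Full-batch backward with an all-ones cotangent.
model.zero_grad()
model(full_batch).backward(torch.ones(16, 4))
full_grad = model.weight.grad.clone()

# Same quantity accumulated over microbatches of size 2.
model.zero_grad()
for mb in torch.split(full_batch, 2, dim=0):
    model(mb).backward(torch.ones(2, 4))

# Gradients accumulate in .grad, so splitting the batch must not change the result.
assert torch.allclose(full_grad, model.weight.grad)
```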
Diff for the third changed file:

```diff
@@ -166,9 +166,9 @@ def run_test(
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]

-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
@@ -472,10 +472,6 @@ def last_stage_inp_with_loss_fn():
     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )

     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -500,6 +496,7 @@ def last_stage_inp_with_loss_fn():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -511,9 +508,32 @@ def last_stage_inp_with_loss_fn():
         backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)

+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, ranks=[0, 4]
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            return
+
+        rank = torch.distributed.get_rank()
+        if rank == 4:
```
Contributor comment (on `if rank == 4:`): can you somehow not hardcode this
```diff
+            numerics_logger.log_diff(
+                output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
```
Contributor comment (same thread): Yeah, very confusing. Also, do we care about pp_rank or the global rank? Finally, V-style schedules will have the last stage on rank 0?
```diff
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
@@ -542,6 +562,10 @@ def last_stage_inp_with_loss_fn():
         )
         if pp_rank == 0:
             x = runtime_input_fn_first_stage()
+            if rng_seed:
+                numerics_logger.log_diff(
+                    x.to(torch.float32), prefix="full batch input"
+                )
         graph_pp_runner.step(
             x, target=target, losses=losses, return_outputs=False
         )
@@ -556,6 +580,8 @@ def last_stage_inp_with_loss_fn():
         payload_fn=lambda: f"losses: {losses}",
     )

+    numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])
+
     print("All good!")

     if torch.distributed.is_initialized():
```
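One way to address the two comments above about the hardcoded rank, sketched under the assumption that the hook keeps the same enclosing scope (`rng_seed`, `numerics_logger`) as in the diff: rely only on `stage.is_last` and pass the calling rank through to `log_diff`, so the check also works for schedules whose last stage lives on a different global rank (including V-style schedules where it may be rank 0). This is an illustration, not code from the PR:

```python
# Hypothetical variant of the hook with no hardcoded rank: whichever global
# rank owns the last stage does the logging, because stage.is_last is only
# true there; log_diff is told to accept that same rank.
def last_stage_forward_hook(stage, action, output):
    if not stage.is_last or rng_seed is None:
        return
    rank = torch.distributed.get_rank()  # global rank that owns the last stage
    numerics_logger.log_diff(
        output, rank=rank, prefix=f"mb{action.microbatch_index} fwd out"
    )
```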
Review comment: nit: `Optional[Callable]`
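The nit presumably refers to annotating the new hook parameter; a sketch of what an `Optional[Callable]` annotation could look like. The signature below is illustrative, only the `forward_hook` keyword itself appears in the diff:

```python
from typing import Any, Callable, Optional


def stage_forward(
    *args: Any,
    numerics_logs: Optional[list] = None,
    forward_hook: Optional[Callable[..., None]] = None,
    **kwargs: Any,
):
    """Illustrative signature only; the real stage_forward takes more arguments."""
    ...
```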