Commit 81012d1

Update

[ghstack-poisoned]

fegin committed Jul 8, 2024
2 parents: 9dfd0e2 + 811f26b
Showing 4 changed files with 31 additions and 12 deletions.
estimation.py (10 changes: 7 additions & 3 deletions)
@@ -49,14 +49,18 @@ def estimate_memory(job_config: JobConfig):
     # fake tensor doesn't work with fused rmsnorm
     if (
         job_config.model.norm_type == "fused_rmsnorm"
-        and job_config.estimate.mode == "fake"
+        and not job_config.memory_estimation.disable_fake_mode
     ):
         logger.info(
             "Fused RMSNorm is not supported yet under fake estimation mode. "
             "Switching to rmsnorm."
         )
         job_config.model.norm_type = "rmsnorm"
 
+    if job_config.training.compile:
+        logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
+        job_config.training.compile = False
+
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
@@ -107,7 +111,7 @@ def loss_fn(pred, labels):
     model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
 
-    with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext():
+    with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
 
         logger.info(
             f"Building {model_name} {job_config.model.flavor} with {model_config}"
@@ -198,7 +202,7 @@ def loss_fn(pred, labels):
             f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
         )
         print(f"Tracker Max: {tracker_peak / gib} GiB")
-        if job_config.estimate.mode == "real":
+        if job_config.memory_estimation.disable_fake_mode and peak_active > 0:
            print(f"Tracker Accuracy: {tracker_peak/peak_active}")
         gc.enable()
 
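Note on the estimation.py change: fake-tensor estimation is now the default and is only skipped when --memory_estimation.disable_fake_mode is passed, and torch.compile is force-disabled for the estimation run. Below is a minimal sketch of the resulting context selection, assuming the JobConfig fields from this diff; pick_estimation_context is an illustrative helper, not part of the commit.

# Sketch only: mirrors the gating added above; not torchtitan API.
import contextlib

from torch._subclasses.fake_tensor import FakeTensorMode


def pick_estimation_context(job_config):
    # Fake mode is the default; --memory_estimation.disable_fake_mode opts out
    # and runs the estimation with real tensors instead.
    if job_config.memory_estimation.disable_fake_mode:
        return contextlib.nullcontext()
    return FakeTensorMode()
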
run_llama_train.sh (2 changes: 1 addition & 1 deletion)
@@ -31,7 +31,7 @@ if [ $# -ne 0 ]; then
 fi
 
 # Check if --estimate.memory=True is in the arguments
-if echo "$overrides" | grep -q -- "--estimate.memory=True"; then
+if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
     # Calculate WORLD_SIZE as the product of NGPU and NNODES
     # Export WORLD_SIZE and LOCAL_RANK
     export WORLD_SIZE=$((NGPU * NNODES))
test_runner.py (17 changes: 16 additions & 1 deletion)
@@ -265,12 +265,27 @@ def build_test_list():
         ),
         OverrideDefinitions(
             [
-                ["--estimate.memory=True", "--estimate.mode=real"],
+                [
+                    "--memory_estimation.enabled",
+                ]
             ],
             "FSDP2 Memory Tracking and Estimation",
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.tensor_parallel_degree 1",
+                    "--training.data_parallel_degree 8",
+                    "--experimental.data_parallel_type ddp",
+                    "--experimental.enable_compiled_autograd",
+                ]
+            ],
+            "CompiledDDP",
+            "compiled_ddp",
+            ngpu=8,
+        ),
     ]
     return integration_tests_flavors
 
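Note on the test_runner.py change: each OverrideDefinitions entry is a list of argument lists that the runner appends to the base training command. A rough sketch of how the new CompiledDDP entry could expand into a launch command; expand_overrides is illustrative, not the repository's actual runner code.

# Illustrative only: not the actual test_runner logic.
def expand_overrides(base_cmd: str, override_args: list[list[str]]) -> list[str]:
    # Join the base command with each group of override arguments.
    return [" ".join([base_cmd, *args]) for args in override_args]


cmds = expand_overrides(
    "./run_llama_train.sh",
    [
        [
            "--training.tensor_parallel_degree 1",
            "--training.data_parallel_degree 8",
            "--experimental.data_parallel_type ddp",
            "--experimental.enable_compiled_autograd",
        ]
    ],
)
# cmds[0] == "./run_llama_train.sh --training.tensor_parallel_degree 1 ..."
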
torchtitan/config_manager.py (14 changes: 7 additions & 7 deletions)
@@ -503,18 +503,18 @@ def __init__(self):
             help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
         )
 
-        # estimation mode settings
+        # memory estimation settings
         self.parser.add_argument(
-            "--estimate.memory",
+            "--memory_estimation.enabled",
             help="Whether to estimate memory usage for FSDP",
             default=False,
             action="store_true",
         )
 
         self.parser.add_argument(
-            "--estimate.mode",
-            type=str,
-            default="fake",
-            help="Mode of estimation to use ['fake', 'real']",
+            "--memory_estimation.disable_fake_mode",
+            help="Whether to estimate memory under FakeTensorMode",
+            default=False,
+            action="store_true",
        )
 
     def parse_args(self, args_list: list = sys.argv[1:]):
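Note on the config_manager.py change: the string-valued --estimate.mode ['fake', 'real'] is replaced by two store_true booleans. A standalone argparse sketch of the new flag semantics, assuming plain argparse behavior (torchtitan's JobConfig additionally splits the dotted names into config sections).

# Standalone sketch of the new flags with plain argparse; the attribute names
# contain a dot, so they are read back with getattr().
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--memory_estimation.enabled", default=False, action="store_true")
parser.add_argument(
    "--memory_estimation.disable_fake_mode", default=False, action="store_true"
)

args = parser.parse_args(["--memory_estimation.enabled"])
assert getattr(args, "memory_estimation.enabled") is True
# Fake-tensor estimation stays on unless the disable flag is passed explicitly.
assert getattr(args, "memory_estimation.disable_fake_mode") is False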
