
fix batch size stuff
samsja committed Sep 11, 2024
1 parent 9f680f2 commit 8576f1c
Showing 1 changed file with 3 additions and 3 deletions.
open_diloco/train_pure_fsdp.py: 6 changes (3 additions & 3 deletions)
@@ -124,11 +124,12 @@ def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
+    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
     rank = int(os.environ["RANK"])
 
     # batch_size is the total batch size for all GPUs
-    assert config.total_batch_size % world_size == 0
-    batch_size = config.total_batch_size // world_size
+    assert config.total_batch_size % local_world_size == 0
+    batch_size = config.total_batch_size // local_world_size
 
     assert batch_size % config.per_device_train_batch_size == 0
     gradient_accumulation_steps = batch_size // config.per_device_train_batch_size
@@ -141,7 +142,6 @@ def train(config: Config):
     model = get_model(config)
     model = model.to(local_rank)
 
-    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
     nnodes = world_size // local_world_size
 
     # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend
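A minimal sketch of the batch-size arithmetic after this commit, with hypothetical values not taken from the repository (2 nodes of 8 GPUs each, total_batch_size=512, per_device_train_batch_size=16). The per-GPU batch size is now derived from LOCAL_WORLD_SIZE, so total_batch_size appears to be treated as a per-node rather than a global quantity:

    import os

    # Hypothetical torchrun-style environment for a 2-node x 8-GPU run.
    os.environ.setdefault("WORLD_SIZE", "16")        # GPUs across all nodes
    os.environ.setdefault("LOCAL_WORLD_SIZE", "8")   # GPUs on this node

    total_batch_size = 512             # stand-in for config.total_batch_size
    per_device_train_batch_size = 16   # stand-in for config.per_device_train_batch_size

    world_size = int(os.environ["WORLD_SIZE"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

    # After the fix: divide by the GPUs of one node, not of the whole world.
    assert total_batch_size % local_world_size == 0
    batch_size = total_batch_size // local_world_size   # 512 // 8 = 64 samples per GPU

    assert batch_size % per_device_train_batch_size == 0
    gradient_accumulation_steps = batch_size // per_device_train_batch_size  # 64 // 16 = 4

    nnodes = world_size // local_world_size              # 16 // 8 = 2
    print(batch_size, gradient_accumulation_steps, nnodes)  # 64 4 2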
