Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix a small device bug #57

Merged
merged 6 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions slapo/framework_dialect/deepspeed/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,11 @@ def forward(self, *args, **kwargs):
if logger.isEnabledFor(DEBUG):
logger.debug(f"[{self.name}] Flatten: {len(ret)}; metadata: {metadata}")
ret.append(
torch.ByteTensor(
torch.ByteStorage.from_buffer(bytes(encode_metadata(metadata), "utf8"))
).to(device)
torch.tensor(
torch.ByteStorage.from_buffer(bytes(encode_metadata(metadata), "utf8")),
dtype=torch.uint8,
device=device,
)
)
ret = tuple(ret)
return ret
Expand Down
10 changes: 6 additions & 4 deletions slapo/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ def _shard(name, tensor):
self.metadata.tie_weights[param] = new_param
else:
new_param = nn.Parameter(new_tensor)

# Tag param with model parallel attribute, used for grad clipping
new_param.tensor_model_parallel = True
# Save the original size of the parameter for consolidation.
new_param.orig_shape = param.shape

self.mod.register_parameter(tensor_name, new_param)
except AttributeError:
buffer = self.mod.get_buffer(tensor_name)
Expand Down Expand Up @@ -1363,8 +1363,10 @@ def _consolidate_and_broadcast(sch: Schedule):
is_found = True
if is_found:
cnt_shard += 1
new_param = param.detach().split(sharded_size, dim=axis)[tp_rank]
sch.mod.register_parameter(param_name, nn.Parameter(new_param))
sharded_param = param.detach().split(sharded_size, dim=axis)[tp_rank]
new_param = nn.Parameter(sharded_param)
new_param.tensor_model_parallel = True
sch.mod.register_parameter(param_name, new_param)

for subsch in sch.child.values():
ret = _consolidate_and_broadcast(subsch)
Expand Down
2 changes: 1 addition & 1 deletion slapo/sharding/sync_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def all_gather_along_dim(inp, dim, world_size, group):
else:
# Fallback to all_gather. This may lead to suboptimal performance.
parts = [
torch.empty(inp.shape, dtype=inp.dtype).to(inp.device)
torch.empty(inp.shape, dtype=inp.dtype, device=inp.device)
for _ in range(world_size)
]
dist.all_gather(parts, inp, group=group)
Expand Down