
[WIP] FSDP support for MoE training #2357


Closed
wants to merge 5 commits
1 change: 1 addition & 0 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -35,6 +35,7 @@ def _scaled_grouped_mm(
        offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
        out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
    """
    print("$$$ SCALED GROUPED MM")
    return _Float8GroupedMM.apply(
        A,
        B_t,
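For orientation, a hypothetical call into this entry point could look like the sketch below; the shapes, device, offset values, and keyword names are illustrative assumptions inferred from the docstring above, not code from this PR:

import torch
from torchao.prototype.moe_training import _scaled_grouped_mm

# Assumed setup: 4 expert groups over 128 tokens, bf16 inputs on a CUDA device.
A = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")        # 2D input tensor
B_t = torch.randn(4, 512, 1024, dtype=torch.bfloat16, device="cuda")  # 3D grouped weight tensor (shape illustrative)
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32, device="cuda")  # per-group offsets along dim0 of A

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)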
30 changes: 28 additions & 2 deletions torchao/prototype/moe_training/tensor.py
@@ -1,8 +1,8 @@
import torch
from torch.utils._pytree import tree_map

from torchao.prototype.moe_training import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
"""
ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
@@ -16,6 +16,9 @@ class ScaledGroupedMMTensor(torch.Tensor):
    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"ScaledGroupedMMTensor(data={self._data})"

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
@@ -32,4 +35,27 @@ def __torch_function__(cls, func, types, args, kwargs={}):
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)

        # Disable torch_function by hand because we don't want
        # the wrapping behavior of the super() impl, go directly to
        # torch_dispatch for the rest of the ops.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs={}):
        unwrap = lambda x: x._data if isinstance(x, cls) else x
        wrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
        unwrapped_args, unwrapped_kwargs = tree_map(unwrap, (args, kwargs))

        # special case: for ops with out=.. specified, we want the output tensor to be a subclass.
        if 'out' in unwrapped_kwargs:
            unwrapped_kwargs['out'] = tree_map(wrap, unwrapped_kwargs['out'])

        with torch._C.DisableTorchFunctionSubclass():
            output = func(*unwrapped_args, **unwrapped_kwargs)
        wrapped_output = tree_map(wrap, output)
        print("func", func.__name__)
        print(wrapped_output)
        return wrapped_output
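The change above follows the usual wrapper-subclass pattern: __torch_function__ intercepts the grouped-mm call and routes it to _scaled_grouped_mm, and everything else falls through to __torch_dispatch__, which unwraps to plain tensors, runs the op, and re-wraps the result so the subclass type is preserved across arbitrary aten ops. A self-contained sketch of that unwrap/run/re-wrap pattern is below; _ToyWrapper, its _make_wrapper_subclass constructor, and the chunk example are assumptions for illustration, not the PR's implementation:

import torch
from torch.utils._pytree import tree_map

class _ToyWrapper(torch.Tensor):
    # Illustrative stand-in for ScaledGroupedMMTensor, not the PR's class.
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda x: x._data if isinstance(x, cls) else x
        wrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
        args, kwargs = tree_map(unwrap, (args, kwargs or {}))
        out = func(*args, **kwargs)  # run the aten op on plain tensors
        return tree_map(wrap, out)   # re-wrap so the subclass survives the op

w = _ToyWrapper(torch.randn(8, 16))
halves = w.chunk(2, dim=0)           # an ordinary op hits __torch_dispatch__
print(type(halves[0]).__name__)      # _ToyWrapper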