support delayed scaling of weight in float8 all-gather #312
@@ -289,11 +289,10 @@ def inner_func():
        ), "Mismatched lengths of amax tensors."

        if dist.is_initialized():
            # Combine all the amax tensors into one tensor and reduce it
            # Note: do not reduce the weight values, because FSDP already ensures
            # the weight values on all ranks are the same after all-gather.
            all_amax_tensors = torch.cat(
-               fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+               fp8_amax_x_tensor_list
+               + fp8_amax_w_tensor_list
Review comment: should we only do this if we are using fp8 all-gather?

Reply: that could make sense, I'd love to see data on whether this matters for performance. Focusing on numerics for now; I was hoping performance would be tackled in future PRs.
+               + fp8_amax_dL_dY_tensor_list
            )
            all_reduced_amax_tensor = all_reduce(
                all_amax_tensors, "MAX", list(range(dist.get_world_size()))
@@ -302,12 +301,14 @@ def inner_func():
            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

            (
-               reduced_fp8_amax_tensor,
+               reduced_fp8_amax_x_tensor,
+               reduced_fp8_amax_w_tensor,
                reduced_fp8_amax_dL_dY_tensor,
            ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

            for idx, child in enumerate(fp8_layers):
-               child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+               child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx])
+               child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
                child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

        # We create two stacked tensor groups, one for the amax history and one for the current scales
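Taken together, these two hunks extend the amax sync so that weight amaxes are reduced across ranks alongside the activation (`x`) and gradient (`dL_dY`) amaxes, while still issuing only a single collective: all per-layer amax buffers are concatenated, all-reduced with MAX, split back into the three groups, and copied into each layer's buffers. The standalone sketch below illustrates that pattern with the synchronous `torch.distributed` API rather than the functional collective used in the diff; the function name and the assumption that each amax buffer is a shape-`(1,)` tensor are mine, not the PR's, and an initialized process group is assumed.

```python
import torch
import torch.distributed as dist


def sync_amaxes_with_one_all_reduce(
    amax_x_list: list,      # per-layer activation amax buffers, each shape (1,)
    amax_w_list: list,      # per-layer weight amax buffers, each shape (1,)
    amax_dL_dY_list: list,  # per-layer grad-output amax buffers, each shape (1,)
) -> None:
    # 1. concatenate every per-layer amax into one flat tensor
    all_amaxes = torch.cat(amax_x_list + amax_w_list + amax_dL_dY_list)
    # 2. one MAX all-reduce for everything, instead of one collective per buffer
    dist.all_reduce(all_amaxes, op=dist.ReduceOp.MAX)
    # 3. split back into the three groups and copy into the original buffers
    n = len(amax_x_list)  # the three lists have one entry per fp8 layer
    reduced_x, reduced_w, reduced_dL_dY = torch.split(all_amaxes, n)
    for i in range(n):
        amax_x_list[i].copy_(reduced_x[i])
        amax_w_list[i].copy_(reduced_w[i])
        amax_dL_dY_list[i].copy_(reduced_dL_dY[i])
```

Batching the reduction this way keeps the number of collectives constant as the number of fp8 layers grows, which is the same design choice the existing code already made for `x` and `dL_dY`.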
@@ -18,7 +18,7 @@
    ScaledMMConfig,
)

-from float8_experimental.float8_utils import EPS
+from float8_experimental.float8_utils import e4m3_dtype, EPS
from torch._prims_common import suggest_memory_format

@@ -189,3 +189,182 @@ def fsdp_post_all_gather(
            out._scale = scale
            return
        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)

class WeightWithDelayedFloat8CastTensor(torch.Tensor):

Review comment: [no change needed] I wish there was a way to share some more code with the dynamic version

Reply: yeah, me too. Looking at the code below, really the only code which would be shared is

    @staticmethod
    def __new__(
        cls,
        tensor: torch.Tensor,
        amax_buffer: torch.Tensor,
        amax_history_buffer: torch.Tensor,
        scale_buffer: torch.Tensor,
        mm_config: ScaledMMConfig,
        is_amax_initialized: bool,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            memory_format=suggest_memory_format(tensor),
            dtype=tensor.dtype,
            layout=tensor.layout,
            device=tensor.device,
            pin_memory=tensor.is_pinned(),
            requires_grad=tensor.requires_grad,
        )

    def __init__(
        self,
        tensor: torch.Tensor,
        amax_buffer: torch.Tensor,
        amax_history_buffer: torch.Tensor,
        scale_buffer: torch.Tensor,
        mm_config: ScaledMMConfig,
        is_amax_initialized: bool,
    ):
        self._tensor = tensor
        self._amax_buffer = amax_buffer
        self._amax_history_buffer = amax_history_buffer
        self._scale_buffer = scale_buffer
        self._mm_config = mm_config

        # Note: is_amax_initialized is not a buffer to avoid data dependent
        # control flow visible to dynamo
        # TODO(future PR): add serialization for this flag
        self.is_amax_initialized = is_amax_initialized

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func == torch.ops.aten.detach.default:
Review comment: mostly just a nit, but any reason to special-case detach here? Alternatively, you could set it up so that every view op automatically propagates subclass-ness in the same way.

Reply: If this is something I wrote, I think it was just something I saw in some other subclasses. Having every view op propagate subclass-ness in the same way sounds good to me.
            return WeightWithDelayedFloat8CastTensor(
                args[0]._tensor,
                args[0]._amax_buffer,
                args[0]._amax_history_buffer,
                args[0]._scale_buffer,
                args[0]._mm_config,
                args[0].is_amax_initialized,
            )
        mm_config: Optional[ScaledMMConfig] = None
        amax_buffer: Optional[torch.Tensor] = None
        amax_history_buffer: Optional[torch.Tensor] = None
        scale_buffer: Optional[torch.Tensor] = None
        is_amax_initialized: Optional[bool] = None

        def unwrap(t):
            nonlocal mm_config
            if mm_config is None:
                mm_config = t._mm_config
            else:
                mm_config = merge_mm_configs(mm_config, t._mm_config)
            nonlocal amax_buffer
            if amax_buffer is None:
                amax_buffer = t._amax_buffer
            nonlocal amax_history_buffer
            if amax_history_buffer is None:
                amax_history_buffer = t._amax_history_buffer
            nonlocal scale_buffer
            if scale_buffer is None:
                scale_buffer = t._scale_buffer
            nonlocal is_amax_initialized
            if is_amax_initialized is None:
                is_amax_initialized = t.is_amax_initialized
            return t._tensor

        args, kwargs = pytree.tree_map_only(
            WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
        )
        out = func(*args, **kwargs)
        if func not in _ops_to_preserve_subclass:
            return out
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: WeightWithDelayedFloat8CastTensor(
                x,
                amax_buffer,
                amax_history_buffer,
                scale_buffer,
                mm_config,
                is_amax_initialized,
            ),
            out,
        )

    def __tensor_flatten__(self):
        return (
            [
                "_tensor",
                "_amax_buffer",
                "_amax_history_buffer",
                "_scale_buffer",
            ],
            {
                "mm_config": self._mm_config,
"is_amax_initialized": is_amax_initialized, | ||
            },
        )

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return WeightWithDelayedFloat8CastTensor(
            inner_tensors["_tensor"],
            inner_tensors["_amax_buffer"],
            inner_tensors["_amax_history_buffer"],
            inner_tensors["_scale_buffer"],
            metadata["mm_config"],
            metadata["is_amax_initialized"],
        )

    def __repr__(self):
        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"

    def fsdp_pre_all_gather(self, mesh):
Review comment: I'll let @weifengpy confirm this portion

Reply: confirming that the FSDP part looks good
        # initialize if needed
        # TODO(before land): ensure settings are consistent between Float8Linear and here
Review comment: do we still need to resolve this?
        if not self.is_amax_initialized:
            from float8_experimental.float8_linear import (
                _maybe_initialize_amaxes_scales_for_float8_cast,
            )

            _maybe_initialize_amaxes_scales_for_float8_cast(
                self._tensor,
                self._amax_buffer,
                self._amax_history_buffer,
                self._scale_buffer,
                "max",  # TODO(before land): read this from parent
Review comment: ditto
                e4m3_dtype,
                self.is_amax_initialized,
                reduce_amax=True,
            )
            self.is_amax_initialized = True

        # this will:
        # 1. cast the tensor to float8 using `_scale_buffer`
        # 2. populate `_amax_buffer` inplace
        # TODO(future PR): clean up all the casting functions and clearly
        # separate dynamic vs delayed, tech debt has accumulated
        float8_tensor = Float8Tensor.to_float8(
            self._tensor,
            self._scale_buffer,
            e4m3_dtype,
            self._amax_buffer,
            self._mm_config,
        )
        return (float8_tensor._data,), (float8_tensor._scale,)

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            assert isinstance(out, Float8Tensor), f"{type(out)}"
            out._scale = scale
            return
        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
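In this subclass, `fsdp_pre_all_gather` performs the delayed-scaling cast: it initializes the amax and scale buffers on first use, casts the weight to float8 with the precomputed `_scale_buffer`, records the current amax into `_amax_buffer`, and returns the raw float8 data for all-gather plus the scale as metadata; `fsdp_post_all_gather` then rebuilds a `Float8Tensor` from the gathered bytes and that scale. The sketch below is a simplified, plain-PyTorch illustration of the delayed-scaling recipe that cast relies on (amax history to scale, then a saturated cast). The helper names and the exact EPS/clamping details are illustrative assumptions, not the library's code.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for float8_e4m3fn
EPS = 1e-12  # stand-in for the library's EPS, used to avoid division by zero


def delayed_scale_from_history(amax_history: torch.Tensor) -> torch.Tensor:
    # "max" recipe: derive the scale from the largest amax seen in the history window
    amax = torch.clamp(amax_history.max(), min=EPS)
    return E4M3_MAX / amax


def cast_with_delayed_scale(
    w: torch.Tensor,
    amax_buffer: torch.Tensor,   # shape (1,); overwritten with the current amax
    scale_buffer: torch.Tensor,  # shape (1,); holds the scale precomputed from history
) -> torch.Tensor:
    # record the current amax so the next sync can update the history and scale;
    # the cast itself intentionally uses the stale, precomputed scale
    amax_buffer.copy_(w.abs().max())
    scaled = (w.float() * scale_buffer).clamp(-E4M3_MAX, E4M3_MAX)
    return scaled.to(torch.float8_e4m3fn)


# usage sketch: on first use the history is seeded with the observed amax
# (mirroring the "initialize if needed" branch above); afterwards the scale is
# only refreshed from the history at sync points, not on every cast
w = torch.randn(32, 32)
amax_history = torch.zeros(16)
amax_history[0] = w.abs().max()
scale_buffer = torch.ones(1)
amax_buffer = torch.zeros(1)
scale_buffer.copy_(delayed_scale_from_history(amax_history))
w_fp8 = cast_with_delayed_scale(w, amax_buffer, scale_buffer)
```

Because the scale comes from past amaxes rather than the current tensor, the cast needs no extra reduction at cast time, which is what makes it compatible with doing the float8 conversion before FSDP's all-gather.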