Accelerate Utilities #193
base: main
Conversation
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
LGTM! with a few nits, good work on this!
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Fixed a bug, added some tests
why do we need the per token fix to land as a prereq for this PR?
What would be the replacement for get_execution_device?
@dsikka The function [...] This assumption causes an error in [...]
This looks good overall.
Do you mind adding a simple lifecycle docstring which shows the steps offloaded modules/parameters go through, to make it slightly easier to follow how the parameters are updated?
I also think we should kick off W4A16/W8A8 oneshot workflows, similar to what we did here: https://app.asana.com/0/1207078450218847/1208568399648361/f, to make sure they run to completion. I think the past issues we've seen have been with g_idx and activation quantization parameters.
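For reference, a minimal W4A16 oneshot run along those lines could look like the sketch below; it follows the public llm-compressor examples, and the model name, dataset, and calibration settings are placeholders rather than the exact validation workflow.

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# W4A16: quantize Linear weights to 4 bits, leave activations at 16 bits
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
    dataset="open_platypus",                     # placeholder calibration dataset
    recipe=recipe,
    output_dir="TinyLlama-1.1B-Chat-v1.0-W4A16",
    max_seq_length=2048,
    num_calibration_samples=512,
)
```

Running this (and a matching W8A8 recipe) to completion would exercise the g_idx and activation quantization parameters mentioned above.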
I think I understand from your PR why this can be removed.
@dsikka w.r.t. [...]:
For these reasons it's a candidate (and we'll need it for the immediate future), but future work can determine whether we want to keep/update it.
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules
this supports non-offloaded modules? for what case?
Supporting non-offloaded modules allows this function to be used throughout the codebase without having to duplicate code
Ugleh...
if not has_offloaded_params(module):
    param = getattr(module, name)
    data = data.to(param.dtype)
    if param.device.type != "meta":
        param.data.copy_(data)
else:
    update_offload_parameter(module, name, data)
Preetay!
update_offload_parameter(module, name, data)
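For context, here is a minimal sketch of how a single update_offload_parameter entry point can cover both cases internally; the import paths and the _hf_hook.weights_map access are assumptions for illustration, not necessarily the exact implementation in this PR.

```python
import torch
from compressed_tensors.utils import has_offloaded_params, offload_to_weights_map  # assumed paths

def update_offload_parameter(module: torch.nn.Module, name: str, data: torch.Tensor):
    """Sketch: update a parameter in place and, if the module is offloaded,
    keep the offloaded copy in sync as well."""
    param = getattr(module, name)
    data = data.to(param.dtype)  # match dtype so the in-place copy is valid

    if param.device.type != "meta":
        # non-offloaded (or currently onloaded) case: update the live tensor
        param.data.copy_(data)

    if has_offloaded_params(module):
        # offloaded case: also update the copy held by the accelerate hook
        offload_to_weights_map(module._hf_hook.weights_map, name, data)
```

With that branching inside the utility, callers only ever need the single update_offload_parameter(module, name, data) call shown above.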
    module: torch.nn.Module,
    name: str,
    parameter: torch.nn.Parameter,
):
When registering the parameters during initialization, don't we know the device already, depending on whether the module has been offloaded or not?
Can't we pass that device to update_offload_parameter to be used when updating the weights_map?
I've just now rewritten these parts to be a bit clearer.

> don't we know the device already, depending on if the module has been offloaded or not?

During initialization, the _initialize_scale_zero_point function determines the initial onload device:
# begin on the same device as other parameters or cpu if offloaded.
# in the offloaded case, there's no point moving tensors to the execution device
# if they're going to be immediately offloaded by `register_offload_parameter`
params_device = next(module.parameters()).device
device = "cpu" if has_offloaded_params(module) else params_device
It's the job of register_offload_parameter (and by extension update_offload_parameter, offload_to_weights_map) to determine the offload device:
if isinstance(weights_map, dict):
    if key in weights_map:
        offload_device = weights_map[key].device
    else:
        tens = next(iter(weights_map.values()), None)
        offload_device = tens.device if tens is not None else default_device
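To illustrate that division of responsibility, here is a hedged sketch of what the initialization side might look like; the signature of _initialize_scale_zero_point, the parameter shape, and the import paths are assumptions, not the exact code in this PR.

```python
import torch
from compressed_tensors.utils import has_offloaded_params, register_offload_parameter  # assumed paths

def _initialize_scale_zero_point(module: torch.nn.Module, base_name: str = "weight"):
    # choose the initial onload device: same device as the other parameters,
    # or cpu when the module is offloaded (no point in a hop to the execution
    # device right before register_offload_parameter offloads the tensor again)
    params_device = next(module.parameters()).device
    device = "cpu" if has_offloaded_params(module) else params_device

    scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=False)

    # register_offload_parameter attaches the parameter to the module and, when
    # the module is offloaded, picks the offload device and writes the value
    # into the hook's weights_map (via offload_to_weights_map)
    register_offload_parameter(module, f"{base_name}_scale", scale)
```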
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Purpose
Prerequisites
Changes
Changes not covered by prerequisites:
- getattr_chain utility function (also used by llm-compressor)
- deprecated utility decorator for future deprecations
- register_offload_parameter and delete_offload_parameter for easier initialization and removal of parameters related to quantization
- get_execution_device
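For illustration, a hedged usage sketch of these utilities; the import path and exact signatures are assumptions based on the descriptions above.

```python
import torch
from compressed_tensors.utils import (  # assumed import path
    getattr_chain,
    deprecated,
    register_offload_parameter,
    delete_offload_parameter,
    update_offload_parameter,
    get_execution_device,
)

linear = torch.nn.Linear(16, 16)

# getattr_chain: follow a dotted attribute path, returning a default if any link is missing
weight_args = getattr_chain(linear, "quantization_scheme.weights", None)

# register/delete offload-aware parameters (works for offloaded and non-offloaded modules)
zero_point = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
register_offload_parameter(linear, "weight_zero_point", zero_point)
delete_offload_parameter(linear, "weight_zero_point")

# get_execution_device: the device on which the module's forward pass will run
device = get_execution_device(linear)

# deprecated: warn callers of helpers slated for removal (argument name is an assumption)
@deprecated(message="use update_offload_parameter instead")
def legacy_update_helper(module, name, data):  # hypothetical helper, for illustration only
    update_offload_parameter(module, name, data)
```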
Deprecation Strategy
These functions should be deprecated, each for their own reason. These strategies will be implemented in follow-up PRs.
Upstream Strategy
Upstreaming functions to accelerate is a low priority, but comes with the benefit of more reviews and more official support.