
Refactor tensor subclass API to also use parametrization #146

Merged: 2 commits merged into pytorch:main on May 3, 2024

Conversation

@jerryzh168 (Contributor) commented Apr 17, 2024

Summary:
Also added tests for tensor subclass api + AOTI compilation

Test Plan:
python test/integration/test_integration.py -k test_aoti

Two issues right now:

  1. AOTI test: change_linear_weights_to_int8_dqtensors + cuda device doesn't work
  2. another AOTI test also failed with cuda device (only in CI), will create another PR to repro later

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot added the "CLA Signed" label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 17, 2024
@jerryzh168 force-pushed the aoti_tests branch 4 times, most recently from a5b6dda to a9e5563, on April 19, 2024 01:30
@jerryzh168 force-pushed the aoti_tests branch 2 times, most recently from d0b9c23 to 25abb31, on April 20, 2024 00:16
@jerryzh168 changed the base branch from aoti_tests to main on April 20, 2024 04:04
@jerryzh168 force-pushed the aoti_tests branch 5 times, most recently from 578b4f0 to a906c53, on April 30, 2024 00:55
@jerryzh168 force-pushed the aoti_tests branch 4 times, most recently from 2efcc92 to c18e2f6, on April 30, 2024 02:21
@msaroufim previously approved these changes on Apr 30, 2024
@@ -493,7 +493,7 @@ def quant_int8_dynamic_per_token_linear(
         x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
     )
     if bias is not None:
-        mm_out += bias
+        mm_out = mm_out + bias
@cpuhrsch (Contributor):

Why is this needed?

jerryzh168 (Author):

there is some issue with this in AOT Inductor I think. cc @desertfire

desertfire:

I think @cpuhrsch's question is why you're rewriting "+=". I am not aware of any AOTI restriction that needs this rewrite.

jerryzh168 (Author):

@desertfire the error I'm getting with "+=" is this: https://gist.github.com/jerryzh168/d4ea2fb8138376cff903c38aaef8f5ef, is this expected?

@msaroufim self-requested a review on April 30, 2024 18:30
@msaroufim dismissed their stale review on April 30, 2024 18:30:

just meant to review the yaml change

         return lin
 
     return insert_subclass
 
 
-def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
+def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
Contributor:

I'm not a big fan of undocumented, generic kwargs. I think it's hard to tell users the intent, because you can't write documentation for it.

@jerryzh168 (Author) commented Apr 30, 2024:

these are args for the from_float function of the tensor subclass; this is just to be consistent with the existing int4 tensor subclass APIs

also I don't think we want to use this function as a top-level API, can we refactor this a bit later?
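For context, a minimal sketch of how these extra kwargs are meant to flow; the DummyWeight class and the group_size parameter are purely illustrative, and only the function signatures mirror the PR:

import torch.nn as nn

class DummyWeight:
    """Stand-in for a quantized-weight tensor subclass; only shows the from_float hook."""
    @classmethod
    def from_float(cls, weight, group_size=32):
        # a real subclass would quantize `weight` here; we just pass it through
        print(f"from_float called with group_size={group_size}")
        return weight

def _get_subclass_inserter(cls, **kwargs):
    def insert_subclass(lin):
        # **kwargs are forwarded to the subclass's from_float, mirroring the int4 APIs
        lin.weight = nn.Parameter(cls.from_float(lin.weight, **kwargs), requires_grad=False)
        return lin
    return insert_subclass

def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
    # kwargs flow through the inserter into from_float
    for mod in model.modules():
        if isinstance(mod, nn.Linear) and (filter_fn is None or filter_fn(mod)):
            _get_subclass_inserter(DummyWeight, **kwargs)(mod)

change_linear_weights_to_int8_dqtensors(nn.Sequential(nn.Linear(8, 8)), group_size=64)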

     ).reshape(w.shape[0], -1)
 
 
 def pack_tinygemm_scales_and_zeros(scales, zeros):
     assert scales.shape == zeros.shape
-    assert scales.dtype == torch.bfloat16
-    assert zeros.dtype == torch.bfloat16
+    assert scales.dtype == torch.bfloat16, f" got dtype: {scales.dtype}"
Contributor:

Will this also show what dtype was expected? It seems like an opportunity for a dtype guard decorator or somesuch

def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
    if dtype is not None and tensor_arg.dtype != dtype:
        raise ValueError(f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
    if size is not None and tensor_arg.size() != size:
        raise ValueError(f"Expected Tensor argument {arg_name} to have size {size}, but got {tensor_arg.size()} instead.")

guard_dtype_size(scales, "scales", torch.bfloat16, zeros.size())
guard_dtype_size(zeros, "zeros", torch.bfloat16)

See the ValueError reference manual for why I chose ValueError here.

        self.kwargs = kwargs

    def forward(self, int_data, q_scales):
        return from_qtensor_components_int8dyn(int_data, q_scales, *self.args, **self.kwargs)
@cpuhrsch (Contributor) commented Apr 30, 2024:

Can you use cls.__tensor_flatten__(*args) for this?

jerryzh168 (Author):

you mean __tensor_unflatten__? we can't use cls in forward because of pytorch/pytorch#124735 right now

Contributor:

If you do something like

def create_parameterization_module(cls):
    class SubclassParameterization:
        [...]

        def forward(self, args):
            cls.[...](args)

    return SubclassParameterization

then cls is given as an argument to create_parameterization_module and you return an instance of SubclassParameterization where cls is that argument. Essentially a module factory function.

These methods also shouldn't be static.
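A hedged, self-contained sketch of this factory pattern, for reference. FakeSubclass below only mimics the __tensor_flatten__/__tensor_unflatten__ protocol so the example runs; it is not a real torch.Tensor subclass, and none of these names come from the torchao code:

import torch
import torch.nn as nn

class FakeSubclass:
    """Toy stand-in that mimics the flatten/unflatten protocol."""
    def __init__(self, int_data, q_scales):
        self.int_data, self.q_scales = int_data, q_scales

    def __tensor_flatten__(self):
        return ["int_data", "q_scales"], None  # (inner tensor names, metadata)

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size=None, outer_stride=None):
        return cls(inner_tensors["int_data"], inner_tensors["q_scales"])

def create_parametrization_module(cls):
    # cls is captured by the closure, so forward() never reads a class
    # attribute of the module itself (the dynamo limitation discussed above).
    class SubclassParametrization(nn.Module):
        def forward(self, int_data, q_scales):
            return cls.__tensor_unflatten__(
                {"int_data": int_data, "q_scales": q_scales}, None
            )

        def right_inverse(self, instance):
            names, _ = instance.__tensor_flatten__()
            return [getattr(instance, n) for n in names]

    return SubclassParametrization()

sub = FakeSubclass(torch.zeros(4, dtype=torch.int8), torch.ones(4))
mod = create_parametrization_module(FakeSubclass)
rebuilt = mod(*mod.right_inverse(sub))  # round-trips through the factory

A real parametrization would return an actual tensor subclass instance from forward so it can be registered on a module's weight.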

jerryzh168 (Author):

isn't this using cls in forward? I tried this before, with @torch._dynamo.allow_in_graph for the constructor function, and it fails because we can't use class variables in dynamo right now, I think.

are you suggesting something like this: 25abb31#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726R134 (but using cls instead of hardcoding the class)?

Contributor:

Yes exactly and using __tensor_unflatten__ instead of from_qtensor_components

jerryzh168 (Author):

for reference, using cls in forward is not supported until pytorch/pytorch#123350 is landed, according to Brian

        return from_qtensor_components_int8dyn(int_data, q_scales, *self.args, **self.kwargs)

    def right_inverse(self, tensor_subclass_instance):
        return tensor_subclass_instance.int_data, tensor_subclass_instance.q_scales
@cpuhrsch (Contributor) commented Apr 30, 2024:

Can you use return self.__tensor_flatten__ for this?

jerryzh168 (Author):

this works, thanks. I'll create a parent class to host init and right_inverse
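A hedged guess at what such a parent class might look like; apart from from_qtensor_components_int8dyn, which appears in this PR's diff, the names are illustrative, and this is a structural sketch rather than standalone-runnable code:

import torch.nn as nn

class ConstructTensorSubclassBase(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def right_inverse(self, tensor_subclass_instance):
        # __tensor_flatten__ returns (inner_tensor_names, metadata); hand the
        # inner tensors back to the parametrization machinery in that order.
        names, _ = tensor_subclass_instance.__tensor_flatten__()
        return [getattr(tensor_subclass_instance, name) for name in names]

class ConstructTensorSubclassInt8Dyn(ConstructTensorSubclassBase):
    def forward(self, int_data, q_scales):
        # subclass-specific reconstruction stays in the child class;
        # from_qtensor_components_int8dyn is defined elsewhere in this PR
        return from_qtensor_components_int8dyn(int_data, q_scales, *self.args, **self.kwargs)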

torchao/quantization/quant_api.py (outdated review thread, resolved)
    if enable_parametrization:
        lin.weight = torch.nn.Parameter(cls.from_float(lin.weight), requires_grad=False)
        _, args = lin.weight.__tensor_flatten__()
        parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(cls, *args))
@jcaip (Contributor) commented Apr 30, 2024:

noob question - why do we want to enable this parameterization support?

jerryzh168 (Author):

this is for supporting exporting the tensor subclass model, needed by aot_compile and also torch.export.export

Contributor:

Tensor subclasses don't work with AOTI
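To make the motivation concrete, here is a hedged sketch of the flow the AOTI test exercises; exact entry points depend on the PyTorch version (the PR gates the parametrization path on TORCH_VERSION_AFTER_2_4), and a CUDA device is assumed:

import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors

model = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda().eval()
example_inputs = (torch.randn(1, 32, device="cuda"),)

# Swaps Linear weights for Int8DynamicallyQuantizedLinearWeight and, on new
# enough PyTorch, registers the parametrization so export sees plain tensors.
change_linear_weights_to_int8_dqtensors(model)

# AOT Inductor compilation; torch.export.export(model, example_inputs) is the
# analogous path for torch.export / executorch.
so_path = torch._export.aot_compile(model, example_inputs)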

        **kwargs,
    )

class ConstructTensorSubclassInt8Dyn(torch.nn.Module):
Contributor:

Can this be made generic for all tensor subclasses?

jerryzh168 (Author):

can't do it now because of pytorch/pytorch#124735, should be able to do it after this is fixed

@jerryzh168 force-pushed the aoti_tests branch 3 times, most recently from d7ba6af to a50fea5, on May 1, 2024 17:43
def wrapper(*args, **kwargs):
if args[2] == "cuda" and not torch.cuda.is_available():
assert len(args) >= 3, f"Not enough args. Expected more than or equal to 3, but got {len(args)}"
jerryzh168 (Author):

btw @cpuhrsch we need to use checks + skip test here I think, otherwise this test would fail:
FAIL: test_aoti (main.TestAOTI)
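A hedged sketch of the "checks + skip test" approach referred to here; the decorator name and where the length assert sits are assumptions, not necessarily the exact code in test_integration.py:

import functools
import unittest

import torch

def run_on_supported_device(test_method):
    @functools.wraps(test_method)
    def wrapper(*args, **kwargs):
        # the device is assumed to be the third positional arg of the
        # parametrized test method, i.e. (self, dtype, device, ...)
        assert len(args) >= 3, f"Not enough args. Expected more than or equal to 3, but got {len(args)}"
        if args[2] == "cuda" and not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        return test_method(*args, **kwargs)
    return wrapper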

@@ -141,11 +149,11 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
     )
 
     _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), filter_fn
+        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
albanD:

I would expect we use parametrization only for AOTI? As some kind of "pre-processing" there.
Especially given that my understanding of the long term plan is that AOTI will do this pre-processing themselves and we will be able to remove it from there.

jerryzh168 (Author):

we also need this for torch.export (used by executorch); I'll add a test in the next PR. We also want to have a consistent code path for all backends/runtimes, I think. Are there any problems with enabling this for all use cases?

Contributor:

@albanD do you think that long term we want export to do the pre-processing?

I think if that's the case, then we might just want to figure out that story now (it might be less work than getting dynamo to handle parametrizations).

The main contentious bit is probably just where this pre-processing should live. One possible answer is that it should happen transparently as part of torch.export.export(): automatically search the created state dict for subclasses and flatten them (although this might be a problem if the user expects the state dict of the ExportedProgram to alias the original model's state dict)

@jerryzh168 force-pushed the aoti_tests branch 12 times, most recently from 45543c0 to aff0c5b, on May 2, 2024 02:18
@jerryzh168 merged commit 2371ff8 into pytorch:main on May 3, 2024
15 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024