Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor tensor subclass API to also use paramterization #146
Refactor tensor subclass API to also use paramterization #146
Changes from all commits
1562867
32843bb
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a big fan of undocumented, generic kwargs. I think it's hard to tell users the intent, because you can't write documentation for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is args for
from_float
function for the tensor subclass, this is just to be consistent with the existing int4 tensor subclass apisalso I don't think we want to use this function as a top level API, can we refactor this a bit later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect we use parametrization only for AOTI? As some kind of "pre-processing" there.
Especially given that my understanding of the long term plan is that AOTI will do this pre-processing themselves and we wll be able to remove it from there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we also need this for torch.export (used by executorch), I'll add a test in next PR, also we want to have a consistent code path for all backends/runtimes I think. is there any problems with enabling this for all use cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@albanD do you think that long term we want export to do the pre-processing?
I think if that's the case, then we might just want to figure out that story now (it might be less work than getting dynamo to handle parametrizations).
The main contentious bit is probably just where this pre-processing should live. One possible answer is that it should happen transparently as part of
torch.export.export()
: automatically search the created state dict for subclasses and flatten them (although this might be a problem if the user expects the state dict of theExportedProgram
to alias the original model's state dict)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is some issue with this in AOT Inductor I think. cc @desertfire
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think @cpuhrsch 's question is why rewriting "+=". I am not aware any AOTI restriction that needs this rewrite.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@desertfire the error I'm getting with "+=" is this: https://gist.github.com/jerryzh168/d4ea2fb8138376cff903c38aaef8f5ef, is this expected?