
Feat (proxy): scale computation delegated to bias proxy #938

Merged
merged 5 commits into Xilinx:dev on Apr 19, 2024

Conversation

@Giuseppe5 (Collaborator) commented on Apr 19, 2024

As part of decoupling QuantLayers and QuantTensor, the computation of the output scale for bias quantization is now delegated to the underlying proxy, which takes the (possibly quantized) input and weights as input and internally derives the output scale.
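As a rough illustration of what "delegated to the proxy" means here, a minimal sketch could look like the following. The class and method are hypothetical simplifications, not Brevitas' actual API, and they assume the usual affine-quantization rule that the bias scale is the product of the input and weight scales:

```python
from typing import Optional

from torch import Tensor


class BiasQuantProxySketch:
    """Hypothetical bias proxy that derives the output scale internally."""

    def quant_output_scale_impl(
            self,
            input_scale: Optional[Tensor],
            weight_scale: Optional[Tensor]) -> Optional[Tensor]:
        # For y = x @ W + b, the accumulator operates at scale
        # input_scale * weight_scale, so the bias is quantized at that
        # scale whenever both factors are known.
        if input_scale is None or weight_scale is None:
            return None
        return input_scale * weight_scale
```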

Compared to our current setup, the main change happens in the case where:

  • the bias quantizer is shared across multiple layers, and
  • those layers are of different types (e.g., Linear and Conv).

Our current setup has no issue with this case, while after this PR the user would be required to use an internally defined bias scale rather than an externally defined one.
I believe this is an edge case that can be safely ignored.
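To make the edge case concrete, here is a hypothetical illustration: because the proxy infers the channel dimension from the module type (see the diff snippet below), a single externally defined per-channel bias scale layout can no longer serve a Linear and a Conv layer at the same time:

```python
import torch

layers = [torch.nn.Linear(8, 16), torch.nn.Conv2d(8, 16, kernel_size=3)]
for layer in layers:
    # Same rule as in the diff below: channels sit on the last dim for
    # Linear and on dim 1 for convolutions.
    channel_dim = -1 if isinstance(layer, torch.nn.Linear) else 1
    print(type(layer).__name__, "-> channel dim", channel_dim)
# Linear -> channel dim -1
# Conv2d -> channel dim 1
```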

The relevant addition in the diff:

```python
def quant_output_scale_impl(
        self, input: QuantTensor, weight: QuantTensor, module: torch.nn.Module) -> Tensor:
    channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
```
Collaborator commented:

How come it's -1 if Linear and 1 otherwise?

@Giuseppe5 (Author) replied:

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
For a linear layer, the channel dimension is always the last one; for Conv, ConvTranspose, etc., it is always at dim 1, at least for the input tensor.
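For reference, a quick shape check (illustrative only) confirms the convention: Linear tensors put channels on the last dimension, while Conv tensors put them at dim 1, for both inputs and outputs:

```python
import torch

linear_out = torch.nn.Linear(8, 16)(torch.randn(2, 8))
conv_out = torch.nn.Conv2d(3, 16, kernel_size=3)(torch.randn(2, 3, 32, 32))

print(linear_out.shape)  # torch.Size([2, 16])         -> channels at dim -1
print(conv_out.shape)    # torch.Size([2, 16, 30, 30]) -> channels at dim 1
```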

@Giuseppe5 merged commit 670420f into Xilinx:dev on Apr 19, 2024
304 of 347 checks passed