Modified masking before pooling - Fixes issue in ONNX conversion #92

Merged (1 commit) on Apr 12, 2024


@ashokrajab commented Oct 26, 2023

Issue:
In class INSTRUCTOR_Transformer, inside def forward(), the attention mask entries corresponding to the instruction tokens are set to 0 in the following manner:

if context_masks is not None:
    import torch
    assert len(context_masks) == len(attention_mask)
    n = len(attention_mask)
    # print('n ',n)
    for local_idx in range(n):
        assert torch.sum(attention_mask[local_idx]).item() >= context_masks[local_idx].item(),\
            f'{attention_mask[local_idx]}, {context_masks[local_idx]}, ' \
            f'{torch.sum(attention_mask[local_idx]).item()}, {context_masks[local_idx].item()}'
        attention_mask[local_idx][:context_masks[local_idx]] = 0

I want to draw attention to the line n = len(attention_mask). This Python int is treated as a constant during ONNX conversion, which leads to incorrect inference when the instruction token length changes.

Solution:
Instead of getting the instruction token length and manually iterating over attention_mask to set the values to 0,
I have introduced a prepare_input_features() function under class Instructor that carries out the same task using tensor manipulations.
This way, inference with the ONNX model works as expected for any instruction.
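
For readers who want to see the idea, here is a minimal sketch of masking instruction tokens with tensor operations instead of a Python loop. It is not the code from this pull request: the function name mask_instruction_tokens is hypothetical, and it assumes context_masks holds the number of instruction tokens per example, as in the snippet above.

```python
import torch


def mask_instruction_tokens(attention_mask: torch.Tensor,
                            context_masks: torch.Tensor) -> torch.Tensor:
    """Zero the first context_masks[i] positions of row i without a Python loop.

    attention_mask: (batch, seq_len) tensor of 0/1 values.
    context_masks:  (batch,) or (batch, 1) tensor of instruction token counts
                    (assumed layout, hypothetical helper).
    """
    seq_len = attention_mask.shape[1]
    # Column index of every position: shape (1, seq_len).
    positions = torch.arange(seq_len, device=attention_mask.device).unsqueeze(0)
    # True where a position lies at or beyond the instruction prefix: (batch, seq_len).
    keep = positions >= context_masks.reshape(-1, 1)
    # Multiplying keeps existing padding zeros and clears the instruction prefix.
    return attention_mask * keep.to(attention_mask.dtype)


attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 1, 1]])
context_masks = torch.tensor([2, 3])
masked = mask_instruction_tokens(attention_mask, context_masks)
print(masked)
```

Because the mask is built by comparing a position index against the per-example lengths, the computation does not depend on a Python-level loop count or on integers extracted from tensors.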

Other change set:
The pull request contains many other diffs; they result from adhering to formatting/linting standards.

@ashokrajab

@Harry-hash @hongjin-su
Looking forward to your inputs...

@ashokrajab force-pushed the onnx_conversion_fix branch 2 times, most recently from fcf4147 to 43e7c83, on November 15, 2023 06:40
@ashokrajab

@hongjin-su @Harry-hash
Just a gentle reminder.

@ashokrajab

Following up on this.

@ashokrajab

@hongjin-su @Harry-hash
Just a reminder.
