-
Notifications
You must be signed in to change notification settings - Fork 131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Owl v2 speedup #759
Owl v2 speedup #759
Conversation
assert abs(223 - response.predictions[0].x) < 1.5 | ||
|
||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest moving somewhere else - I do maintain stack of useful scripts in development
directory
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you gave me this comment and approved - should I interpret that as, move and rereview, or merge and keep this in mind for the future? I guess, what should my actionable be here?
thank you for the feedback! :)
3d93256
Description
This change introduces some simple approaches to speeding up inference of an OWLv2 model over the original Huggingface implementation. The original implementation on an L4 GPU took 460ms on each iteration of the new test_owlv2.py file, and the new one takes ~36ms. I wrote a tiny little latency test script at the bottom of test_owlv2.py, not sure if there's a better / more 'standard' way of implementing that functionality, or if I should just take it out.
First, we replace the Huggingface preprocessing pipeline with one that makes full utilization of a GPU if available. This reduced preprocessing time from ~200ms to <10ms. I borrowed this implementation from my own open source repo here.
Second, we run the image backbone in mixed precision using PyTorch's autocast. We use float16 instead of bfloat16 as it is compatible with older GPUs such as the T4. We disable mixed precision on CPU as it can sometimes lead to unexpected behavior such as silently hanging if running in float16. We don't run the rest of the model in mixed precision because the built-in NMS kernel does not automatically support mixed precision and working around that would require a refactor with little to no additional speedup.
Third, we compile the model using torch's built-in torch.compile method. I limited the scope of the compilation to the vision backbone's forward pass following my open source repo as it reduces potential issues with the compiler interacting with Python objects. Note that the Huggingface OWLv2 implementation uses a manual self attention kernel, which is very slow compared to existing optimized kernels such as flash attention. Originally I went in and manually replaced the attention mechanism with flash_attn (again following my open source code release) but that package has challenging hardware dependency issues. I then found that torch 2.4's torch.compile method achieved a similar runtime as the flash attention implementation. That and associated VRAM memory reduction led me to conclude that the compile was likely optimizing out the manual attention implementation and replacing it with something more effective. As torch.compile is very general and doesn't have weird hardware dependency issues, I opted to just use that instead of manually plugging in flash attention.
When combined, the second and third improvements on an L4 GPU bought the model time from ~200ms to ~20ms. Overall, I reduced the latency from ~460ms to ~36ms.
On a T4 GPU, the improvements reduce the latency from ~680ms to ~170ms. Additionally, the memory usage is higher on a T4 GPU than an L4 GPU. I suspect this is because torch.compile is not introducing flash attention to the pipeline in the same way as it does on the newer L4 GPU. The most popular flash attention implementation, flash_attn, does not in fact support T4 GPUs, which may be the source of the problem. This could be addressed by building hardware-conditional optimizations manually, changing which version of flash_attn is installed conditional on the available hardware, but I'm leaving that to a future version as it doesn't seem that there yet exist best practices for hardware-conditional code within this codebase.
Finally, running the larger version of the model takes ~140ms on an L4 and ~960ms on a T4 GPU. That means the bigger model with these optimizations is actually faster then the existing version on an L4, but meaningfully slower on a T4. This is likely due to the conjectured lack of flash attention, which could slow down the larger model more than the base model as the larger model uses a larger image size and therefore processes more tokens via attention.
I also changed some of the type signatures from Image.Image to np.ndarray as it looks like they were just receiving numpy arrays anyway. Let me know if that is not correct!
This is my first pull request with Roboflow! Please let me know in what ways I am deviating from expected behavior :)
Type of change
Please delete options that are not relevant.
How has this change been tested, please provide a testcase or example of how you tested the change?
I updated the tolerance of the existing OWLv2 integration test and ran it. @probicheaux and I also tested it briefly in his OWLv2 branch. These changes shouldn't meaningfully change the logic of the model so I would expect no major behavior changes, unless I made a mistake in the prepocessing function.
Let me know if I should build out more thorough testing tools, and what inference best practices might be around them! At some point I would love to introduce more thorough testing anyway as I have a lot of ideas for further speedups that DO meaningfully change the behavior of the model.
Any specific deployment considerations
It's possible that different versions of torch could introduce issues, although I did some testing to that effect. Additionally, different hardware targets may have different behaviors, but torch SHOULD abstract most breaking issues away.
Docs
no updates to docs