[Optimization] Advance parser concurrently with model forward pass #1065
Conversation
Co-authored-by: Loc Huynh <lohuynh@microsoft.com>
…ration but the final one
@Harsha-Nori ran some quick benchmarks to make sure the parallelization was actually working. I ran a JSON generation task (RPG characters) with two different models. Statistics are reported in milliseconds per token.

Meta-Llama-3.1-8B-Instruct-Q8_0: 266 tokens (ran at zero temperature, so this was consistent across runs)
95% CI for difference in means: (1.239, 1.256)

Phi-3-mini-4k-instruct-q4: 103 tokens (ran at zero temperature, so this was consistent across runs)
95% CI for difference in means: (-0.010, 0.141)

Conclusions
We see some slight speedups, and any threading overhead seems pretty negligible. I feel comfortable pushing this PR forward :)
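For context, here is a minimal sketch of how such a confidence interval could be computed from per-token timing samples. This is not the benchmark script used above; it assumes independent samples and a normal approximation with unpooled variances.

```python
import numpy as np

def diff_of_means_ci95(a_ms, b_ms, z=1.96):
    """Approximate 95% CI for mean(a_ms) - mean(b_ms), assuming independent samples."""
    a = np.asarray(a_ms, dtype=float)
    b = np.asarray(b_ms, dtype=float)
    diff = a.mean() - b.mean()
    # Standard error of the difference (Welch-style, unpooled variances)
    se = np.sqrt(a.var(ddof=1) / len(a) + b.var(ddof=1) / len(b))
    return diff - z * se, diff + z * se

# baseline_ms, threaded_ms = per-token latencies (ms) from each run
# print(diff_of_means_ci95(baseline_ms, threaded_ms))
```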
Tests currently failing because a consequence of this PR is that we do an additional forward pass at the end of generation because we do …
@hudson-ai let me know if this helps: guidance-ai/llguidance@e9cfc18
@mmoskal thanks a bunch! We can now often (not always, mind you) prevent an unnecessary forward pass if the next call to the parser is going to give us the …
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1065      +/-   ##
==========================================
- Coverage   66.48%   65.36%   -1.12%
==========================================
  Files          65       65
  Lines        5102     5140      +38
==========================================
- Hits         3392     3360      -32
- Misses       1710     1780      +70

☔ View full report in Codecov by Sentry.
# Upstairs should have already waited on this future
mask, _ = mid_process_future.result()

if mask is None:
    if token is not None:
        raise TokenParserException(f"Expected None, got token {token}")
If mask is None, isn't it in accepting mode? Shouldn't any tokens be accepted?
The mask should never be None unless the parser is actually done (i.e. we should not be accepting ANY tokens, as the loop should be stopping). This condition should be equivalent to ll_response.stop if we were to parse the string in the second slot of the future above.
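To make that equivalence concrete, here is a tiny sketch; the import path and variable names are assumptions, not the PR's actual code:

```python
from guidance._schema import LLInterpreterResponse  # import path is an assumption

# The second slot of the future is assumed to hold the raw mid_process response JSON.
mask, ll_response_str = mid_process_future.result()
ll_response = LLInterpreterResponse.model_validate_json(ll_response_str)

# Expected invariant: the mask is None exactly when the parser says to stop,
# i.e. no further tokens should be accepted and the generation loop should end.
assert (mask is None) == ll_response.stop
```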
Note the .cleanup code, which is currently responsible for sending the final None token to get the generator loop to break. Let me know if you have any better ideas on how to structure it!
guidance/_parser.py
Outdated
gen_data = None
token = yield (gen_data, response)
# Upstairs should have already waited on this future
mask, _ = mid_process_future.result()
Why don't we get LLInterpreterResponse.model_validate_json(ll_response_str) here instead of just the string?
Since both the parser and the engine check ll_response.stop to break, why do we need the cleanup function?
> Why don't we get LLInterpreterResponse.model_validate_json(ll_response_str) here instead of just the string?
We could add the pydantic validation into the thread to make sure the object-version of the string is always returned from the future -- I suppose I was feeling cautious that adding some CPU work into the thread might block the forward pass and slow things down because ya know, GIL. But I think this was unfounded since we have to wait on that work regardless. I can try this and see if it affects timings at all.
> Since both the parser and the engine check ll_response.stop to break, why do we need the cleanup function?
This is definitely a bit annoying... The parser loop isn't running while the "upstairs" caller is running -- i.e. it can't even check the value of ll_response.stop until the caller sends it a final None. Technically, we can just abandon the generator before it terminates, but that puts us in the confusing situation where not all the code in the parser actually runs. The cleanup exists out of an abundance of caution, just making sure that the parser generator finishes, doing any final checks/validation on state as we do so. Happy to jump on a call to discuss.
Thank you for looking this over!
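For illustration, here is a stripped-down sketch of the generator protocol described above (toy names only, not the actual _parse implementation): the caller drives the generator with .send(), and the final None is what lets the generator body fall out of its loop and run any trailing checks instead of being abandoned mid-flight.

```python
def parse_loop():
    """Toy stand-in for the parser coroutine."""
    while True:
        token = yield "gen_data"     # caller sends the sampled token back in
        if token is None:            # the final None lets us exit the loop...
            break
    # ...so any trailing validation placed here actually executes
    print("final state checks ran")

def cleanup(gen):
    """Drive the generator to completion so its trailing code runs."""
    try:
        gen.send(None)
    except StopIteration:
        pass

gen = parse_loop()
next(gen)        # prime the generator up to its first yield
gen.send(123)    # caller sends a sampled token, receives the next gen_data
cleanup(gen)     # rather than abandoning the generator, finish it cleanly
```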
> We could add the pydantic validation into the thread to make sure the object-version of the string is always returned from the future -- I suppose I was feeling cautious that adding some CPU work into the thread might block the forward pass and slow things down because ya know, GIL. But I think this was unfounded since we have to wait on that work regardless. I can try this and see if it affects timings at all.
Yeah, we can try. I guess it'll make the code a bit cleaner. I don't think it will consume a lot of CPU cycles to validate the JSON string.
It's really just on the order of 1-3 hundredths of a millisecond to parse the string. Should be fine to throw it in the thread (again, the engine call would have had to do it sequentially anyway, so this may very well cost literally nothing and give us some cleaner code).
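If the validation does move into the thread, the worker function might look roughly like this. This is a sketch only; the mid_process signature and the names here are assumptions based on the diff above.

```python
def _mid_process_and_parse(interp):
    # Runs in the worker thread. The Rust call releases the GIL, so it overlaps
    # the forward pass; the pydantic parse afterwards costs ~0.01-0.03 ms.
    mask, ll_response_str = interp.mid_process()  # signature assumed
    ll_response = LLInterpreterResponse.model_validate_json(ll_response_str)
    return mask, ll_response

# The future would then yield the validated object rather than the raw string:
# mid_process_future = executor.submit(_mid_process_and_parse, interp)
# mask, ll_response = mid_process_future.result()
```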
Refactors Engine.__call__ and the TokenParser._parse coroutine/generator to yield a future wrapping LLInterpreter.mid_process (running in a ThreadPoolExecutor) such that mid_process can run concurrently with the forward pass. This function comes directly from the Rust extension, so it releases the GIL and threading should be sufficient to ensure true concurrency.

@lochuynh1412 I stole a bit of your code from the interactivity overhaul in order to simplify Engine.__call__. I'm definitely introducing a merge conflict for that PR -- happy to help resolve it :)
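For readers skimming the thread, the overall shape of the change looks roughly like the sketch below. This is not the real Engine.__call__; forward_pass, sample_token, and advance are placeholder names, and only mid_process and the ThreadPoolExecutor usage come from the PR itself.

```python
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def generation_loop(interp, model, tokens):
    """Illustrative only; everything except mid_process is a placeholder."""
    while True:
        # Parser work goes to the worker thread. mid_process comes from the Rust
        # extension and releases the GIL, so it truly overlaps the forward pass.
        mid_process_future = executor.submit(interp.mid_process)

        logits = model.forward_pass(tokens)       # placeholder model call

        mask, _ll_response_str = mid_process_future.result()
        if mask is None:                          # parser finished; stop the loop
            break

        token = model.sample_token(logits, mask)  # placeholder constrained sampling
        tokens.append(token)
        interp.advance(token)                     # placeholder: feed token to parser
```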