
[Optimization] Advance parser concurrently with model forward pass #1065

Merged: 22 commits merged into guidance-ai:main on Nov 10, 2024

Conversation

@hudson-ai (Collaborator) commented Oct 28, 2024

Refactors Engine.__call__ and the TokenParser._parse coroutine/generator to yield a future wrapping LLInterpreter.mid_process (running in a ThreadPoolExecutor) so that mid_process can run concurrently with the forward pass. Since mid_process comes directly from the rust extension, it releases the GIL, and threading should be sufficient to ensure true concurrency.
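Conceptually, the change looks something like the sketch below. The names, signatures, and control flow here are illustrative only, not the actual diff; in particular, mid_process is assumed to return a (mask, response_str) pair as in the review snippets further down, and the post_process signature is assumed.

```python
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)  # one mid_process call in flight at a time

def parse(ll_interpreter):
    """Parser coroutine: yields a Future for the next mask instead of the mask itself."""
    while True:
        # mid_process runs inside the rust extension and releases the GIL, so the
        # worker thread can make progress while the main thread runs the forward pass.
        mid_process_future = _executor.submit(ll_interpreter.mid_process)
        token = yield mid_process_future        # engine sends the sampled token back
        if token is None:                       # engine saw the stop signal
            return
        ll_interpreter.post_process(token)      # advance parser state (signature assumed)

def engine_call(parser, get_logits, sample):
    """Engine side: overlap the forward pass with the parser's mask computation."""
    future = next(parser)                       # kicks off the first mid_process
    while True:
        logits = get_logits()                   # forward pass runs while mid_process does
        mask, _ = future.result()               # then wait for the parser's mask
        if mask is None:                        # parser is done
            try:
                parser.send(None)               # let the generator finish cleanly
            except StopIteration:
                pass
            return
        token = sample(logits, mask)
        future = parser.send(token)             # hand the token back, get the next future
```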

@lochuynh1412 I stole a bit of your code from the interactivity overhaul in order to simplify Engine.__call__. I'm definitely introducing a merge conflict for that PR -- happy to help resolve it :)

@hudson-ai changed the title from "[Optimization]" to "[Optimization] Advance parser concurrently with model forward pass" on Oct 28, 2024
@hudson-ai (Collaborator, Author) commented:

@Harsha-Nori I ran some quick benchmarks to make sure the parallelization was actually working: a JSON generation task (RPG characters) with two different models. Statistics are reported in milliseconds per token.

Meta-Llama-3.1-8B-Instruct-Q8_0

266 tokens (ran at zero temperature, so this was consistent across runs)

| branch | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| main | 100 | 16.4348 | 0.0302434 | 16.391 | 16.4157 | 16.4261 | 16.4468 | 16.5256 |
| parallel_parser | 100 | 15.1874 | 0.0314103 | 15.1202 | 15.1659 | 15.1843 | 15.2038 | 15.2796 |

95% CI for difference in means: (1.239, 1.256)
So there is about a 1.25 ms decrease in the time to generate each token. That sounds marginal, but in this case it's about an 8% difference.
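The thread doesn't show how the interval was computed; one standard choice is a Welch t-interval over the per-run means, sketched below. The helper is hypothetical and assumes the 100 per-run timings are available as numpy arrays.

```python
import numpy as np
from scipy import stats

def diff_ci(a, b, confidence=0.95):
    """Welch confidence interval for mean(a) - mean(b), e.g. main vs. parallel_parser
    per-token times in ms. `a` and `b` are 1-D arrays of per-run mean timings."""
    var_a, var_b = a.var(ddof=1) / len(a), b.var(ddof=1) / len(b)
    se = np.sqrt(var_a + var_b)
    # Welch-Satterthwaite degrees of freedom
    df = (var_a + var_b) ** 2 / (var_a**2 / (len(a) - 1) + var_b**2 / (len(b) - 1))
    t = stats.t.ppf(0.5 + confidence / 2, df)
    diff = a.mean() - b.mean()
    return diff - t * se, diff + t * se
```

Called as diff_ci(main_times, parallel_times), this returns (low, high) bounds of the same form as those reported above.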

Phi-3-mini-4k-instruct-q4

103 tokens (ran at zero temperature, so this was consistent across runs)

| branch | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| main | 100 | 5.01469 | 0.276793 | 4.91669 | 4.94167 | 4.95379 | 4.97787 | 6.56915 |
| parallel_parser | 100 | 4.95095 | 0.271921 | 4.86816 | 4.88707 | 4.89749 | 4.90545 | 6.42321 |

95% CI for difference in means: (-0.010, 0.141)
There's little to no speedup in this case, maybe due to the low number of total tokens generated or just due to the much quicker forward pass time overall.

Conclusions

We see some slight speedups, and any threading overhead seems pretty negligible. I feel comfortable pushing this PR forward :)

@hudson-ai (Collaborator, Author) commented:

Tests are currently failing because a consequence of this PR is that we do an additional forward pass at the end of generation: we call get_logits before we've heard back from the parser (which is what tells us that we're "done"). @mmoskal would it be possible to know whether we're done in the last "post_process" step, or is this a hard limitation?

@mmoskal (Collaborator) commented Oct 30, 2024

@hudson-ai let me know if this helps: guidance-ai/llguidance@e9cfc18

@hudson-ai (Collaborator, Author) commented:

@mmoskal thanks a bunch!

We can now often (not always, mind you) prevent an unnecessary forward pass if the next call to the parser is going to give us the stop signal. The smoke tests that previously failed happen to be such cases, so they will now pass. Just note that changing the grammars inside of them may cause this to change.

@codecov-commenter commented Oct 31, 2024


Codecov Report

Attention: Patch coverage is 88.63636% with 10 lines in your changes missing coverage. Please review.

Project coverage is 65.36%. Comparing base (0ace873) to head (42a5e5e).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| guidance/_parser.py | 90.38% | 5 Missing ⚠️ |
| guidance/models/_model.py | 85.29% | 5 Missing ⚠️ |


Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1065      +/-   ##
==========================================
- Coverage   66.48%   65.36%   -1.12%     
==========================================
  Files          65       65              
  Lines        5102     5140      +38     
==========================================
- Hits         3392     3360      -32     
- Misses       1710     1780      +70     


# Upstairs should have already waited on this future
mask, _ = mid_process_future.result()

if mask is None:
    if token is not None:
        raise TokenParserException(f"Expected None, got token {token}")
A Contributor commented:

If mask is None, isn't it in accepting mode? Shouldn't any token be accepted?

@hudson-ai (Collaborator, Author) replied Nov 4, 2024:

The mask should never be None unless the parser is actually done (i.e. we should not be accepting ANY tokens, as the loop should be stopping). This condition should be equivalent to ll_response.stop if we were to parse the string in the second slot of the future above.

@hudson-ai (Collaborator, Author) commented:

Note the .cleanup code, which is currently responsible for sending the final None token to get the generator loop to break. Let me know if you have any better ideas on how to structure it!
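For readers following the thread, the idea is roughly this (a simplified sketch, not the actual cleanup implementation):

```python
def cleanup(parser_gen):
    """After the engine has seen the stop signal, drive the parser generator to
    completion so its final checks and validation actually run."""
    try:
        parser_gen.send(None)  # the sentinel token; the generator should exit here
    except StopIteration:
        return
    raise RuntimeError("Parser did not terminate after receiving the final None token")
```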

gen_data = None
token = yield (gen_data, response)
# Upstairs should have already waited on this future
mask, _ = mid_process_future.result()
A Contributor commented:

Why don't we get LLInterpreterResponse.model_validate_json(ll_response_str) here instead of just the string?
Since both the parser and the engine check ll_response.stop to break, why do we need the cleanup function?

@hudson-ai (Collaborator, Author) replied Nov 4, 2024:

> Why don't we get LLInterpreterResponse.model_validate_json(ll_response_str) here instead of just the string?

We could add the pydantic validation into the thread to make sure the object-version of the string is always returned from the future -- I suppose I was feeling cautious that adding some CPU work into the thread might block the forward pass and slow things down because ya know, GIL. But I think this was unfounded since we have to wait on that work regardless. I can try this and see if it affects timings at all.
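For concreteness, that would mean submitting a small wrapper to the executor instead of mid_process directly, along these lines (the import path and the mid_process signature are assumptions):

```python
from llguidance import LLInterpreterResponse  # import path assumed

def _mid_process_and_validate(ll_interpreter):
    """Run mid_process and parse its JSON response inside the worker thread, so the
    future resolves to (mask, LLInterpreterResponse) rather than (mask, str)."""
    mask, ll_response_str = ll_interpreter.mid_process()
    return mask, LLInterpreterResponse.model_validate_json(ll_response_str)

# Instead of: executor.submit(ll_interpreter.mid_process)
# mid_process_future = executor.submit(_mid_process_and_validate, ll_interpreter)
```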

> Since both the parser and the engine check ll_response.stop to break, why do we need the cleanup function?

This is definitely a bit annoying... The parser loop isn't running while the "upstairs" caller is running -- i.e. it can't even check the value of ll_response.stop until the caller sends it a final None. Technically, we can just abandon the generator before it terminates, but that puts us in the confusing situation where not all the code in the parser actually runs. The cleanup exists out of an abundance of caution, just making sure that the parser generator finishes, doing any final checks/validation on state as we do so. Happy to jump on a call to discuss.

Thank you for looking this over!

A Contributor replied:

> We could add the pydantic validation into the thread to make sure the object-version of the string is always returned from the future -- I suppose I was feeling cautious that adding some CPU work into the thread might block the forward pass and slow things down because ya know, GIL. But I think this was unfounded since we have to wait on that work regardless. I can try this and see if it affects timings at all.

Yeah, we can try. I guess it'll make the code a bit cleaner. I don't think it will consume a lot of CPU cycles to validate the JSON string.

@hudson-ai (Collaborator, Author) replied:

It's really just on the order of 1-3 hundredths of a millisecond to parse the string. Should be fine to throw it in the thread (again, the engine call would have had to do it sequentially anyway, so this may very well cost literally nothing and give us some cleaner code)
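A quick way to sanity-check that estimate (the schema below is a made-up stand-in, not the real LLInterpreterResponse):

```python
import timeit
from pydantic import BaseModel

class FakeResponse(BaseModel):
    # Hypothetical minimal schema; the real response is richer.
    stop: bool
    temperature: float

payload = '{"stop": false, "temperature": 0.0}'
n = 100_000
per_call = timeit.timeit(lambda: FakeResponse.model_validate_json(payload), number=n) / n
print(f"{per_call * 1e3:.4f} ms per validation")  # typically hundredths of a ms or less
```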

@Harsha-Nori merged commit b6bcee7 into guidance-ai:main on Nov 10, 2024. 98 checks passed.