On Device Sampling #350
Draft
quic-sanising wants to merge 31 commits into quic:main from quic-sanising:on-device-sampling
Commits (31)
All 31 commits are authored by quic-sanising:

718d763  Initial commit
b8d099e  Reformat code
544c0dd  Fix bug
0b4d0a9  Add Gumbel-Max trick based random sampling
24efc93  Bring up to date
2af43c6  Use Gumbel-Max Trick based Random Sampling as default
3eca771  Clip k to max value
b0e9162  Add docstring for sampling parameters
0486e42  Fix bug
e7dda72  Add support for continuous batching
f94c657  Fix ONNX error for batch_size 1 treated as a Constant
fa026a4  Undo docstring deletion
eff2007  Remove device and unncessary reshapes
ebfbaea  Revert batch_size to 1
83d33ac  Remove vocab_size from dynamic axes
fc3dc82  Change condition
abbaf53  Change size of each sampling parameter to (batch_size, 1)
f5f5e2d  Reformat code
05c0bf0  Fix bug
3b63ecb  Allow chunked prompts during prefill
0b6873c  Merge remote-tracking branch 'upstream/main' into on-device-sampling
1691a08  Add missing params
02389f8  Update retain state names with past keyword
7dfdda4  Add output_names for sampler
d48d084  Optimizations (#2)
bf367a6  Merge branch 'main' into on-device-sampling
aa7206d  Fix bugs
9f2c061  Merge branch 'main' into on-device-sampling
c59e1ab  Add include_sampler check
40f176e  Merge branch 'main' into on-device-sampling
52a077f  Update doc-strings
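Commits 0b4d0a9 and 2af43c6 add Gumbel-Max trick based random sampling. As background on the technique, here is a minimal stdlib-Python sketch of the trick; it is not the PR's implementation, which operates on batched tensors on device:

```python
import math
import random

def gumbel_max_sample(logits, rng=random):
    """Draw one index from softmax(logits) via the Gumbel-Max trick.

    argmax_i(logits[i] + g_i), with g_i ~ Gumbel(0, 1) i.i.d., is
    distributed exactly as a categorical sample from softmax(logits),
    so sampling reduces to an argmax, which maps well onto accelerators.
    """
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1).
    # `or 1e-12` guards the (measure-zero) case U == 0.0.
    noisy = [x - math.log(-math.log(rng.random() or 1e-12)) for x in logits]
    return max(range(len(noisy)), key=noisy.__getitem__)

# Heavily skewed logits: index 2 wins with probability ~0.987.
sample = gumbel_max_sample([0.0, 0.0, 5.0])
```

Because the randomness is injected additively before a deterministic argmax, the same kernel serves both greedy decoding (no noise) and random sampling.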
@@ -267,11 +267,11 @@
    QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward

SPD_TARGET = "target"


class CustomOpsTransform(ModuleMappingTransform):
    _module_mapping = {
        GemmaRMSNorm: GemmaCustomRMSNormAIC,

@@ -456,6 +456,40 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -
        return model, transformed


class SamplerTransform:
    """
    ``Mandatory`` Args:
        :model (nn.Module): PyTorch model.

    Returns:
        :model (nn.Module): PyTorch model.
        :transformed (bool): whether transformation was applied successfully.
    """

    # supported architectures
    _module_mapping = {
        # Llama
        QEffLlamaForCausalLM,
    }

    @classmethod
    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
        transformed = False
        if qaic_config is None or (include_sampler := qaic_config.get("include_sampler")) is None:
            return model, transformed
        elif not include_sampler:
            return model, transformed
        elif (model_class := model.__class__) in cls._module_mapping:
            model.forward = MethodType(sampler_forward, model)
            model.return_pdfs = qaic_config.get("return_pdfs", False)
            transformed = True
        else:
            raise NotImplementedError(
                f"model class {model_class} does not yet support returning multiple logits to keep."
            )
        return model, transformed


class VlmKVOffloadTransform(ModuleMappingTransform):
    # supported architectures
    _module_mapping = {

The diff also adds one new empty file.

Inline review comment on the "transformed = False" line: "Please add doc string"
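SamplerTransform works by rebinding the model instance's forward method with types.MethodType rather than editing the class. A minimal self-contained sketch of that dispatch pattern, using toy stand-in classes rather than the actual QEfficient code (the real transform also consults a supported-architecture mapping and raises for unsupported models):

```python
from types import MethodType
from typing import Optional, Tuple

class ToyModel:
    """Stand-in for a causal-LM model class; not a QEfficient class."""
    def forward(self, x):
        return f"logits({x})"

def sampler_forward(self, x):
    # Replacement forward: in the real PR this would also sample on device.
    return f"sampled({x})"

def apply_sampler(model, qaic_config: Optional[dict] = None) -> Tuple[object, bool]:
    """Mirror of the transform's control flow, simplified."""
    transformed = False
    if qaic_config is None or (include_sampler := qaic_config.get("include_sampler")) is None:
        return model, transformed
    if not include_sampler:
        return model, transformed
    # Rebind forward on this instance only; other instances are untouched.
    model.forward = MethodType(sampler_forward, model)
    model.return_pdfs = qaic_config.get("return_pdfs", False)
    transformed = True
    return model, transformed

m, ok = apply_sampler(ToyModel(), {"include_sampler": True})
# m.forward(...) now routes through sampler_forward
```

Binding per instance (instead of monkey-patching the class) keeps models exported without include_sampler on the original forward path.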
Review comment: "Can we define constants for 0.80 and 0.99"
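One typical way to address such a comment is to lift the magic numbers into named module-level constants so their meaning is documented in one place. The names and semantics below are purely illustrative assumptions; the PR does not show what 0.80 and 0.99 control:

```python
# Hypothetical constant names; the actual meaning of 0.80 / 0.99 in the PR
# is not visible in this diff.
DEFAULT_TOP_P = 0.80  # assumed: default nucleus-sampling threshold
MAX_TOP_P = 0.99      # assumed: upper clamp for user-supplied top_p

def clip_top_p(top_p: float) -> float:
    # Clamp to [0.0, MAX_TOP_P] instead of comparing against bare literals.
    return min(max(top_p, 0.0), MAX_TOP_P)
```

Named constants also give the clamp a single point of change if the limits are tuned later.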