On Device Sampling #350

Draft · wants to merge 27 commits into base: main
Conversation

quic-sanising

No description provided.

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
@quic-sanising quic-sanising marked this pull request as ready for review April 9, 2025 04:48
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
@quic-amitraj quic-amitraj marked this pull request as draft April 11, 2025 08:49
@@ -75,7 +76,7 @@ def __repr__(self) -> str:

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs):
Contributor:

Should we make them optional parameters?
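A minimal usage sketch of the extended signature (not from the PR): the class name QEFFAutoModelForCausalLM and the model card "gpt2" are assumptions used only for illustration.

```python
# Hedged sketch, assuming this diff modifies QEfficient's QEFFAutoModelForCausalLM
# loader; the model card name below is a placeholder.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained(
    "gpt2",                # placeholder model card
    include_sampler=True,  # new flag in this PR: enable on-device sampling
    return_pdfs=True,      # new flag in this PR: also return probability distributions
)
```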

@@ -1317,8 +1322,14 @@ def __init__(
if is_tlm:
# TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
self.model, transformed = SpDTransform.apply(self.model)
self.model.return_pdfs = True
quic-hemagnih (Contributor) commented on Apr 23, 2025:

Where is the code that handles the is_tlm == False condition when populating return_pdfs?
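One possible way to populate return_pdfs on both branches, sketched as a fragment of __init__; the names (SpDTransform, self.model.return_pdfs, is_tlm) come from the diff, while defaulting the non-TLM branch to the caller's return_pdfs argument is an assumption, not the author's fix.

```python
# Fragment of __init__ (sketch only, not the PR's code).
if is_tlm:
    # TODO: It is possible to always apply this transform and make value of
    # indices as last indices by default in PyTorch
    self.model, transformed = SpDTransform.apply(self.model)
    self.model.return_pdfs = True         # TLM path always returns distributions
else:
    self.model.return_pdfs = return_pdfs  # assumption: honor the caller's choice otherwise
```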

dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80
dynamic_axes["top_ps"] = {0: "batch_size"}
Contributor:

Can we define constants for 0.80 and 0.99?
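A sketch of what the suggested constants might look like; the constant name is hypothetical, and 0.99 belongs to a parameter that is not visible in the quoted hunk, so only the top_ps default is shown in use.

```python
import torch

# Hypothetical constant name replacing one of the magic numbers the reviewer flags;
# 0.99 appears elsewhere in the diff and its constant would be defined alongside.
DEFAULT_TOP_P = 0.80

bs = 1  # example batch size
example_inputs, dynamic_axes = {}, {}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * DEFAULT_TOP_P
dynamic_axes["top_ps"] = {0: "batch_size"}
```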


@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
Contributor:

Please add a docstring.
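A sketch of what the requested docstring might say, inferred only from the signature and the surrounding diff; the wording is an assumption, not text from the PR.

```python
from typing import Tuple

import torch.nn as nn


class SpDTransform:
    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """Sketch of the requested docstring (wording is an assumption).

        Apply the speculative-decoding (SpD) transform to ``model``.

        Args:
            model: The PyTorch module to transform.

        Returns:
            A tuple ``(model, transformed)`` where ``transformed`` indicates
            whether any module was actually replaced.
        """
        transformed = False
        # ... transform logic from the PR goes here ...
        return model, transformed
```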

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
* Initial commit

* Reformat code

* Fix bug

* Add Gumbel-Max trick based random sampling (an illustrative sketch of the trick follows this commit log)

* Bring up to date

* Use Gumbel-Max Trick based Random Sampling as default

* Clip k to max value

* Add docstring for sampling parameters

* Fix bug

* Add support for continuous batching

* Fix ONNX error for batch_size 1 treated as a Constant

* Undo docstring deletion

* Remove device and unnecessary reshapes

* Revert batch_size to 1

* Remove vocab_size from dynamic axes

* Change condition

* Change size of each sampling parameter to (batch_size, 1)

* Reformat code

* Add optimizations

* Identify optimizations

* Fix bug

* Fix merge issue

* Optimizations:
  * Perform random sampling only on topk_values_asc
  * Only need logits for probs when self.return_pdfs is True

* Remove where clause for temperature

* Remove boolean type casting for retain state

* Always return next_tokens

* Fix bug

* Reformat code

* Initialize retain states

* Optimize imports

* Remove torch.index_select()

* Change dtype of penalty buffers to bool

---------

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
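Because the commit log above adopts Gumbel-Max-trick based random sampling as the default, here is a minimal, self-contained illustration of the trick itself; it is not the PR's implementation, and the shapes and temperature value are illustrative only.

```python
import torch


def gumbel_max_sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Illustrative Gumbel-Max sampling; not the code from this PR.

    Adding i.i.d. Gumbel(0, 1) noise to the temperature-scaled logits and taking
    the argmax draws a token from softmax(logits / temperature) without ever
    materializing the probabilities, which suits on-device execution.
    """
    scaled = logits / temperature
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1); clamp avoids log(0).
    u = torch.rand_like(scaled).clamp_min(1e-20)
    gumbel_noise = -torch.log(-torch.log(u))
    return torch.argmax(scaled + gumbel_noise, dim=-1)


# Example: sample one token id per batch row from a (batch_size, vocab_size) logit tensor.
logits = torch.randn(2, 32000)
next_tokens = gumbel_max_sample(logits, temperature=0.7)
print(next_tokens.shape)  # torch.Size([2])
```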