On Device Sampling #350
Conversation
Force-pushed the branch from 9fab549 to 3b63ecb.
@@ -75,7 +76,7 @@ def __repr__(self) -> str:

     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs):
Should we make these optional parameters?
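One possible reading of this comment (a sketch only, not the PR's actual change): make the two new flags keyword-only with `None` defaults, so a caller who never mentions them is distinguishable from one who explicitly passes `False`. The class name below is a hypothetical stand-in for the class being modified.

```python
from typing import Optional


class QEFFAutoModelForCausalLM:  # hypothetical stand-in for the PR's class
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        is_tlm: bool = False,
        *args,
        include_sampler: Optional[bool] = None,  # None = "not specified by the caller"
        return_pdfs: Optional[bool] = None,
        **kwargs,
    ):
        # Resolve unspecified flags to their effective defaults before use.
        include_sampler = bool(include_sampler)
        return_pdfs = bool(return_pdfs)
        ...
```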
@@ -1317,8 +1322,14 @@ def __init__(
         if is_tlm:
             # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
             self.model, transformed = SpDTransform.apply(self.model)
+            self.model.return_pdfs = True
Where is the code that handles the `is_tlm == False` case when populating `return_pdfs`?
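A minimal sketch of what this comment appears to ask for, assuming the names from the hunk above (`SpDTransform`, `self.model`, a caller-supplied `return_pdfs`); this is not the PR's final code:

```python
if is_tlm:
    # TLM / speculative-decoding path: the transform requires full distributions.
    self.model, transformed = SpDTransform.apply(self.model)
    self.model.return_pdfs = True
else:
    # Non-TLM path: honour the caller's return_pdfs flag instead of leaving it unset.
    self.model.return_pdfs = return_pdfs
```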
dynamic_axes["top_ks"] = {0: "batch_size"} | ||
|
||
example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80 | ||
dynamic_axes["top_ps"] = {0: "batch_size"} |
Can we define constants for 0.80 and 0.99?
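A sketch of the suggested change, with named constants replacing the literals. The constant names are illustrative, and what consumes 0.99 is not visible in this hunk, so that binding is an assumption.

```python
import torch

# Named defaults for the example sampling inputs used during export (names assumed).
DEFAULT_EXAMPLE_TOP_P = 0.80
DEFAULT_EXAMPLE_MIN_P = 0.99  # assumption: example value for another sampling input

bs = 1  # example batch size, as in the export helper
example_inputs, dynamic_axes = {}, {}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * DEFAULT_EXAMPLE_TOP_P
dynamic_axes["top_ps"] = {0: "batch_size"}
```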
    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
Please add a docstring.
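A sketch of how the requested docstring could look, assuming the `apply` signature from the hunk above; the class name is a stand-in for the transform class in the PR.

```python
from typing import Tuple

import torch.nn as nn


class SamplerTransform:  # stand-in for the PR's transform class
    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """Attach on-device sampling to ``model`` if its architecture is supported.

        Args:
            model: The causal-LM module to transform in place.

        Returns:
            A ``(model, transformed)`` tuple where ``transformed`` is ``True`` only
            if the sampler was actually wired in.
        """
        transformed = False
        # ... transform logic from the PR goes here ...
        return model, transformed
```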
* Initial commit
* Reformat code
* Fix bug
* Add Gumbel-Max trick based random sampling
* Bring up to date
* Use Gumbel-Max Trick based Random Sampling as default
* Clip k to max value
* Add docstring for sampling parameters
* Fix bug
* Add support for continuous batching
* Fix ONNX error for batch_size 1 treated as a Constant
* Undo docstring deletion
* Remove device and unnecessary reshapes
* Revert batch_size to 1
* Remove vocab_size from dynamic axes
* Change condition
* Change size of each sampling parameter to (batch_size, 1)
* Reformat code
* Add optimizations
* Identify optimizations
* Fix bug
* Fix merge issue
* Optimizations: Perform random sampling only on topk_values_asc; only need logits for probs when self.return_pdfs is True
* Remove where clause for temperature
* Remove boolean type casting for retain state
* Always return next_tokens
* Fix bug
* Reformat code
* Initialize retain states
* Optimize imports
* Remove torch.index_select()
* Change dtype of penalty buffers to bool

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
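Since the commit list above names Gumbel-Max trick based random sampling as the default, here is a generic sketch of that technique for reference. It is not the PR's on-device kernel (which, per the list, samples only over the top-k values); it only illustrates the underlying trick.

```python
import torch


def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    """Draw one token per row from softmax(logits) via the Gumbel-Max trick."""
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1); clamp avoids log(0).
    uniform = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
    gumbel = -torch.log(-torch.log(uniform))
    # argmax(logits + gumbel) is distributed as a categorical draw from softmax(logits).
    return torch.argmax(logits + gumbel, dim=-1)


# Example: sample one token id per sequence from a (batch, vocab) logits tensor.
tokens = gumbel_max_sample(torch.randn(4, 32000))
```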
Force-pushed the branch from 50953e2 to d48d084.