Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent pad ids, special tokens displaying in generate #1211

Merged
merged 6 commits into from
Aug 5, 2024

Conversation

RdoubleA
Copy link
Contributor

@RdoubleA RdoubleA commented Jul 23, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Pad ID is implicitly assumed to be 0 in utils.generate, and any instances of "0" in the generated tokens is replaced by the tokenizer's pad id. This is very apparent for Llama3 which has a reserved special token for pad id:

Tell me a joke, please. I could use a laugh.
Here's one: Why couldn't the bicycle stand up by itself?

Because it was two-tired!

Hope that brought a smile to your face<|finetune_right_pad_id|> Do you have a favorite joke you'd like to share? I'd love to hear it<|finetune_right_pad_id|> 

(And if you're feeling extra playful, I can try to come up with a joke on the spot. Just let me know!)<|eot_id|>

Here, I simply remove the logic of replacing 0 with pad id. I additionally filter out any special tokens when decoding with the TikToken base tokenizer, with an option to let them be displayed in the decoded string. With these changes, the same generated output for Llama3 is now cleaned up:

Tell me a joke, please. I could use a laugh.
Here's one: Why couldn't the bicycle stand up by itself?

Because it was two-tired!

Hope that brought a smile to your face! Do you have a favorite joke you'd like to share? I'd love to hear it! 

(And if you're feeling extra playful, I can try to come up with a joke on the spot. Just let me know!)

Notably, the 0 ID for Llama3 is the token for "!"

Test plan

tune run generate --config generation for Llama3, which uses tiktoken. Also tried other models that use SentencePiece like Phi3 and confirmed it does not impact any behavior.

Copy link

pytorch-bot bot commented Jul 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1211

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cc6ddbd with merge base 5825500 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 23, 2024
@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.

Project coverage is 26.73%. Comparing base (6e4809a) to head (5cf5fff).
Report is 4 commits behind head on main.

Files Patch % Lines
torchtune/modules/tokenizers/_tiktoken.py 0.00% 2 Missing ⚠️
torchtune/models/llama3/_tokenizer.py 0.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (6e4809a) and HEAD (5cf5fff). Click for more details.

HEAD has 5 uploads less than BASE
Flag BASE (6e4809a) HEAD (5cf5fff)
6 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1211       +/-   ##
===========================================
- Coverage   68.66%   26.73%   -41.94%     
===========================================
  Files         215      219        +4     
  Lines        9734     9860      +126     
===========================================
- Hits         6684     2636     -4048     
- Misses       3050     7224     +4174     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -153,5 +156,11 @@ def decode(
k = None
if k:
token_ids = token_ids[:k]
token_ids = [token_id for token_id in token_ids if token_id != self.bos_id]
if not show_special:
Copy link
Collaborator

@SalmanMohammadi SalmanMohammadi Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is slightly far out, but could we move the logic from here to the generation recipe and move the flag for show_special there, too? Unless we anticipate there's other places we could use this logic, but even if so, would it make sense to throw it in a separate utility function instead, with its own tests?

This way the tokenizers remain agnostic to how they are being used, and all calls to tokenizer.decode will behave identically.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean, but I do like keeping this with the tokenizer so I can quickly debug even in the middle of the recipe by calling tokenizer.decode(sample, show_special=True). I find myself doing this a lot when debugging in general, and it's especially useful to see the exact text that's going into the model . If it was a separate utility it wouldn't be straightforward to do this in a pdb debugger

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, yeah, IGYWM.

@@ -134,6 +134,7 @@ def decode(
self,
token_ids: List[int],
truncate_at_eos: bool = True,
show_special: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: show_special_tokens pls

Copy link
Collaborator

@SalmanMohammadi SalmanMohammadi Aug 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SalmanMohammadi
Copy link
Collaborator

One nit; lgtm

@RdoubleA RdoubleA merged commit 8519c35 into pytorch:main Aug 5, 2024
29 checks passed
@RdoubleA RdoubleA deleted the fix_inference_pad branch August 5, 2024 15:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants