@@ -120,6 +120,8 @@ class Llama4Tokenizer(ModelTokenizer, Transform):
             - Model-specific templates that are required whenever the model is prompted, such as the [INST]
               tags in Llama2 and in Mistral
             - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

             The extra text will still get tokenized as normal text, not as special tokens. Default is None.

@@ -136,6 +138,7 @@ def __init__(
         special_tokens: Optional[dict[str, int]] = None,
         max_seq_len: Optional[int] = None,
         prompt_template: Optional[PromptTemplateInterface] = None,
+        truncation_type: str = "right",
     ):
         self.special_tokens = (
             special_tokens if special_tokens is not None else LLAMA4_SPECIAL_TOKENS
@@ -188,6 +191,8 @@ def __init__(
             r"<\|header_start\|>.*?<\|header_end\|>\n\n"
         )

+        self.truncation_type = truncation_type
+
     def _validate_special_tokens(
         self,
     ):
@@ -420,9 +425,17 @@ def tokenize_messages(

         if self.max_seq_len:
             tokens = truncate(
-                tokens, self.max_seq_len, self.eos_id if add_end_tokens else None
+                tokens=tokens,
+                max_seq_len=self.max_seq_len,
+                eos_id=self.eos_id if add_end_tokens else None,
+                truncation_type=self.truncation_type,
+            )
+            mask = truncate(
+                tokens=mask,
+                max_seq_len=self.max_seq_len,
+                eos_id=True if add_end_tokens else None,
+                truncation_type=self.truncation_type,
             )
-            mask = truncate(mask, self.max_seq_len, True if add_end_tokens else None)

         return tokens, mask

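For context on what the new `truncation_type` argument controls, here is a minimal, illustrative sketch of "left" vs. "right" truncation of a token sequence. This is not torchtune's `truncate` implementation; the helper name `truncate_sketch` and the simplified EOS handling below are assumptions for demonstration only.

from typing import Any, Optional


def truncate_sketch(
    tokens: list[Any],
    max_seq_len: int,
    eos_id: Optional[Any] = None,
    truncation_type: str = "right",
) -> list[Any]:
    """Illustrative only: keep the head ("right") or the tail ("left") of a
    sequence that exceeds max_seq_len."""
    if len(tokens) <= max_seq_len:
        return list(tokens)
    if truncation_type == "right":
        out = tokens[:max_seq_len]   # drop tokens from the end
    elif truncation_type == "left":
        out = tokens[-max_seq_len:]  # drop tokens from the start
    else:
        raise ValueError(
            f"truncation_type must be 'left' or 'right', got {truncation_type!r}"
        )
    if eos_id is not None:
        # Simplified assumption: a truncated sequence should still end with EOS
        # (or, for the mask, with the value passed as eos_id, e.g. True).
        out[-1] = eos_id
    return out


# "right" keeps the beginning, "left" keeps the end (99 stands in for an EOS id):
assert truncate_sketch([1, 2, 3, 4, 5], 4, eos_id=99, truncation_type="right") == [1, 2, 3, 99]
assert truncate_sketch([1, 2, 3, 4, 5], 4, eos_id=99, truncation_type="left") == [2, 3, 4, 99]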