88import importlib .metadata
99import json
1010import os
11- from typing import Optional , Union
11+ from typing import Union
1212
1313import regex as re
1414import torch
3636class BaseLogitsProcessor :
3737
3838 def __init__ (self , guide : Guide , eos_token_id : int ,
39- reasoner : Optional [ ReasoningParser ] ) -> None :
39+ reasoner : ReasoningParser | None ) -> None :
4040 self ._guide : Guide = guide
4141 self ._eos_token_id : int = eos_token_id
42- self ._reasoner : Optional [ ReasoningParser ] = reasoner
43- self ._mask : Optional [ torch .Tensor ] = None
42+ self ._reasoner : ReasoningParser | None = reasoner
43+ self ._mask : torch .Tensor | None = None
4444
4545 def __call__ (self , input_ids : list [int ],
4646 scores : torch .Tensor ) -> torch .Tensor :
@@ -114,7 +114,7 @@ def _get_guide(cls, regex_string: str,
114114 return Guide (index )
115115
116116 def __init__ (self , regex_string : str , tokenizer : PreTrainedTokenizerBase ,
117- reasoner : Optional [ ReasoningParser ] ) -> None :
117+ reasoner : ReasoningParser | None ) -> None :
118118 super ().__init__ (
119119 guide = RegexLogitsProcessor ._get_guide (regex_string , tokenizer ),
120120 eos_token_id = tokenizer .eos_token_id , # type: ignore
@@ -126,7 +126,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
126126 def __init__ (self , schema : Union [str , dict , BaseModel ],
127127 tokenizer : PreTrainedTokenizerBase ,
128128 whitespace_pattern : Union [str , None ],
129- reasoner : Optional [ ReasoningParser ] ) -> None :
129+ reasoner : ReasoningParser | None ) -> None :
130130
131131 if isinstance (schema , type (BaseModel )):
132132 schema_str = json .dumps (schema .model_json_schema ())
0 commit comments