diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 6417c3197..511e6c69b 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -149,7 +149,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
     parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
     parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
-    parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")
+
+    parser.add_argument(
+        "--output_constraint_mode",
+        type=str,
+        choices=["outlines", "xgrammar", "none"],
+        default="none",
+        help="set the output constraint backend, none means no output constraint",
+    )
     parser.add_argument(
         "--first_token_constraint_mode",
         action="store_true",
diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index 3fbac97fd..342f6f7b9 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -47,9 +47,11 @@ def __init__(
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
         input_penalty: bool = DEFAULT_INPUT_PENALTY,
         regular_constraint: Optional[str] = None,  # Regular expressions constrain the output.
+        guided_grammar: Optional[str] = None,  # An EBNF grammar constrains the output.
+        guided_json: Optional[Union[str, dict]] = None,  # A JSON schema constrains the output.
         # If provided, the engine will construct a logits,
         # processor which only retains scores for the given token ids. Defaults to None.
-        # allowed_token_ids only can be used in "--simple_constraint_mode" started server.
+        # allowed_token_ids can only be used when the server is started with "--output_constraint_mode outlines".
         allowed_token_ids: Optional[List[int]] = None,
         # p d mode used params
         group_request_id: Optional[int] = None,
@@ -81,6 +83,8 @@ def __init__(
         self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
         self.print_eos_token = print_eos_token
         self.regular_constraint = regular_constraint
+        self.guided_grammar = guided_grammar
+        self.guided_json = guided_json
         self.allowed_token_ids = allowed_token_ids
         self.group_request_id = group_request_id
         self.move_kv_to_decode_node = move_kv_to_decode_node
@@ -257,6 +261,8 @@ def to_dict(self):
         ret["best_of"] = self.best_of
         ret["input_penalty"] = self.input_penalty
         ret["regular_constraint"] = self.regular_constraint
+        ret["guided_grammar"] = self.guided_grammar
+        ret["guided_json"] = self.guided_json
         ret["allowed_token_ids"] = self.allowed_token_ids
         ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
         return ret
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 2c041c570..959b2b629 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -12,6 +12,8 @@
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
+GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
+JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
 
 
 class StopSequence(ctypes.Structure):
@@ -76,7 +78,7 @@ def to_list(self):
 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
     _fields_ = [
-        ("constraint", ctypes.c_byte * REGULAR_CONSTRAINT_MAX_LENGTH),
+        ("constraint", ctypes.c_ubyte * REGULAR_CONSTRAINT_MAX_LENGTH),
         ("length", ctypes.c_int),
     ]
 
@@ -98,6 +100,66 @@ def to_str(self):
         return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
 
 
+class GuidedGrammar(ctypes.Structure):
+    _pack_ = 4
+    _fields_ = [
+        ("constraint", ctypes.c_ubyte * GRAMMAR_CONSTRAINT_MAX_LENGTH),
+        ("length", ctypes.c_int),
+    ]
+
+    def initialize(self, constraint: str, tokenizer):
+        constraint_bytes = constraint.encode("utf-8")
+        assert len(constraint_bytes) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long."
+
+        ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
+        self.length = len(constraint_bytes)
+        try:
+            if self.length > 0:
+                import xgrammar as xgr
+
+                tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+                xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+                xgrammar_compiler.compile_grammar(constraint)
+        except Exception as e:
+            raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}")
+        return
+
+    def to_str(self):
+        if self.length == 0:
+            return ""
+        return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
+
+
+class GuidedJsonSchema(ctypes.Structure):
+    _pack_ = 4
+    _fields_ = [
+        ("constraint", ctypes.c_ubyte * JSON_SCHEMA_MAX_LENGTH),
+        ("length", ctypes.c_int),
+    ]
+
+    def initialize(self, constraint: str, tokenizer):
+        constraint_bytes = constraint.encode("utf-8")
+        assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long."
+
+        ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
+        self.length = len(constraint_bytes)
+        try:
+            if self.length > 0:
+                import xgrammar as xgr
+
+                tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+                xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+                xgrammar_compiler.compile_json_schema(constraint)
+        except Exception as e:
+            raise ValueError(f"guided_json '{constraint}' has compile_json_schema_error: {str(e)}")
+        return
+
+    def to_str(self):
+        if self.length == 0:
+            return ""
+        return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
+
+
 class AllowedTokenIds(ctypes.Structure):
     _pack_ = 4
     _fields_ = [
@@ -191,9 +253,11 @@ class SamplingParams(ctypes.Structure):
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
         ("input_penalty", ctypes.c_bool),
         ("regular_constraint", RegularConstraint),
+        ("guided_grammar", GuidedGrammar),
+        ("guided_json", GuidedJsonSchema),
         # If provided, the engine will construct a logits,
         # processor which only retains scores for the given token ids. Defaults to None.
-        # allowed_token_ids only can be used in "--simple_constraint_mode" started server.
+        # allowed_token_ids can only be used when the server is started with "--output_constraint_mode outlines".
         ("allowed_token_ids", AllowedTokenIds),
         ("stop_sequences", StopSequenceGroups),
         ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
@@ -251,6 +315,16 @@ def init(self, tokenizer, **kwargs):
         self.regular_constraint = RegularConstraint()
         self.regular_constraint.initialize(regular_constraint)
 
+        # Initialize guided_grammar
+        guided_grammar = kwargs.get("guided_grammar", "")
+        self.guided_grammar = GuidedGrammar()
+        self.guided_grammar.initialize(guided_grammar, tokenizer)
+
+        # Initialize guided_json
+        guided_json = kwargs.get("guided_json", "")
+        self.guided_json = GuidedJsonSchema()
+        self.guided_json.initialize(guided_json, tokenizer)
+
         # Initialize stop_sequence_groups
         stop_sequences = kwargs.get("stop_sequences", [])
         self.stop_sequences = StopSequenceGroups()
@@ -316,13 +390,26 @@ def verify(self):
             )
 
         self._verify_allowed_token_ids()
+        self._verify_grammar_constraint()
 
         return
 
+    def _verify_grammar_constraint(self):
+        if self.guided_grammar.length != 0:
+            if self.regular_constraint.length != 0:
+                raise ValueError("guided_grammar and regular_constraint can not be used in same time")
+            if self.guided_json.length != 0:
+                raise ValueError("guided_grammar and guided_json can not be used in same time")
+        return
+
     def _verify_allowed_token_ids(self):
         if self.allowed_token_ids.size != 0:
             if self.regular_constraint.length != 0:
                 raise ValueError("allowed_token_ids and regular_constraint can not be used in same time")
+            if self.guided_grammar.length != 0:
+                raise ValueError("allowed_token_ids and guided_grammar can not be used in same time")
+            if self.guided_json.length != 0:
+                raise ValueError("allowed_token_ids and guided_json can not be used in same time")
         return
 
     def to_dict(self):
@@ -342,6 +429,8 @@ def to_dict(self):
             "best_of": self.best_of,
             "input_penalty": self.input_penalty,
             "regular_constraint": self.regular_constraint.to_str(),
+            "guided_grammar": self.guided_grammar.to_str(),
+            "guided_json": self.guided_json.to_str(),
             "allowed_token_ids": self.allowed_token_ids.to_list(),
             "group_request_id": self.group_request_id,
             "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),
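The two new structs validate eagerly: `initialize` compiles the grammar or schema against the request tokenizer as soon as the request object is built, so a malformed constraint fails with a `ValueError` before it reaches the backend, and `_verify_grammar_constraint` rejects requests that combine `guided_grammar` with `guided_json` or `regular_constraint`. A minimal standalone sketch of that behavior follows; the model path is an assumption, and any HF tokenizer that xgrammar can wrap should work:

```python
# Sketch only: exercising GuidedGrammar / GuidedJsonSchema outside the server.
from transformers import AutoTokenizer
from lightllm.server.core.objs.sampling_params import GuidedGrammar, GuidedJsonSchema

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  # assumed model path

grammar = GuidedGrammar()
grammar.initialize(r'root ::= "yes" | "no"', tokenizer)  # compiles the EBNF now; raises ValueError if invalid
assert grammar.to_str() == r'root ::= "yes" | "no"'

schema = GuidedJsonSchema()
schema.initialize('{"type": "object", "properties": {"ok": {"type": "boolean"}}}', tokenizer)
print(schema.to_str())
```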
index 32898a668..7c2e3c6c2 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -41,7 +41,7 @@ class StartArgs:
     enable_chunked_prefill: bool = field(default=False)
     diverse_mode: bool = field(default=False)
     token_healing_mode: bool = field(default=False)
-    simple_constraint_mode: bool = field(default=False)
+    output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "outlines", "xgrammar"]})
     first_token_constraint_mode: bool = field(default=False)
     enable_multimodal: bool = field(default=False)
     cache_capacity: int = field(default=200)
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 85d4bc78f..d2a1e84f3 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -7,7 +7,7 @@
 import collections
 from dataclasses import dataclass, field
-from typing import List, Dict, Tuple, Optional, Any
+from typing import List, Dict, Tuple, Optional, Union, Any
 from lightllm.common.req_manager import ReqManager
 from lightllm.utils.infer_utils import mark_start, mark_end
 from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
@@ -194,10 +194,15 @@ def __init__(
 
         # output constraint states
         self.regular_constraint = self.shm_param.regular_constraint.to_str()
+        self.guided_grammar = self.shm_param.guided_grammar.to_str()
+        self.guided_json = self.shm_param.guided_json.to_str()
         if len(self.regular_constraint) == 0:
             self.regular_constraint = None
+        if len(self.guided_grammar) == 0:
+            self.guided_grammar = None
+        if len(self.guided_json) == 0:
+            self.guided_json = None
-        self.regex_guide = None
         self.fsm_current_state: int = 0
         self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list()
         if len(self.allowed_token_ids) == 0:
@@ -217,7 +222,12 @@ def __init__(
         return
 
     def has_constraint_setting(self) -> bool:
-        return self.regular_constraint is not None or self.allowed_token_ids is not None
+        return (
+            self.regular_constraint is not None
+            or self.allowed_token_ids is not None
+            or self.guided_grammar is not None
+            or self.guided_json is not None
+        )
 
 
 class InferReq:
diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py
index 88edb07e5..509f1dd70 100644
--- a/lightllm/server/router/model_infer/mode_backend/__init__.py
+++ b/lightllm/server/router/model_infer/mode_backend/__init__.py
@@ -4,8 +4,9 @@
 from .chunked_prefill.impl import ChunkedPrefillBackend
 from .diverse_backend.impl import DiversehBackend
 from .continues_batch.impl_for_token_healing import TokenHealingBackend
-from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
+from .continues_batch.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
 from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
 from .dp_backend.impl import DPBackend
 from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode
 from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
+from .continues_batch.impl_for_xgrammar_mode import XgrammarBackend
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py
similarity index 99%
rename from lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py
rename to lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py
index 963cf4ff8..00af16be4 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py
@@ -14,7 +14,7 @@
 logger = init_logger(__name__)
 
 
-class SimpleConstraintBackend(ContinuesBatchBackend):
+class OutlinesConstraintBackend(ContinuesBatchBackend):
     def __init__(self) -> None:
         super().__init__()
 
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py
new file mode 100644
index 000000000..3d880614c
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py
@@ -0,0 +1,137 @@
+import os
+import shutil
+import torch
+from typing import List, Tuple
+
+from .impl import ContinuesBatchBackend
+from .pre_process import prepare_prefill_inputs, prepare_decode_inputs
+from .post_process import sample
+
+from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
+from lightllm.server.core.objs import FinishStatus
+from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams
+from lightllm.server.tokenizer import get_tokenizer
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class XgrammarBackend(ContinuesBatchBackend):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def init_custom(self):
+        import xgrammar as xgr
+
+        self.tokenizer = get_tokenizer(
+            self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
+        )
+
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer)
+        self.xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+        self.xgrammar_token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
+
+        eos_token_ids = []
+        eos_token_ids.append(self.tokenizer.eos_token_id)
+        eos_token_ids.extend(self.args.eos_id)
+        return
+
+    @calculate_time(show=False, min_cost_ms=300)
+    def prefill(self, reqs: List[Tuple]):
+        import xgrammar as xgr
+
+        req_ids = self._init_reqs(reqs)
+        kwargs, run_reqs = prepare_prefill_inputs(req_ids, is_multimodal=self.is_multimodal)
+
+        logics = self.model.forward(**kwargs)
+
+        for i, run_obj in enumerate(run_reqs):
+            run_obj: InferReq = run_obj
+            sample_params = run_obj.sampling_param
+            if sample_params.guided_grammar is not None:
+                xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar)
+                sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar)
+            elif sample_params.guided_json is not None:
+                xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json)
+                sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar)
+            self._mask_req_out_token(i, run_obj, logics[i])
+
+        # replace -inf logits with a large negative value before sampling
+        logics[logics == float("-inf")] = -1000000.0
+
+        next_token_ids, next_token_probs = sample(logics, run_reqs, self.eos_id)
+        next_token_ids = next_token_ids.detach().cpu().numpy()
+        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+
+        self.post_handel(run_reqs, next_token_ids, next_token_logprobs)
+
+        return
+
+    @calculate_time(show=True, min_cost_ms=200)
+    def decode(self):
+        import xgrammar as xgr
+
+        kwargs, run_reqs = prepare_decode_inputs(g_infer_context.infer_req_ids)
+        run_reqs: List[InferReq] = run_reqs
+
+        logits = self.model.forward(**kwargs)
+
+        all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs])
+        if not all_has_no_constraint:
+            for i, run_obj in enumerate(run_reqs):
+                self._mask_req_out_token(i, run_obj, logits[i])
+
+        logits[logits == float("-inf")] = -1000000.0
+        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+        next_token_ids = next_token_ids.detach().cpu().numpy()
+        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+
+        self.post_handel(run_reqs, next_token_ids, next_token_logprobs)
+        return
+
+    def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logprobs):
+        import xgrammar as xgr
+
+        finished_req_ids = []
+
+        for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs):
+            # the handling here is the same for prefill and decode
+            req_obj: InferReq = req_obj
+            req_obj.cur_kv_len = req_obj.get_cur_total_len()
+
+            req_obj.set_next_gen_token_id(next_token_id, next_token_logprob)
+            req_obj.cur_output_len += 1
+
+            req_obj.out_token_id_count[next_token_id] += 1
+            req_obj.update_finish_status(self.eos_id)
+
+            matcher = req_obj.sampling_param.xgrammar_matcher
+            assert matcher.accept_token(next_token_id)
+
+            if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted or matcher.is_terminated():
+                finished_req_ids.append(req_obj.shm_req.request_id)
+
+            if self.tp_rank < self.dp_size:
+                # shm_cur_kv_len and shm_cur_output_len are read by the router scheduling process.
+                # finish_token_index, finish_status and candetoken_out_len are read by the detokenization
+                # process; keep this write order to avoid cross-process coordination issues.
+                req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len
+                req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len
+
+                if req_obj.finish_status.is_finished():
+                    req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1
+                    req_obj.shm_req.finish_status = req_obj.finish_status
+
+                req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len
+
+        g_infer_context.filter(finished_req_ids)
+        return
+
+    def _mask_req_out_token(self, i, run_obj: InferReq, logits):
+        import xgrammar as xgr
+
+        sample_params = run_obj.sampling_param
+        if sample_params.guided_grammar is not None or sample_params.guided_json is not None:
+            sample_params.xgrammar_matcher.fill_next_token_bitmask(self.xgrammar_token_bitmask)
+            xgr.apply_token_bitmask_inplace(logits, self.xgrammar_token_bitmask.to(logits.device))
+        return
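The new backend compiles each request's grammar or JSON schema once at prefill and then re-masks the logits at every step: `fill_next_token_bitmask` records which tokens the matcher currently allows, `apply_token_bitmask_inplace` drives every other logit to -inf (clamped to -1e6 before sampling), and the sampled token is fed back through `accept_token` so the matcher advances. A condensed single-request sketch of that loop outside lightllm, with the model name and schema as assumptions:

```python
# Sketch only: the compile -> mask -> sample -> accept loop used by XgrammarBackend.
import torch
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  # assumed model
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

compiled = compiler.compile_json_schema('{"type": "object", "properties": {"ok": {"type": "boolean"}}}')
matcher = xgr.GrammarMatcher(compiled)
bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

logits = torch.randn(tokenizer_info.vocab_size)  # stand-in for one row of model output
matcher.fill_next_token_bitmask(bitmask)         # mark the tokens the grammar currently allows
xgr.apply_token_bitmask_inplace(logits, bitmask)  # disallowed tokens become -inf
logits[logits == float("-inf")] = -1000000.0      # same clamp the backend applies before sampling
next_token_id = int(torch.argmax(logits))
assert matcher.accept_token(next_token_id)        # advance the matcher, as post_handel does
print("terminated:", matcher.is_terminated())
```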
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 88fb45878..2a63e6b21 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -13,7 +13,8 @@
     DiversehBackend,
     RewardModelBackend,
     TokenHealingBackend,
-    SimpleConstraintBackend,
+    OutlinesConstraintBackend,
+    XgrammarBackend,
     FirstTokenConstraintBackend,
     ContinuesBatchBackendForPrefillNode,
     ContinuesBatchBackendForDecodeNode,
@@ -106,11 +107,16 @@ def init_model(self, kvargs):
         is_token_healing = kvargs.get("is_token_healing", False)
         is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False)
         if kvargs.get("args", None) is not None:
-            is_simple_constraint_mode = kvargs.get("args", None).simple_constraint_mode
+            is_outlines_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines"
+            is_xgrammar_constraint_mode = kvargs.get("args", None).output_constraint_mode == "xgrammar"
+            assert not (
+                is_outlines_constraint_mode and is_xgrammar_constraint_mode
+            ), "only one constraint mode can be true"
             is_prefill_node = kvargs.get("args", None).run_mode == "prefill"
             is_decode_node = kvargs.get("args", None).run_mode == "decode"
         else:
-            is_simple_constraint_mode = False
+            is_outlines_constraint_mode = False
+            is_xgrammar_constraint_mode = False
             is_prefill_node = False
             is_decode_node = False
         # use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
@@ -128,8 +134,10 @@ def init_model(self, kvargs):
             self.backend = DiversehBackend()
         elif is_token_healing:
             self.backend = TokenHealingBackend()
-        elif is_simple_constraint_mode:
-            self.backend = SimpleConstraintBackend()
+        elif is_outlines_constraint_mode:
+            self.backend = OutlinesConstraintBackend()
+        elif is_xgrammar_constraint_mode:
+            self.backend = XgrammarBackend()
         elif is_first_token_constraint_mode:
             self.backend = FirstTokenConstraintBackend()
         elif kvargs.get("dp_size", 1) > 1:
diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py
index fb61f42d5..793b81662 100644
--- a/lightllm/server/router/req_queue/__init__.py
+++ b/lightllm/server/router/req_queue/__init__.py
@@ -15,7 +15,7 @@ def build_req_queue(args, router, dp_size: int):
         queue_class = ChunkedPrefillQueue
     if args.token_healing_mode:
         queue_class = ContinuesBatchQueue
-    if args.simple_constraint_mode:
+    if args.output_constraint_mode != "none":
         queue_class = ContinuesBatchQueue
     if args.first_token_constraint_mode:
         queue_class = ContinuesBatchQueue
diff --git a/test/format_out/test_constraint_server.py b/test/format_out/test_constraint_server.py
new file mode 100644
index 000000000..62b622031
--- /dev/null
+++ b/test/format_out/test_constraint_server.py
@@ -0,0 +1,67 @@
+import time
+import requests
+import json
+import threading
+
+"""
+python -m lightllm.server.api_server --model_dir /Meta-Llama-3-8B-Instruct \
+    --host 0.0.0.0 \
+    --port 8017 \
+    --tp 1 \
+    --max_total_token_num 100000 \
+    --output_constraint_mode outlines \
+    --use_dynamic_prompt_cache
+"""
+
+
+class RequestThread(threading.Thread):
+    def __init__(self, url, headers, data):
+        threading.Thread.__init__(self)
+        self.url = url
+        self.headers = headers
+        self.data = data
+
+    def run(self):
+        response = requests.post(self.url, headers=self.headers, data=json.dumps(self.data))
+        if response.status_code == 200:
+            print(response.json())
+        else:
+            print("Error:", response.status_code, response.text)
+
+
+url = "http://localhost:8017/generate"
+headers = {"Content-Type": "application/json"}
+
+for i in range(1):
+    data = {
+        "inputs": "(100+1+3)*2=",
+        # 'temperature': 0.1,
+        "parameters": {"do_sample": False, "regular_constraint": r"-?\d+"},
+    }
+    thread = RequestThread(url, headers, data)
+    thread.start()
+
+time.sleep(2)
+
+for i in range(20):
+    data = {
+        "inputs": "Are dog a man? ",
+        "parameters": {
+            "do_sample": False,
+            "ignore_eos": True,
+            "max_new_tokens": 200,
+            "regular_constraint": r"(Yes|No) Reason is [a-zA-Z\s]+",
+        },
+    }
+    thread = RequestThread(url, headers, data)
+    thread.start()
+
+time.sleep(10)
+
+for i in range(20):
+    data = {
+        "inputs": "Are dog a man? ",
+        "parameters": {"do_sample": False, "ignore_eos": True, "max_new_tokens": 200, "allowed_token_ids": [2, 3]},
+    }
+    thread = RequestThread(url, headers, data)
+    thread.start()
diff --git a/test/format_out/test_xgrammar_constraint.py b/test/format_out/test_xgrammar_constraint.py
new file mode 100644
index 000000000..67490dd31
--- /dev/null
+++ b/test/format_out/test_xgrammar_constraint.py
@@ -0,0 +1,138 @@
+import time
+import requests
+import json
+import threading
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("/mnt/nvme0/models/Meta-Llama-3.1-8B-Instruct")
+
+
+class RequestThread(threading.Thread):
+    def __init__(self, url, headers, data):
+        threading.Thread.__init__(self)
+        self.url = url
+        self.headers = headers
+        self.data = data
+
+    def run(self):
+        response = requests.post(self.url, headers=self.headers, data=json.dumps(self.data))
+        if response.status_code == 200:
+            print(response.json())
+        else:
+            print("Error:", response.status_code, response.text)
+
+
+url = "http://0.0.0.0:8888/generate"
+headers = {"Content-Type": "application/json"}
+json_grammar_ebnf_str = r"""
+root ::= basic_array | basic_object
+basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
+basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
+basic_string ::= (([\"] basic_string_1 [\"]))
+basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
+basic_boolean ::= "true" | "false"
+basic_null ::= "null"
+basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
+basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
+ws ::= [ \n\t]*
+"""
+
+json_schema_str = r"""
+{
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "金额": {
+                "type": "number"
+            },
+            "标题": {
+                "type": "string"
+            },
+            "类型": {
+                "type": "string"
+            },
+            "大类": {
+                "type": "string"
+            },
+            "小类": {
+                "type": "string"
+            },
+            "日期": {
+                "type": "string"
+            },
+            "时间": {
+                "type": "string"
+            }
+        },
+        "required": [
+            "金额",
+            "标题",
+            "类型",
+            "大类",
+            "小类",
+            "时间"
+        ]
+    }
+}
+"""
+
+person_schema = r"""{
+    "title": "Person",
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        }
+    },
+    "required": ["name", "age"]
+}
+"""
+
+system_prompt = open("system.md", "r").read()
+user_input = open("user.md", "r").read()
+
+# user_input = """generate a person information for me, for example, {'name': 'John', 'age': 25}."""
+
+messages = [
+    {"role": "system", "content": system_prompt},
+    {"role": "user", "content": user_input},
+]
+
+inputs = tokenizer.apply_chat_template(messages, tokenize=False)
+
+for i in range(1):
+    data = {
+        "inputs": inputs,
+        # 'temperature': 0.1,
+        "parameters": {
+            "do_sample": False,
+            # "guided_json": json_schema_str,
+            "max_new_tokens": 200,
+        },
+    }
+    thread = RequestThread(url, headers, data)
+    thread.start()
+
+# time.sleep(2)
+
+# for i in range(20):
+#     data = {
+#         "inputs": "12-(25+16)*7=",
+#         "parameters": {
+#             "do_sample": False,
+#             "ignore_eos": True,
+#             "max_new_tokens": 200,
+#             "guided_grammar": r"""root ::= (expr "=" term)+
+# expr ::= term ([-+*/] term)*
+# term ::= num | "(" expr ")"
+# num ::= [0-9]+""",
+#         },
+#     }
+#     thread = RequestThread(url, headers, data)
+#     thread.start()
diff --git a/unit_tests/server/core/objs/test_sampling_params.py b/unit_tests/server/core/objs/test_sampling_params.py
index 8cb89f681..489f8ae34 100644
--- a/unit_tests/server/core/objs/test_sampling_params.py
+++ b/unit_tests/server/core/objs/test_sampling_params.py
@@ -7,11 +7,31 @@
     ExponentialDecayLengthPenalty,
     DecodeNode,
     SamplingParams,
+    GuidedGrammar,
+    GuidedJsonSchema,
     STOP_SEQUENCE_MAX_LENGTH,
     REGULAR_CONSTRAINT_MAX_LENGTH,
     ALLOWED_TOKEN_IDS_MAX_LENGTH,
 )
 
+grammar_str = r"""root ::= (expr "=" term)+
+expr ::= term ([-+*/] term)*
+term ::= num | "(" expr ")"
+num ::= [0-9]+"""
+
+schema_str = r"""{
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "Title": {"type": "string"},
+            "Date": {"type": "string"},
+            "Time": {"type": "string"}
+        },
+        "required": ["Title", "Time", "Date"]
+    }
+}"""
+
 
 @pytest.mark.parametrize(
     "sequence, expected",
@@ -58,6 +78,24 @@ def test_regular_constraint_initialization():
         constraint.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1))
 
 
+def test_guided_grammar_initialization():
+    grammar = GuidedGrammar()
+    grammar.initialize(grammar_str)
+    assert grammar.to_str() == grammar_str
+
+    with pytest.raises(AssertionError):
+        grammar.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1))
+
+
+def test_guided_json_schema_initialization():
+    schema = GuidedJsonSchema()
+    schema.initialize(schema_str)
+    assert schema.to_str() == schema_str
+
+    with pytest.raises(AssertionError):
+        schema.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1))
+
+
 def test_allowed_token_ids_initialization():
     allowed_ids = AllowedTokenIds()
     allowed_ids.initialize([1, 2, 3])
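For completeness, a hedged end-to-end example of the new request parameters, modeled on the test scripts above; the host, port, model path and schema are assumptions, and the server is presumed to have been started with `--output_constraint_mode xgrammar`:

```python
# Sketch only: querying a running lightllm server that was launched with, e.g.:
#   python -m lightllm.server.api_server --model_dir /path/to/Meta-Llama-3-8B-Instruct \
#       --port 8888 --output_constraint_mode xgrammar
import json
import requests

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

response = requests.post(
    "http://localhost:8888/generate",
    headers={"Content-Type": "application/json"},
    data=json.dumps(
        {
            "inputs": "Generate a person record as JSON: ",
            "parameters": {"do_sample": False, "max_new_tokens": 200, "guided_json": schema},
        }
    ),
)
print(response.status_code, response.json())
```

Swapping `guided_json` for a `guided_grammar` EBNF string follows the same pattern, as does `regular_constraint` when the server runs with `--output_constraint_mode outlines`.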