diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md
index e39bbacf1138..5c0c1762f8aa 100644
--- a/docs/source/features/reasoning_outputs.md
+++ b/docs/source/features/reasoning_outputs.md
@@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
 }
 ```
 
-Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
+Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
+
+## Limitations
+
+- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
+- It is not compatible with [`tool_calling`](#tool_calling).
+- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
 
 ## How to support a new reasoning model
 
@@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
         """
 ```
 
-After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
+Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
+
+```python
+@dataclass
+class DeepSeekReasoner(Reasoner):
+    """
+    Reasoner for DeepSeek R series models.
+    """
+    start_token_id: int
+    end_token_id: int
+
+    start_token: str = ""
+    end_token: str = ""
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+        return cls(start_token_id=tokenizer.encode(
+            "", add_special_tokens=False)[0],
+                   end_token_id=tokenizer.encode("",
+                                                 add_special_tokens=False)[0])
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.end_token_id in input_ids
+```
+
+The structured output engine like xgrammar will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
+
+Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
 
 ```bash
 vllm serve  \
     --enable-reasoning --reasoning-parser example
 ```
-
-## Limitations
-
-- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
-- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
new file mode 100644
index 000000000000..1f72e1164d42
--- /dev/null
+++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+An example shows how to generate structured outputs from reasoning models
+like DeepSeekR1. The thinking process will not be guided by the JSON
+schema provided by the user. Only the final output will be structured.
+
+To run this example, you need to start the vLLM server with the reasoning 
+parser:
+
+```bash
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+     --enable-reasoning --reasoning-parser deepseek_r1
+```
+
+This example demonstrates how to generate chat completions from reasoning models
+using the OpenAI Python client library.
+"""
+
+from enum import Enum
+
+from openai import OpenAI
+from pydantic import BaseModel
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+
+# Guided decoding by JSON using Pydantic schema
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
+json_schema = CarDescription.model_json_schema()
+
+prompt = ("Generate a JSON with the brand, model and car_type of"
+          "the most iconic car from the 90's, think in 100 tokens")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_json": json_schema},
+)
+print("content", completion.choices[0].message.content)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index be544698fa03..531c3a8c13b2 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -16,17 +16,33 @@
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
 GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 
 
-def test_guided_logits_processors(sample_regex, sample_json_schema):
+# Initialize the tokenizer for the model here to avoid repeated loading
+@pytest.fixture(scope="module")
+def zephyr_7B_tokenzer():
+    return AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
+def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
+                                  sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
-    regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
+    regex_LP = RegexLogitsProcessor(sample_regex,
+                                    zephyr_7B_tokenzer,
+                                    reasoner=None)
     json_LP = JSONLogitsProcessor(sample_json_schema,
-                                  tokenizer,
-                                  whitespace_pattern=None)
+                                  zephyr_7B_tokenzer,
+                                  whitespace_pattern=None,
+                                  reasoner=None)
 
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
 
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an employee profile that fits this schema: {sample_json_schema}"
     )
     tensor = torch.rand(32000)
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 @pytest.mark.parametrize("is_local", [True, False])
 async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
                                                  sample_regex,
-                                                 sample_json_schema):
+                                                 sample_json_schema,
+                                                 zephyr_7B_tokenzer):
 
     config = ModelConfig(
         MODEL_NAME,
@@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
         seed=0,
         dtype="bfloat16",
     )
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
 
     regex_lp = get_local_guided_decoding_logits_processor(
-            regex_request, tokenizer, config) if is_local else \
+            regex_request, zephyr_7B_tokenzer, config) if is_local else \
             await get_guided_decoding_logits_processor(
-                    regex_request, tokenizer, config)
+                    regex_request, zephyr_7B_tokenzer, config)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
 
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an employee profile that fits this schema: {sample_json_schema}"
     )
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        json_request, tokenizer, config)
+        json_request, zephyr_7B_tokenzer, config)
+    assert json_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend",
+                         GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
+@pytest.mark.parametrize("is_local", [True, False])
+@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
+async def test_guided_logits_processor_with_reasoning(
+        backend: str, is_local: bool, reasoning_backend: str, sample_regex,
+        sample_json_schema, deepseek_r1_qwen_tokenizer):
+
+    config = ModelConfig(
+        REASONING_MODEL_NAME,
+        task="generate",
+        tokenizer=REASONING_MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    token_ids = deepseek_r1_qwen_tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}."
+        "here is the thinking process")
+    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
+
+    regex_lp = get_local_guided_decoding_logits_processor(regex_request,
+                    deepseek_r1_qwen_tokenizer, config,
+                    reasoning_backend) if is_local else \
+            await get_guided_decoding_logits_processor(
+                    regex_request, deepseek_r1_qwen_tokenizer, config,
+                    reasoning_backend)
+    assert regex_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = regex_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert torch.allclose(tensor, original_tensor)
+
+    token_ids = deepseek_r1_qwen_tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}."
+        "here is the thinking process")
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
+    json_lp = get_local_guided_decoding_logits_processor(
+        json_request, deepseek_r1_qwen_tokenizer, config,
+        reasoning_backend) if is_local else \
+        await get_guided_decoding_logits_processor(
+            json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
+    assert json_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert torch.allclose(tensor, original_tensor)
+
+    # Thinking is over, so the tensor should change.
+    token_ids = deepseek_r1_qwen_tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}."
+        "here is the thinking process Then")
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
+    json_lp = get_local_guided_decoding_logits_processor(
+        json_request, deepseek_r1_qwen_tokenizer, config,
+        reasoning_backend) if is_local else \
+        await get_guided_decoding_logits_processor(
+            json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
diff --git a/vllm/config.py b/vllm/config.py
index c7108473442b..54ed38418dd4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2715,6 +2715,8 @@ class DecodingConfig:
     # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
     guided_decoding_backend: str = 'xgrammar'
 
+    reasoning_backend: Optional[str] = None
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1a2f794c9151..989eb4dbfd14 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -213,6 +213,8 @@ class EngineArgs:
     calculate_kv_scales: Optional[bool] = None
 
     additional_config: Optional[Dict[str, Any]] = None
+    enable_reasoning: Optional[bool] = None
+    reasoning_parser: Optional[str] = None
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -1059,6 +1061,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "Different platforms may support different configs. Make sure the "
             "configs are valid for the platform you are using. The input format"
             " is like '{\"config_key\":\"config_value\"}'")
+
+        parser.add_argument(
+            "--enable-reasoning",
+            action="store_true",
+            default=False,
+            help="Whether to enable reasoning_content for the model. "
+            "If enabled, the model will be able to generate reasoning content."
+        )
+
+        parser.add_argument(
+            "--reasoning-parser",
+            type=str,
+            choices=["deepseek_r1"],
+            default=None,
+            help=
+            "Select the reasoning parser depending on the model that you're "
+            "using. This is used to parse the reasoning content into OpenAI "
+            "API format. Required for ``--enable-reasoning``.")
+
         return parser
 
     @classmethod
@@ -1332,7 +1353,10 @@ def create_engine_config(self,
                                         if self.enable_prompt_adapter else None
 
         decoding_config = DecodingConfig(
-            guided_decoding_backend=self.guided_decoding_backend)
+            guided_decoding_backend=self.guided_decoding_backend,
+            reasoning_backend=self.reasoning_parser
+            if self.enable_reasoning else None,
+        )
 
         show_hidden_metrics = False
         if self.show_hidden_metrics_for_version is not None:
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 93d9b74d8e1e..90e66b005f39 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -509,6 +509,7 @@ async def add_request_async(
                 tokenizer=await self.get_tokenizer_async(lora_request),
                 default_guided_backend=self.decoding_config.
                 guided_decoding_backend,
+                reasoning_backend=self.decoding_config.reasoning_backend,
                 model_config=self.model_config)
 
         self._add_processed_request(
@@ -530,7 +531,7 @@ async def check_health_async(self) -> None:
 
 async def build_guided_decoding_logits_processor_async(
         sampling_params: SamplingParams, tokenizer: AnyTokenizer,
-        default_guided_backend: str,
+        default_guided_backend: str, reasoning_backend: Optional[str],
         model_config: ModelConfig) -> SamplingParams:
     """Constructs logits processors based on the guided_decoding,
     logits_bias, and allowed_token_ids fields in sampling_params. Deletes
@@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
     sampling_params = copy.copy(sampling_params)
     guided_decoding = sampling_params.guided_decoding
 
-    logger.debug("Building guided decoding logits processor. "
-                 "Params: %s", guided_decoding)
+    logger.info(
+        "Building guided decoding logits processor. "
+        "guided_decoding: %s%s", guided_decoding,
+        f", reasoning_backend: {reasoning_backend}"
+        if reasoning_backend is not None else "")
 
     guided_decoding.backend = guided_decoding.backend or default_guided_backend
 
     processor = await get_guided_decoding_logits_processor(
         guided_params=guided_decoding,
         tokenizer=tokenizer,
+        reasoning_backend=reasoning_backend,
         model_config=model_config)
 
     if processor:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 9c83ea75ead7..f055438d1feb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2048,10 +2048,15 @@ def _build_logits_processors(
             guided_decoding.backend = guided_decoding.backend or \
                 self.decoding_config.guided_decoding_backend
 
+            logger.debug("Reasoning backend: %s",
+                         self.decoding_config.reasoning_backend)
+
             processor = get_local_guided_decoding_logits_processor(
                 guided_params=guided_decoding,
                 tokenizer=tokenizer,
-                model_config=self.model_config)
+                model_config=self.model_config,
+                reasoning_backend=self.decoding_config.reasoning_backend,
+            )
             if processor:
                 logits_processors.append(processor)
 
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index c12fe242082b..005ba81cd226 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -611,7 +611,8 @@ async def _process_request(
                     default_guided_backend=(self.decoding_config.guided_decoding_backend
                         if self.decoding_config
                         else DecodingConfig.guided_decoding_backend),
-                    model_config=self.model_config
+                    model_config=self.model_config,
+                    reasoning_backend=self.decoding_config.reasoning_backend,
                 )
 
         # 1) Create output queue for this requests.
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index ba953c219708..8d877046f75f 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -13,7 +13,6 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
 from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Enable auto tool choice for supported models. Use "
         "``--tool-call-parser`` to specify which parser to use.")
-    parser.add_argument(
-        "--enable-reasoning",
-        action="store_true",
-        default=False,
-        help="Whether to enable reasoning_content for the model. "
-        "If enabled, the model will be able to generate reasoning content.")
-
-    valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
-    parser.add_argument(
-        "--reasoning-parser",
-        type=str,
-        metavar="{" + ",".join(valid_reasoning_parsers) + "}",
-        default=None,
-        help=
-        "Select the reasoning parser depending on the model that you're using."
-        " This is used to parse the reasoning content into OpenAI API "
-        "format. Required for ``--enable-reasoning``.")
 
     valid_tool_parsers = ToolParserManager.tool_parsers.keys()
     parser.add_argument(
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 1522e3404182..86f6f0e5f907 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -5,6 +5,7 @@
 from typing import TYPE_CHECKING
 
 from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding.reasoner import get_reasoner
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
@@ -103,8 +104,13 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
 
 
 async def get_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
-        model_config: ModelConfig) -> LogitsProcessor | None:
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig,
+        reasoning_backend: str | None = None) -> LogitsProcessor | None:
+
+    reasoner = get_reasoner(tokenizer, reasoning_backend)
+
     guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
@@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
-            guided_params, tokenizer)
-    if guided_params.backend_name == 'lm-format-enforcer':
+            guided_params, tokenizer, reasoner)
+    if guided_params.backend == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
@@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config)
+            guided_params, tokenizer, model_config, reasoner)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -130,16 +136,22 @@ async def get_guided_decoding_logits_processor(
 
 
 def get_local_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
-        model_config: ModelConfig) -> LogitsProcessor | None:
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig,
+        reasoning_backend: str | None = None) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
+
+    # Get the reasoner if needed, it will be None if reasoning_
+    reasoner = get_reasoner(tokenizer, reasoning_backend)
+
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
-            guided_params, tokenizer)
+            guided_params, tokenizer, reasoner)
     if guided_params.backend_name == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
@@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config)
+            guided_params, tokenizer, model_config, reasoner)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index ba9c98290368..97f63ae11f45 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -6,12 +6,13 @@
 from enum import Enum
 from json import dumps as json_dumps
 from re import escape as regex_escape
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 from transformers import PreTrainedTokenizerBase
 
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
+from vllm.model_executor.guided_decoding.reasoner import Reasoner
 from vllm.sampling_params import GuidedDecodingParams
 
 
@@ -58,7 +59,9 @@ class GuidedDecodingMode(Enum):
 
 
 async def get_outlines_guided_decoding_logits_processor(
-    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
+    guided_params: GuidedDecodingParams,
+    tokenizer: PreTrainedTokenizerBase,
+    reasoner: Optional[Reasoner],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
            None]:
     """
@@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor(
 
     return await loop.run_in_executor(global_thread_pool,
                                       _get_logits_processor, guide, tokenizer,
-                                      mode, guided_params.whitespace_pattern)
+                                      mode, guided_params.whitespace_pattern,
+                                      reasoner)
 
 
 def get_local_outlines_guided_decoding_logits_processor(
-    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
+    guided_params: GuidedDecodingParams,
+    tokenizer: PreTrainedTokenizerBase,
+    reasoner: Optional[Reasoner],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
            None]:
     """
@@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor(
         return None
 
     return _get_logits_processor(guide, tokenizer, mode,
-                                 guided_params.whitespace_pattern)
+                                 guided_params.whitespace_pattern, reasoner)
 
 
 def _get_guide_and_mode(
@@ -131,14 +137,18 @@ def _get_guide_and_mode(
 
 
 def _get_logits_processor(
-    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
-    whitespace_pattern: Union[str, None]
+    guide: str,
+    tokenizer: PreTrainedTokenizerBase,
+    mode: GuidedDecodingMode,
+    whitespace_pattern: Union[str, None],
+    reasoner: Optional[Reasoner],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
-        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
+        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
+                                   reasoner)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
-        return RegexLogitsProcessor(guide, tokenizer)
+        return RegexLogitsProcessor(guide, tokenizer, reasoner)
     elif mode == GuidedDecodingMode.GRAMMAR:
-        return CFGLogitsProcessor(guide, tokenizer)
+        return CFGLogitsProcessor(guide, tokenizer, reasoner)
     else:
         raise ValueError(f"Unknown guided decoding mode {mode}")
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index a05267d921d1..db5d738f42e4 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -19,7 +19,7 @@
 import json
 from collections import defaultdict
 from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Union
+from typing import Callable, DefaultDict, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -32,13 +32,18 @@
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
 
+from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding.reasoner import Reasoner
 from vllm.platforms import current_platform
 
+logger = init_logger(__name__)
+
 
 class BaseLogitsProcessor:
 
-    def __init__(self, guide: Guide):
+    def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
         self._guide: Guide = guide
+        self._reasoner = reasoner
         # CFGState is used for the FSM state for CFGGuide
         self._fsm_state: DefaultDict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
@@ -46,6 +51,14 @@ def __init__(self, guide: Guide):
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
+
+        # Skip the structured logits processing if reasoning is not finished.
+        # reasoner is not None only when `--enable-reasoning` is set.
+        if self._reasoner is not None and \
+        not self._reasoner.is_reasoning_end(
+                input_ids):
+            return scores
+
         seq_id = hash(tuple(input_ids))
 
         if len(input_ids) > 0:
@@ -113,7 +126,12 @@ def _get_guide(cls, regex_string: str,
         tokenizer = _adapt_tokenizer(tokenizer)
         return RegexGuide.from_regex(regex_string, tokenizer)
 
-    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+    def __init__(
+        self,
+        regex_string: str,
+        tokenizer: PreTrainedTokenizerBase,
+        reasoner: Optional[Reasoner],
+    ):
         """Compile the FSM that drives the regex-structured generation.
 
         Parameters
@@ -125,14 +143,15 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
 
         """
         super().__init__(
-            RegexLogitsProcessor._get_guide(regex_string, tokenizer))
+            RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
     def __init__(self, schema: Union[str, Dict, BaseModel],
                  tokenizer: PreTrainedTokenizerBase,
-                 whitespace_pattern: Union[str, None]):
+                 whitespace_pattern: Union[str, None],
+                 reasoner: Optional[Reasoner]):
         """Compile the FSM that drives the JSON-guided generation.
 
         Parameters
@@ -160,7 +179,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
                 f"a Pydantic object, a dictionary or a string that contains "
                 f"the JSON Schema specification")
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
-        super().__init__(regex_string, tokenizer)
+        super().__init__(regex_string, tokenizer, reasoner)
 
 
 class CFGLogitsProcessor(BaseLogitsProcessor):
@@ -171,7 +190,8 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
         tokenizer = _adapt_tokenizer(tokenizer)
         return CFGGuide(cfg, tokenizer)
 
-    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
+    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
+                 reasoner: Optional[Reasoner]):
         """Compile the FSM that drives the context free grammar generation.
 
         Parameters
@@ -182,7 +202,8 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
             The model's tokenizer
 
         """
-        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
+                         reasoner)
         self._guide = self._guide.copy()
 
 
diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py
new file mode 100644
index 000000000000..5a91f791d45b
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from transformers import PreTrainedTokenizer
+
+from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import (  # noqa: E501
+    DeepSeekReasoner)
+from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
+
+
+def get_reasoner(tokenizer: PreTrainedTokenizer,
+                 reasoning_backend: str | None) -> Reasoner | None:
+    if reasoning_backend is None:
+        # No reasoning backend specified
+        return None
+    elif reasoning_backend == "deepseek_r1":
+        return DeepSeekReasoner.from_tokenizer(tokenizer)
+    else:
+        raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
+
+
+__all__ = ["Reasoner", "get_reasoner"]
diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
new file mode 100644
index 000000000000..e762fb0659de
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+from transformers import PreTrainedTokenizer
+
+from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
+
+
+@dataclass
+class DeepSeekReasoner(Reasoner):
+    """
+    Reasoner for DeepSeek R series models.
+    """
+    start_token_id: int
+    end_token_id: int
+
+    start_token: str = ""
+    end_token: str = ""
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+        return cls(start_token_id=tokenizer.encode(
+            "", add_special_tokens=False)[0],
+                   end_token_id=tokenizer.encode("",
+                                                 add_special_tokens=False)[0])
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.end_token_id in input_ids
diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
new file mode 100644
index 000000000000..5db0c9bc7850
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from transformers import PreTrainedTokenizer
+
+
+@dataclass
+class Reasoner(ABC):
+
+    @abstractmethod
+    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+        pass
+
+    @abstractmethod
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        pass
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index eb9d83acb286..ce278c15ab3b 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -11,6 +11,8 @@
 import torch
 from transformers import PreTrainedTokenizerFast
 
+from vllm.logger import init_logger
+
 try:
     import xgrammar as xgr
     from xgrammar.base import _core as xgr_core
@@ -19,7 +21,6 @@
     xgr_installed = False
     pass
 
-from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
                                                        grammar_is_likely_lark)
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -28,6 +29,7 @@
     from transformers import PreTrainedTokenizer
 
     from vllm.config import ModelConfig
+    from vllm.model_executor.guided_decoding.reasoner import Reasoner
     from vllm.sampling_params import GuidedDecodingParams
 
 logger = init_logger(__name__)
@@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
         model_config: ModelConfig,
+        reasoner: Reasoner | None,
         max_threads: int = 8):
     config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                               model_config=model_config,
                                               tokenizer=tokenizer,
                                               max_threads=max_threads)
-    return XGrammarLogitsProcessor(config)
+    return XGrammarLogitsProcessor(config, reasoner)
 
 
 @dataclass(frozen=True)
@@ -293,6 +296,7 @@ def choice_as_grammar(choice: List[str] | None) -> str:
 class XGrammarLogitsProcessor:
     """Wrapper class to support pickle protocol"""
     config: GrammarConfig
+    reasoner: Reasoner | None = None
 
     ctx: xgr.CompiledGrammar | None = None
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
@@ -301,10 +305,11 @@ class XGrammarLogitsProcessor:
     prefilled: bool = field(default=False)
 
     def __getstate__(self) -> dict[str, Any]:
-        return {'config': self.config}
+        return {'config': self.config, 'reasoner': self.reasoner}
 
     def __setstate__(self, state: dict[str, Any]):
         self.config = state['config']
+        self.reasoner = state['reasoner']
 
         self.ctx = None
         self.matchers = []
@@ -331,6 +336,14 @@ def _ensure_ctx(self):
 
     def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
+
+        # Skip the structured logits processing if reasoning is not finished.
+        # reasoner is not None only when `--enable-reasoning` is set.
+        if self.reasoner is not None and \
+        not self.reasoner.is_reasoning_end(
+                input_ids):
+            return scores
+
         if self.ctx is None:
             self._ensure_ctx()