Merged

27 commits
7458b28  Data LLM Config Refactor - Part 1 (nrghosh, Oct 30, 2025)
beb5d98  Data LLM Config Refactor - Part 2 (nrghosh, Oct 31, 2025)
1327d43  Data LLM Config Refactor - Part 3: Update SGLang processor with stage… (nrghosh, Nov 7, 2025)
399ef58  Data LLM Config Refactor - Part 4: Add deprecation warnings for legac… (nrghosh, Nov 12, 2025)
275a93d  Data LLM Config Refactor - Part 5: Update public API docstrings (nrghosh, Nov 12, 2025)
5299043  wip - feedback: Fix config mutation and chat_template_kwargs bugs (nrghosh, Nov 12, 2025)
2cf40d4  wip - feedback: Add concurrency to resolve_stage_config and fix CPU s… (nrghosh, Nov 12, 2025)
163f51f  wip - feedback: Fix runtime_env and batch_size falsy value handling (nrghosh, Nov 13, 2025)
89ba2fe  wip - Add num_cpus and memory support to StageConfig for per-stage re… (nrghosh, Nov 13, 2025)
42dcaf0  wip - feedback: fix or operators to checks (nrghosh, Nov 13, 2025)
5a607f6  wip - feedback: Fix resolve_stage_config to raise TypeError for unsup… (nrghosh, Nov 13, 2025)
45dcf5e  wip (nrghosh, Nov 13, 2025)
31ec8f6  wip - readability: refactor merging logic and simplify expressions (nrghosh, Nov 13, 2025)
cee0247  wip - cleanup: remove redundant checks for merged fields (nrghosh, Nov 13, 2025)
b938ba8  wip - extract CPU stage building helpers to shared utils module (nrghosh, Nov 13, 2025)
f4e2a5c  wip - detailed public API docstrings for nested stage configs (nrghosh, Nov 13, 2025)
636e957  wip - Fix: handle None values in legacy flag coercion (nrghosh, Nov 13, 2025)
245f521  Merge branch 'master' into nrghosh/data-llm-config-refactor (nrghosh, Nov 14, 2025)
12bffb3  Update python/ray/llm/_internal/batch/processor/sglang_engine_proc.py (nrghosh, Nov 18, 2025)
7c8d97f  Update python/ray/llm/_internal/batch/processor/vllm_engine_proc.py (nrghosh, Nov 18, 2025)
66fee73  Merge branch 'master' into nrghosh/data-llm-config-refactor (nrghosh, Nov 18, 2025)
474f119  Data LLM Config Refactor - Part 6: Add unit tests for stage config re… (nrghosh, Nov 18, 2025)
f0f7fdf  wip - fix None concurrency in `normalize_cpu_stage_concurrency` (nrghosh, Nov 18, 2025)
4e35742  wip - update tests to use nested StageConfig schema (nrghosh, Nov 18, 2025)
3293d47  wip (nrghosh, Nov 18, 2025)
c200da2  fix test (nrghosh, Nov 19, 2025)
cac411a  wip - test (nrghosh, Nov 19, 2025)
52 changes: 37 additions & 15 deletions python/ray/data/llm.py
@@ -113,20 +113,30 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
each batch. The default value may not be optimal when the batch size
or the batch processing latency is too small, but it should be good
enough for batch size >= 64.
apply_chat_template: Whether to apply chat template.
chat_template: The chat template to use. This is usually not needed if the
model checkpoint already contains the chat template.
tokenize: Whether to tokenize the input before passing it to the vLLM engine.
If not, vLLM will tokenize the prompt in the engine.
detokenize: Whether to detokenize the output.
has_image: Whether the input messages have images.
chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template``
and ``chat_template`` fields are deprecated but still supported.
tokenize_stage: Tokenizer stage config (bool | dict | TokenizerStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
``tokenize`` field is deprecated but still supported.
detokenize_stage: Detokenizer stage config (bool | dict | DetokenizeStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
``detokenize`` field is deprecated but still supported.
prepare_image_stage: Prepare image stage config (bool | dict | PrepareImageStageConfig).
Defaults to False. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, and memory. Legacy ``has_image`` field
is deprecated but still supported.
accelerator_type: The accelerator type used by the LLM stage in a processor.
Defaults to None, meaning that only the CPU will be used.
concurrency: The number of workers for data parallelism. Defaults to 1.
If ``concurrency`` is a tuple ``(m, n)``, Ray creates an autoscaling
actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
If ``concurrency`` is an ``int`` ``n``, CPU stages use an autoscaling
pool from ``(1, n)``, while GPU stages use a fixed pool of ``n`` workers.
Stage-specific concurrency can be set via nested stage configs.

Examples:

@@ -205,19 +215,26 @@ class SGLangEngineProcessorConfig(_SGLangEngineProcessorConfig):
each batch. The default value may not be optimal when the batch size
or the batch processing latency is too small, but it should be good
enough for batch size >= 64.
apply_chat_template: Whether to apply chat template.
chat_template: The chat template to use. This is usually not needed if the
model checkpoint already contains the chat template.
tokenize: Whether to tokenize the input before passing it to the SGLang engine.
If not, SGLang will tokenize the prompt in the engine.
detokenize: Whether to detokenize the output.
chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template``
and ``chat_template`` fields are deprecated but still supported.
tokenize_stage: Tokenizer stage config (bool | dict | TokenizerStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
``tokenize`` field is deprecated but still supported.
detokenize_stage: Detokenizer stage config (bool | dict | DetokenizeStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
``detokenize`` field is deprecated but still supported.
accelerator_type: The accelerator type used by the LLM stage in a processor.
Defaults to None, meaning that only the CPU will be used.
concurrency: The number of workers for data parallelism. Defaults to 1.
If ``concurrency`` is a tuple ``(m, n)``, Ray creates an autoscaling
actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
If ``concurrency`` is an ``int`` ``n``, CPU stages use an autoscaling
pool from ``(1, n)``, while GPU stages use a fixed pool of ``n`` workers.
Stage-specific concurrency can be set via nested stage configs.

Examples:
.. testcode::
@@ -375,7 +392,12 @@ def build_llm_processor(
"""Build a LLM processor using the given config.

Args:
config: The processor config.
config: The processor config. Supports nested stage configs for per-stage
control over batch_size, concurrency, runtime_env, num_cpus, and memory
(e.g., ``chat_template_stage=ChatTemplateStageConfig(batch_size=128)``
or ``tokenize_stage={"batch_size": 256, "concurrency": 2}``). Legacy
boolean flags (``apply_chat_template``, ``tokenize``, ``detokenize``,
``has_image``) are deprecated; they still work but emit deprecation warnings.
preprocess: An optional lambda function that takes a row (dict) as input
and returns a preprocessed row (dict). The output row must contain the
required fields for the following processing stages. Each row
@@ -483,7 +505,7 @@ def build_llm_processor(

config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen3-0.6B",
apply_chat_template=True,
chat_template_stage={"enabled": True},
concurrency=1,
batch_size=64,
)
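The nested stage config API documented above can be exercised end to end. A minimal sketch, using the public names shown in this file; the model id and per-stage values are illustrative placeholders, not values from this PR:

from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# Nested stage configs accept a bool, a plain dict, or a typed config
# object; per-stage values override the processor-wide defaults.
config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen3-0.6B",
    chat_template_stage={"enabled": True},
    tokenize_stage={"batch_size": 256, "concurrency": 2},
    detokenize_stage=True,
    concurrency=1,
    batch_size=64,
)
processor = build_llm_processor(config)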
83 changes: 74 additions & 9 deletions python/ray/llm/_internal/batch/processor/base.py
@@ -2,7 +2,7 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from pydantic import Field, field_validator
from pydantic import Field, field_validator, root_validator

import ray
from ray.data import Dataset
@@ -155,29 +155,94 @@ class OfflineProcessorConfig(ProcessorConfig):
"enough for batch size >= 32.",
)

# Processor stage configurations.
# Processor stage configurations (legacy boolean flags, deprecated).
apply_chat_template: bool = Field(
default=True, description="Whether to apply chat template."
default=True,
description="[DEPRECATED] Prefer `chat_template_stage`. Whether to apply chat template.",
)
chat_template: Optional[str] = Field(
default=None,
description="The chat template to use. This is usually not needed if the "
"model checkpoint already contains the chat template.",
description="[DEPRECATED] Prefer `chat_template_stage.chat_template`. The chat template to use.",
)
tokenize: bool = Field(
default=True,
description="Whether to tokenize the input before passing it to the "
"backend engine. If not, the backend engine will tokenize the prompt.",
description="[DEPRECATED] Prefer `tokenize_stage`. Whether to tokenize input before engine.",
)
detokenize: bool = Field(
default=True,
description="Whether to detokenize the output.",
description="[DEPRECATED] Prefer `detokenize_stage`. Whether to detokenize the output.",
)
has_image: bool = Field(
default=False,
description="Whether the input messages have images.",
description="[DEPRECATED] Prefer `prepare_image_stage`. Whether the input messages have images.",
)

# New nested stage configuration (bool | dict | typed config).
chat_template_stage: Any = Field(
default=True,
description="Chat templating stage config (bool | dict | ChatTemplateStageConfig).",
)
tokenize_stage: Any = Field(
default=True,
description="Tokenizer stage config (bool | dict | TokenizerStageConfig).",
)
detokenize_stage: Any = Field(
default=True,
description="Detokenizer stage config (bool | dict | DetokenizeStageConfig).",
)
prepare_image_stage: Any = Field(
default=False,
description="Prepare image stage config (bool | dict | PrepareImageStageConfig).",
)

@root_validator(pre=True)
def _coerce_legacy_to_stage_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# Only set stage fields if not explicitly provided.
# Emit deprecation warnings when legacy boolean flags are used.

# Chat template stage: special case (handles both apply_chat_template and chat_template fields)
if "chat_template_stage" not in values:
if "apply_chat_template" in values or "chat_template" in values:
logger.warning(
"The `apply_chat_template` and `chat_template` fields are deprecated. "
"Use `chat_template_stage` instead. For example: "
"`chat_template_stage=ChatTemplateStageConfig(enabled=True, chat_template='...')` "
"or `chat_template_stage={'enabled': True, 'chat_template': '...'}`. "
"This will raise an error in a future version."
)
enabled_value = values.get("apply_chat_template")
enabled = enabled_value if enabled_value is not None else True
stage: Dict[str, Any] = {"enabled": enabled}
if values.get("chat_template") is not None:
stage["chat_template"] = values["chat_template"]
values["chat_template_stage"] = stage

# Other stages: simple boolean-to-stage mapping
stage_mappings = [
("tokenize_stage", "tokenize", True, "TokenizerStageConfig"),
("detokenize_stage", "detokenize", True, "DetokenizeStageConfig"),
("prepare_image_stage", "has_image", False, "PrepareImageStageConfig"),
]
for (
stage_field,
legacy_field,
default_enabled,
config_class_name,
) in stage_mappings:
if stage_field not in values and legacy_field in values:
logger.warning(
f"The `{legacy_field}` field is deprecated. "
f"Use `{stage_field}` instead. For example: "
f"`{stage_field}={config_class_name}(enabled=True)` "
f"or `{stage_field}={{'enabled': True}}`. "
"This will raise an error in a future version."
)
legacy_value = values.get(legacy_field)
enabled = default_enabled if legacy_value is None else legacy_value
values[stage_field] = {"enabled": enabled}

return values


@PublicAPI(stability="alpha")
class Processor:
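To make the coercion above concrete, here is a sketch of the mapping the validator performs, based only on the validator body shown; the template string is a placeholder:

# A legacy flag, when the nested field is absent, is coerced into the
# nested dict form and a deprecation warning is logged:
#   {"tokenize": False}  -> values["tokenize_stage"] == {"enabled": False}
#   {"has_image": True}  -> values["prepare_image_stage"] == {"enabled": True}
# The chat-template case folds two legacy fields into one stage dict:
#   {"apply_chat_template": True, "chat_template": "<template>"}
#       -> values["chat_template_stage"] ==
#              {"enabled": True, "chat_template": "<template>"}
# Explicitly passing the nested field wins: the legacy value is left in
# place and no coercion happens for that stage.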
78 changes: 52 additions & 26 deletions python/ray/llm/_internal/batch/processor/sglang_engine_proc.py
@@ -18,12 +18,22 @@
Processor,
ProcessorBuilder,
)
from ray.llm._internal.batch.processor.utils import (
build_cpu_stage_map_kwargs,
get_value_or_fallback,
)
from ray.llm._internal.batch.stages import (
ChatTemplateStage,
DetokenizeStage,
SGLangEngineStage,
TokenizeStage,
)
from ray.llm._internal.batch.stages.configs import (
ChatTemplateStageConfig,
DetokenizeStageConfig,
TokenizerStageConfig,
resolve_stage_config,
)
from ray.llm._internal.batch.stages.sglang_engine_stage import SGLangTaskType
from ray.llm._internal.common.observability.telemetry_utils import DEFAULT_GPU_TYPE

@@ -85,35 +95,50 @@ def build_sglang_engine_processor(

stages = []

if config.apply_chat_template:
# Prepare processor defaults for merging into stage configs
processor_defaults = {
"batch_size": config.batch_size,
"concurrency": config.concurrency,
"runtime_env": config.runtime_env,
"model_source": config.model_source,
}

# Resolve and build ChatTemplateStage if enabled
chat_template_stage_cfg = resolve_stage_config(
config.chat_template_stage,
ChatTemplateStageConfig,
processor_defaults,
)
if chat_template_stage_cfg.enabled:
stages.append(
ChatTemplateStage(
fn_constructor_kwargs=dict(
model=config.model_source,
chat_template=config.chat_template,
chat_template_kwargs=chat_template_kwargs,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
model=chat_template_stage_cfg.model_source,
chat_template=get_value_or_fallback(
chat_template_stage_cfg.chat_template, config.chat_template
),
chat_template_kwargs=get_value_or_fallback(
chat_template_stage_cfg.chat_template_kwargs,
chat_template_kwargs,
),
),
map_batches_kwargs=build_cpu_stage_map_kwargs(chat_template_stage_cfg),
)
)

if config.tokenize:
# Resolve and build TokenizeStage if enabled
tokenize_stage_cfg = resolve_stage_config(
getattr(config, "tokenize_stage", config.tokenize),
TokenizerStageConfig,
processor_defaults,
)
if tokenize_stage_cfg.enabled:
stages.append(
TokenizeStage(
fn_constructor_kwargs=dict(
model=config.model_source,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
model=tokenize_stage_cfg.model_source,
),
map_batches_kwargs=build_cpu_stage_map_kwargs(tokenize_stage_cfg),
)
)

@@ -149,18 +174,19 @@ def build_sglang_engine_processor(
)
)

if config.detokenize:
# Resolve and build DetokenizeStage if enabled
detokenize_stage_cfg = resolve_stage_config(
getattr(config, "detokenize_stage", config.detokenize),
DetokenizeStageConfig,
processor_defaults,
)
if detokenize_stage_cfg.enabled:
stages.append(
DetokenizeStage(
fn_constructor_kwargs=dict(
model=config.model_source,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
model=detokenize_stage_cfg.model_source,
),
map_batches_kwargs=build_cpu_stage_map_kwargs(detokenize_stage_cfg),
)
)

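The body of resolve_stage_config is not part of this diff; from the call sites above, it takes a bool | dict | typed config plus processor-level defaults and returns a typed, merged stage config. A hypothetical illustration of the per-stage override flow (the model id and field values are invented for the example):

config = SGLangEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    batch_size=64,                        # processor-wide default
    tokenize_stage={"batch_size": 256},   # per-stage override
)
# Expected outcome, given the merging shown in the builder above:
#   TokenizeStage maps batches with batch_size=256, while
#   ChatTemplateStage and DetokenizeStage inherit the processor
#   default of 64 via processor_defaults.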
54 changes: 54 additions & 0 deletions python/ray/llm/_internal/batch/processor/utils.py
@@ -0,0 +1,54 @@
"""Shared utility functions for processor builders."""

from typing import Any, Dict, Optional, Tuple, Union

from ray.llm._internal.batch.stages.configs import _StageConfigBase


def get_value_or_fallback(value: Any, fallback: Any) -> Any:
"""Return value if not None, otherwise return fallback."""
return value if value is not None else fallback


def extract_resource_kwargs(
runtime_env: Optional[Dict[str, Any]],
num_cpus: Optional[float],
memory: Optional[float],
) -> Dict[str, Any]:
"""Extract non-None resource kwargs for map_batches."""
kwargs = {}
if runtime_env is not None:
kwargs["runtime_env"] = runtime_env
if num_cpus is not None:
kwargs["num_cpus"] = num_cpus
if memory is not None:
kwargs["memory"] = memory
return kwargs


def normalize_cpu_stage_concurrency(
concurrency: Optional[Union[int, Tuple[int, int]]]
) -> Tuple[int, int]:
"""Normalize concurrency for CPU stages (int -> (1, int) for autoscaling)."""
if concurrency is None:
return (1, 1) # Default to minimal autoscaling pool
if isinstance(concurrency, int):
return (1, concurrency)
return concurrency


def build_cpu_stage_map_kwargs(
stage_cfg: _StageConfigBase,
) -> Dict[str, Any]:
"""Build map_batches_kwargs for CPU stages."""
concurrency = normalize_cpu_stage_concurrency(stage_cfg.concurrency)
return dict(
zero_copy_batch=True,
concurrency=concurrency,
batch_size=stage_cfg.batch_size,
**extract_resource_kwargs(
stage_cfg.runtime_env,
stage_cfg.num_cpus,
stage_cfg.memory,
),
)
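
A usage sketch for these helpers; the return values follow directly from the definitions above:

# Concurrency normalization for CPU stages:
normalize_cpu_stage_concurrency(None)    # -> (1, 1): single-worker pool
normalize_cpu_stage_concurrency(4)       # -> (1, 4): autoscaling up to 4
normalize_cpu_stage_concurrency((2, 8))  # -> (2, 8): tuple passed through

# Only non-None resources are forwarded to map_batches:
extract_resource_kwargs(None, 2.0, None)  # -> {"num_cpus": 2.0}

# build_cpu_stage_map_kwargs combines both, always enabling
# zero_copy_batch and carrying the stage's batch_size.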