Skip to content

Commit 22c755d

Browse files
authored
[data][llm] Add per-stage map kwargs for build_llm_processor preprocess/postprocess (#57826)
Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
1 parent 7806bf2 commit 22c755d

File tree

7 files changed

+212
-7
lines changed

7 files changed

+212
-7
lines changed

python/ray/data/llm.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ def build_llm_processor(
368368
config: ProcessorConfig,
369369
preprocess: Optional[UserDefinedFunction] = None,
370370
postprocess: Optional[UserDefinedFunction] = None,
371+
preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
372+
postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
371373
builder_kwargs: Optional[Dict[str, Any]] = None,
372374
) -> Processor:
373375
"""Build a LLM processor using the given config.
@@ -383,6 +385,12 @@ def build_llm_processor(
383385
postprocess: An optional lambda function that takes a row (dict) as input
384386
and returns a postprocessed row (dict). To keep all the original columns,
385387
you can use the `**row` syntax to return all the original columns.
388+
preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
389+
preprocess stage. Useful for controlling resources (e.g., num_cpus=0.5)
390+
and concurrency independently of the main LLM stage.
391+
postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
392+
postprocess stage. Useful for controlling resources (e.g., num_cpus=0.25)
393+
and concurrency independently of the main LLM stage.
386394
builder_kwargs: Optional additional kwargs to pass to the processor builder
387395
function. These will be passed through to the registered builder and
388396
should match the signature of the specific builder being used.
@@ -435,6 +443,36 @@ def build_llm_processor(
435443
for row in ds.take_all():
436444
print(row)
437445
446+
Using map_kwargs to control preprocess/postprocess resources:
447+
448+
.. testcode::
449+
:skipif: True
450+
451+
import ray
452+
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
453+
454+
config = vLLMEngineProcessorConfig(
455+
model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
456+
concurrency=1,
457+
batch_size=64,
458+
)
459+
460+
processor = build_llm_processor(
461+
config,
462+
preprocess=lambda row: dict(
463+
messages=[{"role": "user", "content": row["prompt"]}],
464+
sampling_params=dict(temperature=0.3, max_tokens=20),
465+
),
466+
postprocess=lambda row: dict(resp=row["generated_text"]),
467+
preprocess_map_kwargs={"num_cpus": 0.5},
468+
postprocess_map_kwargs={"num_cpus": 0.25},
469+
)
470+
471+
ds = ray.data.range(300)
472+
ds = processor(ds)
473+
for row in ds.take_all():
474+
print(row)
475+
438476
Using builder_kwargs to pass chat_template_kwargs:
439477
440478
.. testcode::
@@ -474,9 +512,12 @@ def build_llm_processor(
474512
from ray.llm._internal.batch.processor import ProcessorBuilder
475513

476514
ProcessorBuilder.validate_builder_kwargs(builder_kwargs)
515+
477516
build_kwargs = dict(
478517
preprocess=preprocess,
479518
postprocess=postprocess,
519+
preprocess_map_kwargs=preprocess_map_kwargs,
520+
postprocess_map_kwargs=postprocess_map_kwargs,
480521
)
481522

482523
# Pass through any additional builder kwargs

python/ray/llm/_internal/batch/processor/base.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ class Processor:
193193
required fields for the following processing stages.
194194
postprocess: An optional lambda function that takes a row (dict) as input
195195
and returns a postprocessed row (dict).
196+
preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
197+
preprocess stage (e.g., num_cpus, memory, concurrency).
198+
postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
199+
postprocess stage (e.g., num_cpus, memory, concurrency).
196200
"""
197201

198202
# The internal used data column name ("__data"). Your input
@@ -206,10 +210,14 @@ def __init__(
206210
stages: List[StatefulStage],
207211
preprocess: Optional[UserDefinedFunction] = None,
208212
postprocess: Optional[UserDefinedFunction] = None,
213+
preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
214+
postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
209215
):
210216
self.config = config
211217
self.preprocess = None
212218
self.postprocess = None
219+
self.preprocess_map_kwargs = preprocess_map_kwargs or {}
220+
self.postprocess_map_kwargs = postprocess_map_kwargs or {}
213221
self.stages: OrderedDict[str, StatefulStage] = OrderedDict()
214222

215223
# FIXES: https://github.com/ray-project/ray/issues/53124
@@ -251,7 +259,7 @@ def __call__(self, dataset: Dataset) -> Dataset:
251259
The output dataset.
252260
"""
253261
if self.preprocess is not None:
254-
dataset = dataset.map(self.preprocess)
262+
dataset = dataset.map(self.preprocess, **self.preprocess_map_kwargs)
255263

256264
# Apply stages.
257265
for stage in self.stages.values():
@@ -262,7 +270,7 @@ def __call__(self, dataset: Dataset) -> Dataset:
262270
dataset = dataset.map_batches(stage.fn, **kwargs)
263271

264272
if self.postprocess is not None:
265-
dataset = dataset.map(self.postprocess)
273+
dataset = dataset.map(self.postprocess, **self.postprocess_map_kwargs)
266274
return dataset
267275

268276
def _append_stage(self, stage: StatefulStage) -> None:
@@ -360,7 +368,12 @@ def validate_builder_kwargs(cls, builder_kwargs: Optional[Dict[str, Any]]) -> No
360368
"""
361369
if builder_kwargs is not None:
362370
# Check for conflicts with explicitly passed arguments
363-
reserved_keys = {"preprocess", "postprocess"}
371+
reserved_keys = {
372+
"preprocess",
373+
"postprocess",
374+
"preprocess_map_kwargs",
375+
"postprocess_map_kwargs",
376+
}
364377
conflicting_keys = reserved_keys & builder_kwargs.keys()
365378
if conflicting_keys:
366379
raise ValueError(

python/ray/llm/_internal/batch/processor/http_request_proc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def build_http_request_processor(
5959
config: HttpRequestProcessorConfig,
6060
preprocess: Optional[UserDefinedFunction] = None,
6161
postprocess: Optional[UserDefinedFunction] = None,
62+
preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
63+
postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
6264
) -> Processor:
6365
"""Construct a Processor and configure stages.
6466
@@ -69,6 +71,10 @@ def build_http_request_processor(
6971
required fields for the following processing stages.
7072
postprocess: An optional lambda function that takes a row (dict) as input
7173
and returns a postprocessed row (dict).
74+
preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
75+
preprocess stage (e.g., num_cpus, memory, concurrency).
76+
postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
77+
postprocess stage (e.g., num_cpus, memory, concurrency).
7278
7379
Returns:
7480
The constructed processor.
@@ -100,6 +106,8 @@ def build_http_request_processor(
100106
stages,
101107
preprocess=preprocess,
102108
postprocess=postprocess,
109+
preprocess_map_kwargs=preprocess_map_kwargs,
110+
postprocess_map_kwargs=postprocess_map_kwargs,
103111
)
104112
return processor
105113

python/ray/llm/_internal/batch/processor/serve_deployment_proc.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def build_serve_deployment_processor(
3636
config: ServeDeploymentProcessorConfig,
3737
preprocess: Optional[UserDefinedFunction] = None,
3838
postprocess: Optional[UserDefinedFunction] = None,
39+
preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
40+
postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
3941
) -> Processor:
40-
"""
41-
Construct a processor that runs a serve deployment.
42+
"""Construct a processor that runs a serve deployment.
4243
4344
Args:
4445
config: The configuration for the processor.
@@ -47,6 +48,10 @@ def build_serve_deployment_processor(
4748
required fields for the following processing stages.
4849
postprocess: An optional lambda function that takes a row (dict) as input
4950
and returns a postprocessed row (dict).
51+
preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
52+
preprocess stage (e.g., num_cpus, memory, concurrency).
53+
postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
54+
postprocess stage (e.g., num_cpus, memory, concurrency).
5055
5156
Returns:
5257
The constructed processor.
@@ -69,6 +74,8 @@ def build_serve_deployment_processor(
6974
stages,
7075
preprocess=preprocess,
7176
postprocess=postprocess,
77+
preprocess_map_kwargs=preprocess_map_kwargs,
78+
postprocess_map_kwargs=postprocess_map_kwargs,
7279
)
7380
return processor
7481

python/ray/llm/_internal/batch/processor/sglang_engine_proc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@ def build_sglang_engine_processor(
5858
chat_template_kwargs: Optional[Dict[str, Any]] = None,
5959
preprocess: Optional[UserDefinedFunction] = None,
6060
postprocess: Optional[UserDefinedFunction] = None,
61+
preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
62+
postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
6163
telemetry_agent: Optional[TelemetryAgent] = None,
6264
) -> Processor:
6365
"""Construct a Processor and configure stages.
66+
6467
Args:
6568
config: The configuration for the processor.
6669
chat_template_kwargs: The optional kwargs to pass to apply_chat_template.
@@ -69,6 +72,10 @@ def build_sglang_engine_processor(
6972
required fields for the following processing stages.
7073
postprocess: An optional lambda function that takes a row (dict) as input
7174
and returns a postprocessed row (dict).
75+
preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
76+
preprocess stage (e.g., num_cpus, memory, concurrency).
77+
postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
78+
postprocess stage (e.g., num_cpus, memory, concurrency).
7279
telemetry_agent: An optional telemetry agent for collecting usage telemetry.
7380
7481
Returns:
@@ -179,6 +186,8 @@ def build_sglang_engine_processor(
179186
stages,
180187
preprocess=preprocess,
181188
postprocess=postprocess,
189+
preprocess_map_kwargs=preprocess_map_kwargs,
190+
postprocess_map_kwargs=postprocess_map_kwargs,
182191
)
183192
return processor
184193

python/ray/llm/_internal/batch/processor/vllm_engine_proc.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,12 @@ def build_vllm_engine_processor(
106106
chat_template_kwargs: Optional[Dict[str, Any]] = None,
107107
preprocess: Optional[UserDefinedFunction] = None,
108108
postprocess: Optional[UserDefinedFunction] = None,
109+
preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
110+
postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
109111
telemetry_agent: Optional[TelemetryAgent] = None,
110112
) -> Processor:
111113
"""Construct a Processor and configure stages.
114+
112115
Args:
113116
config: The configuration for the processor.
114117
chat_template_kwargs: The optional kwargs to pass to apply_chat_template.
@@ -117,9 +120,12 @@ def build_vllm_engine_processor(
117120
required fields for the following processing stages.
118121
postprocess: An optional lambda function that takes a row (dict) as input
119122
and returns a postprocessed row (dict).
123+
preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
124+
preprocess stage (e.g., num_cpus, memory, concurrency).
125+
postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
126+
postprocess stage (e.g., num_cpus, memory, concurrency).
120127
telemetry_agent: An optional telemetry agent for collecting usage telemetry.
121128
122-
123129
Returns:
124130
The constructed processor.
125131
"""
@@ -262,6 +268,8 @@ def build_vllm_engine_processor(
262268
stages,
263269
preprocess=preprocess,
264270
postprocess=postprocess,
271+
preprocess_map_kwargs=preprocess_map_kwargs,
272+
postprocess_map_kwargs=postprocess_map_kwargs,
265273
)
266274
return processor
267275

0 commit comments

Comments
 (0)