# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides functionality for:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load-balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline parallel inference

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version(
    "2.44.1"), "Ray version must be at least 2.44.1"

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False

# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage in formats such as JSONL, Parquet, CSV, and raw binary.
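# For example, Parquet inputs could be read analogously with
# ray.data.read_parquet("s3://<your-bucket>/<your-prefix>") (placeholder path).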
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")

# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=64,
)
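# NOTE: For models too large for a single GPU, the tensor/pipeline
# parallelism mentioned in the docstring is configured through
# engine_kwargs (see the Ray Data LLM docs linked above). A sketch of a
# two-GPU tensor-parallel variant; the values are illustrative only:
#
# config = vLLMEngineProcessorConfig(
#     model_source="unsloth/Llama-3.1-8B-Instruct",
#     engine_kwargs={
#         "tensor_parallel_size": 2,
#         "max_model_len": 16384,
#     },
#     concurrency=1,
#     batch_size=64,
# )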

# Create a Processor object, which will be used to
# do batch inference on the dataset.
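# NOTE: Per the Ray Data LLM docs linked above, preprocess must emit the
# fields the engine consumes ("messages" and "sampling_params"), and
# postprocess receives the engine output, including a "generated_text"
# column.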
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{
            "role": "system",
            "content": "You are a bot that responds with haikus."
        }, {
            "role": "user",
            "content": row["text"]
        }],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        )),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row  # This will return all the original columns in the dataset.
    ),
)

ds = vllm_processor(ds)
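# NOTE: Ray Data executes lazily, so no inference has run yet; execution
# is triggered when the dataset is consumed (e.g., by take() or
# write_parquet() below).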

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out to storage as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")

# Write the inference output as Parquet files to S3.
# Multiple files will be written to the output destination,
# with each Ray task writing one or more files.
#
# ds.write_parquet("s3://<your-output-bucket>")
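#
# For a quick local test, a regular filesystem path also works; the path
# below is just an illustrative placeholder:
# ds.write_parquet("/tmp/ray_vllm_output")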