 import argparse
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import random
 
-import ray
+import torch
+try:
+    import ray
+except ImportError:
+    ray = None
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
@@ -31,13 +35,18 @@ def __init__(
         all_stage_devices: List[List[DeviceID]],
         gpu_memory: int,
         cpu_memory: int,
+        use_ray: bool,
         collect_stats: bool = False,
         do_memory_analysis: bool = False,
     ):
         self.num_nodes = num_nodes
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
 
+        if not use_ray:
+            assert self.world_size == 1, (
+                "Only support single GPU without Ray.")
+
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
             block_size=block_size,
@@ -72,6 +81,7 @@ def __init__(
                 model_path=model_path,
                 use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
+                use_ray=use_ray,
             )
             self.controllers.append(controller)
 
@@ -105,11 +115,30 @@ def has_unfinished_requests(self):
                 self.scheduler.swapped)
 
 
-def initialize_ray_cluster(
-    address: str = 'auto',
+def initialize_cluster(
+    use_ray: bool = False,
+    address: Optional[str] = None,
     pipeline_parallel_size: int = 1,
     tensor_parallel_size: int = 1,
 ) -> Tuple[int, int, str, List[List[DeviceID]]]:
+    # Initialize cluster locally.
+    if not use_ray:
+        assert pipeline_parallel_size * tensor_parallel_size == 1, (
+            "Only support single GPU without Ray.")
+        num_nodes = 1
+        num_devices_per_node = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        # We need to set up the distributed init method to make sure
+        # the distributed megatron code (e.g., get world size) works correctly.
+        distributed_init_method = f"tcp://localhost:{port}"
+        all_stage_devices = [[(0, None, 0)]]
+        return (num_nodes, num_devices_per_node, distributed_init_method,
+                all_stage_devices)
+
+    assert ray is not None, (
+        "Ray is not installed. Please install Ray to use distributed "
+        "serving.")
+
     # Connect to a ray cluster.
     ray.init(address=address)
 
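Note: the `tcp://localhost:{port}` string built in the local path above is the standard PyTorch `init_method` URL. A minimal sketch of how a single-process worker might consume it is shown below; the worker-side call is not part of this diff, and the function name is hypothetical.

```python
import torch.distributed as dist

def init_single_gpu_worker(distributed_init_method: str) -> None:
    # Hypothetical worker-side setup: world_size=1, rank=0 matches the
    # "single GPU without Ray" path taken when use_ray is False.
    dist.init_process_group(
        backend='nccl',
        init_method=distributed_init_method,
        world_size=1,
        rank=0,
    )
```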
@@ -177,6 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
+    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
@@ -190,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
+
+def process_server_arguments(args: argparse.Namespace):
+    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
+        args.use_ray = True
+    return args
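For context, here is a minimal sketch of how the new pieces might be wired together in a server entry point. The wiring is an assumption and not shown in this commit; only `add_server_arguments`, `process_server_arguments`, and `initialize_cluster` come from the diff, and the `cacheflow.master.server` module path is a guess at the file being modified.

```python
import argparse

# Assumed import path for the functions changed in this diff.
from cacheflow.master.server import (
    add_server_arguments, process_server_arguments, initialize_cluster)

def main():
    parser = argparse.ArgumentParser(description='CacheFlow server (sketch)')
    parser = add_server_arguments(parser)
    # Flips args.use_ray on automatically when pp * tp > 1.
    args = process_server_arguments(parser.parse_args())

    # Without --use-ray (and with pp * tp == 1) this stays on the local GPU;
    # otherwise it connects to a Ray cluster via ray.init(address=...).
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = initialize_cluster(
        use_ray=args.use_ray,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size)
    print(num_nodes, num_devices_per_node, distributed_init_method)

if __name__ == '__main__':
    main()
```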