Implement parallel model preloading #211

Draft · wants to merge 1 commit into main
6 changes: 6 additions & 0 deletions exo/inference/inference_engine.py
@@ -7,6 +7,7 @@


class InferenceEngine(ABC):

  @abstractmethod
  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
    pass
@@ -15,6 +16,11 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_s
  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    pass

  @abstractmethod
  async def preload_model(self, shard: Shard) -> None:
    """Preload the model into memory without full initialization."""
    pass


def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
  if inference_engine_name == "mlx":
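Note on the interface change: because preload_model is declared with @abstractmethod, every existing InferenceEngine subclass must now implement it or it can no longer be instantiated. A minimal sketch of a no-op implementation for an engine that has nothing to warm up (the class name is hypothetical and not part of this diff, and the other method bodies are elided):

class EagerlyLoadedEngine(InferenceEngine):
  async def infer_prompt(self, request_id, shard, prompt, image_str=None, inference_state=None):
    ...

  async def infer_tensor(self, request_id, shard, input_data, inference_state=None):
    ...

  async def preload_model(self, shard: Shard) -> None:
    # Nothing to preload for this engine; implementing the abstract method
    # keeps the class instantiable after this change.
    return None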
21 changes: 21 additions & 0 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -46,3 +46,24 @@ def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
    model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
    self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
    self.shard = shard

  async def preload_model(self, shard: Shard) -> None:
    # Implement MLX-specific preloading logic
    # This might involve loading weights into memory
    # without fully initializing the model
    if self.model is None:
      # Load the model configuration
      config = await self.load_config(shard)

      # Load the model weights into memory
      # but don't initialize the full model yet
      self.weights = await self.load_weights(config, shard)

  async def load_weights(self, config, shard):
    # Implement weight loading logic here
    # This should load the weights into memory without full model initialization
    pass

  def initialize_model(self, weights):
    # Implement full model initialization using preloaded weights
    pass
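The load_weights and initialize_model bodies are still stubs in this draft. Below is one way the weight-loading half could be filled in, as a sketch only: it assumes the shard's files are stored as .safetensors under the path returned by self.shard_downloader.ensure_shard(shard) (the same downloader call ensure_shard already relies on), reads them with mx.load, which returns plain tensor dicts without constructing the model graph, and leaves config unused.

import asyncio
import glob
from pathlib import Path

import mlx.core as mx


async def load_weights(self, config, shard):
  # Make sure the shard's files are on disk, then read the raw tensors.
  model_path = await self.shard_downloader.ensure_shard(shard)
  weight_files = sorted(glob.glob(str(Path(model_path) / "*.safetensors")))

  def _read_all():
    weights = {}
    for wf in weight_files:
      # mx.load on a .safetensors file returns a dict of arrays;
      # no model is built, so this is preloading only.
      weights.update(mx.load(wf))
    return weights

  # Reading many large files is blocking work; keep it off the event loop.
  loop = asyncio.get_running_loop()
  return await loop.run_in_executor(self.executor, _read_all)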
5 changes: 5 additions & 0 deletions exo/orchestration/standard_node.py
@@ -432,3 +432,8 @@ async def send_status_to_peer(peer):
  @property
  def current_topology(self) -> Topology:
    return self.topology

  def get_assigned_shards(self):
    # For a standard node, all shards are assigned to it.
    # Assumes self.shards exists; otherwise this falls back to an empty list.
    return self.shards if hasattr(self, 'shards') else []
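StandardNode does not currently define a shards attribute, so as written get_assigned_shards returns [] and the preloading pass in main.py does nothing. One hypothetical way to give it real data, recording each shard the node resolves for itself, is sketched below; the mixin, the _record_shard helper, and where it would be called are assumptions, not existing exo code.

from typing import List

from exo.inference.shard import Shard


class ShardTrackingMixin:
  """Hypothetical helper: remember which shards this node has been asked to serve."""

  def __init__(self):
    self.shards: List[Shard] = []

  def _record_shard(self, shard: Shard) -> None:
    # Call this wherever the node resolves the shard it will run,
    # so preloading sees the same set of shards.
    if shard not in self.shards:
      self.shards.append(shard)

  def get_assigned_shards(self) -> List[Shard]:
    return list(self.shards)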
8 changes: 8 additions & 0 deletions main.py
@@ -172,6 +172,14 @@ def handle_exit():

  await node.start(wait_for_peers=args.wait_for_peers)

  # Parallelize model preloading
  shards_to_load = node.get_assigned_shards()
  await asyncio.gather(*(inference_engine.preload_model(shard) for shard in shards_to_load))

  # Finish initialization sequentially if needed
  for shard in shards_to_load:
    await inference_engine.ensure_shard(shard)

  if args.run_model:
    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
  else:
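A note on the asyncio.gather call above: it starts every preload at once, so a node assigned many large shards will try to page them all in simultaneously. If that turns out to be too aggressive, a bounded variant is easy to sketch; the helper name and the limit of 2 below are arbitrary assumptions, not part of this PR.

import asyncio


async def preload_with_limit(inference_engine, shards, limit: int = 2):
  # Cap how many models are being preloaded concurrently.
  sem = asyncio.Semaphore(limit)

  async def _preload_one(shard):
    async with sem:
      await inference_engine.preload_model(shard)

  await asyncio.gather(*(_preload_one(shard) for shard in shards))

Calling await preload_with_limit(inference_engine, shards_to_load) would slot into main.py in place of the unbounded gather.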