diff --git a/.gitignore b/.gitignore index b3e39b04..2b36a72f 100644 --- a/.gitignore +++ b/.gitignore @@ -36,7 +36,6 @@ piped_mlx.egg-info/ .coverage .idea/ -.vscode/ .DS_Store .ruff_cache/ .logs/ \ No newline at end of file diff --git a/.vscode/.gitignore b/.vscode/.gitignore new file mode 100644 index 00000000..f8a3f97d --- /dev/null +++ b/.vscode/.gitignore @@ -0,0 +1,5 @@ +* + +# ignore everything but itself & tasks +!.gitignore +!tasks.json \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 00000000..aed560d2 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,101 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Start API (8080 + 58080)", + "type": "shell", + "command": "uv", + "args": [ + "run", + "dnet-api", + "--http-port", + "8080", + "--grpc-port", + "58080" + ], + "group": "build", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "new", + "showReuseMessage": true, + "clear": false + } + }, + { + "label": "Start Shard (8081 + 58081)", + "type": "shell", + "command": "uv", + "args": [ + "run", + "dnet-shard", + "--http-port", + "8081", + "--grpc-port", + "58081" + ], + "group": "build", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "new", + "showReuseMessage": true, + "clear": false + } + }, + { + "label": "Health (Shard 8081)", + "type": "shell", + "command": "curl -s http://localhost:8081/health -H \"Content-Type: application/json\" | bun -p \"Bun.inspect(await Bun.stdin.json(), { colors: true })\"" + }, + { + "label": "Start Shard (8082 + 58082)", + "type": "shell", + "command": "uv", + "args": [ + "run", + "dnet-shard", + "--http-port", + "8082", + "--grpc-port", + "58082" + ], + "group": "build", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "new", + "showReuseMessage": true, + "clear": false + } + }, + { + "label": "Prepare Topology (Qwen/Qwen3-4B-MLX-4bit)", + "type": "shell", + "command": "curl -X POST http://localhost:8080/v1/prepare_topology -H \"Content-Type: application/json\" -d '{ \"model\": \"Qwen/Qwen3-4B-MLX-4bit\" }' | bun -p \"Bun.inspect(await Bun.stdin.json(), { colors: true })\"" + }, + { + "label": "Prepare & Load (Qwen/Qwen3-4B-MLX-4bit)", + "type": "shell", + "command": "uv run ./scripts/prepare_model.py Qwen/Qwen3-4B-MLX-4bit" + }, + { + "label": "Get API Devices", + "type": "shell", + "command": "curl -s http://localhost:8080/v1/devices -H \"Content-Type: application/json\" | bun -p \"Bun.inspect(await Bun.stdin.json(), { colors: true })\"" + }, + { + "label": "Get API Topology", + "type": "shell", + "command": "curl -s http://localhost:8080/v1/topology -H \"Content-Type: application/json\" | bun -p \"Bun.inspect(await Bun.stdin.json(), { colors: true })\"" + }, + { + "label": "Chat Completions (Qwen/Qwen3-4B-MLX-4bit)", + "type": "shell", + "command": "curl -X POST http://localhost:8080/v1/chat/completions -H \"Content-Type: application/json\" -d '{\"model\":\"Qwen/Qwen3-4B-MLX-4bit\", \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}], \"max_tokens\": 100}' | bun -p \"Bun.inspect(await Bun.stdin.json(), { colors: true })\"" + } + ] +} diff --git a/Makefile b/Makefile index adf27e3a..b7c42d5f 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,14 @@ format: protos: uv run ./scripts/generate_protos.py -.PHONY: update # | Update git submodules +.PHONY: update # | Update git submodules update: git submodule update --init --recursive +.PHONY: 
test # | Run tests +test: + uv run pytest -v + .PHONY: help # | List targets help: @grep '^.PHONY: .* #' Makefile | sed 's/\.PHONY: \(.*\) # \(.*\)/\1 \2/' | expand -t20 \ No newline at end of file diff --git a/README.md b/README.md index f3769406..35752059 100644 --- a/README.md +++ b/README.md @@ -158,12 +158,26 @@ curl http://localhost:8080/v1/devices \ ## Testing -You can lint the code using Ruff: +You can run the Pytest suite via: ```sh +uv run pytest -v +``` + +You can check linting and formatting via Ruff: + +```sh +# lint uvx ruff check + +# format +uvx ruff format --diff ``` +> [!TIP] +> +> If you are using VS Code, we have prepared [tasks](./.vscode/tasks.json) that you can run easily from the Command Palette > Tasks: Run Task. + ## License You can find the license [here](./LICENSE). diff --git a/lib/dperf b/lib/dperf index baa3a1e5..8e2b1387 160000 --- a/lib/dperf +++ b/lib/dperf @@ -1 +1 @@ -Subproject commit baa3a1e5d8f4ff1dbf37d77c3f26d66c59391cc6 +Subproject commit 8e2b1387d9127ca266e81863ea1ad6da2cf08fa3 diff --git a/pyproject.toml b/pyproject.toml index 427753e3..2ba8aa03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,10 +50,10 @@ requires = ["uv_build>=0.8.17,<0.9.0"] build-backend = "uv_build" [tool.pytest.ini_options] -python_files = ["dnet/**/*_test.py", "tests/*.py"] +python_files = ["src/dnet/**/*_test.py", "tests/*.py"] python_functions = ["test_"] log_cli = true [tool.ruff] exclude = [".git", ".venv", "__pycache__", "build", "dist", "lib"] -line-length = 88 # black +line-length = 88 # black diff --git a/src/dnet/ring/api/models.py b/src/dnet/ring/api/models.py index e85da84b..af00b2e8 100644 --- a/src/dnet/ring/api/models.py +++ b/src/dnet/ring/api/models.py @@ -65,15 +65,17 @@ class ChatBaseParams(BaseModel): class ChatParams(ChatBaseParams): """Extended parameters for chat requests.""" - stream: Optional[bool] = False - max_tokens: Optional[int] = Field(default=100, ge=0) - logit_bias: Optional[Dict[int, float]] = Field(default_factory=dict) - logprobs: Optional[int] = Field(default=-1) - stop: Optional[Union[str, List[str]]] = [] - profile: Optional[bool] = False + stream: bool = Field(default=False) + max_tokens: int = Field(default=100, ge=0) + logit_bias: Dict[int, float] = Field(default_factory=dict) + logprobs: int = Field(default=-1) + stop: Union[str, List[str]] = Field(default_factory=list) + profile: bool = Field(default=False) def __init__(self, **data: Any): super().__init__(**data) + + # FIXME: why do this? if isinstance(self.stop, str): self.stop = [self.stop] diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index f7e668ce..9b171058 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -14,7 +14,6 @@ from fastapi import FastAPI, HTTPException, status from fastapi.responses import JSONResponse, StreamingResponse from grpc import aio as aio_grpc -from hypercorn import Config import hypercorn.asyncio as aio_hypercorn from mlx_lm.tokenizer_utils import load_tokenizer @@ -42,7 +41,11 @@ from ...utils.logger import logger from ...utils.model import ModelMetadata, get_model_metadata -from .utils import create_generate_step_for_ring_with_grpc +from .utils import ( + create_generate_step_for_ring_with_grpc, + compute_layer_assignments, + optimize_device_ordering, +) from .models import ( ChatBaseParams, ChatChoice, @@ -188,7 +191,7 @@ async def _start_grpc_server(self) -> None: return server = aio_grpc.server() - servicer = ShardApiServicer(self) + servicer = ShardApiServicer(self) # type: ignore # FIXME: !!!
add_ShardApiServiceServicer_to_server(servicer, server) listen_addr = f"[::]:{self.grpc_port}" server.add_insecure_port(listen_addr) @@ -202,6 +205,8 @@ async def _start_http_server(self, shutdown_trigger: Any) -> None: Args: shutdown_trigger: Shutdown trigger function """ + from hypercorn import Config + await self._setup_routes() config = Config.from_mapping( @@ -213,8 +218,8 @@ async def _start_http_server(self, shutdown_trigger: Any) -> None: ) self.http_server = asyncio.create_task( - aio_hypercorn.serve( # type: ignore - self.app, + aio_hypercorn.serve( + self.app, # type: ignore config, shutdown_trigger=shutdown_trigger, ) @@ -319,8 +324,12 @@ async def unload_model() -> UnloadModelResponse: ) from e @self.app.post("/v1/chat/completions") - async def chat_completions(req: ChatRequestModel) -> ChatResponseModel: - """Handle chat completion requests.""" + async def chat_completions( + req: ChatRequestModel, + ) -> ChatResponseModel: + """Handle chat completion requests. + + If streaming is requested, returns a StreamingResponse.""" if self.model is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -331,16 +340,18 @@ async def chat_completions(req: ChatRequestModel) -> ChatResponseModel: status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Not connected to first shard", ) - if bool(getattr(req, "stream", False)): + if req.stream: + # FIXME: return type mismatch here return StreamingResponse( self._stream_chat(req), media_type="text/event-stream" ) - return await self._handle_chat_completion(req) + else: + return await self._handle_chat_completion(req) @self.app.post("/v1/completions") async def completions(req: CompletionRequestModel): # type: ignore - """Handle completion requests (not implemented).""" - if bool(getattr(req, "stream", False)): + """Handle completion requests.""" + if req.stream: return StreamingResponse( self._stream_completion(req), media_type="text/event-stream" ) @@ -386,7 +397,7 @@ async def _handle_prepare_topology( ) # Optimize device ordering to place Thunderbolt-connected devices adjacently - optimized_device_name_order = self._optimize_device_ordering( + optimized_device_name_order = optimize_device_ordering( shard_profiles, thunderbolt_conns ) @@ -399,8 +410,8 @@ async def _handle_prepare_topology( ] # shards ordered w.r.t. the solver # Compute layer assignments, next service mapping, and prefetch windows - layer_assignments = self._compute_layer_assignments( - optimized_device_name_order, solution, shards + layer_assignments = compute_layer_assignments( + optimized_device_name_order, solution.w, solution.k, shards ) # Store topology (can be GET'ed later) @@ -432,6 +443,7 @@ async def _handle_prepare_topology_manual( raise ValueError("Device names must be unique in manual topology") # Normalize assignments and validate services + # FIXME: may not need normalized array here, just use assignments services = set(device_names) normalized: List[LayerAssignment] = [] for assignment in req.assignments: if assignment.service not in services: raise ValueError( f"Assignment references unknown service: {assignment.service}" ) - layers_2d = assignment.layers - try: - # Accept flat list; wrap to single round - if layers_2d and all(isinstance(x, int) for x in layers_2d): # type: ignore - layers_2d = [layers_2d] # type: ignore - except Exception: - pass normalized.append( LayerAssignment( service=assignment.service, - layers=layers_2d, + layers=assignment.layers, next_service=assignment.next_service,
window_size=assignment.window_size, ) @@ -479,6 +484,7 @@ async def _handle_prepare_topology_manual( ) # If next_service missing and >1 device, compute simple ring by min layer + # FIXME: may not need this edge case at all, probably redundant if any(a.next_service is None for a in normalized) and len(normalized) > 1: order = sorted( normalized, @@ -495,7 +501,7 @@ async def _handle_prepare_topology_manual( service=a.service, layers=a.layers, next_service=a.next_service or ring_map.get(a.service), - prefetch_window=a.prefetch_window, + window_size=a.window_size, ) for a in normalized ] @@ -528,16 +534,17 @@ async def _handle_load_model( """ # Decide model and assignments if self.topology: - model_to_load = self.topology.model - assignments_to_use = self.topology.assignments - if getattr(req, "model", None) and req.model != model_to_load: + topology = self.topology + + if req.model and req.model != self.topology.model: logger.info( "load_model request model %s overridden by topology model %s", req.model, - model_to_load, + self.topology.model, ) else: - if not getattr(req, "model", None): + # ensure model is given + if not req.model: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=( @@ -546,25 +553,19 @@ async def _handle_load_model( "or include 'model' to bootstrap with discovery." ), ) + # Bootstrap: run discovery-based prepare - await self._handle_prepare_topology(PrepareTopologyRequest(model=req.model)) - model_to_load = self.topology.model # type: ignore - assignments_to_use = self.topology.assignments # type: ignore + topology = await self._handle_prepare_topology( + PrepareTopologyRequest(model=req.model) + ) + model_to_load = topology.model + assignments_to_use = topology.assignments + shards = {dev.instance: dev for dev in topology.devices} logger.info("Loading model: %s", model_to_load) - # Resolve shard endpoints - api_properties = self.discovery.get_own_properties() - if self.topology and getattr(self.topology, "devices", []): - # Build mapping from topology (service/instance -> properties) - shards: Dict[str, DnetDeviceProperties] = { - getattr(dev, "instance"): dev for dev in self.topology.devices - } - else: - # Fallback to discovery when no topology is configured - shards = self._get_shards_from_discovery() - # Notify each shard to load their layers via HTTP + api_properties = self.discovery.get_own_properties() shard_statuses: List[ShardLoadStatus] = [] async with httpx.AsyncClient() as http_client: for assignment in assignments_to_use: @@ -605,13 +606,6 @@ async def _handle_load_model( ) try: - # Total layers from topology (present after bootstrap) - total_layers = ( - self.topology.num_layers - if self.topology - else (max(layers) + 1 if layers else 0) - ) - # Build API callback address (gRPC) api_callback_address = f"{api_properties.local_ip}:{self.grpc_port}" @@ -623,7 +617,7 @@ async def _handle_load_model( warmup=True, next_node=next_shard, window_size=assignment.window_size, - total_layers=total_layers, + total_layers=topology.num_layers, api_callback_address=api_callback_address, ).model_dump() @@ -692,7 +686,7 @@ async def _handle_load_model( model=model_to_load, success=False, shard_statuses=shard_statuses, - message=("Error loading API-side model: %s", e), + message=f"Error loading API-side model: {e}", ) else: failed_shards = [ @@ -731,7 +725,7 @@ async def _handle_unload_model(self) -> UnloadModelResponse: # Call unload_model via HTTP url = f"http://{shard.local_ip}:{shard.server_port}/unload_model" response = await http_client.post(url, 
timeout=30.0) - result = response.json() + result = response.json() # FIXME: add shard response type shard_statuses.append( ShardUnloadStatus( @@ -812,13 +806,14 @@ async def _connect_first_shard(self) -> bool: Falls back to the first device in topology when ownership cannot be determined. Returns True on success, False otherwise. """ - if not self.topology or not getattr(self.topology, "devices", []): + if not self.topology or not self.topology.devices: + logger.error("No topology configured; cannot connect to first shard") return False # Pick the device whose assignment contains layer 0; fallback to index 0 start_service: str | None = None try: - for assignment in getattr(self.topology, "assignments", []) or []: + for assignment in self.topology.assignments: # Flatten round layers flat = [ layer @@ -831,11 +826,12 @@ async def _connect_first_shard(self) -> bool: except Exception: start_service = None + # find the start device w.r.t service name start_device = None if start_service is not None: try: for dev in self.topology.devices: - if getattr(dev, "instance", None) == start_service: + if dev.instance == start_service: start_device = dev break except Exception: @@ -903,7 +899,7 @@ async def _profile_model( batch_sizes=batch_sizes, sequence_length=sequence_length, ) - logger.info("Model profiling completed.") + logger.info(f"Model profiling completed for {repo_id}.") return load_model_profile_from_dict(asdict(model_profile_split)) async def _collect_shard_profiles( @@ -1002,81 +998,6 @@ async def _collect_shard_profiles( logger.info("Collected profiles from %d shards", len(shard_profiles)) return shard_profiles, all_thunderbolts - # FIXME: move this to elsewhere - def _optimize_device_ordering( - self, - shard_profiles: Dict[str, DeviceProfile], - thunderbolt_conns: Dict[str, Dict[str, ThunderboltConnection]], - ) -> List[str]: - """Optimize device ordering to place Thunderbolt-connected devices adjacently. - - Args: - shard_profiles: Collected shard profiles - thunderbolt_conns: Thunderbolt connections mapping (device -> {neighbor -> connection_info}) - - Returns: - Optimized list of device names with head devices first and Thunderbolt neighbors adjacent - """ - device_names = list(shard_profiles.keys()) - - # Find all head devices (multiple shards can run on same machine as API) - head_devices = [] - for device_name, profile_data in shard_profiles.items(): - if profile_data.is_head: - head_devices.append(device_name) - - if not head_devices: - logger.warning("No head device found in profiles, using first device") - head_devices = [device_names[0]] if device_names else [] - - logger.info("Found %d head device(s): %s", len(head_devices), head_devices) - - # FIXME: shards on the same machine should be adjacent too! 
- - # Build adjacency graph of Thunderbolt connections - # Graph: device_name -> set of connected device names - tb_graph: Dict[str, set[str]] = {name: set() for name in device_names} - for device_name, neighbors in thunderbolt_conns.items(): - if device_name in tb_graph: - for neighbor_name in neighbors.keys(): - if neighbor_name in tb_graph: - tb_graph[device_name].add(neighbor_name) - tb_graph[neighbor_name].add(device_name) - - # Greedy ordering: Start with all head devices, then pick neighbors with most TB connections - ordered = head_devices.copy() - remaining = set(device_names) - set(head_devices) - - while remaining: - best_candidate = None - best_score = -1 - - # For each remaining device, calculate connection score to already-ordered devices - for candidate in remaining: - # Count Thunderbolt connections to devices already in the order - score = sum( - 1 for ordered_dev in ordered if ordered_dev in tb_graph[candidate] - ) - - # Prioritize devices with TB connections, otherwise any device is fine - if score > best_score: - best_score = score - best_candidate = candidate - - # Add best candidate (or any remaining if no TB connections exist) - if best_candidate: - ordered.append(best_candidate) - remaining.remove(best_candidate) - else: - # Fallback: just pick any remaining device - next_device = remaining.pop() - ordered.append(next_device) - - logger.info("Optimized device ordering: %s", ordered) - logger.info("Thunderbolt graph: %s", tb_graph) - - return ordered - # FIXME: move this to elsewhere async def _run_solver( self, @@ -1116,113 +1037,6 @@ async def _run_solver( return solution - # FIXME: move this to elsewhere - def _compute_layer_assignments( - self, - device_names: List[str], - solution: HALDAResult, - shards: Dict[str, DnetDeviceProperties], - ) -> List[LayerAssignment]: - """Compute round-aware layer assignments, next node mapping, and prefetch windows from solver output. - - Args: - device_names: Device names in solver order - solution: Solver result - shards: Discovered shards - - Returns: - Tuple of (layer assignments per device per round, next service per device in ring, prefetch window per device) - """ - if len(solution.w) != len(shards) or len(device_names) != len(shards): - raise ValueError( - f"Device count mismatch: solution={len(solution.w)}, " - f"shards={len(shards)}" - ) - - num_layers = sum(solution.w) * solution.k - logger.info( - "Distributing %d layers to %d devices in %d rounds", - num_layers, - len(shards), - solution.k, - ) - - layer_assignments: Dict[str, List[List[int]]] = { - name: [[] for _ in range(solution.k)] for name in device_names - } - current_layer = 0 - for round_idx in range(solution.k): - for device_idx, device_name in enumerate(device_names): - for _ in range(solution.w[device_idx]): - layer_assignments[device_name][round_idx].append(current_layer) - current_layer += 1 - assert current_layer == num_layers, ( - f"Assigned {current_layer} layers, expected {num_layers}" - ) - - # Compute next service for each device in ring topology - # In ring: dev1 -> dev2 -> ... 
-> devN -> dev1 (wraps around) - # Each shard will detect when processing the final layer and send to API - next_service_map: Dict[str, Optional[str]] = {} - - if len(device_names) == 1: - # Single device: forwards to itself in a loop - next_service_map[device_names[0]] = device_names[0] - logger.info( - "Ring (single device): %s -> SELF (loops back)", device_names[0] - ) - else: - # Multiple devices: each forwards to the next in the ring - for i, service_name in enumerate(device_names): - if i < len(device_names) - 1: - # Forward to next device - next_service_map[service_name] = device_names[i + 1] - else: - # Last device wraps to first device - next_service_map[service_name] = device_names[0] - - # Log ring topology - for service_name in device_names: - logger.info( - "Ring: %s -> %s", service_name, next_service_map[service_name] - ) - - # Compute window size for each device: total_layers_per_device / k - window_sizes: Dict[str, int] = {} - for service_name, rounds_layers in layer_assignments.items(): - # Flatten to count total layers - total_layers = sum(len(round_layers) for round_layers in rounds_layers) - if total_layers > 0: - window_size = max(1, total_layers // solution.k) - window_sizes[service_name] = window_size - logger.info( - "Window size for %s: %d (total_layers=%d, k=%d)", - service_name, - window_size, - total_layers, - solution.k, - ) - else: - # FIXME: how to handle? - logger.error( - "No layers assigned to %s, setting window size to 1", - service_name, - ) - window_sizes[service_name] = 1 - - logger.info("Layer assignments (by rounds): %s", layer_assignments) - # return layer_assignments, next_service_map, window_size - - return [ - LayerAssignment( - service=name, - layers=layer_assignments[name], - next_service=next_service_map[name], - window_size=window_sizes[name], - ) - for name in device_names - ] - async def _handle_chat_completion(self, req: ChatRequestModel) -> ChatResponseModel: """Handle chat completion request. 
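(For reference: the round-aware distribution that `_compute_layer_assignments` performed here, and that `compute_layer_assignments` now performs in `api/utils.py` below, can be pictured with a minimal, self-contained sketch; the solver values `w`, `k` and the device names are made up for illustration.)

```py
# Hypothetical solver output: w = layers per device per round, k = rounds.
w, k = [2, 1], 2
names = ["dev-a", "dev-b"]
assignments = {n: [[] for _ in range(k)] for n in names}
layer = 0
for r in range(k):
    for i, name in enumerate(names):
        for _ in range(w[i]):
            assignments[name][r].append(layer)
            layer += 1
# Total layers = sum(w) * k = 6
print(assignments)  # {'dev-a': [[0, 1], [3, 4]], 'dev-b': [[2], [5]]}
```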
@@ -1301,7 +1115,7 @@ async def _handle_completion( Returns: Chat response """ - profile_enabled = bool(getattr(req, "profile", False)) + profile_enabled = bool(req.profile) t_start = time.perf_counter() t_first_token = None nonce = f"chatcmpl-{uuid.uuid4()}" @@ -1541,7 +1355,7 @@ async def gen(): repetition_penalty=req.repetition_penalty, repetition_context_size=req.repetition_context_size, logit_bias=req.logit_bias, - ), + ), # type: ignore ), arange(req.max_tokens or 0), ): @@ -1603,7 +1417,7 @@ async def gen(): repetition_context_size=req.repetition_context_size, logit_bias=req.logit_bias, ), - ), + ), # type: ignore arange(req.max_tokens or 0), ): detok.add_token(token) diff --git a/src/dnet/ring/api/utils.py b/src/dnet/ring/api/utils.py index 2d485403..1742ace0 100644 --- a/src/dnet/ring/api/utils.py +++ b/src/dnet/ring/api/utils.py @@ -1,21 +1,21 @@ """API utilities for ring topology generation.""" import asyncio -import time -from typing import AsyncGenerator, Dict, Tuple - +from typing import AsyncGenerator, Dict, Tuple, Optional +from dnet_p2p import DnetDeviceProperties +from dnet_p2p.thunderbolt import ThunderboltConnection import mlx.core as mx import numpy as np +from distilp import DeviceProfile -from ...protos import dnet_ring_pb2 -from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub -from ...utils.logger import logger -from .models import ChatBaseParams +from dnet.protos import dnet_ring_pb2 +from dnet.protos.dnet_ring_pb2_grpc import DnetRingServiceStub +from dnet.utils.logger import logger +from dnet.utils.time import utc_epoch_now +from dnet.ring.common import LayerAssignment -def utc_epoch_now() -> int: - """Return current UTC timestamp in milliseconds.""" - return int(time.time() * 1000) +from .models import ChatBaseParams def create_generate_step_for_ring_with_grpc( @@ -126,3 +126,180 @@ async def _step(y): y, logprobs = next_y, next_logprobs return generate_step + + +def compute_layer_assignments( + device_names: list[str], + solution_w: list[int], + solution_k: int, + shards: Dict[str, DnetDeviceProperties], +) -> list[LayerAssignment]: + """Compute round-aware layer assignments, next node mapping, and prefetch windows from solver output. + + Args: + device_names: Device names in solver order + solution_w: Solver result `w`: number of layers per device per round. + solution_k: Solver result `k`: number of rounds. + shards: Discovered shards + + Returns: + One LayerAssignment per device, carrying its per-round layers, next service in the ring, and prefetch window size. + """ + if len(solution_w) != len(shards) or len(device_names) != len(shards): + raise ValueError( + f"Device count mismatch: solution={len(solution_w)}, shards={len(shards)}" + ) + + num_layers = sum(solution_w) * solution_k + logger.info( + "Distributing %d layers to %d devices in %d rounds", + num_layers, + len(shards), + solution_k, + ) + + layer_assignments: Dict[str, list[list[int]]] = { + name: [[] for _ in range(solution_k)] for name in device_names + } + current_layer = 0 + for round_idx in range(solution_k): + for device_idx, device_name in enumerate(device_names): + for _ in range(solution_w[device_idx]): + layer_assignments[device_name][round_idx].append(current_layer) + current_layer += 1 + assert current_layer == num_layers, ( + f"Assigned {current_layer} layers, expected {num_layers}" + ) + + # Compute next service for each device in ring topology + # In ring: dev1 -> dev2 -> ...
-> devN -> dev1 (wraps around) + # Each shard will detect when processing the final layer and send to API + next_service_map: Dict[str, Optional[str]] = {} + + if len(device_names) == 1: + # Single device: forwards to itself in a loop + next_service_map[device_names[0]] = device_names[0] + logger.info("Ring (single device): %s -> SELF (loops back)", device_names[0]) + else: + # Multiple devices: each forwards to the next in the ring + for i, service_name in enumerate(device_names): + if i < len(device_names) - 1: + # Forward to next device + next_service_map[service_name] = device_names[i + 1] + else: + # Last device wraps to first device + next_service_map[service_name] = device_names[0] + + # Log ring topology + for service_name in device_names: + logger.info("Ring: %s -> %s", service_name, next_service_map[service_name]) + + # Compute window size for each device: total_layers_per_device / k + window_sizes: Dict[str, int] = {} + for service_name, rounds_layers in layer_assignments.items(): + # Flatten to count total layers + total_layers = sum(len(round_layers) for round_layers in rounds_layers) + if total_layers > 0: + window_size = max(1, total_layers // solution_k) + window_sizes[service_name] = window_size + logger.info( + "Window size for %s: %d (total_layers=%d, k=%d)", + service_name, + window_size, + total_layers, + solution_k, + ) + else: + # FIXME: how to handle? + logger.error( + "No layers assigned to %s, setting window size to 1", + service_name, + ) + window_sizes[service_name] = 1 + + logger.info("Layer assignments (by rounds): %s", layer_assignments) + # return layer_assignments, next_service_map, window_size + + return [ + LayerAssignment( + service=name, + layers=layer_assignments[name], + next_service=next_service_map[name], + window_size=window_sizes[name], + ) + for name in device_names + ] + + +def optimize_device_ordering( + shard_profiles: Dict[str, DeviceProfile], + thunderbolt_conns: Dict[str, Dict[str, ThunderboltConnection]], +) -> list[str]: + """Optimize device ordering to place Thunderbolt-connected devices adjacently. + + Args: + shard_profiles: Collected shard profiles + thunderbolt_conns: Thunderbolt connections mapping (device -> {neighbor -> connection_info}) + + Returns: + Optimized list of device names with head devices first and Thunderbolt neighbors adjacent + """ + device_names = list(shard_profiles.keys()) + + # Find all head devices (multiple shards can run on same machine as API) + head_devices = [] + for device_name, profile_data in shard_profiles.items(): + if profile_data.is_head: + head_devices.append(device_name) + + if not head_devices: + logger.warning("No head device found in profiles, using first device") + head_devices = [device_names[0]] if device_names else [] + + logger.info("Found %d head device(s): %s", len(head_devices), head_devices) + + # FIXME: shards on the same machine should be adjacent too! 
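+ # Example (hypothetical device names): with head device 'a' and Thunderbolt + # links a<->b and b<->c, the greedy pass below orders ['a', 'b', 'c']; a + # device 'd' with no Thunderbolt link is simply appended at the end.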
+ + # Build adjacency graph of Thunderbolt connections + # Graph: device_name -> set of connected device names + tb_graph: Dict[str, set[str]] = {name: set() for name in device_names} + for device_name, neighbors in thunderbolt_conns.items(): + if device_name in tb_graph: + for neighbor_name in neighbors.keys(): + if neighbor_name in tb_graph: + tb_graph[device_name].add(neighbor_name) + tb_graph[neighbor_name].add(device_name) + + # Greedy ordering: Start with all head devices, then pick neighbors with most TB connections + ordered = head_devices.copy() + remaining = set(device_names) - set(head_devices) + + while remaining: + best_candidate = None + best_score = -1 + + # For each remaining device, calculate connection score to already-ordered devices + for candidate in remaining: + # Count Thunderbolt connections to devices already in the order + score = sum( + 1 for ordered_dev in ordered if ordered_dev in tb_graph[candidate] + ) + + # Prioritize devices with TB connections, otherwise any device is fine + if score > best_score: + best_score = score + best_candidate = candidate + + # Add best candidate (or any remaining if no TB connections exist) + if best_candidate: + ordered.append(best_candidate) + remaining.remove(best_candidate) + else: + # Fallback: just pick any remaining device + next_device = remaining.pop() + ordered.append(next_device) + + logger.info("Optimized device ordering: %s", ordered) + logger.info("Thunderbolt graph: %s", tb_graph) + + return ordered diff --git a/src/dnet/ring/shard/attrib.py b/src/dnet/ring/shard/attrib.py new file mode 100644 index 00000000..a49feb42 --- /dev/null +++ b/src/dnet/ring/shard/attrib.py @@ -0,0 +1,92 @@ +from typing import Callable, Optional, Any +from fastapi import FastAPI +import grpc.aio as aio_grpc +import asyncio +import threading +from mlx.core import Dtype +from queue import Queue +from concurrent.futures import ThreadPoolExecutor +from dnet_p2p import DnetDeviceProperties, DnetP2P + +from dnet.ring.data_types import ActivationMessage +from dnet.ring.memory_pool import LayerAwareMemoryPool +from dnet.ring.model.base import BaseRingModel +from dnet.ring.shard.config import ShardConfig +from dnet.utils.model import ModelMetadata +from dnet.ring.weight_cache import WeightCache +from dnet.ring.observability import Profiler + + +class RingShardNodeAttributes: + """A mixin class that defines the attributes for a ring shard node, intended + to be shared by all mixins & ensure type-safety among them.""" + + _mlx_lock: threading.Lock + + # prefetch-related + _prefetch_scheduled: set[int] + _prefetch_pending: set[int] + _prefetch_pause: threading.Event + _prefetch_active = 0 + + _streaming_enabled: bool + + _resident_windows: int + _recent_windows: list[list[int]] + node_id: int + running: bool + weight_cache: WeightCache + weight_prefetch_queue: Queue[int] + _materialize_prefetch_default: bool + executor: ThreadPoolExecutor + _touch_during_compute: bool + _compute_busy: threading.Event + + activation_computed_queue: asyncio.Queue[ActivationMessage] + _defer_unload: bool + _warmup_keep_flag: bool + + # node + grpc_port: int + http_port: int + app: FastAPI + discovery: DnetP2P + next_node: Optional[DnetDeviceProperties] + next_node_stub: Optional[Any] + + config: ShardConfig + + _sync_per_layer: bool + _sync_every_n: int + + # profiler + _profile: bool + _prof: Profiler + + input_pool: LayerAwareMemoryPool + output_pool: Optional[LayerAwareMemoryPool] + activation_recv_queue: Queue[ActivationMessage] + + # model + model: 
Optional[BaseRingModel] + model_metadata: Optional[ModelMetadata] + model_path: Optional[str] + + _wire_dtype_str: str + _wire_mx_dtype: Dtype + _assigned_set: set[int] + assigned_layers: list[int] + window_size: int + + _assigned_sorted: list[int] + _bound_versions: dict[int, int] + + next_node_channel: Optional[aio_grpc.Channel] + + # shared methods + _prefetch_to_ram: Callable[[int], None] + _clear_prefetch_state: Callable[[], None] + _enqueue_weight_prefetch: Callable[[int], None] + _next_local_layers: Callable[[int, int], list[int]] + _get_or_make_kv: Callable[[str], list] diff --git a/src/dnet/ring/shard/send.py b/src/dnet/ring/shard/comms.py similarity index 70% rename from src/dnet/ring/shard/send.py rename to src/dnet/ring/shard/comms.py index ed5ef180..dc351bf9 100644 --- a/src/dnet/ring/shard/send.py +++ b/src/dnet/ring/shard/comms.py @@ -3,61 +3,49 @@ import asyncio import time from dataclasses import dataclass -from typing import Callable, Optional, Any +from typing import Mapping, Optional, Any from urllib.parse import urlparse import grpc from grpc import aio as aio_grpc import numpy as np -from mlx.core import Dtype -from dnet.ring.memory_pool import LayerAwareMemoryPool +from dnet_p2p import ( + DnetDeviceProperties, + discover_thunderbolt_connection, +) +from dnet_p2p.thunderbolt import ThunderboltConnection +from dnet.utils.latency import DeviceLatencyResult, LatencyMeasurement, LatencyResults + from ...utils.grpc_config import GRPC_AIO_OPTIONS from ...utils.logger import logger from ...utils.time import utc_epoch_now from ...utils.serialization import dtype_map, tensor_to_bytes -from ...utils.model import ModelMetadata from ...protos import shard_api_comm_pb2, shard_api_comm_pb2_grpc, dnet_ring_pb2 from ..data_types import ActivationMessage +from .attrib import RingShardNodeAttributes + + +@dataclass +class _StreamCtx: + nonce: str + queue: asyncio.Queue + call: Optional[Any] = None + ack_task: Optional[asyncio.Task] = None + open: bool = False + disabled: bool = False + disabled_until: float = 0.0 + last_seq: int = 0 + last_activity_t: float = 0.0 + -class SendMixin: - next_node_stub: Optional[Any] - activation_computed_queue: asyncio.Queue[ActivationMessage] - node_id: int - _profile: bool - output_pool: LayerAwareMemoryPool - running: bool - model_metadata: ModelMetadata - _prefetch_pending: set[int] - _prefetch_active = 0 - weight_prefetch_queue: asyncio.Queue[int] - _wire_dtype_str: str - _wire_mx_dtype: Dtype - _assigned_set: set[int] - window_size: int - - _prefetch_to_ram: Callable[[int], None] - _enqueue_weight_prefetch: Callable[[int], None] - _prefetch_pause: asyncio.Event - _next_local_layers: Callable[[int, int], list[int]] - _clear_prefetch_state: Callable[[], None] - - @dataclass - class _StreamCtx: - nonce: str - queue: asyncio.Queue - call: Optional[Any] = None - ack_task: Optional[asyncio.Task] = None - open: bool = False - disabled: bool = False - disabled_until: float = 0.0 - last_seq: int = 0 - last_activity_t: float = 0.0 +class CommsMixin(RingShardNodeAttributes): + """Communication-related methods for ring shard node.""" async def _ensure_stream(self, nonce: str): try: - if not getattr(self, "_streaming_enabled", False): + if not self._streaming_enabled: return None if self.next_node_stub is None: return None @@ -77,7 +65,7 @@ async def _ensure_stream(self, nonce: str): pass return ctx - ctx = SendMixin._StreamCtx(nonce=nonce, queue=asyncio.Queue(maxsize=64)) + ctx = _StreamCtx(nonce=nonce,
queue=asyncio.Queue(maxsize=64)) if not hasattr(self, "_streams"): self._streams = {} self._streams[nonce] = ctx @@ -150,7 +138,7 @@ async def _stream_sweeper(self): idle_s = float(getattr(self, "_stream_idle_s", 2.0)) while getattr(self, "running", False): try: - if not getattr(self, "_streaming_enabled", False): + if not self._streaming_enabled: await asyncio.sleep(1.0) continue now = asyncio.get_running_loop().time() @@ -203,9 +191,10 @@ async def _send_worker(self): logger.error("Send worker error: %s", e) async def _send_activation(self, activation_msg: ActivationMessage): - if not self._check_model_loaded() or not self.output_pool: + if not self.output_pool or not self.model_metadata: logger.error( - "Node %s: Cannot send activation - model not loaded", self.node_id + "Node %s: Cannot send activation - output pool / model metadata not initialized", + self.node_id, ) return try: @@ -220,14 +209,14 @@ async def _send_activation(self, activation_msg: ActivationMessage): logger.error("Invalid gRPC callback URL for token: %s", cb) return # Ensure API channel/stub - if (getattr(self, "api_channel", None) is None) or ( - addr != getattr(self, "api_address", None) - ): + if (self.api_channel is None) or (addr != self.api_address): + # close existing channel if any try: - if getattr(self, "api_channel", None) is not None: + if self.api_channel is not None: await self.api_channel.close() except Exception: pass + self.api_address = addr self.api_channel = aio_grpc.insecure_channel( addr, options=GRPC_AIO_OPTIONS @@ -267,7 +256,9 @@ async def _send_activation(self, activation_msg: ActivationMessage): return used_pool = False - shaped = getattr(activation_msg, "tensor", None) + + # FIXME: shaped var is a bit weird (is it np_array or mlx_array), @andthattoo shall check + shaped = activation_msg.tensor if shaped is None: output_buffer = self.output_pool.get_buffer(activation_msg.pool_id) if output_buffer is None: @@ -303,7 +294,9 @@ async def _send_activation(self, activation_msg: ActivationMessage): wire_np_dtype = dtype_map[self._wire_dtype_str] except Exception: wire_np_dtype = np.float16 # reasonable default fallback + if isinstance(shaped, np.ndarray): + logger.warning("Activation tensor is a numpy array!!!") if shaped.dtype != wire_np_dtype: shaped = shaped.astype(wire_np_dtype, copy=False) else: @@ -380,13 +373,12 @@ async def _send_activation(self, activation_msg: ActivationMessage): grpc.StatusCode.DEADLINE_EXCEEDED, }: logger.warning( - "SendActivation attempt %s/%s failed (%s); reconnecting to %s", + "SendActivation attempt %s/%s failed (%s); reconnecting...", attempt, max_attempts, code.name, - self.next_node_address, # FIXME: will be `next_node` ! ) - await self._reconnect_next_node() # FIXME: !!! + await self._reconnect_next_node() await asyncio.sleep(min(0.25 * attempt, 1.0)) continue raise @@ -501,3 +493,182 @@ async def _send_activation(self, activation_msg: ActivationMessage): pass except Exception as e: logger.exception("Error sending activation: %s", e) + + async def _connect_next_node(self) -> bool: + """Connect to next node in ring. 
+ + Returns: + True if connected or no next node, False on failure + """ + if not self.next_node: + logger.info(f"Shard node {self.node_id} is the final shard (no next node)") + return True + + if self.next_node_channel: + logger.debug(f"Shard node {self.node_id} already connected to next node.") + return True + + try: + # use thunderbolt here if available + this_properties = self.discovery.get_own_properties() + thunderbolt_conn = discover_thunderbolt_connection( + this_properties, + self.next_node, + ) + next_ip = ( + thunderbolt_conn.ip_addr + if thunderbolt_conn + else self.next_node.local_ip + ) + address = f"{next_ip}:{self.next_node.shard_port}" + logger.info( + f"Shard node {this_properties.instance} connecting to next node {self.next_node.instance} at {address}" + ) + + self.next_node_channel = aio_grpc.insecure_channel(address) + from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub + + self.next_node_stub = DnetRingServiceStub(self.next_node_channel) + return True + except Exception as e: + # note: `address` may be unbound if discovery failed, so log the instance instead + logger.warning( + f"Shard node {self.node_id} failed to connect to next node {self.next_node.instance}: {e}" + ) + self.next_node_channel = None + self.next_node_stub = None + return False + + async def _reconnect_next_node(self) -> bool: + try: + if self.next_node_channel: + await self.next_node_channel.close() + except Exception: + pass + self.next_node_channel = None + self.next_node_stub = None + return await self._connect_next_node() + + async def _measure_latency_to_devices( + self, + devices: Mapping[str, DnetDeviceProperties], + thunderbolts: Mapping[str, ThunderboltConnection], + payload_sizes: list[int], + ) -> LatencyResults: + """Measure latency to all devices except self. + + Args: + devices: Device information mapping + thunderbolts: Thunderbolt connection information + payload_sizes: List of payload sizes to test + + Returns: + Latency measurement results + """ + latency_results_dict: dict[str, DeviceLatencyResult] = {} + + for service_name, device_info in devices.items(): + # Skip measuring latency to ourselves + # FIXME: just equals check should suffice here?
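+ # (Assumption, for illustration: instance names look like 'shard-1a2b-host' + # while `devices` keys may carry an extra mDNS suffix, which would explain + # the prefix match below rather than a strict equality check.)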
+ if service_name.startswith(self.discovery.instance_name()): + logger.debug("Skipping latency measurement to self: %s", service_name) + continue + + # Skip measuring latency to API (manager) devices + if device_info.is_manager: + logger.debug( + "Skipping latency measurement to manager/API: %s", service_name + ) + continue + + try: + shard_port = device_info.shard_port + + # Check for Thunderbolt connection + if service_name in thunderbolts: + tb_data = thunderbolts[service_name] + service_ip = tb_data.ip_addr + logger.info( + "Using Thunderbolt for %s at %s, connected to instance %s", + service_name, + service_ip, + tb_data.instance, + ) + else: + # No Thunderbolt, use WiFi + service_ip = device_info.local_ip + + if not shard_port or not service_ip: + logger.warning( + "No shard_port or local_ip for device %s", service_name + ) + continue + + # Connect to target shard's gRPC server + target_address = f"{service_ip}:{shard_port}" + channel = aio_grpc.insecure_channel(target_address) + from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub + + stub = DnetRingServiceStub(channel) + + # Measure latency for each payload size + latency_measurements: list[LatencyMeasurement] = [] + for payload_size in payload_sizes: + # Create dummy payload + dummy_data = b"x" * payload_size + + start_time = time.perf_counter() + timestamp_ms = int(time.time() * 1000) + + request = dnet_ring_pb2.LatencyMeasureRequest( + requester_id=str(self.node_id), + payload_size=payload_size, + dummy_data=dummy_data, + timestamp=timestamp_ms, + ) + + response = await stub.MeasureLatency(request) # type: ignore + end_time = time.perf_counter() + + if response.success: + latency_ms = (end_time - start_time) * 1000 + latency_measurements.append( + LatencyMeasurement( + payload_size=payload_size, + latency_ms=round(latency_ms, 2), + success=True, + error=None, + ) + ) + else: + latency_measurements.append( + LatencyMeasurement( + payload_size=payload_size, + success=False, + error=response.message, + latency_ms=0, + ) + ) + + # Store results + result = DeviceLatencyResult( + target_node_id=response.node_id if response.success else None, + measurements=latency_measurements, + success=True, + error=None, + ) + latency_results_dict[service_name] = result + + # Close channel + await channel.close() + + except Exception as e: + logger.error("Error measuring latency to %s: %s", service_name, e) + result = DeviceLatencyResult( + target_node_id=None, + success=False, + error=str(e), + measurements=[], + ) + latency_results_dict[service_name] = result + + return LatencyResults(results=latency_results_dict) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index ee74bbdb..d3497ac7 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -11,14 +11,16 @@ from ...utils.serialization import mlx_dtype_map from ...utils.time import utc_epoch_now from ..data_types import ActivationMessage +from .attrib import RingShardNodeAttributes -class ComputeMixin: +class ComputeMixin(RingShardNodeAttributes): """Split out the hot-path compute from RingShardNode.""" def _process_activation(self, activation_msg: ActivationMessage): if ( - not self._check_model_loaded() + not self.model + or not self.model_metadata or not self.weight_cache or not self.input_pool or not self.output_pool @@ -172,23 +174,17 @@ def _process_activation(self, activation_msg: ActivationMessage): pass layer_times_ms: list[tuple[int, float]] = [] for i, lyr in enumerate(window_layers): - t_l0 = ( - time.perf_counter() if 
getattr(self, "_profile", False) else 0.0 - ) + t_l0 = time.perf_counter() if self._profile else 0.0 with self._mlx_lock: x = self.model.apply_single_layer(lyr, x, cache=kv) # Optional per-n-layer sync for profiling, gated by settings - if getattr(self, "_profile", False) and getattr( - self, "_sync_per_layer", False - ): + if self._profile and self._sync_per_layer: do_sync = True - try: - n = int(getattr(self, "_sync_every_n", 0) or 0) - except Exception: - n = 0 + n = self._sync_every_n if n > 0 and (i % n) != 0: do_sync = False + if do_sync: try: with self._mlx_lock: @@ -361,7 +357,7 @@ def _process_activation(self, activation_msg: ActivationMessage): pass nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model + if nxt >= self.model_metadata.num_layers: # End of model try: with self._mlx_lock: y = self.model.normalize(x_cast) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index d909d0c4..99fe68cf 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -8,16 +8,29 @@ from typing import Any, Dict, List, Optional, cast from bisect import bisect_left as _bisect_left +from socket import gethostname +from secrets import token_hex + import mlx.core as mx import numpy as np -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from grpc import aio as aio_grpc from dnet_p2p import DnetP2P, DnetDeviceProperties +from dnet.utils.latency import calculate_median_latency_seconds +from dnet.utils.serialization import tensor_to_bytes + +from .servicer import ShardServicer +from dnet.protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server + from .models import ( + HealthResponse, ShardLoadModelRequest, ShardLoadModelResponse, + ShardProfileRequest, + ShardProfileResponse, ShardUnloadModelResponse, ) @@ -44,12 +57,11 @@ from ..model import get_ring_model from .compute import ComputeMixin from .prefetch import PrefetchMixin -from .send import SendMixin -from .startup import StartupMixin +from .comms import CommsMixin from ..weight_cache import WeightCache -class RingShardNode(ComputeMixin, PrefetchMixin, SendMixin, StartupMixin): +class RingShardNode(ComputeMixin, PrefetchMixin, CommsMixin): """Single shard node in the distributed inference ring with dynamic model loading.""" def __init__( @@ -114,6 +126,7 @@ def __init__( # Offloading/config-derived params self._resident_windows = int(self.config.resident_windows) + self._recent_windows = [] self._defer_unload = True self._await_next_ready = False set_prefetch_mode(self.config.prefetch_mode) @@ -306,7 +319,11 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse has_start = 0 in self.assigned_layers has_end = (self.model_metadata.num_layers - 1) in self.assigned_layers tied = bool( - getattr(getattr(self.model, "config", object()), "tie_word_embeddings", False) + getattr( + getattr(self.model, "config", object()), + "tie_word_embeddings", + False, + ) ) loaded_cnt = 0 @@ -317,7 +334,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse if tied: # End shard needs embeddings for tied projection if not has_start: - loaded_cnt += load_embeddings(self.model_metadata, self.model) + loaded_cnt += load_embeddings(self.model_metadata, self.model) # fmt: skip try: setattr(self.model, "force_tied_head", True) except Exception: @@ -333,9 +350,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse int(tied), ) except Exception as e: - 
logger.warning( - "Failed to load API-layer weights: %s", e - ) + logger.warning("Failed to load API-layer weights: %s", e) # Reset prefetch tracking self._prefetch_scheduled = set() @@ -349,7 +364,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse if self.next_node: await self._connect_next_node() else: - logger.info("Node %s: No next node configured", self.node_id) + logger.warning("Node %s: No next node configured", self.node_id) # Warmup if requested (run in executor to avoid blocking event loop) if req.warmup: @@ -468,17 +483,9 @@ async def unload_model(self) -> ShardUnloadModelResponse: message=f"Error unloading model: {str(e)}", ) - def _check_model_loaded(self) -> bool: - """Check if model is loaded. - - Returns: - True if model is loaded, False otherwise - """ - return self.model is not None and self.model_metadata is not None - async def reset_cache(self) -> None: """Reset LLM KV cache.""" - if not self._check_model_loaded(): + if not self.model: logger.warning( "Node %s: Cannot reset cache - no model loaded", self.node_id ) @@ -497,6 +504,13 @@ async def reset_cache(self) -> None: async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or forward.""" + if self.input_pool is None: + logger.error( + "Node %s: Cannot receive activation - input pool not initialized", + self.node_id, + ) + return + t_recv = time.perf_counter() await self._connect_next_node() @@ -808,7 +822,9 @@ async def _ingress_worker(self): activation_msg.nonce, ) try: - self.input_pool.release(activation_msg.pool_id) + if self.input_pool: + # FIXME: !!! + self.input_pool.release(activation_msg.pool_id) except Exception: pass else: @@ -825,6 +841,8 @@ async def _ingress_worker(self): def _get_or_make_kv(self, nonce: str) -> list: """Return a per-nonce KV cache list for this shard's local layers.""" + if not self.model: + raise RuntimeError("Model not initialized") try: now = time.perf_counter() ttl = float(getattr(self, "_kv_ttl_s", 30.0)) @@ -865,6 +883,13 @@ def _prepare_activation_message_blocking( Returns None on failure. 
""" + if self.input_pool is None: + logger.error( + "Node %s: Cannot prepare activation - input pool not initialized", + self.node_id, + ) + return None + try: activation = request.activation if "|" in activation.dtype: @@ -1067,6 +1092,348 @@ def _clear_prefetch_state(self): except Exception: pass try: - self.weight_cache.cancel_all_prefetch() + if self.weight_cache: + self.weight_cache.cancel_all_prefetch() + except Exception: + pass + + async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): + self.running = True + # Capture the main event loop for cross-thread scheduling + try: + self._loop = asyncio.get_running_loop() + except Exception: + self._loop = None + await self._start_grpc_server() + await self._start_http_server(shutdown_trigger) + await asyncio.sleep(0.2) + + self.background_tasks = [ + asyncio.create_task(self._ingress_worker()), + asyncio.create_task(self._prefetch_worker()), + asyncio.create_task(self._send_worker()), + ] + # Start idle sweeper to close silent streams + try: + if getattr(self, "_streaming_enabled", False) and hasattr( + self, "_stream_sweeper" + ): + self.background_tasks.append( + asyncio.create_task(self._stream_sweeper()) + ) + except Exception: + pass + + self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) + self.compute_thread.start() + + self._start_discovery() + logger.info( + "Shard node %s started on gRPC port %s HTTP port %s", + self.node_id, + self.grpc_port, + self.http_port, + ) + + def _start_discovery(self) -> None: + """Start mDNS discovery service.""" + + hostname = gethostname() + # TODO: optionally take shard name from CLI + instance = f"shard-{token_hex(4)}-{hostname}" + self.discovery.create_instance( + instance, + hostname, + "0.0.0.0", # Binds to all addresses + self.http_port, # HTTP port + self.grpc_port, # gRPC port + is_manager=False, # Shard is never a manager + ) + self.discovery.start() + logger.info( + "Discovery service started for shard node %s with name %s", + self.node_id, + self.discovery.fullname(), + ) + + async def _start_grpc_server(self) -> None: + """Start gRPC server.""" + self.server = aio_grpc.server() + + # Add the ring servicer; shard acts as client for ShardApiService (to API) + servicer = ShardServicer(self) # type: ignore # FIXME: !!! 
+ add_DnetRingServiceServicer_to_server(servicer, self.server) + + listen_addr = f"[::]:{self.grpc_port}" + self.server.add_insecure_port(listen_addr) + await self.server.start() + logger.info( + "Shard node %s gRPC server started on %s", self.node_id, listen_addr + ) + try: + await asyncio.get_running_loop().run_in_executor( + self.executor, self._warmup_serialization + ) + logger.info("Warmup serialization completed") + except Exception as e: + logger.warning("Warmup serialization failed: %s", e) + + def _warmup_serialization(self): + try: + dummy = mx.random.normal((1024, 1024), dtype=mx.float32) + dummy16 = dummy.astype(self._wire_mx_dtype) + _ = tensor_to_bytes(dummy16) except Exception: pass + + def _warmup_shard(self): + logger.info( + "[WARMUP] Starting shard warmup with window size %s", self.window_size + ) + if not self.model or not self.model_metadata or not self.weight_cache: + logger.warning("[WARMUP] No model loaded; skipping warmup") + return + + batch_size, seq_len = 1, 1 + hidden_size = self.model_metadata.model_config.get("hidden_size", 2560) + x = mx.zeros((batch_size, seq_len, hidden_size), dtype=mx.bfloat16) + start_time = time.perf_counter() + + max_windows = max(1, self.config.warmup_windows) + windows: list[list[int]] = [] + for window_start in range(0, len(self._assigned_sorted), self.window_size): + window_end = min( + window_start + self.window_size, len(self._assigned_sorted) + ) + windows.append(self._assigned_sorted[window_start:window_end]) + for wi, window_layers in enumerate(windows[:max_windows]): + weights_to_bind = {} + for layer_id in window_layers: + weights = self.weight_cache.get_weight(layer_id) + if weights: + for k, v in weights.items(): + weights_to_bind[k] = v + if weights_to_bind: + self.model.load_weights(list(weights_to_bind.items()), strict=False) + try: + for layer_id in window_layers: + x = self.model.apply_single_layer(layer_id, x, cache=None) + _s = mx.sum(x) + mx.eval(_s) + except Exception: + pass + try: + for lid in window_layers: + self.weight_cache.decrease_reference(lid) + except Exception: + pass + if not self._warmup_keep_flag: + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(window_layers) # type: ignore[attr-defined] + except Exception: + pass + try: + self.weight_cache.evict_layers(window_layers) + except Exception: + pass + total_time = (time.perf_counter() - start_time) * 1000 + self._warmup_completed = True + logger.info( + "[WARMUP] Shard warmup completed in %.2fms; windows=%s kept=%s", + total_time, + min(len(windows), max_windows), + int(self._warmup_keep_flag), + ) + + async def _start_http_server(self, shutdown_trigger: Any) -> None: + """Start HTTP server. 
+ + Args: + shutdown_trigger: Shutdown trigger function + """ + from hypercorn import Config + import hypercorn.asyncio as aio_hypercorn + + await self._setup_routes() + + # Start HTTP server in background + config = Config.from_mapping( + bind=f"0.0.0.0:{self.http_port}", + log_level="info", + log_config=None, + use_reloader=False, + h2c=False, + ) + + # Start the server as a background task + self.http_server = asyncio.create_task( + aio_hypercorn.serve(self.app, config, shutdown_trigger=shutdown_trigger) # type: ignore + ) + logger.info( + "Shard node %s HTTP server started on port %s", self.node_id, self.http_port + ) + + async def _setup_routes(self) -> None: + """Setup HTTP routes.""" + + @self.app.get("/health") + async def health() -> HealthResponse: + try: + instance = self.discovery.instance_name() + except Exception: + instance = None + return HealthResponse( + status="ok", + node_id=self.node_id, + running=self.running, + model_loaded=self.model is not None, + model_path=self.model_path, + assigned_layers=self.assigned_layers, + queue_size=self.activation_recv_queue.qsize(), + grpc_port=self.grpc_port, + http_port=self.http_port, + instance=instance, + ) + + @self.app.post("/profile") + async def profile(req: ShardProfileRequest) -> ShardProfileResponse: + logger.info("Received /profile request") + try: + # Measure latencies + latency_results = await self._measure_latency_to_devices( + req.devices, req.thunderbolts, req.payload_sizes + ) + + # Profile device using dperf + device_profile = await self._profile_device( + req.repo_id, req.max_batch_exp + ) + + # Overwrite `t_comm` with median latency (subprocess returns a dict) + median_latency = calculate_median_latency_seconds(latency_results) + if median_latency is not None: + device_profile["t_comm"] = float(median_latency) + logger.info( + f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" + ) + else: + logger.warning( + "No valid latency measurements, keeping default t_comm" + ) + + # Return the dict payload directly + return ShardProfileResponse( + profile=device_profile, + latency=latency_results, + ) + except Exception as e: + logger.error(f"Error in /profile endpoint: {e}") + raise + + @self.app.post("/load_model") + async def load_model_endpoint( + req: ShardLoadModelRequest, + ) -> ShardLoadModelResponse: + """Load model with specified layers.""" + try: + logger.info( + f"HTTP /load_model: model={req.model_path}, layers={req.layers}, " + f"next_node={req.next_node or 'none'}, window_size={req.window_size}, " + f"total_layers={req.total_layers}, api_callback={req.api_callback_address or 'none'}" + ) + result = await self.load_model(req) + return result + + except Exception as e: + logger.error(f"Error in /load_model endpoint: {e}") + return ShardLoadModelResponse( + success=False, + message=f"Error: {str(e)}", + layers_loaded=[], + load_time_ms=0.0, + ) + + @self.app.post("/unload_model") + async def unload_model_endpoint() -> ShardUnloadModelResponse: + """Unload current model.""" + try: + logger.info("HTTP /unload_model") + result = await self.unload_model() + return result + + except Exception as e: + logger.error(f"Error in /unload_model endpoint: {e}") + return ShardUnloadModelResponse( + success=False, + message=f"Error: {str(e)}", + ) + + @self.app.post("/warm") + # FIXME: add pydantic type here + async def warm(request: Request) -> JSONResponse: + try: + body = await request.json() + start = int(body.get("start", -1)) + window = int(body.get("window", self.window_size)) + if start < 0: + return 
+                        status_code=400, content={"error": "missing/invalid start"}
+                    )
+                start_idx = 0
+                for i, lyr in enumerate(self._assigned_sorted):
+                    if lyr >= start:
+                        start_idx = i
+                        break
+                else:
+                    return JSONResponse(content={"prefetched": []})
+                window_layers = self._assigned_sorted[
+                    start_idx : start_idx + max(1, window)
+                ]
+                for wl in window_layers:
+                    self._prefetch_to_ram(wl)
+                    self._enqueue_weight_prefetch(wl)
+                return JSONResponse(content={"prefetched": window_layers})
+            except Exception as e:
+                logger.error("/warm failed: %s", e)
+                return JSONResponse(status_code=500, content={"error": str(e)})
+
+    async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict:
+        """Profile device using dperf in a subprocess and return a dict.
+
+        Args:
+            repo_id: Hugging Face repository ID
+            max_batch_exp: Maximum batch size exponent (2^max_batch_exp)
+
+        Returns:
+            Device profile information as a plain dict
+        """
+        from ...utils.profile_subproc import profile_device_via_subprocess
+
+        profile_dict = profile_device_via_subprocess(
+            repo_id, max_batch_exp=max_batch_exp, debug=0
+        )
+        logger.info("Device profiling completed for node %s", self.node_id)
+        return profile_dict
+
+    # FIXME: this is not used; wire it into the health check.
+    # It checks the health of the entire ring, but needs a bit more setup,
+    # e.g. it must not get stuck in an infinite loop.
+    async def _health_check(self):
+        try:
+            health_request = dnet_ring_pb2.HealthRequest(requester_id=str(self.node_id))
+            response = await self.next_node_stub.HealthCheck(health_request)  # type: ignore # FIXME: this assumes an existing connection
+            logger.info(
+                "Shard node %s successfully pinged: %s, healthy: %s",
+                self.node_id,
+                response.node_id,
+                response.healthy,
+            )
+            return True
+        except Exception as e:
+            logger.warning(
+                "Shard node %s failed to ping next node: %s",
+                self.node_id,
+                e,
+            )
+            return False
diff --git a/src/dnet/ring/shard/prefetch.py b/src/dnet/ring/shard/prefetch.py
index 290a1907..12a3b528 100644
--- a/src/dnet/ring/shard/prefetch.py
+++ b/src/dnet/ring/shard/prefetch.py
@@ -1,35 +1,17 @@
 from __future__ import annotations
 
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
 import time
 import logging
 from typing import Dict
-from threading import Lock
-from queue import Queue
-
 import mlx.core as mx
 
-from dnet.ring.weight_cache import WeightCache
 from ...utils.logger import logger
+from .attrib import RingShardNodeAttributes
 
 
-class PrefetchMixin:
-    _mlx_lock: Lock
-    _prefetch_scheduled: set[int]
-    _prefetch_pending: set[int]
-    _prefetch_pause: asyncio.Event
-    _profile: bool
-    node_id: int
-    running: bool
-    weight_cache: WeightCache
-    weight_prefetch_queue: Queue[int]
-    _materialize_prefetch_default: bool
-    executor: ThreadPoolExecutor
-    _touch_during_compute: bool
-    _compute_busy: asyncio.Event
-
+class PrefetchMixin(RingShardNodeAttributes):
     def _touch_weights(self, layer_id: int, weights: Dict[str, mx.array]) -> None:
         mode = getattr(self, "_prefetch_touch_mode", "none")
         if mode in ("", "none"):
diff --git a/src/dnet/ring/shard/servicer.py b/src/dnet/ring/shard/servicer.py
index 25e1b100..93e8627e 100644
--- a/src/dnet/ring/shard/servicer.py
+++ b/src/dnet/ring/shard/servicer.py
@@ -158,4 +158,4 @@ async def StreamActivations(self, request_iterator, context):
         except Exception as e:
             logger.error("[STREAM][RX] error: %s", e)
-            await context.abort(grpc.StatusCode.INTERNAL, str(e))
+            context.abort(grpc.StatusCode.INTERNAL, str(e))
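The prefetch.py hunk above swaps PrefetchMixin's hand-maintained attribute
annotations for a shared RingShardNodeAttributes base from .attrib. Below is a
minimal sketch of that pattern; the field names are illustrative, borrowed from
the old declarations, and the real attrib.py may differ:

from __future__ import annotations

from queue import Queue
from threading import Lock


class RingShardNodeAttributes:
    # Declared once; every mixin that subclasses this sees the same
    # attribute types without re-declaring them.
    node_id: int
    running: bool
    _mlx_lock: Lock
    weight_prefetch_queue: Queue[int]


class PrefetchMixin(RingShardNodeAttributes):
    def _drain_prefetch_queue(self) -> list[int]:
        # Type checkers resolve self.weight_prefetch_queue via the base class.
        items: list[int] = []
        while not self.weight_prefetch_queue.empty():
            items.append(self.weight_prefetch_queue.get_nowait())
        return items

This keeps the mixins type-checkable while the concrete node class remains the
single owner of the runtime state.
diff --git a/src/dnet/ring/shard/startup.py 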
b/src/dnet/ring/shard/startup.py deleted file mode 100644 index b630438b..00000000 --- a/src/dnet/ring/shard/startup.py +++ /dev/null @@ -1,564 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from typing import Any, Dict, List, Mapping -import threading -from socket import gethostname -from secrets import token_hex - -import mlx.core as mx -from fastapi import Request -from fastapi.responses import JSONResponse -from grpc import aio as aio_grpc - -from hypercorn import Config -import hypercorn.asyncio as aio_hypercorn -from dnet_p2p.thunderbolt import ThunderboltConnection -from dnet_p2p import ( - DnetDeviceProperties, - discover_thunderbolt_connection, -) - -from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server -from .servicer import ShardServicer -from ...utils.logger import logger -from ...utils.serialization import tensor_to_bytes -from ...utils.latency import ( - DeviceLatencyResult, - LatencyMeasurement, - LatencyResults, - calculate_median_latency_seconds, -) -from .models import ( - HealthResponse, - ShardLoadModelRequest, - ShardLoadModelResponse, - ShardProfileRequest, - ShardProfileResponse, - ShardUnloadModelResponse, -) -from ...protos import dnet_ring_pb2 - - -class StartupMixin: - async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): - self.running = True - # Capture the main event loop for cross-thread scheduling - try: - self._loop = asyncio.get_running_loop() - except Exception: - self._loop = None - await self._start_grpc_server() - await self._start_http_server(shutdown_trigger) - await asyncio.sleep(0.2) - - self.background_tasks = [ - asyncio.create_task(self._ingress_worker()), - asyncio.create_task(self._prefetch_worker()), - asyncio.create_task(self._send_worker()), - ] - # Start idle sweeper to close silent streams - try: - if getattr(self, "_streaming_enabled", False) and hasattr( - self, "_stream_sweeper" - ): - self.background_tasks.append( - asyncio.create_task(self._stream_sweeper()) - ) - except Exception: - pass - - self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) - self.compute_thread.start() - - self._start_discovery() - logger.info( - "Shard node %s started on gRPC port %s HTTP port %s", - self.node_id, - self.grpc_port, - self.http_port, - ) - - def _start_discovery(self) -> None: - """Start mDNS discovery service.""" - hostname = gethostname() - # TODO: optionally take shard name from CLI - instance = f"shard-{token_hex(4)}-{hostname}" - self.discovery.create_instance( - instance, - hostname, - "0.0.0.0", # Binds to all addresses - self.http_port, # HTTP port - self.grpc_port, # gRPC port - is_manager=False, # Shard is never a manager - ) - self.discovery.start() - logger.info( - "Discovery service started for shard node %s with name %s", - self.node_id, - self.discovery.fullname(), - ) - - async def _start_grpc_server(self) -> None: - """Start gRPC server.""" - self.server = aio_grpc.server() - - # Add the ring servicer; shard acts as client for ShardApiService (to API) - servicer = ShardServicer(self) # type: ignore # FIXME: !!! 
- add_DnetRingServiceServicer_to_server(servicer, self.server) - - listen_addr = f"[::]:{self.grpc_port}" - self.server.add_insecure_port(listen_addr) - await self.server.start() - logger.info( - "Shard node %s gRPC server started on %s", self.node_id, listen_addr - ) - try: - await asyncio.get_running_loop().run_in_executor( - self.executor, self._warmup_serialization - ) - logger.info("Warmup serialization completed") - except Exception as e: - logger.warning("Warmup serialization failed: %s", e) - - def _warmup_serialization(self): - try: - dummy = mx.random.normal((1024, 1024), dtype=mx.float32) - dummy16 = dummy.astype(self._wire_mx_dtype) - _ = tensor_to_bytes(dummy16) - except Exception: - pass - - def _warmup_shard(self): - logger.info( - "[WARMUP] Starting shard warmup with window size %s", self.window_size - ) - batch_size, seq_len = 1, 1 - hidden_size = self.model_metadata.model_config.get("hidden_size", 2560) - x = mx.zeros((batch_size, seq_len, hidden_size), dtype=mx.bfloat16) - start_time = time.perf_counter() - try: - default_n = max(1, int(getattr(self, "_resident_windows", 1))) - except Exception: - default_n = 1 - try: - max_windows = max( - 1, - int( - getattr(self, "config", None).warmup_windows - if getattr(self, "config", None) - else default_n - ), - ) - except Exception: - max_windows = default_n - windows: list[list[int]] = [] - for window_start in range(0, len(self._assigned_sorted), self.window_size): - window_end = min( - window_start + self.window_size, len(self._assigned_sorted) - ) - windows.append(self._assigned_sorted[window_start:window_end]) - for wi, window_layers in enumerate(windows[:max_windows]): - weights_to_bind = {} - for layer_id in window_layers: - weights = self.weight_cache.get_weight(layer_id) - if weights: - for k, v in weights.items(): - weights_to_bind[k] = v - if weights_to_bind: - self.model.load_weights(list(weights_to_bind.items()), strict=False) - try: - for layer_id in window_layers: - x = self.model.apply_single_layer(layer_id, x, cache=None) - _s = mx.sum(x) - mx.eval(_s) - except Exception: - pass - try: - for lid in window_layers: - self.weight_cache.decrease_reference(lid) - except Exception: - pass - if not self._warmup_keep_flag: - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] - except Exception: - pass - try: - self.weight_cache.evict_layers(window_layers) - except Exception: - pass - total_time = (time.perf_counter() - start_time) * 1000 - self._warmup_completed = True - logger.info( - "[WARMUP] Shard warmup completed in %.2fms; windows=%s kept=%s", - total_time, - min(len(windows), max_windows), - int(self._warmup_keep_flag), - ) - - async def _start_http_server(self, shutdown_trigger: Any) -> None: - """Start HTTP server. 
- - Args: - shutdown_trigger: Shutdown trigger function - """ - await self._setup_routes() - - # Start HTTP server in background - config = Config.from_mapping( - bind=f"0.0.0.0:{self.http_port}", - log_level="info", - log_config=None, - use_reloader=False, - h2c=False, - ) - - # Start the server as a background task - self.http_server = asyncio.create_task( - aio_hypercorn.serve(self.app, config, shutdown_trigger=shutdown_trigger) # type: ignore - ) - logger.info( - "Shard node %s HTTP server started on port %s", self.node_id, self.http_port - ) - - async def _setup_routes(self) -> None: - """Setup HTTP routes.""" - - @self.app.get("/health") - async def health() -> HealthResponse: - try: - instance = self.discovery.instance_name() - except Exception: - instance = None - return HealthResponse( - status="ok", - node_id=self.node_id, - running=self.running, - model_loaded=self._check_model_loaded(), - model_path=self.model_path, - assigned_layers=self.assigned_layers, - queue_size=self.activation_recv_queue.qsize(), - grpc_port=self.grpc_port, - http_port=self.http_port, - instance=instance, - ) - - @self.app.post("/profile") - async def profile(req: ShardProfileRequest) -> ShardProfileResponse: - logger.info("Received /profile request") - try: - # Measure latencies - latency_results = await self._measure_latency_to_devices( - req.devices, req.thunderbolts, req.payload_sizes - ) - - # Profile device using dperf - device_profile = await self._profile_device( - req.repo_id, req.max_batch_exp - ) - - # Overwrite `t_comm` with median latency (subprocess returns a dict) - median_latency = calculate_median_latency_seconds(latency_results) - if median_latency is not None: - device_profile["t_comm"] = float(median_latency) - logger.info( - f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" - ) - else: - logger.warning( - "No valid latency measurements, keeping default t_comm" - ) - - # Return the dict payload directly - return ShardProfileResponse( - profile=device_profile, - latency=latency_results, - ) - except Exception as e: - logger.error(f"Error in /profile endpoint: {e}") - raise - - @self.app.post("/load_model") - async def load_model_endpoint( - req: ShardLoadModelRequest, - ) -> ShardLoadModelResponse: - """Load model with specified layers.""" - try: - logger.info( - f"HTTP /load_model: model={req.model_path}, layers={req.layers}, " - f"next_node={req.next_node or 'none'}, window_size={req.window_size}, " - f"total_layers={req.total_layers}, api_callback={req.api_callback_address or 'none'}" - ) - result = await self.load_model(req) - return result - - except Exception as e: - logger.error(f"Error in /load_model endpoint: {e}") - return ShardLoadModelResponse( - success=False, - message=f"Error: {str(e)}", - layers_loaded=[], - load_time_ms=0.0, - ) - - @self.app.post("/unload_model") - async def unload_model_endpoint() -> ShardUnloadModelResponse: - """Unload current model.""" - try: - logger.info("HTTP /unload_model") - result = await self.unload_model() - return result - - except Exception as e: - logger.error(f"Error in /unload_model endpoint: {e}") - return ShardUnloadModelResponse( - success=False, - message=f"Error: {str(e)}", - ) - - @self.app.post("/warm") - async def warm(request: Request) -> JSONResponse: - try: - body = await request.json() - start = int(body.get("start", -1)) - window = int(body.get("window", self.window_size)) - if start < 0: - return JSONResponse( - status_code=400, content={"error": "missing/invalid start"} - ) - start_idx = 0 - for i, lyr in 
enumerate(self._assigned_sorted): - if lyr >= start: - start_idx = i - break - else: - return JSONResponse(content={"prefetched": []}) - window_layers = self._assigned_sorted[ - start_idx : start_idx + max(1, window) - ] - for wl in window_layers: - self._prefetch_to_ram(wl) - self._enqueue_weight_prefetch(wl) - return JSONResponse(content={"prefetched": window_layers}) - except Exception as e: - logger.error("/warm failed: %s", e) - return JSONResponse(status_code=500, content={"error": str(e)}) - - async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: - """Profile device using dperf in a subprocess and return a dict. - - Args: - repo_id: Hugging Face repository ID - max_batch_exp: Maximum batch size exponent (2^max_batch_exp) - - Returns: - Device profile information as a plain dict - """ - from ...utils.profile_subproc import profile_device_via_subprocess - - profile_dict = profile_device_via_subprocess( - repo_id, max_batch_exp=max_batch_exp, debug=0 - ) - logger.info("Device profiling completed for node %s", self.node_id) - return profile_dict - - async def _connect_next_node(self) -> bool: - """Connect to next node in ring. - - Returns: - True if connected or no next node, False on failure - """ - if not self.next_node: - logger.info(f"Shard node {self.node_id} is the final shard (no next node)") - return True - - if self.next_node_channel: - logger.debug(f"Shard node {self.node_id} already connected to next node.") - return True - - try: - # use thunderbolt here if available - this_properties = self.discovery.get_own_properties() - thunderbolt_conn = discover_thunderbolt_connection( - this_properties, - self.next_node, - ) - next_ip = ( - thunderbolt_conn.ip_addr - if thunderbolt_conn - else self.next_node.local_ip - ) - address = f"{next_ip}:{self.next_node.shard_port}" - logger.info( - f"Shard node {this_properties.instance} connecting to next node {self.next_node.instance} at {address}" - ) - - self.next_node_channel = aio_grpc.insecure_channel(address) - from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub - - self.next_node_stub = DnetRingServiceStub(self.next_node_channel) - return True - except Exception as e: - logger.warning( - f"Shard node {self.node_id} failed to connect to next node {address}: {e}" - ) - self.next_node_channel = None - self.next_node_stub = None - return False - - async def _reconnect_next_node(self) -> bool: - try: - if self.next_node_channel: - await self.next_node_channel.close() - except Exception: - pass - self.next_node_channel = None - self.next_node_stub = None - return await self._connect_next_node() - - async def _health_check(self): - try: - health_request = dnet_ring_pb2.HealthRequest(requester_id=str(self.node_id)) - response = await self.next_node_stub.HealthCheck(health_request) # type: ignore - logger.info( - "Shard node %s successfully pinged: %s, healthy: %s", - self.node_id, - response.node_id, - response.healthy, - ) - return True - except Exception as e: - logger.warning( - "Shard node %s failed to ping next node %s: %s", - self.node_id, - self.next_node_address, - e, - ) - return False - - async def _measure_latency_to_devices( - self, - devices: Mapping[str, DnetDeviceProperties], - thunderbolts: Mapping[str, ThunderboltConnection], - payload_sizes: List[int], - ) -> LatencyResults: - """Measure latency to all devices except self. 
- - Args: - devices: Device information mapping - thunderbolts: Thunderbolt connection information - payload_sizes: List of payload sizes to test - - Returns: - Latency measurement results - """ - latency_results_dict: Dict[str, DeviceLatencyResult] = {} - - for service_name, device_info in devices.items(): - # Skip measuring latency to ourselves - if service_name.startswith(self.discovery.instance_name()): - logger.debug("Skipping latency measurement to self: %s", service_name) - continue - - # Skip measuring latency to API (manager) devices - if device_info.is_manager: - logger.debug( - "Skipping latency measurement to manager/API: %s", service_name - ) - continue - - try: - shard_port = device_info.shard_port - - # Check for Thunderbolt connection - if service_name in thunderbolts: - tb_data = thunderbolts[service_name] - service_ip = tb_data.ip_addr - logger.info( - "Using Thunderbolt for %s at %s, connected to instance %s", - service_name, - service_ip, - tb_data.instance, - ) - else: - # No Thunderbolt, use WiFi - service_ip = device_info.local_ip - - if not shard_port or not service_ip: - logger.warning( - "No shard_port or local_ip for device %s", service_name - ) - continue - - # Connect to target shard's gRPC server - target_address = f"{service_ip}:{shard_port}" - channel = aio_grpc.insecure_channel(target_address) - from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub - - stub = DnetRingServiceStub(channel) - - # Measure latency for each payload size - latency_measurements: List[LatencyMeasurement] = [] - for payload_size in payload_sizes: - # Create dummy payload - dummy_data = b"x" * payload_size - - start_time = time.perf_counter() - timestamp_ms = int(time.time() * 1000) - - request = dnet_ring_pb2.LatencyMeasureRequest( - requester_id=str(self.node_id), - payload_size=payload_size, - dummy_data=dummy_data, - timestamp=timestamp_ms, - ) - - response = await stub.MeasureLatency(request) # type: ignore - end_time = time.perf_counter() - - if response.success: - latency_ms = (end_time - start_time) * 1000 - latency_measurements.append( - LatencyMeasurement( - payload_size=payload_size, - latency_ms=round(latency_ms, 2), - success=True, - error=None, - ) - ) - else: - latency_measurements.append( - LatencyMeasurement( - payload_size=payload_size, - success=False, - error=response.message, - latency_ms=0, - ) - ) - - # Store results - result = DeviceLatencyResult( - target_node_id=response.node_id if response.success else None, - measurements=latency_measurements, - success=True, - error=None, - ) - latency_results_dict[service_name] = result - - # Close channel - await channel.close() - - except Exception as e: - logger.error("Error measuring latency to %s: %s", service_name, e) - result = DeviceLatencyResult( - target_node_id=None, - success=False, - error=str(e), - measurements=[], - ) - latency_results_dict[service_name] = result - - return LatencyResults(results=latency_results_dict) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 38a8b896..5b0b3b96 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -105,9 +105,7 @@ def get_weight( # Estimate bytes by summing tensor sizes for the layer try: winfo = self.layer_manager.weight_info.get(layer_id, {}) - total_bytes = sum( - getattr(w, "size_bytes", 0) for w in winfo.values() - ) + total_bytes = sum(w.size_bytes for w in winfo.values()) except Exception: total_bytes = 0 # Commit to cache under lock
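For reference, a minimal client sketch for the /warm endpoint defined above.
The host and port (localhost:8081) are assumptions for a locally running shard;
the request and response shapes follow the handler ("start"/"window" in,
{"prefetched": [...]} out):

import json
from urllib.request import Request, urlopen

# Ask the shard to prefetch a window of 4 layers starting at layer 0.
payload = json.dumps({"start": 0, "window": 4}).encode()
req = Request(
    "http://localhost:8081/warm",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.load(resp))  # e.g. {"prefetched": [0, 1, 2, 3]}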