3 changes: 0 additions & 3 deletions .github/workflows/matrix-checks.yml
@@ -5,9 +5,6 @@ on:
       python_version:
         required: false
         type: string
-    secrets:
-      CI_SSH_KEY:
-        required: true
 
 jobs:
   matrix-checks:
3 changes: 0 additions & 3 deletions .github/workflows/tests-macos.yml
@@ -5,9 +5,6 @@ on:
      python_version:
        required: false
        type: string
-    secrets:
-      CI_SSH_KEY:
-        required: true
 
 jobs:
   tests-macos:
24 changes: 24 additions & 0 deletions README.md
@@ -236,6 +236,30 @@ curl -X POST http://localhost:8080/v1/chat/completions \
   }'
 ```
 
+#### MCP Integration
+
+dnet exposes an MCP server at `/mcp` for use with Claude Desktop, Cursor, and other MCP clients.
+
+Add this to your MCP config:
+
+```json
+{
+  "mcpServers": {
+    "dnet": {
+      "command": "npx",
+      "args": [
+        "-y",
+        "mcp-remote@latest",
+        "http://localhost:8080/mcp",
+        "--allow-http"
+      ]
+    }
+  }
+}
+```
+
+Available tools: `chat_completion`, `load_model`, `unload_model`, `list_models`, `get_status`, `get_cluster_details`.
+
 #### Devices
 
 You can get the list of discoverable devices with:
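A quick way to smoke-test the new endpoint without a desktop MCP client is to call it with the `fastmcp` client library, which this PR already pins as a server-side dependency. The snippet below is a sketch, not part of the diff; the tool names come from the README list above, and the no-argument call to `get_status` is an assumption:

```python
# Sketch: talk to the dnet MCP endpoint directly (assumes a node on localhost:8080).
import asyncio

from fastmcp import Client


async def main() -> None:
    # Client infers the HTTP transport from the URL.
    async with Client("http://localhost:8080/mcp") as client:
        tools = await client.list_tools()
        print("tools:", [t.name for t in tools])

        # Assumed to take no arguments, per the tool list above.
        result = await client.call_tool("get_status", {})
        print(result)


asyncio.run(main())
```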
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
     "dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py",
     "rich>=13.0.0",
     "psutil>=5.9.0",
+    "fastmcp==2.13.0",
 ]
 
 [project.optional-dependencies]
@@ -47,6 +48,7 @@ cuda = ["mlx[cuda]"]
 dev = [
     "openai>=2.6.0",  # for OpenAI compatibility tests
     "pytest>=8.4.2",
+    "pytest-asyncio>=0.24.0",
     "mypy>=1.3.0",  # Type checking
     "ruff>=0.0.285",
     "types-psutil>=7.1.3",
@@ -67,6 +69,7 @@ python_files = ["test_*.py", "*_test.py"]
 testpaths = ["tests"]
 python_functions = ["test_"]
 log_cli = true
+asyncio_mode = "auto"
 markers = [
     "api: tests for API node components (HTTP, gRPC, managers)",
     "shard: tests for Shard node components (HTTP, gRPC, runtime, policies, ring)",
@@ -79,6 +82,7 @@ markers = [
     "core: tests for core memory/cache/utils not tied to api/shard",
     "e2e: integration tests requiring live servers or multiple components",
     "integration: model catalog integration tests for CI (manual trigger)",
+    "mcp: tests for MCP handler tools and server integration",
 ]
 
 [tool.ruff]
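With `asyncio_mode = "auto"`, pytest collects and runs `async def` tests without per-test `@pytest.mark.asyncio` decorators, which is what makes `pytest-asyncio` a dev dependency here. A hypothetical test using the new `mcp` marker could look like this (the test body and endpoint are illustrative, not taken from the PR's test suite):

```python
import pytest

from fastmcp import Client


@pytest.mark.mcp
async def test_mcp_lists_expected_tools() -> None:
    # No @pytest.mark.asyncio needed: asyncio_mode = "auto" handles collection.
    async with Client("http://localhost:8080/mcp") as client:
        tools = {t.name for t in await client.list_tools()}
        assert "list_models" in tools
        assert "chat_completion" in tools
```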
90 changes: 34 additions & 56 deletions src/dnet/api/http_api.py
@@ -1,6 +1,5 @@
 from typing import Optional, Any, List
 import asyncio
-import os
 from hypercorn import Config
 from hypercorn.utils import LifespanFailureError
 import hypercorn.asyncio as aio_hypercorn
@@ -25,6 +24,12 @@
 from .inference import InferenceManager
 from .model_manager import ModelManager
 from dnet_p2p import DnetDeviceProperties
+from .mcp_handler import create_mcp_server
+from .load_helpers import (
+    _prepare_topology_core,
+    _load_model_core,
+    _unload_model_core,
+)
 
 
 class HTTPServer:
@@ -41,9 +46,19 @@ def __init__(
         self.inference_manager = inference_manager
         self.model_manager = model_manager
         self.node_id = node_id
-        self.app = FastAPI()
         self.http_server: Optional[asyncio.Task] = None
 
+        # Create MCP server first to get lifespan
+        mcp = create_mcp_server(inference_manager, model_manager, cluster_manager)
+        # Use path='/' since we're mounting at /mcp, so final path will be /mcp/
+        mcp_app = mcp.http_app(path="/")
+
+        # Create FastAPI app with MCP lifespan
+        self.app = FastAPI(lifespan=mcp_app.lifespan)
+
+        # Mount MCP server as ASGI app
+        self.app.mount("/mcp", mcp_app)
+
     async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None:
         await self._setup_routes()
@@ -152,59 +167,27 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse:
                     ),
                 )
 
-            model_config = get_model_config_json(req.model)
-            embedding_size = int(model_config["hidden_size"])
-            num_layers = int(model_config["num_hidden_layers"])
-
-            await self.cluster_manager.scan_devices()
-            batch_sizes = [1]
-            profiles = await self.cluster_manager.profile_cluster(
-                req.model, embedding_size, 2, batch_sizes
-            )
-            if not profiles:
-                return APILoadModelResponse(
-                    model=req.model,
-                    success=False,
-                    shard_statuses=[],
-                    message="No profiles collected",
+            try:
+                topology = await _prepare_topology_core(
+                    self.cluster_manager, req.model, req.kv_bits, req.seq_len
                 )
-
-            model_profile_split = profile_model(
-                repo_id=req.model,
-                batch_sizes=batch_sizes,
-                sequence_length=req.seq_len,
-            )
-            model_profile = model_profile_split.to_model_profile()
-            topology = await self.cluster_manager.solve_topology(
-                profiles, model_profile, req.model, num_layers, req.kv_bits
-            )
+            except RuntimeError as e:
+                if "No profiles collected" in str(e):
+                    return APILoadModelResponse(
+                        model=req.model,
+                        success=False,
+                        shard_statuses=[],
+                        message="No profiles collected",
+                    )
+                raise
             self.cluster_manager.current_topology = topology
 
-            api_props = await self.cluster_manager.discovery.async_get_own_properties()
-            grpc_port = int(self.inference_manager.grpc_port)
-
-            # Callback address shards should use for SendToken.
-            # In static discovery / cloud setups, discovery may report 127.0.0.1 which is not usable.
-            api_callback_addr = (os.getenv("DNET_API_CALLBACK_ADDR") or "").strip()
-            if not api_callback_addr:
-                api_callback_addr = f"{api_props.local_ip}:{grpc_port}"
-            if api_props.local_ip in ("127.0.0.1", "localhost"):
-                logger.warning(
-                    "API callback address is loopback (%s). Remote shards will fail to SendToken. "
-                    "Set DNET_API_CALLBACK_ADDR to a reachable host:port.",
-                    api_callback_addr,
-                )
-            response = await self.model_manager.load_model(
+            response = await _load_model_core(
+                self.cluster_manager,
+                self.model_manager,
+                self.inference_manager,
                 topology,
-                api_props,
-                self.inference_manager.grpc_port,
-                api_callback_address=api_callback_addr,
             )
-            if response.success:
-                first_shard = topology.devices[0]
-                await self.inference_manager.connect_to_ring(
-                    first_shard.local_ip, first_shard.shard_port, api_callback_addr
-                )
             return response
 
         except Exception as e:
@@ -217,12 +200,7 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse:
             )
 
     async def unload_model(self) -> UnloadModelResponse:
-        await self.cluster_manager.scan_devices()
-        shards = self.cluster_manager.shards
-        response = await self.model_manager.unload_model(shards)
-        if response.success:
-            self.cluster_manager.current_topology = None
-        return response
+        return await _unload_model_core(self.cluster_manager, self.model_manager)
 
     async def get_devices(self) -> JSONResponse:
         devices = await self.cluster_manager.discovery.async_get_properties()
98 changes: 98 additions & 0 deletions src/dnet/api/load_helpers.py
@@ -0,0 +1,98 @@
+import os
+from dnet.utils.logger import logger
+from dnet.utils.model import get_model_config_json
+from distilp.profiler import profile_model
+from dnet.core.types.topology import TopologyInfo
+from .models import APILoadModelResponse, UnloadModelResponse
+
+
+async def get_api_callback_address(
+    cluster_manager,
+    grpc_port: int | str,
+) -> str:
+    api_props = await cluster_manager.discovery.async_get_own_properties()
+    grpc_port_int = int(grpc_port)
+    api_callback_addr = (os.getenv("DNET_API_CALLBACK_ADDR") or "").strip()
+    if not api_callback_addr:
+        api_callback_addr = f"{api_props.local_ip}:{grpc_port_int}"
+    if api_props.local_ip in ("127.0.0.1", "localhost"):
+        logger.warning(
+            "API callback address is loopback (%s). Remote shards will fail to SendToken. "
+            "Set DNET_API_CALLBACK_ADDR to a reachable host:port.",
+            api_callback_addr,
+        )
+    return api_callback_addr
+
+
+async def _prepare_topology_core(
+    cluster_manager,
+    model: str,
+    kv_bits: str,
+    seq_len: int,
+    progress_callback=None,
+) -> TopologyInfo:
+    model_config = get_model_config_json(model)
+    embedding_size = int(model_config["hidden_size"])
+    num_layers = int(model_config["num_hidden_layers"])
+
+    await cluster_manager.scan_devices()
+    if progress_callback:
+        await progress_callback("Profiling cluster performance")
+    batch_sizes = [1]
+    profiles = await cluster_manager.profile_cluster(
+        model, embedding_size, 2, batch_sizes
+    )
+    if not profiles:
+        raise RuntimeError("No profiles collected")
+
+    if progress_callback:
+        await progress_callback("Computing optimal layer distribution")
+    model_profile_split = profile_model(
+        repo_id=model,
+        batch_sizes=batch_sizes,
+        sequence_length=seq_len,
+    )
+    model_profile = model_profile_split.to_model_profile()
+
+    topology = await cluster_manager.solve_topology(
+        profiles, model_profile, model, num_layers, kv_bits
+    )
+    return topology
+
+
+async def _load_model_core(
+    cluster_manager,
+    model_manager,
+    inference_manager,
+    topology: TopologyInfo,
+) -> APILoadModelResponse:
+    api_props = await cluster_manager.discovery.async_get_own_properties()
+    grpc_port = int(inference_manager.grpc_port)
+
+    api_callback_addr = await get_api_callback_address(
+        cluster_manager, inference_manager.grpc_port
+    )
+    response = await model_manager.load_model(
+        topology,
+        api_props,
+        grpc_port,
+        api_callback_address=api_callback_addr,
+    )
+    if response.success:
+        first_shard = topology.devices[0]
+        await inference_manager.connect_to_ring(
+            first_shard.local_ip, first_shard.shard_port, api_callback_addr
+        )
+    return response
+
+
+async def _unload_model_core(
+    cluster_manager,
+    model_manager,
+) -> UnloadModelResponse:
+    await cluster_manager.scan_devices()
+    shards = cluster_manager.shards
+    response = await model_manager.unload_model(shards)
+    if response.success:
+        cluster_manager.current_topology = None
+    return response
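One contract worth noting in `get_api_callback_address`: the `DNET_API_CALLBACK_ADDR` override always wins over the discovery-derived address. A rough sketch with stubbed discovery objects (the stubs are illustrative, not dnet code, and the import path assumes the module location above):

```python
import asyncio
import os

from dnet.api.load_helpers import get_api_callback_address


class _StubProps:
    local_ip = "127.0.0.1"


class _StubDiscovery:
    async def async_get_own_properties(self) -> _StubProps:
        return _StubProps()


class _StubClusterManager:
    discovery = _StubDiscovery()


async def demo() -> None:
    cm = _StubClusterManager()

    # No override: falls back to discovery's IP and logs the loopback warning.
    os.environ.pop("DNET_API_CALLBACK_ADDR", None)
    print(await get_api_callback_address(cm, 50051))  # 127.0.0.1:50051

    # Override set: returned verbatim even though discovery reports loopback.
    os.environ["DNET_API_CALLBACK_ADDR"] = "203.0.113.7:50051"
    print(await get_api_callback_address(cm, 50051))  # 203.0.113.7:50051


asyncio.run(demo())
```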