Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions docs/_static/llama-stack-spec.html
Original file line number Diff line number Diff line change
Expand Up @@ -4946,6 +4946,74 @@
}
}
},
"/v1/providers/{api}/{provider_id}/{provider_type}": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ProviderInfo"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Providers"
],
"description": "",
"parameters": [
{
"name": "api",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "provider_id",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "provider_type",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UpdateProviderRequest"
}
}
},
"required": true
}
}
},
"/v1/version": {
"get": {
"responses": {
Expand Down Expand Up @@ -16101,6 +16169,41 @@
"title": "SyntheticDataGenerationResponse",
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
},
"UpdateProviderRequest": {
"type": "object",
"properties": {
"config": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"config"
],
"title": "UpdateProviderRequest"
},
"VersionInfo": {
"type": "object",
"properties": {
Expand Down
61 changes: 61 additions & 0 deletions docs/_static/llama-stack-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3484,6 +3484,50 @@ paths:
schema:
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
required: true
/v1/providers/{api}/{provider_id}/{provider_type}:
post:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/ProviderInfo'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Providers
description: ''
parameters:
- name: api
in: path
required: true
schema:
type: string
- name: provider_id
in: path
required: true
schema:
type: string
- name: provider_type
in: path
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UpdateProviderRequest'
required: true
/v1/version:
get:
responses:
Expand Down Expand Up @@ -11234,6 +11278,23 @@ components:
description: >-
Response from the synthetic data generation. Batch of (prompt, response, score)
tuples that pass the threshold.
UpdateProviderRequest:
type: object
properties:
config:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- config
title: UpdateProviderRequest
VersionInfo:
type: object
properties:
Expand Down
5 changes: 5 additions & 0 deletions llama_stack/apis/providers/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ async def inspect_provider(self, provider_id: str) -> ProviderInfo:
:returns: A ProviderInfo object containing the provider's details.
"""
...

@webmethod(route="/providers/{api}/{provider_id}/{provider_type}", method="PUT")
async def update_provider(
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
) -> ProviderInfo: ...
1 change: 1 addition & 0 deletions llama_stack/distribution/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class Provider(BaseModel):
provider_id: str | None
provider_type: str
config: dict[str, Any]
immutable: bool = False


class LoggingConfig(BaseModel):
Expand Down
17 changes: 17 additions & 0 deletions llama_stack/distribution/library_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
AsyncStream,
LlamaStackClient,
)
from llama_stack_client.types import provider_info
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
Expand Down Expand Up @@ -293,6 +294,22 @@ async def request(
cast_to=cast_to,
options=options,
)
# Check if response is of a certain type
# this indicates we have done a provider update
if (
isinstance(response, provider_info.ProviderInfo)
and hasattr(response, "config")
and options.method.lower() == "put"
):
# patch in the new provider config
for api, providers in self.config.providers.items():
if api != response.api:
continue
for prov in providers:
if prov.provider_id == response.provider_id:
prov.config = response.config
break
await self.initialize()
return response

async def _call_non_streaming(
Expand Down
89 changes: 88 additions & 1 deletion llama_stack/distribution/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the root directory of this source tree.

import asyncio
import copy
from typing import Any

from pydantic import BaseModel
Expand All @@ -13,7 +14,7 @@
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus

from .datatypes import StackRunConfig
from .datatypes import Provider, StackRunConfig
from .utils.config import redact_sensitive_fields

logger = get_logger(name=__name__, category="core")
Expand Down Expand Up @@ -129,3 +130,89 @@ async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
providers_health[api_name] = health_response

return providers_health

async def update_provider(
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
) -> ProviderInfo:
# config = ast.literal_eval(provider_request.config)
prov = Provider(
provider_id=provider_id,
provider_type=provider_type,
config=config,
)
assert prov.provider_id is not None
existing_provider = None
# if the provider isn't there or the API is invalid, we should not continue
for prov_api, providers in self.config.run_config.providers.items():
if prov_api != api:
continue
for p in providers:
# the provider needs to be mutable for us to update its config
if p.provider_id == provider_id:
if p.immutable:
raise ValueError(f"Provider {provider_id} is immutable, you can only update mutable providers.")
existing_provider = p
break
if existing_provider is not None:
break

if existing_provider is None:
raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.")

new_config = self.merge_providers(existing_provider, prov)
existing_provider.config = new_config
providers_health = await self.get_providers_health()
# takes a single provider, validates its in the registry
# if it is, merge the provider config with the existing one
ret = ProviderInfo(
api=api,
provider_id=prov.provider_id,
provider_type=prov.provider_type,
config=new_config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"),
),
)

return ret

def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
"""Recursively merges `overrides` into `base`, replacing only specified keys."""

merged = copy.deepcopy(base) # Preserve original dict
for key, value in overrides.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
# Recursively merge if both are dictionaries
merged[key] = self.merge_dicts(merged[key], value)
else:
# Otherwise, directly override
merged[key] = value

return merged

def merge_configs(
self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]]
) -> dict[str, list[Provider]]:
merged_config = copy.deepcopy(global_config) # Preserve original structure

for key, new_providers in new_config.items():
if key in merged_config:
existing_providers = {p.provider_id: p for p in merged_config[key]}

for new_provider in new_providers:
if new_provider.provider_id in existing_providers:
# Override settings of existing provider
existing = existing_providers[new_provider.provider_id]
existing.config = self.merge_dicts(existing.config, new_provider.config)
else:
# Append new provider
merged_config[key].append(new_provider)
else:
# Add new category entirely
merged_config[key] = new_providers

return merged_config

def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]:
return self.merge_dicts(current_provider.config, new_provider.config)
2 changes: 2 additions & 0 deletions llama_stack/distribution/server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
elif webmethod.method == hdrs.METH_PUT:
http_method = hdrs.METH_PUT
else:
http_method = hdrs.METH_POST
routes.append(
Expand Down
Loading
Loading