Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion docs/_static/llama-stack-spec.html
Original file line number Diff line number Diff line change
Expand Up @@ -2642,7 +2642,81 @@
}
}
},
"/v1/inspect/providers": {
"/v1/providers": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListProvidersResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Providers"
],
"description": "",
"parameters": []
}
},
"/v1/providers/{provider_id}": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GetProviderResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Providers"
],
"description": "",
"parameters": [
{
"name": "provider_id",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
}
]
},
"/v1/inspect/providers": {
"get": {
"responses": {
"200": {
Expand Down
51 changes: 51 additions & 0 deletions docs/_static/llama-stack-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,57 @@ paths:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
/v1/providers:
get:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/ListProvidersResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Providers
description: ''
parameters: []
/v1/providers/{provider_id}:
get:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/GetProviderResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Providers
description: ''
parameters:
- name: provider_id
in: path
required: true
schema:
type: string
/v1/inspect/providers:
get:
responses:
Expand Down
1 change: 1 addition & 0 deletions llama_stack/apis/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

@json_schema_type
class Api(Enum):
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"
Expand Down
18 changes: 9 additions & 9 deletions llama_stack/apis/inspect/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,6 @@
from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str


@json_schema_type
class RouteInfo(BaseModel):
route: str
Expand All @@ -32,14 +25,21 @@ class HealthInfo(BaseModel):


@json_schema_type
class VersionInfo(BaseModel):
version: str
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str


class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]


@json_schema_type
class VersionInfo(BaseModel):
version: str


class ListRoutesResponse(BaseModel):
data: List[RouteInfo]

Expand Down
7 changes: 7 additions & 0 deletions llama_stack/apis/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .providers import * # noqa: F401 F403
40 changes: 40 additions & 0 deletions llama_stack/apis/providers/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.distribution.datatypes import Provider
from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str


class GetProviderResponse(BaseModel):
data: Provider | None


class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]


@runtime_checkable
class Providers(Protocol):
"""
Providers API for inspecting, listing, and modifying providers and their configurations.
"""

@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...

@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> GetProviderResponse: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this return ProviderInfo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ProviderInfo doesn't contain a config, so I created a new type for that. inspect returns detailed info about a provider. list returns ProviderInfo

2 changes: 1 addition & 1 deletion llama_stack/distribution/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]

for api_str in apis_to_serve:
api = Api(api_str)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/distribution/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:

def providable_apis() -> List[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
Expand Down
59 changes: 59 additions & 0 deletions llama_stack/distribution/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.providers import GetProviderResponse, ListProvidersResponse, ProviderInfo, Providers

from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields


class ProviderImplConfig(BaseModel):
run_config: StackRunConfig


async def get_provider_impl(config, deps):
impl = ProviderImpl(config, deps)
await impl.initialize()
return impl


class ProviderImpl(Providers):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably need a fast path to such builtin APIs rather than having to go through the whole resolution abstraction (since there never will be multiple "implementations" of this API)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep I agree, maybe we should scope future work to add this for inspect and provider APIs?

def __init__(self, config, deps):
self.config = config
self.deps = deps

async def initialize(self) -> None:
pass

async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)

return ListProvidersResponse(data=ret)

async def inspect_provider(self, provider_id: str) -> GetProviderResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
ret = None
for _, providers in safe_config.providers.items():
for p in providers:
if p.provider_id == provider_id:
ret = p

return GetProviderResponse(data=ret)
21 changes: 21 additions & 0 deletions llama_stack/distribution/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
Expand Down Expand Up @@ -59,6 +60,7 @@ class InvalidProviderError(Exception):

def api_protocol_map() -> Dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Expand Down Expand Up @@ -247,6 +249,25 @@ def sort_providers_by_deps(
)
)

sorted_providers.append(
(
"providers",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.providers,
provider_type="__builtin__",
config_class="llama_stack.distribution.providers.ProviderImplConfig",
module="llama_stack.distribution.providers",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)

logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
Expand Down
1 change: 1 addition & 0 deletions llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def main():
apis_to_serve.add(inf.routing_table_api.value)

apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)

Expand Down
2 changes: 2 additions & 0 deletions llama_stack/distribution/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
Expand All @@ -44,6 +45,7 @@


class LlamaStack(
Providers,
VectorDBs,
Inference,
BatchInference,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ exclude = [
"^llama_stack/apis/inspect/inspect\\.py$",
"^llama_stack/apis/models/models\\.py$",
"^llama_stack/apis/post_training/post_training\\.py$",
"^llama_stack/apis/providers/providers\\.py$",
"^llama_stack/apis/resource\\.py$",
"^llama_stack/apis/safety/safety\\.py$",
"^llama_stack/apis/scoring/scoring\\.py$",
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
Loading