Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions llama_stack/distribution/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
Expand Down Expand Up @@ -160,7 +161,11 @@ async def register_model(
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.

Expand Down Expand Up @@ -298,7 +303,12 @@ async def stream_generator():
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
Expand Down Expand Up @@ -471,21 +481,36 @@ async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass

async def get_rows_paginated(
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)

async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
start_index=start_index,
limit=limit,
)

async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
Expand Down
52 changes: 34 additions & 18 deletions llama_stack/distribution/routers/routing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,22 @@
# the root directory of this source tree.

import logging
import uuid
from typing import Any, Dict, List, Optional

from pydantic import TypeAdapter

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
Expand Down Expand Up @@ -352,34 +360,42 @@ async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:

async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
if provider_dataset_id is None:
provider_dataset_id = dataset_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
dataset_id: Optional[str] = None,
) -> Dataset:
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"

provider_dataset_id = dataset_id

# infer provider from source
if source.type == DatasetType.rows:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Infering provider in this way is not scalable for additional dataset sources (for example, relational databases) - adding more dataset sources will become more and more challenging/

provider_id = "localfs"
elif source.type == DatasetType.uri:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")

if metadata is None:
metadata = {}

dataset = Dataset(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
dataset_schema=dataset_schema,
url=url,
purpose=purpose,
source=source,
metadata=metadata,
)

await self.register_object(dataset)
return dataset

async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def run_evaluation_3():
eval_candidate = st.session_state["eval_candidate"]

dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated(
rows = llama_stack_api.client.datasetio.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
Expand Down
33 changes: 12 additions & 21 deletions llama_stack/providers/inline/datasetio/localfs/datasetio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
Expand Down Expand Up @@ -128,36 +128,27 @@ async def unregister_dataset(self, dataset_id: str) -> None:
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]

async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()

if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0

if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)

start = next_page_token
if rows_in_page == -1:
if limit is None or limit == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
end = min(start_index + limit, len(dataset_info.dataset_impl))

rows = dataset_info.dataset_impl[start:end]
rows = dataset_info.dataset_impl[start_index:end]

return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_index=end if end < len(dataset_info.dataset_impl) else None,
)

async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/eval/meta_reference/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def run_eval(
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def _setup_data(
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated(
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
Expand Down
6 changes: 4 additions & 2 deletions llama_stack/providers/inline/scoring/basic/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn

Expand Down Expand Up @@ -82,7 +84,7 @@ async def score_batch(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def score_batch(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def score_batch(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
Expand Down
33 changes: 12 additions & 21 deletions llama_stack/providers/remote/datasetio/huggingface/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import datasets as hf_datasets

from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
Expand Down Expand Up @@ -73,36 +73,27 @@ async def unregister_dataset(self, dataset_id: str) -> None:
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]

async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)

if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0

if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)

start = next_page_token
if rows_in_page == -1:
if limit is None or limit == -1:
end = len(loaded_dataset)
else:
end = min(start + rows_in_page, len(loaded_dataset))
end = min(start_index + limit, len(loaded_dataset))

rows = [loaded_dataset[i] for i in range(start, end)]
rows = [loaded_dataset[i] for i in range(start_index, end)]

return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_index=end if end < len(loaded_dataset) else None,
)

async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
Expand Down