12 changes: 6 additions & 6 deletions docs/_static/llama-stack-spec.html
@@ -40,7 +40,7 @@
}
],
"paths": {
"/v1/datasets/{dataset_id}/append-rows": {
"/v1/datasetio/append-rows/{dataset_id}": {
"post": {
"responses": {
"200": {
@@ -60,7 +60,7 @@
}
},
"tags": [
"Datasets"
"DatasetIO"
],
"description": "",
"parameters": [
@@ -2177,7 +2177,7 @@
}
}
},
"/v1/datasets/{dataset_id}/iterrows": {
"/v1/datasetio/iterrows/{dataset_id}": {
"get": {
"responses": {
"200": {
@@ -2204,7 +2204,7 @@
}
},
"tags": [
"Datasets"
"DatasetIO"
],
"description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
"parameters": [
@@ -10274,7 +10274,7 @@
"name": "Benchmarks"
},
{
"name": "Datasets"
"name": "DatasetIO"
},
{
"name": "Datasets"
@@ -10342,7 +10342,7 @@
"Agents",
"BatchInference (Coming Soon)",
"Benchmarks",
"Datasets",
"DatasetIO",
"Datasets",
"Eval",
"Files",
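Side note on the renamed routes above: a minimal sketch of what calls against the new paths look like over plain HTTP, assuming a llama-stack server on http://localhost:8321 and an already-registered dataset id "my-eval-set". The request and response shapes are inferred from the method signatures and the IterrowsResponse usage later in this diff, not from the full spec, so treat them as illustrative.

```python
import requests

BASE = "http://localhost:8321"  # assumed local deployment; adjust to your server
dataset_id = "my-eval-set"      # assumed dataset id registered beforehand

# POST /v1/datasetio/append-rows/{dataset_id}  (was /v1/datasets/{dataset_id}/append-rows)
requests.post(
    f"{BASE}/v1/datasetio/append-rows/{dataset_id}",
    json={"rows": [{"question": "1+1?", "answer": "2"}]},  # body shape inferred from append_rows(rows=...)
).raise_for_status()

# GET /v1/datasetio/iterrows/{dataset_id}  (was /v1/datasets/{dataset_id}/iterrows)
resp = requests.get(
    f"{BASE}/v1/datasetio/iterrows/{dataset_id}",
    params={"start_index": 0, "limit": 10},
)
print(resp.json())  # expected to carry "data" and "next_index" per IterrowsResponse
```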
12 changes: 6 additions & 6 deletions docs/_static/llama-stack-spec.yaml
@@ -10,7 +10,7 @@ info:
servers:
- url: http://any-hosted-llama-stack.com
paths:
/v1/datasets/{dataset_id}/append-rows:
/v1/datasetio/append-rows/{dataset_id}:
post:
responses:
'200':
@@ -26,7 +26,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- Datasets
- DatasetIO
description: ''
parameters:
- name: dataset_id
@@ -1457,7 +1457,7 @@ paths:
schema:
$ref: '#/components/schemas/InvokeToolRequest'
required: true
/v1/datasets/{dataset_id}/iterrows:
/v1/datasetio/iterrows/{dataset_id}:
get:
responses:
'200':
@@ -1477,7 +1477,7 @@
default:
$ref: '#/components/responses/DefaultError'
tags:
- Datasets
- DatasetIO
description: >-
Get a paginated list of rows from a dataset. Uses cursor-based pagination.
parameters:
@@ -6931,7 +6931,7 @@ tags:
Agents API for creating and interacting with agentic systems.
- name: BatchInference (Coming Soon)
- name: Benchmarks
- name: Datasets
- name: DatasetIO
- name: Datasets
- name: Eval
x-displayName: >-
@@ -6971,7 +6971,7 @@ x-tagGroups:
- Agents
- BatchInference (Coming Soon)
- Benchmarks
- Datasets
- DatasetIO
- Datasets
- Eval
- Files
4 changes: 2 additions & 2 deletions docs/openapi_generator/pyopenapi/generator.py
@@ -552,8 +552,8 @@ def _build_operation(self, op: EndpointOperation) -> Operation:
print(op.defining_class.__name__)

# TODO (xiyan): temporary fix for datasetio inner impl + datasets api
if op.defining_class.__name__ in ["DatasetIO"]:
op.defining_class.__name__ = "Datasets"
# if op.defining_class.__name__ in ["DatasetIO"]:
# op.defining_class.__name__ = "Datasets"

doc_string = parse_type(op.func_ref)
doc_params = dict(
5 changes: 3 additions & 2 deletions llama_stack/apis/datasetio/datasetio.py
@@ -34,7 +34,8 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore

@webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self,
dataset_id: str,
@@ -49,5 +50,5 @@ async def iterrows(
"""
...

@webmethod(route="/datasets/{dataset_id}/append-rows", method="POST")
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
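Since iterrows is documented as cursor-based pagination, here is a minimal consumer sketch written against the protocol signature above. `impl` stands for any concrete DatasetIO implementation (for example the localfs provider further down), and IterrowsResponse is assumed to expose `data` and `next_index`, as the localfs code constructs it.

```python
from typing import Any, Dict, List


async def read_all_rows(impl, dataset_id: str, page_size: int = 100) -> List[Dict[str, Any]]:
    """Follow next_index cursors until the dataset is exhausted."""
    rows: List[Dict[str, Any]] = []
    cursor = 0
    while True:
        page = await impl.iterrows(dataset_id, start_index=cursor, limit=page_size)
        rows.extend(page.data)
        if page.next_index is None:  # server signals the end with a null cursor
            break
        cursor = page.next_index
    return rows
```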
4 changes: 2 additions & 2 deletions llama_stack/distribution/routers/routing_tables.py
@@ -371,9 +371,9 @@ async def register_dataset(
provider_dataset_id = dataset_id

# infer provider from source
if source.type == DatasetType.rows:
if source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri:
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
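The `.value` comparison above matters because `source.type` arrives as a plain string while DatasetType is an enum; unless the enum also subclasses `str`, comparing a string to an enum member is always False. A standalone illustration (this DatasetType is a stand-in with assumed member values, not the real class):

```python
from enum import Enum


class DatasetType(Enum):  # stand-in; the real DatasetType lives in llama_stack.apis.datasets
    uri = "uri"
    rows = "rows"


source_type = "uri"                           # what source.type holds at runtime
print(source_type == DatasetType.uri)         # False: str compared to an Enum member
print(source_type == DatasetType.uri.value)   # True: str compared to str, as in the fix
```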
112 changes: 24 additions & 88 deletions llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -3,51 +3,22 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import pandas

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl

from .config import LocalFSDatasetIOConfig

DATASETS_PREFIX = "localfs_datasets:"


class BaseDataset(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

@abstractmethod
def __len__(self) -> int:
raise NotImplementedError()

@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError()

@abstractmethod
def load(self):
raise NotImplementedError()


@dataclass
class DatasetInfo:
dataset_def: Dataset
dataset_impl: BaseDataset


class PandasDataframeDataset(BaseDataset):
class PandasDataframeDataset:
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ def __getitem__(self, idx):
else:
return self.df.iloc[idx].to_dict()

def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df

def load(self) -> None:
if self.df is not None:
return

df = get_dataframe_from_url(self.dataset_def.url)
if df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")

self.df = self._validate_dataset_schema(df)
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")


class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,29 +66,21 @@ async def initialize(self) -> None:

for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
self.dataset_infos[dataset.identifier] = dataset

async def shutdown(self) -> None: ...

async def register_dataset(
self,
dataset: Dataset,
dataset_def: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset.identifier}"
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset.json(),
)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def

async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
@@ -134,51 +93,28 @@ async def iterrows(
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()

start_index = start_index or 0

if limit is None or limit == -1:
end = len(dataset_info.dataset_impl)
end = len(dataset_impl)
else:
end = min(start_index + limit, len(dataset_info.dataset_impl))
end = min(start_index + limit, len(dataset_impl))

rows = dataset_info.dataset_impl[start_index:end]
rows = dataset_impl[start_index:end]

return IterrowsResponse(
data=rows,
next_index=end if end < len(dataset_info.dataset_impl) else None,
next_index=end if end < len(dataset_impl) else None,
)

async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")

dataset_impl = dataset_info.dataset_impl
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()

new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)

if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")

csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)
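For context on the refactor above: the localfs provider now builds a PandasDataframeDataset on demand from the stored dataset definition and loads it from `source` (either inline rows or a URI). A self-contained sketch of the rows path follows; the SimpleNamespace objects stand in for the real Dataset/DataSource models, and the class body is a trimmed copy of the code in this diff.

```python
from types import SimpleNamespace

import pandas


class PandasDataframeDataset:  # trimmed copy of the refactored provider class
    def __init__(self, dataset_def) -> None:
        self.dataset_def = dataset_def
        self.df = None

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self.df.iloc[idx].to_dict(orient="records")
        return self.df.iloc[idx].to_dict()

    def load(self) -> None:
        if self.df is not None:
            return
        if self.dataset_def.source.type == "rows":
            self.df = pandas.DataFrame(self.dataset_def.source.rows)
        # the "uri" branch delegates to get_dataframe_from_uri in the real provider


# stand-in for a registered dataset whose source carries inline rows
dataset_def = SimpleNamespace(
    identifier="demo",
    source=SimpleNamespace(type="rows", rows=[{"q": "1+1?", "a": "2"}, {"q": "2+2?", "a": "4"}]),
)

ds = PandasDataframeDataset(dataset_def)
ds.load()
print(len(ds))   # 2
print(ds[0:2])   # the rows back as a list of dicts, the slice shape iterrows returns
```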