From 917679cc2f2776f5fbd6f1a21656fad00cf1c750 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:18:52 -0700 Subject: [PATCH 1/7] types --- llama_stack/distribution/routers/routers.py | 156 +++++++++++++----- .../inline/datasetio/localfs/datasetio.py | 18 +- .../datasetio/huggingface/huggingface.py | 14 +- 3 files changed, 139 insertions(+), 49 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 22a1e46f9a..875c8c94eb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,11 +8,11 @@ from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( - URL, InterleavedContent, InterleavedContentItem, + URL, ) -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.eval import ( BenchmarkConfig, Eval, @@ -93,7 +93,9 @@ async def register_vector_db( provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + logger.debug( + f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}" + ) await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -111,7 +113,9 @@ async def insert_chunks( logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( + vector_db_id, chunks, ttl_seconds + ) async def query_chunks( self, @@ -120,7 +124,9 @@ async def query_chunks( params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( + vector_db_id, query, params + ) class InferenceRouter(Inference): @@ -157,10 +163,16 @@ async def register_model( logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + await self.routing_table.register_model( + model_id, provider_model_id, provider_id, metadata, model_type + ) def _construct_metrics( - self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + model: Model, ) -> List[MetricEvent]: """Constructs a list of MetricEvent objects containing token usage metrics. 
@@ -207,11 +219,16 @@ async def _compute_and_log_token_usage( total_tokens: int, model: Model, ) -> List[MetricInResponse]: - metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + metrics = self._construct_metrics( + prompt_tokens, completion_tokens, total_tokens, model + ) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def _count_tokens( self, @@ -236,7 +253,9 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -246,12 +265,19 @@ async def chat_completion( if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + if ( + tool_prompt_format + and tool_prompt_format != tool_config.tool_prompt_format + ): + raise ValueError( + "tool_prompt_format and tool_config.tool_prompt_format must match" + ) else: params = {} if tool_choice: @@ -269,9 +295,14 @@ async def chat_completion( pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + tool_names = [ + t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value + for t in tools + ] if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + raise ValueError( + f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" + ) params = dict( model_id=model_id, @@ -286,19 +317,32 @@ async def chat_completion( tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + prompt_tokens = await self._count_tokens( + messages, tool_config.tool_prompt_format + ) if stream: async def stream_generator(): completion_text = "" async for chunk in await provider.chat_completion(**params): - if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.progress + ): if chunk.event.delta.type == "text": completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.complete + ): completion_tokens = await self._count_tokens( - [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + [ + 
CompletionMessage( + content=completion_text, + stop_reason=StopReason.end_of_turn, + ) + ], tool_config.tool_prompt_format, ) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) @@ -308,7 +352,11 @@ async def stream_generator(): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -325,7 +373,9 @@ async def stream_generator(): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def completion( @@ -346,7 +396,9 @@ async def completion( if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -366,7 +418,11 @@ async def stream_generator(): async for chunk in await provider.completion(**params): if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if ( + hasattr(chunk, "stop_reason") + and chunk.stop_reason + and self.telemetry + ): completion_tokens = await self._count_tokens(completion_text) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) metrics = await self._compute_and_log_token_usage( @@ -375,7 +431,11 @@ async def stream_generator(): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -389,7 +449,9 @@ async def stream_generator(): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def embeddings( @@ -405,7 +467,9 @@ async def embeddings( if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + raise ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, @@ -439,7 +503,9 @@ async def register_shield( params: Optional[Dict[str, Any]] = None, ) -> Shield: logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + return await self.routing_table.register_shield( + shield_id, provider_shield_id, provider_id, params + ) async def run_shield( self, @@ -477,11 +543,13 @@ async def get_rows_paginated( rows_in_page: int, page_token: Optional[str] = None, filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + ) -> IterrowsResponse: logger.debug( f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", ) - return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( + return await 
self.routing_table.get_provider_impl( + dataset_id + ).get_rows_paginated( dataset_id=dataset_id, rows_in_page=rows_in_page, page_token=page_token, @@ -521,7 +589,9 @@ async def score_batch( logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -539,11 +609,15 @@ async def score( input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") + logger.debug( + f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions" + ) res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -586,7 +660,9 @@ async def evaluate_rows( scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + logger.debug( + f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows" + ) return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -600,7 +676,9 @@ async def job_status( job_id: str, ) -> Optional[JobStatus]: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + return await self.routing_table.get_provider_impl(benchmark_id).job_status( + benchmark_id, job_id + ) async def job_cancel( self, @@ -654,9 +732,9 @@ async def insert( logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + return await self.routing_table.get_provider_impl( + "insert_into_memory" + ).insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -689,4 +767,6 @@ async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + return await self.routing_table.get_provider_impl(tool_group_id).list_tools( + tool_group_id, mcp_endpoint + ) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index c5216e026f..03dbae337e 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -13,7 +13,7 @@ import pandas from llama_stack.apis.common.content_types import URL -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, 
IterrowsResponse from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -134,7 +134,7 @@ async def get_rows_paginated( rows_in_page: int, page_token: Optional[str] = None, filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + ) -> IterrowsResponse: dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() @@ -154,7 +154,7 @@ async def get_rows_paginated( rows = dataset_info.dataset_impl[start:end] - return PaginatedRowsResult( + return IterrowsResponse( rows=rows, total_count=len(rows), next_page_token=str(end), @@ -170,7 +170,9 @@ async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None new_rows_df = pandas.DataFrame(rows) new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) - dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) + dataset_impl.df = pandas.concat( + [dataset_impl.df, new_rows_df], ignore_index=True + ) url = str(dataset_info.dataset_def.url.uri) parsed_url = urlparse(url) @@ -185,8 +187,12 @@ async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None raise ValueError("Data URL must be a base64-encoded CSV") csv_buffer = dataset_impl.df.to_csv(index=False) - base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8") - dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}") + base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode( + "utf-8" + ) + dataset_info.dataset_def.url = URL( + uri=f"data:text/csv;base64,{base64_content}" + ) else: raise ValueError( f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing." 
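A minimal usage sketch (not part of the patch series) of how a caller pages through a dataset with the interface as it stands after this first patch: `get_rows_paginated` still takes `rows_in_page`/`page_token`, but now returns the renamed `IterrowsResponse` carrying `rows`, `total_count`, and `next_page_token`. The `dump_all_rows` helper and its loop are illustrative assumptions; `datasetio` stands in for any DatasetIO implementation (router or provider). Later patches in this series replace the token-based parameters with `start_index`/`limit` and rename the response fields to `data`/`next_index`.

async def dump_all_rows(datasetio, dataset_id: str, page_size: int = 100) -> list:
    """Collect every row of a dataset by walking pages until an empty page comes back."""
    page_token = None
    rows = []
    while True:
        resp = await datasetio.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=page_token,
        )
        rows.extend(resp.rows)
        if not resp.rows:
            # Providers return an empty slice once the start offset passes the dataset length.
            break
        # localfs and huggingface both set next_page_token to str(end), the next start offset.
        page_token = resp.next_page_token
    return rows
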
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cd4e7f1f10..8df64a1901 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -7,7 +7,7 @@ import datasets as hf_datasets -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -79,7 +79,7 @@ async def get_rows_paginated( rows_in_page: int, page_token: Optional[str] = None, filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + ) -> IterrowsResponse: dataset_def = self.dataset_infos[dataset_id] loaded_dataset = load_hf_dataset(dataset_def) @@ -99,7 +99,7 @@ async def get_rows_paginated( rows = [loaded_dataset[i] for i in range(start, end)] - return PaginatedRowsResult( + return IterrowsResponse( rows=rows, total_count=len(rows), next_page_token=str(end), @@ -113,9 +113,13 @@ async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None new_dataset = hf_datasets.Dataset.from_list(rows) # Concatenate the new rows with existing dataset - updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset]) + updated_dataset = hf_datasets.concatenate_datasets( + [loaded_dataset, new_dataset] + ) if dataset_def.metadata.get("path", None): updated_dataset.push_to_hub(dataset_def.metadata["path"]) else: - raise NotImplementedError("Uploading to URL-based datasets is not supported yet") + raise NotImplementedError( + "Uploading to URL-based datasets is not supported yet" + ) From a9c662d68bb67a7c298116e88b9d6d8c50ee3702 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:20:17 -0700 Subject: [PATCH 2/7] fix all occurrence --- llama_stack/distribution/routers/routers.py | 8 +- .../ui/page/evaluations/native_eval.py | 6 +- .../inline/datasetio/localfs/datasetio.py | 2 +- .../inline/eval/meta_reference/eval.py | 74 +++++++++++++++---- .../recipes/lora_finetuning_single_device.py | 74 ++++++++++++++----- .../providers/inline/scoring/basic/scoring.py | 26 +++++-- .../inline/scoring/braintrust/braintrust.py | 35 ++++++--- .../inline/scoring/llm_as_judge/scoring.py | 20 +++-- .../datasetio/huggingface/huggingface.py | 2 +- 9 files changed, 177 insertions(+), 70 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 875c8c94eb..5acd945fe7 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -537,7 +537,7 @@ async def shutdown(self) -> None: logger.debug("DatasetIORouter.shutdown") pass - async def get_rows_paginated( + async def iterrows( self, dataset_id: str, rows_in_page: int, @@ -545,11 +545,9 @@ async def get_rows_paginated( filter_condition: Optional[str] = None, ) -> IterrowsResponse: logger.debug( - f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", + f"DatasetIORouter.iterrows: {dataset_id}, rows_in_page={rows_in_page}", ) - return await self.routing_table.get_provider_impl( - dataset_id - ).get_rows_paginated( + return await self.routing_table.get_provider_impl(dataset_id).iterrows( dataset_id=dataset_id, rows_in_page=rows_in_page, page_token=page_token, diff 
--git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 00e949ed61..e14c1e01b6 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -166,7 +166,7 @@ def run_evaluation_3(): eval_candidate = st.session_state["eval_candidate"] dataset_id = benchmarks[selected_benchmark].dataset_id - rows = llama_stack_api.client.datasetio.get_rows_paginated( + rows = llama_stack_api.client.datasetio.iterrows( dataset_id=dataset_id, rows_in_page=-1, ) @@ -230,7 +230,9 @@ def run_evaluation_3(): output_res[scoring_fn] = [] output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) - progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})") + progress_text_container.write( + f"Expand to see current processed result ({i + 1} / {len(rows)})" + ) results_container.json(eval_res, expanded=2) progress_bar.progress(1.0, text="Evaluation complete!") diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 03dbae337e..2f8e900030 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -128,7 +128,7 @@ async def unregister_dataset(self, dataset_id: str) -> None: await self.kvstore.delete(key=key) del self.dataset_infos[dataset_id] - async def get_rows_paginated( + async def iterrows( self, dataset_id: str, rows_in_page: int, diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 85b3512624..e55c588888 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -89,10 +89,16 @@ async def run_eval( dataset_id = task_def.dataset_id scoring_functions = task_def.scoring_functions dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) - all_rows = await self.datasetio_api.get_rows_paginated( + validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.eval.value) + ) + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), + rows_in_page=( + -1 + if benchmark_config.num_examples is None + else benchmark_config.num_examples + ), ) res = await self.evaluate_rows( benchmark_id=benchmark_id, @@ -118,10 +124,14 @@ async def _run_agent_generation( for i, x in tqdm(enumerate(input_rows)): assert ColumnName.chat_completion_input.value in x, "Invalid input row" input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"] + input_messages = [ + UserMessage(**x) for x in input_messages if x["role"] == "user" + ] # NOTE: only single-turn agent generation is supported. 
Create a new session for each input row - session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") + session_create_response = await self.agents_api.create_agent_session( + agent_id, f"session-{i}" + ) session_id = session_create_response.session_id turn_request = dict( @@ -130,7 +140,12 @@ async def _run_agent_generation( messages=input_messages, stream=True, ) - turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)] + turn_response = [ + chunk + async for chunk in await self.agents_api.create_agent_turn( + **turn_request + ) + ] final_event = turn_response[-1].event.payload # check if there's a memory retrieval step and extract the context @@ -139,10 +154,14 @@ async def _run_agent_generation( if step.step_type == StepType.tool_execution.value: for tool_response in step.tool_responses: if tool_response.tool_name == MEMORY_QUERY_TOOL: - memory_rag_context = " ".join(x.text for x in tool_response.content) + memory_rag_context = " ".join( + x.text for x in tool_response.content + ) agent_generation = {} - agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content + agent_generation[ColumnName.generated_answer.value] = ( + final_event.turn.output_message.content + ) if memory_rag_context: agent_generation[ColumnName.context.value] = memory_rag_context @@ -154,7 +173,9 @@ async def _run_model_generation( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig ) -> List[Dict[str, Any]]: candidate = benchmark_config.eval_candidate - assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" + assert ( + candidate.sampling_params.max_tokens is not None + ), "SamplingParams.max_tokens must be provided" generations = [] for x in tqdm(input_rows): @@ -165,21 +186,39 @@ async def _run_model_generation( content=input_content, sampling_params=candidate.sampling_params, ) - generations.append({ColumnName.generated_answer.value: response.completion_message.content}) + generations.append( + { + ColumnName.generated_answer.value: response.completion_message.content + } + ) elif ColumnName.chat_completion_input.value in x: - chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"] + chat_completion_input_json = json.loads( + x[ColumnName.chat_completion_input.value] + ) + input_messages = [ + UserMessage(**x) + for x in chat_completion_input_json + if x["role"] == "user" + ] messages = [] if candidate.system_message: messages.append(candidate.system_message) - messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"] + messages += [ + SystemMessage(**x) + for x in chat_completion_input_json + if x["role"] == "system" + ] messages += input_messages response = await self.inference_api.chat_completion( model_id=candidate.model, messages=messages, sampling_params=candidate.sampling_params, ) - generations.append({ColumnName.generated_answer.value: response.completion_message.content}) + generations.append( + { + ColumnName.generated_answer.value: response.completion_message.content + } + ) else: raise ValueError("Invalid input row") @@ -202,7 +241,8 @@ async def evaluate_rows( # scoring with generated_answer score_input_rows = [ - input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False) + input_r | generated_r + for input_r, generated_r in 
zip(input_rows, generations, strict=False) ] if benchmark_config.scoring_params is not None: @@ -211,7 +251,9 @@ async def evaluate_rows( for scoring_fn_id in scoring_functions } else: - scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions} + scoring_functions_dict = { + scoring_fn_id: None for scoring_fn_id in scoring_functions + } score_response = await self.scoring_api.score( input_rows=score_input_rows, scoring_functions=scoring_functions_dict diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 941c629e36..24ad95b3d1 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -17,8 +17,7 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import modules, training -from torchtune import utils as torchtune_utils +from torchtune import modules, training, utils as torchtune_utils from torchtune.data import padded_collate_sft from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( @@ -89,7 +88,9 @@ def __init__( self.job_uuid = job_uuid self.training_config = training_config if not isinstance(algorithm_config, LoraFinetuningConfig): - raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning") + raise ValueError( + "You need to speicifc LoraFinetuningConfig for LoRA finetuning" + ) self.algorithm_config = algorithm_config self._device = torchtune_utils.get_device() self._dtype = training.get_dtype(training_config.dtype, device=self._device) @@ -98,7 +99,10 @@ def __init__( def model_checkpoint_dir(model) -> str: checkpoint_dir = Path(model_local_dir(model.descriptor())) - paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]] + paths = [ + Path(checkpoint_dir / f"consolidated.{ext}") + for ext in ["pth", "00.pth"] + ] if not any(p.exists() for p in paths): checkpoint_dir = checkpoint_dir / "original" @@ -113,7 +117,9 @@ def model_checkpoint_dir(model) -> str: else: model = resolve_model(self.model_id) if model is None: - raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list") + raise ValueError( + f"{self.model_id} not found. Your model id should be in the llama models SKU list" + ) self.checkpoint_dir = model_checkpoint_dir(model) self._output_dir = str(DEFAULT_CHECKPOINT_DIR) @@ -185,7 +191,9 @@ async def setup(self) -> None: self._tokenizer = await self._setup_tokenizer() log.info("Tokenizer is initialized.") - self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config) + self._optimizer = await self._setup_optimizer( + optimizer_config=self.training_config.optimizer_config + ) log.info("Optimizer is initialized.") self._loss_fn = CEWithChunkedOutputLoss() @@ -213,8 +221,13 @@ async def setup(self) -> None: # by the dataloader and the max_steps_per_epoch param set by the user and is used # for logging and tracking training state. 
This should be computed after the dataloader # has been setup - self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps - if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch: + self._steps_per_epoch = ( + len(self._training_dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch @@ -228,7 +241,9 @@ async def setup(self) -> None: log.info("Learning rate scheduler is initialized.") # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device) + self.ignore_labels_cache = torch.full( + (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) def _log_memory_stats(self): # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing @@ -269,9 +284,13 @@ async def _setup_model( set_trainable_params(model, self.adapter_params) if enable_activation_checkpointing: - training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}) + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) - base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False) + base_missing, base_unexpected = model.load_state_dict( + base_model_state_dict, strict=False + ) # This is for any adapters that need to be initialized after base weights # have been loaded (e.g. DoRA). @@ -280,7 +299,9 @@ async def _setup_model( if hasattr(m, "initialize_dora_magnitude"): m.initialize_dora_magnitude() if lora_weights_state_dict: - lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False) + lora_missing, lora_unexpected = model.load_state_dict( + lora_weights_state_dict, strict=False + ) else: lora_missing, lora_unexpected = None, None validate_missing_and_unexpected_for_lora( @@ -294,10 +315,14 @@ async def _setup_model( ) # Validate model adapter params were loaded in with the expected dtype - training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype) + training.validate_expected_param_dtype( + self.adapter_params.items(), dtype=self._dtype + ) # activation offloading - self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) self._log_memory_stats() @@ -328,7 +353,7 @@ async def _setup_data( batch_size: int, ) -> Tuple[DistributedSampler, DataLoader]: async def fetch_rows(dataset_id: str): - return await self.datasetio_api.get_rows_paginated( + return await self.datasetio_api.iterrows( dataset_id=dataset_id, rows_in_page=-1, ) @@ -433,7 +458,9 @@ async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] # But this way we dont need to slice the logits. We just add an ignore index to labels. 
- labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])) + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) if not isinstance(logits, list): labels = labels.reshape(-1) logits = logits.reshape(-1, logits.size(-1)) @@ -462,7 +489,9 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: for curr_epoch in range(self.epochs_run, self.total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True - metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log") + metric_logger = DiskLogger( + log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log" + ) self._training_sampler.set_epoch(curr_epoch) loss_to_log = 0.0 @@ -470,7 +499,8 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: for idx, batch in enumerate(self._training_dataloader): if ( self.max_steps_per_epoch is not None - and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch ): break @@ -478,7 +508,9 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step - current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum() + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() num_tokens += current_num_tokens # Loss is normalized by default so we multiply by the number of tokens @@ -503,7 +535,9 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: loss_to_log = running_loss.item() / num_tokens pbar.update(1) - pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}") + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) time_per_step = time.perf_counter() - t0 log_dict = { diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 599f5f98c5..2a0613bfe9 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -24,7 +24,9 @@ from .config import BasicScoringConfig from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn -from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn +from .scoring_fn.regex_parser_math_response_scoring_fn import ( + RegexParserMathResponseScoringFn, +) from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn @@ -62,11 +64,15 @@ async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFn]: scoring_fn_defs_list = [ - fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs() + fn_def + for impl in self.scoring_fn_id_impls.values() + for fn_def in impl.get_supported_scoring_fn_defs() ] for f in scoring_fn_defs_list: - assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! " + assert f.identifier.startswith( + "basic" + ), "All basic scoring fn must have identifier prefixed with 'basic'! 
" return scoring_fn_defs_list @@ -80,9 +86,11 @@ async def score_batch( save_results_dataset: bool = False, ) -> ScoreBatchResponse: dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) + validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) + ) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, rows_in_page=-1, ) @@ -110,8 +118,12 @@ async def score( raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] scoring_fn_params = scoring_functions.get(scoring_fn_id, None) - score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) - agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params) + score_results = await scoring_fn.score( + input_rows, scoring_fn_id, scoring_fn_params + ) + agg_results = await scoring_fn.aggregate( + score_results, scoring_fn_id, scoring_fn_params + ) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index a48b6b58bb..81b309b8fe 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -122,10 +122,12 @@ def __init__( self.datasets_api = datasets_api self.braintrust_evaluators = { - entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY + entry.identifier: entry.evaluator + for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY } self.supported_fn_defs_registry = { - entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY + entry.identifier: entry.fn_def + for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY } async def initialize(self) -> None: ... @@ -135,14 +137,16 @@ async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFn]: scoring_fn_defs_list = list(self.supported_fn_defs_registry.values()) for f in scoring_fn_defs_list: - assert f.identifier.startswith("braintrust"), ( - "All braintrust scoring fn must have identifier prefixed with 'braintrust'! " - ) + assert f.identifier.startswith( + "braintrust" + ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! 
" return scoring_fn_defs_list async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: - raise NotImplementedError("Registering scoring function not allowed for braintrust provider") + raise NotImplementedError( + "Registering scoring function not allowed for braintrust provider" + ) async def set_api_key(self) -> None: # api key is in the request headers @@ -165,13 +169,17 @@ async def score_batch( await self.set_api_key() dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) + validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) + ) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, rows_in_page=-1, ) - res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions) + res = await self.score( + input_rows=all_rows.rows, scoring_functions=scoring_functions + ) if save_results_dataset: # TODO: persist and register dataset on to server for reading # self.datasets_api.register_dataset() @@ -212,8 +220,13 @@ async def score( if scoring_fn_id not in self.supported_fn_defs_registry: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows] - aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions + score_results = [ + await self.score_row(input_row, scoring_fn_id) + for input_row in input_rows + ] + aggregation_functions = self.supported_fn_defs_registry[ + scoring_fn_id + ].params.aggregation_functions # override scoring_fn params if provided if scoring_functions[scoring_fn_id] is not None: diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 5b1715d9f6..d6d051f9e2 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -54,9 +54,9 @@ async def list_scoring_functions(self) -> List[ScoringFn]: scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs() for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs(): - assert f.identifier.startswith("llm-as-judge"), ( - "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! " - ) + assert f.identifier.startswith( + "llm-as-judge" + ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! 
" return scoring_fn_defs_list @@ -70,9 +70,11 @@ async def score_batch( save_results_dataset: bool = False, ) -> ScoreBatchResponse: dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) + validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) + ) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, rows_in_page=-1, ) @@ -98,8 +100,12 @@ async def score( for scoring_fn_id in scoring_functions.keys(): scoring_fn = self.llm_as_judge_fn scoring_fn_params = scoring_functions.get(scoring_fn_id, None) - score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) - agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params) + score_results = await scoring_fn.score( + input_rows, scoring_fn_id, scoring_fn_params + ) + agg_results = await scoring_fn.aggregate( + score_results, scoring_fn_id, scoring_fn_params + ) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 8df64a1901..2c7e4bdf07 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -73,7 +73,7 @@ async def unregister_dataset(self, dataset_id: str) -> None: await self.kvstore.delete(key=key) del self.dataset_infos[dataset_id] - async def get_rows_paginated( + async def iterrows( self, dataset_id: str, rows_in_page: int, From 82ec0d24f3a88fd069348351c1746f35d1f70a0d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:22:10 -0700 Subject: [PATCH 3/7] fix hf --- .../datasetio/huggingface/huggingface.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 2c7e4bdf07..8145a4a935 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -76,33 +76,24 @@ async def unregister_dataset(self, dataset_id: str) -> None: async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, + start_index: Optional[int] = None, + limit: Optional[int] = None, ) -> IterrowsResponse: dataset_def = self.dataset_infos[dataset_id] loaded_dataset = load_hf_dataset(dataset_def) - if page_token and not page_token.isnumeric(): - raise ValueError("Invalid page_token") + start_index = start_index or 0 - if page_token is None or len(page_token) == 0: - next_page_token = 0 - else: - next_page_token = int(page_token) - - start = next_page_token - if rows_in_page == -1: + if limit == -1: end = len(loaded_dataset) else: - end = min(start + rows_in_page, len(loaded_dataset)) + end = min(start_index + limit, len(loaded_dataset)) - rows = [loaded_dataset[i] for i in range(start, end)] + rows = [loaded_dataset[i] for i in range(start_index, end)] return IterrowsResponse( - rows=rows, - total_count=len(rows), - next_page_token=str(end), + data=rows, + next_index=end, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: From 
a19710163549728b7097ab8b7e2734a50eefd341 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:24:46 -0700 Subject: [PATCH 4/7] fix iterrows --- .../inline/datasetio/localfs/datasetio.py | 25 ++++++------------- .../datasetio/huggingface/huggingface.py | 4 +-- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 2f8e900030..54e9900526 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -131,33 +131,24 @@ async def unregister_dataset(self, dataset_id: str) -> None: async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, + start_index: Optional[int] = None, + limit: Optional[int] = None, ) -> IterrowsResponse: dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() - if page_token and not page_token.isnumeric(): - raise ValueError("Invalid page_token") + start_index = start_index or 0 - if page_token is None or len(page_token) == 0: - next_page_token = 0 - else: - next_page_token = int(page_token) - - start = next_page_token - if rows_in_page == -1: + if limit is None or limit == -1: end = len(dataset_info.dataset_impl) else: - end = min(start + rows_in_page, len(dataset_info.dataset_impl)) + end = min(start_index + limit, len(dataset_info.dataset_impl)) - rows = dataset_info.dataset_impl[start:end] + rows = dataset_info.dataset_impl[start_index:end] return IterrowsResponse( - rows=rows, - total_count=len(rows), - next_page_token=str(end), + data=rows, + next_index=end if end < len(dataset_info.dataset_impl) else None, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 8145a4a935..8a6599f8f0 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -84,7 +84,7 @@ async def iterrows( start_index = start_index or 0 - if limit == -1: + if limit is None or limit == -1: end = len(loaded_dataset) else: end = min(start_index + limit, len(loaded_dataset)) @@ -93,7 +93,7 @@ async def iterrows( return IterrowsResponse( data=rows, - next_index=end, + next_index=end if end < len(loaded_dataset) else None, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: From 081ec3131d234b37b67bd56f88afca03f58d6400 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:43:41 -0700 Subject: [PATCH 5/7] fix router --- llama_stack/distribution/routers/routers.py | 30 ++++-- .../distribution/routers/routing_tables.py | 98 ++++++++++++++----- 2 files changed, 97 insertions(+), 31 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5acd945fe7..879fc924b4 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -13,6 +13,7 @@ URL, ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse +from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.eval import ( BenchmarkConfig, Eval, @@ -537,21 +538,36 @@ async def shutdown(self) -> None: logger.debug("DatasetIORouter.shutdown") pass + async def 
register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: Optional[Dict[str, Any]] = None, + dataset_id: Optional[str] = None, + ) -> None: + logger.debug( + f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", + ) + await self.routing_table.register_dataset( + purpose=purpose, + source=source, + metadata=metadata, + dataset_id=dataset_id, + ) + async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, + start_index: Optional[int] = None, + limit: Optional[int] = None, ) -> IterrowsResponse: logger.debug( - f"DatasetIORouter.iterrows: {dataset_id}, rows_in_page={rows_in_page}", + f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", ) return await self.routing_table.get_provider_impl(dataset_id).iterrows( dataset_id=dataset_id, - rows_in_page=rows_in_page, - page_token=page_token, - filter_condition=filter_condition, + start_index=start_index, + limit=limit, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 1be43ec8bd..ec7abba90c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import logging +import uuid from typing import Any, Dict, List, Optional from pydantic import TypeAdapter @@ -12,7 +13,14 @@ from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, +) from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( @@ -97,7 +105,9 @@ def __init__( self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: + async def add_objects( + objs: List[RoutableObjectWithProvider], provider_id: str, cls + ) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -132,7 +142,9 @@ async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: + def get_provider_impl( + self, routing_key: str, provider_id: Optional[str] = None + ) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -170,7 +182,9 @@ def apiname_object(): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: @@ -180,9 +194,13 @@ async def get_object_by_identifier(self, type: str, identifier: str) -> Optional async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await 
self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + await unregister_object_from_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + async def register_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -237,7 +255,9 @@ async def register_model( if model_type is None: model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = Model( identifier=model_id, provider_resource_id=provider_model_id, @@ -257,7 +277,9 @@ async def unregister_model(self, model_id: str) -> None: class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse( + data=await self.get_all_with_type(ResourceType.shield.value) + ) async def get_shield(self, identifier: str) -> Optional[Shield]: return await self.get_object_by_identifier("shield", identifier) @@ -316,14 +338,18 @@ async def register_vector_db( f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: - raise ValueError("No provider available. Please configure a vector_io provider.") + raise ValueError( + "No provider available. Please configure a vector_io provider." 
+ ) model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + raise ValueError( + f"Model {embedding_model} does not have an embedding dimension" + ) vector_db_data = { "identifier": vector_db_id, "type": ResourceType.vector_db.value, @@ -345,22 +371,37 @@ async def unregister_vector_db(self, vector_db_id: str) -> None: class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse( + data=await self.get_all_with_type(ResourceType.dataset.value) + ) async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: return await self.get_object_by_identifier("dataset", dataset_id) async def register_dataset( self, - dataset_id: str, - dataset_schema: Dict[str, ParamType], - url: URL, - provider_dataset_id: Optional[str] = None, - provider_id: Optional[str] = None, + purpose: DatasetPurpose, + source: DataSource, metadata: Optional[Dict[str, Any]] = None, - ) -> None: - if provider_dataset_id is None: - provider_dataset_id = dataset_id + dataset_id: Optional[str] = None, + ) -> Dataset: + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if source.type == DatasetType.rows: + provider_id = "localfs" + elif source.type == DatasetType.uri: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = "huggingface" + else: + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + if provider_id is None: # If provider_id not specified, use the only provider if it supports this dataset if len(self.impls_by_provider_id) == 1: @@ -371,15 +412,18 @@ async def register_dataset( ) if metadata is None: metadata = {} + dataset = Dataset( identifier=dataset_id, provider_resource_id=provider_dataset_id, provider_id=provider_id, - dataset_schema=dataset_schema, - url=url, + purpose=purpose, + source=source, metadata=metadata, ) + await self.register_object(dataset) + return dataset async def unregister_dataset(self, dataset_id: str) -> None: dataset = await self.get_dataset(dataset_id) @@ -390,7 +434,9 @@ async def unregister_dataset(self, dataset_id: str) -> None: class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse( + data=await self.get_all_with_type(ResourceType.scoring_function.value) + ) async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: return await self.get_object_by_identifier("scoring_function", scoring_fn_id) @@ -487,8 +533,12 @@ async def register_tool_group( args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + tool_defs = await 
self.impls_by_provider_id[provider_id].list_runtime_tools( + toolgroup_id, mcp_endpoint + ) + tool_host = ( + ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ) for tool_def in tool_defs: tools.append( From 9b38ae9323a34df19a6a455051b6a8845e65554a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:45:29 -0700 Subject: [PATCH 6/7] huggingface --- llama_stack/distribution/routers/routing_tables.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ec7abba90c..f626f209c3 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -402,14 +402,6 @@ async def register_dataset( else: raise ValueError(f"Unknown data source type: {source.type}") - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this dataset - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" - ) if metadata is None: metadata = {} From f262bfd061239971aa3d35a62ac41bf963d35e70 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:48:26 -0700 Subject: [PATCH 7/7] fix --- llama_stack/distribution/routers/routers.py | 135 +++++------------- .../distribution/routers/routing_tables.py | 52 ++----- .../ui/page/evaluations/native_eval.py | 4 +- .../inline/datasetio/localfs/datasetio.py | 12 +- .../inline/eval/meta_reference/eval.py | 72 ++-------- .../recipes/lora_finetuning_single_device.py | 72 +++------- .../providers/inline/scoring/basic/scoring.py | 20 +-- .../inline/scoring/braintrust/braintrust.py | 33 ++--- .../inline/scoring/llm_as_judge/scoring.py | 18 +-- .../datasetio/huggingface/huggingface.py | 8 +- 10 files changed, 107 insertions(+), 319 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 879fc924b4..2cf38f5440 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,9 +8,9 @@ from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( + URL, InterleavedContent, InterleavedContentItem, - URL, ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import DatasetPurpose, DataSource @@ -94,9 +94,7 @@ async def register_vector_db( provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug( - f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}" - ) + logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -114,9 +112,7 @@ async def insert_chunks( logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( - vector_db_id, chunks, ttl_seconds - ) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) async def query_chunks( self, @@ -125,9 +121,7 @@ async def query_chunks( params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( - vector_db_id, query, params - ) + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) class InferenceRouter(Inference): @@ -164,9 +158,7 @@ async def register_model( logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model( - model_id, provider_model_id, provider_id, metadata, model_type - ) + await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) def _construct_metrics( self, @@ -220,16 +212,11 @@ async def _compute_and_log_token_usage( total_tokens: int, model: Model, ) -> List[MetricInResponse]: - metrics = self._construct_metrics( - prompt_tokens, completion_tokens, total_tokens, model - ) + metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in metrics - ] + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] async def _count_tokens( self, @@ -254,9 +241,7 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ - ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] - ]: + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -266,19 +251,12 @@ async def chat_completion( if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError( - f"Model '{model_id}' is an embedding model and does not support chat completions" - ) + raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if ( - tool_prompt_format - and tool_prompt_format != tool_config.tool_prompt_format - ): - raise ValueError( - "tool_prompt_format and tool_config.tool_prompt_format must match" - ) + if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: + raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") else: params = {} if tool_choice: @@ -296,14 +274,9 @@ async def chat_completion( pass else: # verify tool_choice is one of the tools - tool_names = [ - t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value - for t in tools - ] + tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] if tool_config.tool_choice not in tool_names: - raise ValueError( - f"Tool choice 
{tool_config.tool_choice} is not one of the tools: {tool_names}" - ) + raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") params = dict( model_id=model_id, @@ -318,25 +291,17 @@ async def chat_completion( tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens( - messages, tool_config.tool_prompt_format - ) + prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) if stream: async def stream_generator(): completion_text = "" async for chunk in await provider.chat_completion(**params): - if ( - chunk.event.event_type - == ChatCompletionResponseEventType.progress - ): + if chunk.event.event_type == ChatCompletionResponseEventType.progress: if chunk.event.delta.type == "text": completion_text += chunk.event.delta.text - if ( - chunk.event.event_type - == ChatCompletionResponseEventType.complete - ): + if chunk.event.event_type == ChatCompletionResponseEventType.complete: completion_tokens = await self._count_tokens( [ CompletionMessage( @@ -353,11 +318,7 @@ async def stream_generator(): total_tokens, model, ) - chunk.metrics = ( - metrics - if chunk.metrics is None - else chunk.metrics + metrics - ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics yield chunk return stream_generator() @@ -374,9 +335,7 @@ async def stream_generator(): total_tokens, model, ) - response.metrics = ( - metrics if response.metrics is None else response.metrics + metrics - ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics return response async def completion( @@ -397,9 +356,7 @@ async def completion( if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError( - f"Model '{model_id}' is an embedding model and does not support chat completions" - ) + raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -419,11 +376,7 @@ async def stream_generator(): async for chunk in await provider.completion(**params): if hasattr(chunk, "delta"): completion_text += chunk.delta - if ( - hasattr(chunk, "stop_reason") - and chunk.stop_reason - and self.telemetry - ): + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: completion_tokens = await self._count_tokens(completion_text) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) metrics = await self._compute_and_log_token_usage( @@ -432,11 +385,7 @@ async def stream_generator(): total_tokens, model, ) - chunk.metrics = ( - metrics - if chunk.metrics is None - else chunk.metrics + metrics - ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics yield chunk return stream_generator() @@ -450,9 +399,7 @@ async def stream_generator(): total_tokens, model, ) - response.metrics = ( - metrics if response.metrics is None else response.metrics + metrics - ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics return response async def embeddings( @@ -468,9 +415,7 @@ async def embeddings( if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: - raise ValueError( - f"Model '{model_id}' is an LLM model and does not support embeddings" - ) + raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") 
return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, @@ -504,9 +449,7 @@ async def register_shield( params: Optional[Dict[str, Any]] = None, ) -> Shield: logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield( - shield_id, provider_shield_id, provider_id, params - ) + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) async def run_shield( self, @@ -603,9 +546,7 @@ async def score_batch( logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl( - fn_identifier - ).score_batch( + score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -623,15 +564,11 @@ async def score( input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logger.debug( - f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions" - ) + logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl( - fn_identifier - ).score( + score_response = await self.routing_table.get_provider_impl(fn_identifier).score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -674,9 +611,7 @@ async def evaluate_rows( scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logger.debug( - f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows" - ) + logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -690,9 +625,7 @@ async def job_status( job_id: str, ) -> Optional[JobStatus]: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status( - benchmark_id, job_id - ) + return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) async def job_cancel( self, @@ -746,9 +679,9 @@ async def insert( logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl( - "insert_into_memory" - ).insert(documents, vector_db_id, chunk_size_in_tokens) + return await self.routing_table.get_provider_impl("insert_into_memory").insert( + documents, vector_db_id, chunk_size_in_tokens + ) def __init__( self, @@ -781,6 +714,4 @@ async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools( - tool_group_id, mcp_endpoint - ) + return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index f626f209c3..589a03b25d 
100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -105,9 +105,7 @@ def __init__( self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects( - objs: List[RoutableObjectWithProvider], provider_id: str, cls - ) -> None: + async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -142,9 +140,7 @@ async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl( - self, routing_key: str, provider_id: Optional[str] = None - ) -> Any: + def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -182,9 +178,7 @@ def apiname_object(): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier( - self, type: str, identifier: str - ) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: @@ -194,13 +188,9 @@ async def get_object_by_identifier( async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider( - obj, self.impls_by_provider_id[obj.provider_id] - ) + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) - async def register_object( - self, obj: RoutableObjectWithProvider - ) -> RoutableObjectWithProvider: + async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -255,9 +245,7 @@ async def register_model( if model_type is None: model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError( - "Embedding model must have an embedding dimension in its metadata" - ) + raise ValueError("Embedding model must have an embedding dimension in its metadata") model = Model( identifier=model_id, provider_resource_id=provider_model_id, @@ -277,9 +265,7 @@ async def unregister_model(self, model_id: str) -> None: class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse( - data=await self.get_all_with_type(ResourceType.shield.value) - ) + return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) async def get_shield(self, identifier: str) -> Optional[Shield]: return await self.get_object_by_identifier("shield", identifier) @@ -338,18 +324,14 @@ async def register_vector_db( f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: - raise ValueError( - "No provider available. Please configure a vector_io provider." - ) + raise ValueError("No provider available. 
Please configure a vector_io provider.") model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: - raise ValueError( - f"Model {embedding_model} does not have an embedding dimension" - ) + raise ValueError(f"Model {embedding_model} does not have an embedding dimension") vector_db_data = { "identifier": vector_db_id, "type": ResourceType.vector_db.value, @@ -371,9 +353,7 @@ async def unregister_vector_db(self, vector_db_id: str) -> None: class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse( - data=await self.get_all_with_type(ResourceType.dataset.value) - ) + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: return await self.get_object_by_identifier("dataset", dataset_id) @@ -426,9 +406,7 @@ async def unregister_dataset(self, dataset_id: str) -> None: class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse( - data=await self.get_all_with_type(ResourceType.scoring_function.value) - ) + return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: return await self.get_object_by_identifier("scoring_function", scoring_fn_id) @@ -525,12 +503,8 @@ async def register_tool_group( args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( - toolgroup_id, mcp_endpoint - ) - tool_host = ( - ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - ) + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) + tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution for tool_def in tool_defs: tools.append( diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index e14c1e01b6..5ce5bc5d26 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -230,9 +230,7 @@ def run_evaluation_3(): output_res[scoring_fn] = [] output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) - progress_text_container.write( - f"Expand to see current processed result ({i + 1} / {len(rows)})" - ) + progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})") results_container.json(eval_res, expanded=2) progress_bar.progress(1.0, text="Evaluation complete!") diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 54e9900526..afa9ee0ff5 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -161,9 +161,7 @@ async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None new_rows_df = pandas.DataFrame(rows) new_rows_df = 
dataset_impl._validate_dataset_schema(new_rows_df) - dataset_impl.df = pandas.concat( - [dataset_impl.df, new_rows_df], ignore_index=True - ) + dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) url = str(dataset_info.dataset_def.url.uri) parsed_url = urlparse(url) @@ -178,12 +176,8 @@ async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None raise ValueError("Data URL must be a base64-encoded CSV") csv_buffer = dataset_impl.df.to_csv(index=False) - base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode( - "utf-8" - ) - dataset_info.dataset_def.url = URL( - uri=f"data:text/csv;base64,{base64_content}" - ) + base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8") + dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}") else: raise ValueError( f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing." diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index e55c588888..67e2eb193a 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -89,16 +89,10 @@ async def run_eval( dataset_id = task_def.dataset_id scoring_functions = task_def.scoring_functions dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema( - dataset_def.dataset_schema, get_valid_schemas(Api.eval.value) - ) + validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=( - -1 - if benchmark_config.num_examples is None - else benchmark_config.num_examples - ), + rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), ) res = await self.evaluate_rows( benchmark_id=benchmark_id, @@ -124,14 +118,10 @@ async def _run_agent_generation( for i, x in tqdm(enumerate(input_rows)): assert ColumnName.chat_completion_input.value in x, "Invalid input row" input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [ - UserMessage(**x) for x in input_messages if x["role"] == "user" - ] + input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"] # NOTE: only single-turn agent generation is supported. 
Create a new session for each input row - session_create_response = await self.agents_api.create_agent_session( - agent_id, f"session-{i}" - ) + session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") session_id = session_create_response.session_id turn_request = dict( @@ -140,12 +130,7 @@ async def _run_agent_generation( messages=input_messages, stream=True, ) - turn_response = [ - chunk - async for chunk in await self.agents_api.create_agent_turn( - **turn_request - ) - ] + turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)] final_event = turn_response[-1].event.payload # check if there's a memory retrieval step and extract the context @@ -154,14 +139,10 @@ async def _run_agent_generation( if step.step_type == StepType.tool_execution.value: for tool_response in step.tool_responses: if tool_response.tool_name == MEMORY_QUERY_TOOL: - memory_rag_context = " ".join( - x.text for x in tool_response.content - ) + memory_rag_context = " ".join(x.text for x in tool_response.content) agent_generation = {} - agent_generation[ColumnName.generated_answer.value] = ( - final_event.turn.output_message.content - ) + agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content if memory_rag_context: agent_generation[ColumnName.context.value] = memory_rag_context @@ -173,9 +154,7 @@ async def _run_model_generation( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig ) -> List[Dict[str, Any]]: candidate = benchmark_config.eval_candidate - assert ( - candidate.sampling_params.max_tokens is not None - ), "SamplingParams.max_tokens must be provided" + assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" generations = [] for x in tqdm(input_rows): @@ -186,39 +165,21 @@ async def _run_model_generation( content=input_content, sampling_params=candidate.sampling_params, ) - generations.append( - { - ColumnName.generated_answer.value: response.completion_message.content - } - ) + generations.append({ColumnName.generated_answer.value: response.completion_message.content}) elif ColumnName.chat_completion_input.value in x: - chat_completion_input_json = json.loads( - x[ColumnName.chat_completion_input.value] - ) - input_messages = [ - UserMessage(**x) - for x in chat_completion_input_json - if x["role"] == "user" - ] + chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) + input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"] messages = [] if candidate.system_message: messages.append(candidate.system_message) - messages += [ - SystemMessage(**x) - for x in chat_completion_input_json - if x["role"] == "system" - ] + messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"] messages += input_messages response = await self.inference_api.chat_completion( model_id=candidate.model, messages=messages, sampling_params=candidate.sampling_params, ) - generations.append( - { - ColumnName.generated_answer.value: response.completion_message.content - } - ) + generations.append({ColumnName.generated_answer.value: response.completion_message.content}) else: raise ValueError("Invalid input row") @@ -241,8 +202,7 @@ async def evaluate_rows( # scoring with generated_answer score_input_rows = [ - input_r | generated_r - for input_r, generated_r in zip(input_rows, generations, strict=False) + input_r | generated_r for input_r, generated_r in 
zip(input_rows, generations, strict=False) ] if benchmark_config.scoring_params is not None: @@ -251,9 +211,7 @@ async def evaluate_rows( for scoring_fn_id in scoring_functions } else: - scoring_functions_dict = { - scoring_fn_id: None for scoring_fn_id in scoring_functions - } + scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions} score_response = await self.scoring_api.score( input_rows=score_input_rows, scoring_functions=scoring_functions_dict diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 24ad95b3d1..482bbd309e 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -17,7 +17,8 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import modules, training, utils as torchtune_utils +from torchtune import modules, training +from torchtune import utils as torchtune_utils from torchtune.data import padded_collate_sft from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( @@ -88,9 +89,7 @@ def __init__( self.job_uuid = job_uuid self.training_config = training_config if not isinstance(algorithm_config, LoraFinetuningConfig): - raise ValueError( - "You need to speicifc LoraFinetuningConfig for LoRA finetuning" - ) + raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning") self.algorithm_config = algorithm_config self._device = torchtune_utils.get_device() self._dtype = training.get_dtype(training_config.dtype, device=self._device) @@ -99,10 +98,7 @@ def __init__( def model_checkpoint_dir(model) -> str: checkpoint_dir = Path(model_local_dir(model.descriptor())) - paths = [ - Path(checkpoint_dir / f"consolidated.{ext}") - for ext in ["pth", "00.pth"] - ] + paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]] if not any(p.exists() for p in paths): checkpoint_dir = checkpoint_dir / "original" @@ -117,9 +113,7 @@ def model_checkpoint_dir(model) -> str: else: model = resolve_model(self.model_id) if model is None: - raise ValueError( - f"{self.model_id} not found. Your model id should be in the llama models SKU list" - ) + raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list") self.checkpoint_dir = model_checkpoint_dir(model) self._output_dir = str(DEFAULT_CHECKPOINT_DIR) @@ -191,9 +185,7 @@ async def setup(self) -> None: self._tokenizer = await self._setup_tokenizer() log.info("Tokenizer is initialized.") - self._optimizer = await self._setup_optimizer( - optimizer_config=self.training_config.optimizer_config - ) + self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config) log.info("Optimizer is initialized.") self._loss_fn = CEWithChunkedOutputLoss() @@ -221,13 +213,8 @@ async def setup(self) -> None: # by the dataloader and the max_steps_per_epoch param set by the user and is used # for logging and tracking training state. 
This should be computed after the dataloader # has been setup - self._steps_per_epoch = ( - len(self._training_dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): + self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps + if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch @@ -241,9 +228,7 @@ async def setup(self) -> None: log.info("Learning rate scheduler is initialized.") # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) + self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device) def _log_memory_stats(self): # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing @@ -284,13 +269,9 @@ async def _setup_model( set_trainable_params(model, self.adapter_params) if enable_activation_checkpointing: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} - ) + training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}) - base_missing, base_unexpected = model.load_state_dict( - base_model_state_dict, strict=False - ) + base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False) # This is for any adapters that need to be initialized after base weights # have been loaded (e.g. DoRA). @@ -299,9 +280,7 @@ async def _setup_model( if hasattr(m, "initialize_dora_magnitude"): m.initialize_dora_magnitude() if lora_weights_state_dict: - lora_missing, lora_unexpected = model.load_state_dict( - lora_weights_state_dict, strict=False - ) + lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False) else: lora_missing, lora_unexpected = None, None validate_missing_and_unexpected_for_lora( @@ -315,14 +294,10 @@ async def _setup_model( ) # Validate model adapter params were loaded in with the expected dtype - training.validate_expected_param_dtype( - self.adapter_params.items(), dtype=self._dtype - ) + training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype) # activation offloading - self.activations_handling_ctx = training.get_act_offloading_ctx_manager( - model, enable_activation_offloading - ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading) self._log_memory_stats() @@ -458,9 +433,7 @@ async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] # But this way we dont need to slice the logits. We just add an ignore index to labels. 
- labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) + labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])) if not isinstance(logits, list): labels = labels.reshape(-1) logits = logits.reshape(-1, logits.size(-1)) @@ -489,9 +462,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: for curr_epoch in range(self.epochs_run, self.total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True - metric_logger = DiskLogger( - log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log" - ) + metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log") self._training_sampler.set_epoch(curr_epoch) loss_to_log = 0.0 @@ -499,8 +470,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: for idx, batch in enumerate(self._training_dataloader): if ( self.max_steps_per_epoch is not None - and (idx // self._gradient_accumulation_steps) - == self.max_steps_per_epoch + and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch ): break @@ -508,9 +478,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step - current_num_tokens = ( - batch["labels"] != self._loss_fn.ignore_index - ).sum() + current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum() num_tokens += current_num_tokens # Loss is normalized by default so we multiply by the number of tokens @@ -535,9 +503,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: loss_to_log = running_loss.item() / num_tokens pbar.update(1) - pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" - ) + pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}") time_per_step = time.perf_counter() - t0 log_dict = { diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 2a0613bfe9..915c33c8d4 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -64,15 +64,11 @@ async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFn]: scoring_fn_defs_list = [ - fn_def - for impl in self.scoring_fn_id_impls.values() - for fn_def in impl.get_supported_scoring_fn_defs() + fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs() ] for f in scoring_fn_defs_list: - assert f.identifier.startswith( - "basic" - ), "All basic scoring fn must have identifier prefixed with 'basic'! " + assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! 
" return scoring_fn_defs_list @@ -86,9 +82,7 @@ async def score_batch( save_results_dataset: bool = False, ) -> ScoreBatchResponse: dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema( - dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) - ) + validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, @@ -118,12 +112,8 @@ async def score( raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] scoring_fn_params = scoring_functions.get(scoring_fn_id, None) - score_results = await scoring_fn.score( - input_rows, scoring_fn_id, scoring_fn_params - ) - agg_results = await scoring_fn.aggregate( - score_results, scoring_fn_id, scoring_fn_params - ) + score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) + agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 81b309b8fe..1f5c3e147e 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -122,12 +122,10 @@ def __init__( self.datasets_api = datasets_api self.braintrust_evaluators = { - entry.identifier: entry.evaluator - for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY + entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY } self.supported_fn_defs_registry = { - entry.identifier: entry.fn_def - for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY + entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY } async def initialize(self) -> None: ... @@ -137,16 +135,14 @@ async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFn]: scoring_fn_defs_list = list(self.supported_fn_defs_registry.values()) for f in scoring_fn_defs_list: - assert f.identifier.startswith( - "braintrust" - ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! " + assert f.identifier.startswith("braintrust"), ( + "All braintrust scoring fn must have identifier prefixed with 'braintrust'! 
" + ) return scoring_fn_defs_list async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: - raise NotImplementedError( - "Registering scoring function not allowed for braintrust provider" - ) + raise NotImplementedError("Registering scoring function not allowed for braintrust provider") async def set_api_key(self) -> None: # api key is in the request headers @@ -169,17 +165,13 @@ async def score_batch( await self.set_api_key() dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema( - dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) - ) + validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, rows_in_page=-1, ) - res = await self.score( - input_rows=all_rows.rows, scoring_functions=scoring_functions - ) + res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions) if save_results_dataset: # TODO: persist and register dataset on to server for reading # self.datasets_api.register_dataset() @@ -220,13 +212,8 @@ async def score( if scoring_fn_id not in self.supported_fn_defs_registry: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - score_results = [ - await self.score_row(input_row, scoring_fn_id) - for input_row in input_rows - ] - aggregation_functions = self.supported_fn_defs_registry[ - scoring_fn_id - ].params.aggregation_functions + score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows] + aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions # override scoring_fn params if provided if scoring_functions[scoring_fn_id] is not None: diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index d6d051f9e2..c6e0d39c91 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -54,9 +54,9 @@ async def list_scoring_functions(self) -> List[ScoringFn]: scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs() for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs(): - assert f.identifier.startswith( - "llm-as-judge" - ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! " + assert f.identifier.startswith("llm-as-judge"), ( + "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! 
" + ) return scoring_fn_defs_list @@ -70,9 +70,7 @@ async def score_batch( save_results_dataset: bool = False, ) -> ScoreBatchResponse: dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema( - dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) - ) + validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, @@ -100,12 +98,8 @@ async def score( for scoring_fn_id in scoring_functions.keys(): scoring_fn = self.llm_as_judge_fn scoring_fn_params = scoring_functions.get(scoring_fn_id, None) - score_results = await scoring_fn.score( - input_rows, scoring_fn_id, scoring_fn_params - ) - agg_results = await scoring_fn.aggregate( - score_results, scoring_fn_id, scoring_fn_params - ) + score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) + agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 8a6599f8f0..d59edda301 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -104,13 +104,9 @@ async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None new_dataset = hf_datasets.Dataset.from_list(rows) # Concatenate the new rows with existing dataset - updated_dataset = hf_datasets.concatenate_datasets( - [loaded_dataset, new_dataset] - ) + updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset]) if dataset_def.metadata.get("path", None): updated_dataset.push_to_hub(dataset_def.metadata["path"]) else: - raise NotImplementedError( - "Uploading to URL-based datasets is not supported yet" - ) + raise NotImplementedError("Uploading to URL-based datasets is not supported yet")