10 changes: 5 additions & 5 deletions src/llama_stack_client/resources/eval/jobs.py
@@ -13,9 +13,9 @@
     async_to_raw_response_wrapper,
     async_to_streamed_response_wrapper,
 )
+from ...types.job import Job
 from ..._base_client import make_request_options
 from ...types.evaluate_response import EvaluateResponse
-from ...types.eval.job_status_response import JobStatusResponse

 __all__ = ["JobsResource", "AsyncJobsResource"]

@@ -124,7 +124,7 @@ def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> JobStatusResponse:
+    ) -> Job:
         """
         Get the status of a job.

@@ -146,7 +146,7 @@ def status(
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
-            cast_to=JobStatusResponse,
+            cast_to=Job,
         )


@@ -254,7 +254,7 @@ async def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> JobStatusResponse:
+    ) -> Job:
         """
         Get the status of a job.

@@ -276,7 +276,7 @@ async def status(
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
-            cast_to=JobStatusResponse,
+            cast_to=Job,
         )

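The visible effect of this change for callers is the return type: eval.jobs.status() now returns the shared Job model rather than the removed JobStatusResponse. A minimal sketch of the new call shape, assuming a reachable server; the base URL and IDs are hypothetical placeholders:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # hypothetical endpoint

# status() now returns a Job; the IDs mirror the test fixtures further down
job = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")
print(job.job_id, job.status)  # status is the new Literal field on Job (see types/job.py below)
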
10 changes: 7 additions & 3 deletions src/llama_stack_client/resources/post_training/post_training.py
@@ -14,7 +14,10 @@
     JobResourceWithStreamingResponse,
     AsyncJobResourceWithStreamingResponse,
 )
-from ...types import post_training_preference_optimize_params, post_training_supervised_fine_tune_params
+from ...types import (
+    post_training_preference_optimize_params,
+    post_training_supervised_fine_tune_params,
+)
 from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
 from ..._utils import (
     maybe_transform,
@@ -30,6 +33,7 @@
 )
 from ..._base_client import make_request_options
 from ...types.post_training_job import PostTrainingJob
+from ...types.algorithm_config_param import AlgorithmConfigParam

 __all__ = ["PostTrainingResource", "AsyncPostTrainingResource"]

@@ -111,7 +115,7 @@ def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
+        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -228,7 +232,7 @@ async def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
+        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
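With the import swap above, supervised_fine_tune() now accepts any dict matching the shared AlgorithmConfigParam union. A sketch of a LoRA invocation, assuming the client from the earlier example; the model ID, LoRA values, and module names are hypothetical, and training_config is elided to the fields visible in this diff (a real call also needs data_config, optimizer_config, and the other TrainingConfig fields):

# Sketch only; values are illustrative placeholders.
job = client.post_training.supervised_fine_tune(
    job_uuid="ft-job-1",                # hypothetical job identifier
    model="my-base-model",              # hypothetical model ID
    training_config={"dtype": "bf16"},  # elided; see TrainingConfig in the params module
    hyperparam_search_config={},
    logger_config={},
    algorithm_config={                  # AlgorithmConfigParam, LoRA variant
        "type": "LoRA",
        "alpha": 16,
        "rank": 8,
        "apply_lora_to_mlp": True,
        "apply_lora_to_output": False,
        "lora_attn_modules": ["q_proj", "v_proj"],  # hypothetical module names
    },
)
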
1 change: 1 addition & 0 deletions src/llama_stack_client/types/__init__.py
@@ -76,6 +76,7 @@
 from .model_register_params import ModelRegisterParams as ModelRegisterParams
 from .query_chunks_response import QueryChunksResponse as QueryChunksResponse
 from .query_condition_param import QueryConditionParam as QueryConditionParam
+from .algorithm_config_param import AlgorithmConfigParam as AlgorithmConfigParam
 from .benchmark_config_param import BenchmarkConfigParam as BenchmarkConfigParam
 from .list_datasets_response import ListDatasetsResponse as ListDatasetsResponse
 from .provider_list_response import ProviderListResponse as ProviderListResponse
37 changes: 37 additions & 0 deletions src/llama_stack_client/types/algorithm_config_param.py
@@ -0,0 +1,37 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import List, Union
+from typing_extensions import Literal, Required, TypeAlias, TypedDict
+
+__all__ = ["AlgorithmConfigParam", "LoraFinetuningConfig", "QatFinetuningConfig"]
+
+
+class LoraFinetuningConfig(TypedDict, total=False):
+    alpha: Required[int]
+
+    apply_lora_to_mlp: Required[bool]
+
+    apply_lora_to_output: Required[bool]
+
+    lora_attn_modules: Required[List[str]]
+
+    rank: Required[int]
+
+    type: Required[Literal["LoRA"]]
+
+    quantize_base: bool
+
+    use_dora: bool
+
+
+class QatFinetuningConfig(TypedDict, total=False):
+    group_size: Required[int]
+
+    quantizer_name: Required[str]
+
+    type: Required[Literal["QAT"]]
+
+
+AlgorithmConfigParam: TypeAlias = Union[LoraFinetuningConfig, QatFinetuningConfig]
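Because both variants are TypedDicts, plain dicts type-check against the union, with the required "type" literal selecting the variant. Illustrative values only; the module and quantizer names are hypothetical:

from llama_stack_client.types import AlgorithmConfigParam

lora: AlgorithmConfigParam = {
    "type": "LoRA",
    "alpha": 16,
    "rank": 8,
    "apply_lora_to_mlp": True,
    "apply_lora_to_output": False,
    "lora_attn_modules": ["q_proj", "v_proj"],  # hypothetical attention modules
}

qat: AlgorithmConfigParam = {
    "type": "QAT",
    "group_size": 32,          # hypothetical
    "quantizer_name": "int4",  # hypothetical
}
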
2 changes: 0 additions & 2 deletions src/llama_stack_client/types/eval/__init__.py
@@ -1,5 +1,3 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

 from __future__ import annotations
-
-from .job_status_response import JobStatusResponse as JobStatusResponse
7 changes: 0 additions & 7 deletions src/llama_stack_client/types/eval/job_status_response.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/llama_stack_client/types/job.py
@@ -1,5 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

+from typing_extensions import Literal

 from .._models import BaseModel

@@ -8,3 +9,5 @@

 class Job(BaseModel):
     job_id: str
+
+    status: Literal["completed", "in_progress", "failed", "scheduled"]
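Folding status into Job is what lets the eval jobs API drop the separate response type: one model now carries both identity and state, so simple polling needs no extra imports. A sketch, reusing the client from the first example; the five-second interval is arbitrary:

import time

# Poll until the job reaches a terminal state.
while True:
    job = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")
    if job.status in ("completed", "failed"):
        break
    time.sleep(5)
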
40 changes: 5 additions & 35 deletions src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
@@ -2,18 +2,17 @@

 from __future__ import annotations

-from typing import Dict, List, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing import Dict, Union, Iterable
+from typing_extensions import Literal, Required, TypedDict

+from .algorithm_config_param import AlgorithmConfigParam
+
 __all__ = [
     "PostTrainingSupervisedFineTuneParams",
     "TrainingConfig",
     "TrainingConfigDataConfig",
     "TrainingConfigOptimizerConfig",
     "TrainingConfigEfficiencyConfig",
-    "AlgorithmConfig",
-    "AlgorithmConfigLoraFinetuningConfig",
-    "AlgorithmConfigQatFinetuningConfig",
 ]


@@ -28,7 +27,7 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False):

     training_config: Required[TrainingConfig]

-    algorithm_config: AlgorithmConfig
+    algorithm_config: AlgorithmConfigParam

     checkpoint_dir: str

@@ -85,32 +84,3 @@ class TrainingConfig(TypedDict, total=False):
     dtype: str

     efficiency_config: TrainingConfigEfficiencyConfig
-
-
-class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False):
-    alpha: Required[int]
-
-    apply_lora_to_mlp: Required[bool]
-
-    apply_lora_to_output: Required[bool]
-
-    lora_attn_modules: Required[List[str]]
-
-    rank: Required[int]
-
-    type: Required[Literal["LoRA"]]
-
-    quantize_base: bool
-
-    use_dora: bool
-
-
-class AlgorithmConfigQatFinetuningConfig(TypedDict, total=False):
-    group_size: Required[int]
-
-    quantizer_name: Required[str]
-
-    type: Required[Literal["QAT"]]
-
-
-AlgorithmConfig: TypeAlias = Union[AlgorithmConfigLoraFinetuningConfig, AlgorithmConfigQatFinetuningConfig]
2 changes: 1 addition & 1 deletion src/llama_stack_client/types/tool_invocation_result.py
@@ -9,7 +9,7 @@


 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     """A image content item"""

     error_code: Optional[int] = None
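Since content is now Optional, callers should guard against None before using a result. A sketch, assuming a configured client and that the tool-runtime invoke_tool() method is available; the tool name and arguments are hypothetical:

result = client.tool_runtime.invoke_tool(
    tool_name="my_tool",        # hypothetical tool
    kwargs={"query": "hello"},  # hypothetical arguments
)
if result.content is not None:
    print(result.content)
elif result.error_code is not None:
    print("tool failed with error code", result.error_code)
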
15 changes: 7 additions & 8 deletions tests/api_resources/eval/test_jobs.py
@@ -9,8 +9,7 @@

 from tests.utils import assert_matches_type
 from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
-from llama_stack_client.types import EvaluateResponse
-from llama_stack_client.types.eval import JobStatusResponse
+from llama_stack_client.types import Job, EvaluateResponse

 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")

@@ -120,7 +119,7 @@ def test_method_status(self, client: LlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])

     @parametrize
     def test_raw_response_status(self, client: LlamaStackClient) -> None:
@@ -132,7 +131,7 @@ def test_raw_response_status(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = response.parse()
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])

     @parametrize
     def test_streaming_response_status(self, client: LlamaStackClient) -> None:
@@ -144,7 +143,7 @@ def test_streaming_response_status(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"

             job = response.parse()
-            assert_matches_type(JobStatusResponse, job, path=["response"])
+            assert_matches_type(Job, job, path=["response"])

         assert cast(Any, response.is_closed) is True

@@ -268,7 +267,7 @@ async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])

     @parametrize
     async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -280,7 +279,7 @@ async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = await response.parse()
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])

     @parametrize
     async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -292,7 +291,7 @@ async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"

             job = await response.parse()
-            assert_matches_type(JobStatusResponse, job, path=["response"])
+            assert_matches_type(Job, job, path=["response"])

         assert cast(Any, response.is_closed) is True