99
1010from gradientai import GradientAI , AsyncGradientAI
1111from tests .utils import assert_matches_type
12- from gradientai .types .inference import ModelListResponse , ModelRetrieveResponse
12+ from gradientai .types .inference import Model , ModelListResponse
1313
1414base_url = os .environ .get ("TEST_API_BASE_URL" , "http://127.0.0.1:4010" )
1515
@@ -23,7 +23,7 @@ def test_method_retrieve(self, client: GradientAI) -> None:
2323 model = client .inference .models .retrieve (
2424 "llama3-8b-instruct" ,
2525 )
26- assert_matches_type (ModelRetrieveResponse , model , path = ["response" ])
26+ assert_matches_type (Model , model , path = ["response" ])
2727
2828 @pytest .mark .skip ()
2929 @parametrize
@@ -35,7 +35,7 @@ def test_raw_response_retrieve(self, client: GradientAI) -> None:
3535 assert response .is_closed is True
3636 assert response .http_request .headers .get ("X-Stainless-Lang" ) == "python"
3737 model = response .parse ()
38- assert_matches_type (ModelRetrieveResponse , model , path = ["response" ])
38+ assert_matches_type (Model , model , path = ["response" ])
3939
4040 @pytest .mark .skip ()
4141 @parametrize
@@ -47,7 +47,7 @@ def test_streaming_response_retrieve(self, client: GradientAI) -> None:
4747 assert response .http_request .headers .get ("X-Stainless-Lang" ) == "python"
4848
4949 model = response .parse ()
50- assert_matches_type (ModelRetrieveResponse , model , path = ["response" ])
50+ assert_matches_type (Model , model , path = ["response" ])
5151
5252 assert cast (Any , response .is_closed ) is True
5353
@@ -97,7 +97,7 @@ async def test_method_retrieve(self, async_client: AsyncGradientAI) -> None:
9797 model = await async_client .inference .models .retrieve (
9898 "llama3-8b-instruct" ,
9999 )
100- assert_matches_type (ModelRetrieveResponse , model , path = ["response" ])
100+ assert_matches_type (Model , model , path = ["response" ])
101101
102102 @pytest .mark .skip ()
103103 @parametrize
@@ -109,7 +109,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncGradientAI) -> Non
109109 assert response .is_closed is True
110110 assert response .http_request .headers .get ("X-Stainless-Lang" ) == "python"
111111 model = await response .parse ()
112- assert_matches_type (ModelRetrieveResponse , model , path = ["response" ])
112+ assert_matches_type (Model , model , path = ["response" ])
113113
114114 @pytest .mark .skip ()
115115 @parametrize
@@ -121,7 +121,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncGradientAI)
121121 assert response .http_request .headers .get ("X-Stainless-Lang" ) == "python"
122122
123123 model = await response .parse ()
124- assert_matches_type (ModelRetrieveResponse , model , path = ["response" ])
124+ assert_matches_type (Model , model , path = ["response" ])
125125
126126 assert cast (Any , response .is_closed ) is True
127127
0 commit comments