1
+ # Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+
1
14
import json
15
+ from collections .abc import AsyncIterator
2
16
3
17
import pytest
4
18
5
- from smithy_python ._private .http import Response
6
- from smithy_python .interfaces .http import HeadersList
19
+ from smithy_python ._private .http import HTTPResponse
20
+ from smithy_python .async_utils import async_list
21
+ from smithy_python .interfaces .http import Fields
7
22
from smithy_python .protocolutils import RestJsonErrorInfo , parse_rest_json_error_info
8
23
from smithy_python .types import Document
9
24
10
25
11
- class _AsyncReader :
12
- def __init__ (self , body : str ):
13
- self ._body : bytes = body .encode ("utf-8" )
14
-
15
- async def read (self , size : int = - 1 ) -> bytes :
16
- result : bytes = self ._body
17
- if size <= 0 :
18
- self ._body = b""
19
- else :
20
- result = self ._body [:size ]
21
- self ._body = self ._body [size :]
22
- return result
23
-
24
-
25
26
@pytest .mark .parametrize (
26
27
"headers, body, expected" ,
27
28
[
@@ -82,17 +83,17 @@ async def read(self, size: int = -1) -> bytes:
82
83
],
83
84
)
84
85
async def test_parse_rest_json_error_info (
85
- headers : HeadersList , body : Document , expected : RestJsonErrorInfo
86
+ fields : Fields , body : Document , expected : RestJsonErrorInfo
86
87
) -> None :
87
- response = Response (
88
- status_code = 400 , headers = headers , body = _AsyncReader ( json .dumps (body ))
88
+ response = HTTPResponse (
89
+ status = 400 , fields = fields , body = async_list ([ json .dumps (body ). encode ()] )
89
90
)
90
91
actual = await parse_rest_json_error_info (response )
91
92
assert actual == expected
92
93
93
94
94
95
class _ExceptionThrowingBody :
95
- async def read (self , size : int = - 1 ) -> bytes :
96
+ def __aiter__ (self ) -> AsyncIterator [ bytes ] :
96
97
raise Exception ("Body unexpectedly accessed" )
97
98
98
99
@@ -117,8 +118,8 @@ async def read(self, size: int = -1) -> bytes:
117
118
],
118
119
)
119
120
async def test_parse_rest_json_error_info_without_body (
120
- headers : HeadersList , expected : RestJsonErrorInfo
121
+ fields : Fields , expected : RestJsonErrorInfo
121
122
) -> None :
122
- response = Response ( status_code = 400 , headers = headers , body = _ExceptionThrowingBody ())
123
+ response = HTTPResponse ( status = 400 , fields = fields , body = _ExceptionThrowingBody ())
123
124
actual = await parse_rest_json_error_info (response , check_body = False )
124
125
assert actual == expected
0 commit comments