28
28
from openapi_core .spec .paths import Spec
29
29
from openapi_core .templating .media_types .exceptions import MediaTypeFinderError
30
30
from openapi_core .templating .paths .exceptions import PathError
31
- from openapi_core .unmarshalling .schemas .enums import UnmarshalContext
31
+ from openapi_core .unmarshalling .schemas import (
32
+ oas30_request_schema_unmarshallers_factory ,
33
+ )
34
+ from openapi_core .unmarshalling .schemas import (
35
+ oas31_schema_unmarshallers_factory ,
36
+ )
32
37
from openapi_core .unmarshalling .schemas .exceptions import UnmarshalError
33
38
from openapi_core .unmarshalling .schemas .exceptions import ValidateError
34
39
from openapi_core .unmarshalling .schemas .factories import (
50
55
51
56
52
57
class BaseRequestValidator (BaseValidator ):
58
+
59
+ schema_unmarshallers_factories = {
60
+ ("openapi" , "3.0" ): oas30_request_schema_unmarshallers_factory ,
61
+ ("openapi" , "3.1" ): oas31_schema_unmarshallers_factory ,
62
+ }
63
+
53
64
def __init__ (
54
65
self ,
55
- schema_unmarshallers_factory : SchemaUnmarshallersFactory ,
66
+ spec : Spec ,
67
+ base_url : Optional [str ] = None ,
68
+ schema_unmarshallers_factory : Optional [
69
+ SchemaUnmarshallersFactory
70
+ ] = None ,
56
71
schema_casters_factory : SchemaCastersFactory = schema_casters_factory ,
57
72
parameter_deserializers_factory : ParameterDeserializersFactory = parameter_deserializers_factory ,
58
73
media_type_deserializers_factory : MediaTypeDeserializersFactory = media_type_deserializers_factory ,
59
74
security_provider_factory : SecurityProviderFactory = security_provider_factory ,
60
75
):
61
76
super ().__init__ (
62
- schema_unmarshallers_factory ,
77
+ spec ,
78
+ base_url = base_url ,
79
+ schema_unmarshallers_factory = schema_unmarshallers_factory ,
63
80
schema_casters_factory = schema_casters_factory ,
64
81
parameter_deserializers_factory = parameter_deserializers_factory ,
65
82
media_type_deserializers_factory = media_type_deserializers_factory ,
66
83
)
67
84
self .security_provider_factory = security_provider_factory
68
85
69
- def iter_errors (
70
- self ,
71
- spec : Spec ,
72
- request : Request ,
73
- base_url : Optional [str ] = None ,
74
- ) -> Iterator [Exception ]:
75
- result = self .validate (spec , request , base_url = base_url )
86
+ def iter_errors (self , request : Request ) -> Iterator [Exception ]:
87
+ result = self .validate (request )
76
88
yield from result .errors
77
89
78
- def validate (
79
- self ,
80
- spec : Spec ,
81
- request : Request ,
82
- base_url : Optional [str ] = None ,
83
- ) -> RequestValidationResult :
90
+ def validate (self , request : Request ) -> RequestValidationResult :
84
91
raise NotImplementedError
85
92
86
93
def _get_parameters (
@@ -143,11 +150,11 @@ def _get_parameter(self, param: Spec, request: Request) -> Any:
143
150
raise MissingParameter (name )
144
151
145
152
def _get_security (
146
- self , spec : Spec , request : Request , operation : Spec
153
+ self , request : Request , operation : Spec
147
154
) -> Optional [Dict [str , str ]]:
148
155
security = None
149
- if "security" in spec :
150
- security = spec / "security"
156
+ if "security" in self . spec :
157
+ security = self . spec / "security"
151
158
if "security" in operation :
152
159
security = operation / "security"
153
160
@@ -157,20 +164,16 @@ def _get_security(
157
164
for security_requirement in security :
158
165
try :
159
166
return {
160
- scheme_name : self ._get_security_value (
161
- spec , scheme_name , request
162
- )
167
+ scheme_name : self ._get_security_value (scheme_name , request )
163
168
for scheme_name in list (security_requirement .keys ())
164
169
}
165
170
except SecurityError :
166
171
continue
167
172
168
173
raise InvalidSecurity
169
174
170
- def _get_security_value (
171
- self , spec : Spec , scheme_name : str , request : Request
172
- ) -> Any :
173
- security_schemes = spec / "components#securitySchemes"
175
+ def _get_security_value (self , scheme_name : str , request : Request ) -> Any :
176
+ security_schemes = self .spec / "components#securitySchemes"
174
177
if scheme_name not in security_schemes :
175
178
return
176
179
scheme = security_schemes [scheme_name ]
@@ -207,16 +210,9 @@ def _get_body_value(self, request_body: Spec, request: Request) -> Any:
207
210
208
211
209
212
class RequestParametersValidator (BaseRequestValidator ):
210
- def validate (
211
- self ,
212
- spec : Spec ,
213
- request : Request ,
214
- base_url : Optional [str ] = None ,
215
- ) -> RequestValidationResult :
213
+ def validate (self , request : Request ) -> RequestValidationResult :
216
214
try :
217
- path , operation , _ , path_result , _ = self ._find_path (
218
- spec , request , base_url = base_url
219
- )
215
+ path , operation , _ , path_result , _ = self ._find_path (request )
220
216
except PathError as exc :
221
217
return RequestValidationResult (errors = [exc ])
222
218
@@ -239,16 +235,9 @@ def validate(
239
235
240
236
241
237
class RequestBodyValidator (BaseRequestValidator ):
242
- def validate (
243
- self ,
244
- spec : Spec ,
245
- request : Request ,
246
- base_url : Optional [str ] = None ,
247
- ) -> RequestValidationResult :
238
+ def validate (self , request : Request ) -> RequestValidationResult :
248
239
try :
249
- _ , operation , _ , _ , _ = self ._find_path (
250
- spec , request , base_url = base_url
251
- )
240
+ _ , operation , _ , _ , _ = self ._find_path (request )
252
241
except PathError as exc :
253
242
return RequestValidationResult (errors = [exc ])
254
243
@@ -277,21 +266,14 @@ def validate(
277
266
278
267
279
268
class RequestSecurityValidator (BaseRequestValidator ):
280
- def validate (
281
- self ,
282
- spec : Spec ,
283
- request : Request ,
284
- base_url : Optional [str ] = None ,
285
- ) -> RequestValidationResult :
269
+ def validate (self , request : Request ) -> RequestValidationResult :
286
270
try :
287
- _ , operation , _ , _ , _ = self ._find_path (
288
- spec , request , base_url = base_url
289
- )
271
+ _ , operation , _ , _ , _ = self ._find_path (request )
290
272
except PathError as exc :
291
273
return RequestValidationResult (errors = [exc ])
292
274
293
275
try :
294
- security = self ._get_security (spec , request , operation )
276
+ security = self ._get_security (request , operation )
295
277
except InvalidSecurity as exc :
296
278
return RequestValidationResult (errors = [exc ])
297
279
@@ -302,22 +284,15 @@ def validate(
302
284
303
285
304
286
class RequestValidator (BaseRequestValidator ):
305
- def validate (
306
- self ,
307
- spec : Spec ,
308
- request : Request ,
309
- base_url : Optional [str ] = None ,
310
- ) -> RequestValidationResult :
287
+ def validate (self , request : Request ) -> RequestValidationResult :
311
288
try :
312
- path , operation , _ , path_result , _ = self ._find_path (
313
- spec , request , base_url = base_url
314
- )
289
+ path , operation , _ , path_result , _ = self ._find_path (request )
315
290
# don't process if operation errors
316
291
except PathError as exc :
317
292
return RequestValidationResult (errors = [exc ])
318
293
319
294
try :
320
- security = self ._get_security (spec , request , operation )
295
+ security = self ._get_security (request , operation )
321
296
except InvalidSecurity as exc :
322
297
return RequestValidationResult (errors = [exc ])
323
298
0 commit comments