Skip to content

Commit c39eefa

Browse files
committed
Validators refactor
1 parent 6d8f1d9 commit c39eefa

File tree

9 files changed

+193
-244
lines changed

9 files changed

+193
-244
lines changed

openapi_core/templating/paths/finders.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,8 @@ def __init__(self, spec: Spec, base_url: Optional[str] = None):
2929
def find(
3030
self,
3131
method: str,
32-
host_url: str,
33-
path: str,
34-
path_pattern: Optional[str] = None,
32+
full_url: str,
3533
) -> ServerOperationPath:
36-
if path_pattern is not None:
37-
full_url = urljoin(host_url, path_pattern)
38-
else:
39-
full_url = urljoin(host_url, path)
40-
4134
paths_iter = self._get_paths_iter(full_url)
4235
paths_iter_peek = peekable(paths_iter)
4336

openapi_core/validation/request/__init__.py

+28-32
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from openapi_core.unmarshalling.schemas import (
66
oas31_schema_unmarshallers_factory,
77
)
8-
from openapi_core.validation.request.proxies import DetectRequestValidatorProxy
8+
from openapi_core.validation.request.proxies import RequestValidatorProxy
99
from openapi_core.validation.request.validators import RequestBodyValidator
1010
from openapi_core.validation.request.validators import (
1111
RequestParametersValidator,
@@ -32,29 +32,37 @@
3232
"openapi_request_validator",
3333
]
3434

35-
openapi_v30_request_body_validator = RequestBodyValidator(
35+
openapi_v30_request_body_validator = RequestValidatorProxy(
36+
RequestBodyValidator,
3637
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
3738
)
38-
openapi_v30_request_parameters_validator = RequestParametersValidator(
39+
openapi_v30_request_parameters_validator = RequestValidatorProxy(
40+
RequestParametersValidator,
3941
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
4042
)
41-
openapi_v30_request_security_validator = RequestSecurityValidator(
43+
openapi_v30_request_security_validator = RequestValidatorProxy(
44+
RequestSecurityValidator,
4245
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
4346
)
44-
openapi_v30_request_validator = RequestValidator(
47+
openapi_v30_request_validator = RequestValidatorProxy(
48+
RequestValidator,
4549
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
4650
)
4751

48-
openapi_v31_request_body_validator = RequestBodyValidator(
52+
openapi_v31_request_body_validator = RequestValidatorProxy(
53+
RequestBodyValidator,
4954
schema_unmarshallers_factory=oas31_schema_unmarshallers_factory,
5055
)
51-
openapi_v31_request_parameters_validator = RequestParametersValidator(
56+
openapi_v31_request_parameters_validator = RequestValidatorProxy(
57+
RequestParametersValidator,
5258
schema_unmarshallers_factory=oas31_schema_unmarshallers_factory,
5359
)
54-
openapi_v31_request_security_validator = RequestSecurityValidator(
60+
openapi_v31_request_security_validator = RequestValidatorProxy(
61+
RequestSecurityValidator,
5562
schema_unmarshallers_factory=oas31_schema_unmarshallers_factory,
5663
)
57-
openapi_v31_request_validator = RequestValidator(
64+
openapi_v31_request_validator = RequestValidatorProxy(
65+
RequestValidator,
5866
schema_unmarshallers_factory=oas31_schema_unmarshallers_factory,
5967
)
6068

@@ -67,27 +75,15 @@
6775
openapi_v3_request_validator = openapi_v31_request_validator
6876

6977
# detect version spec
70-
openapi_request_body_validator = DetectRequestValidatorProxy(
71-
{
72-
("openapi", "3.0"): openapi_v30_request_body_validator,
73-
("openapi", "3.1"): openapi_v31_request_body_validator,
74-
},
75-
)
76-
openapi_request_parameters_validator = DetectRequestValidatorProxy(
77-
{
78-
("openapi", "3.0"): openapi_v30_request_parameters_validator,
79-
("openapi", "3.1"): openapi_v31_request_parameters_validator,
80-
},
81-
)
82-
openapi_request_security_validator = DetectRequestValidatorProxy(
83-
{
84-
("openapi", "3.0"): openapi_v30_request_security_validator,
85-
("openapi", "3.1"): openapi_v31_request_security_validator,
86-
},
87-
)
88-
openapi_request_validator = DetectRequestValidatorProxy(
89-
{
90-
("openapi", "3.0"): openapi_v30_request_validator,
91-
("openapi", "3.1"): openapi_v31_request_validator,
92-
},
78+
openapi_request_body_validator = RequestValidatorProxy(
79+
RequestBodyValidator,
80+
)
81+
openapi_request_parameters_validator = RequestValidatorProxy(
82+
RequestParametersValidator,
83+
)
84+
openapi_request_security_validator = RequestValidatorProxy(
85+
RequestSecurityValidator,
86+
)
87+
openapi_request_validator = RequestValidatorProxy(
88+
RequestValidator,
9389
)
+18-21
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,45 @@
11
"""OpenAPI spec validator validation proxies module."""
22
from typing import Any
3-
from typing import Hashable
43
from typing import Iterator
5-
from typing import Mapping
64
from typing import Optional
7-
from typing import Tuple
5+
from typing import Type
86

9-
from openapi_core.exceptions import OpenAPIError
107
from openapi_core.spec import Spec
11-
from openapi_core.validation.exceptions import ValidatorDetectError
128
from openapi_core.validation.request.datatypes import RequestValidationResult
139
from openapi_core.validation.request.protocols import Request
1410
from openapi_core.validation.request.validators import BaseRequestValidator
1511

1612

17-
class DetectRequestValidatorProxy:
13+
class RequestValidatorProxy:
1814
def __init__(
19-
self, choices: Mapping[Tuple[str, str], BaseRequestValidator]
15+
self,
16+
validator_cls: Type[BaseRequestValidator],
17+
**validator_kwargs: Any,
2018
):
21-
self.choices = choices
22-
23-
def detect(self, spec: Spec) -> BaseRequestValidator:
24-
for (key, value), validator in self.choices.items():
25-
if key in spec and spec[key].startswith(value):
26-
return validator
27-
raise ValidatorDetectError("Spec schema version not detected")
19+
self.validator_cls = validator_cls
20+
self.validator_kwargs = validator_kwargs
2821

2922
def validate(
3023
self,
3124
spec: Spec,
3225
request: Request,
3326
base_url: Optional[str] = None,
3427
) -> RequestValidationResult:
35-
validator = self.detect(spec)
36-
return validator.validate(spec, request, base_url=base_url)
28+
validator = self.validator_cls(
29+
spec, base_url=base_url, **self.validator_kwargs
30+
)
31+
return validator.validate(request)
3732

3833
def is_valid(
3934
self,
4035
spec: Spec,
4136
request: Request,
4237
base_url: Optional[str] = None,
4338
) -> bool:
44-
validator = self.detect(spec)
45-
error = next(
46-
validator.iter_errors(spec, request, base_url=base_url), None
39+
validator = self.validator_cls(
40+
spec, base_url=base_url, **self.validator_kwargs
4741
)
42+
error = next(validator.iter_errors(request), None)
4843
return error is None
4944

5045
def iter_errors(
@@ -53,5 +48,7 @@ def iter_errors(
5348
request: Request,
5449
base_url: Optional[str] = None,
5550
) -> Iterator[Exception]:
56-
validator = self.detect(spec)
57-
yield from validator.iter_errors(spec, request, base_url=base_url)
51+
validator = self.validator_cls(
52+
spec, base_url=base_url, **self.validator_kwargs
53+
)
54+
yield from validator.iter_errors(request)

openapi_core/validation/request/validators.py

+39-64
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
from openapi_core.spec.paths import Spec
2929
from openapi_core.templating.media_types.exceptions import MediaTypeFinderError
3030
from openapi_core.templating.paths.exceptions import PathError
31-
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
31+
from openapi_core.unmarshalling.schemas import (
32+
oas30_request_schema_unmarshallers_factory,
33+
)
34+
from openapi_core.unmarshalling.schemas import (
35+
oas31_schema_unmarshallers_factory,
36+
)
3237
from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError
3338
from openapi_core.unmarshalling.schemas.exceptions import ValidateError
3439
from openapi_core.unmarshalling.schemas.factories import (
@@ -50,37 +55,39 @@
5055

5156

5257
class BaseRequestValidator(BaseValidator):
58+
59+
schema_unmarshallers_factories = {
60+
("openapi", "3.0"): oas30_request_schema_unmarshallers_factory,
61+
("openapi", "3.1"): oas31_schema_unmarshallers_factory,
62+
}
63+
5364
def __init__(
5465
self,
55-
schema_unmarshallers_factory: SchemaUnmarshallersFactory,
66+
spec: Spec,
67+
base_url: Optional[str] = None,
68+
schema_unmarshallers_factory: Optional[
69+
SchemaUnmarshallersFactory
70+
] = None,
5671
schema_casters_factory: SchemaCastersFactory = schema_casters_factory,
5772
parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory,
5873
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
5974
security_provider_factory: SecurityProviderFactory = security_provider_factory,
6075
):
6176
super().__init__(
62-
schema_unmarshallers_factory,
77+
spec,
78+
base_url=base_url,
79+
schema_unmarshallers_factory=schema_unmarshallers_factory,
6380
schema_casters_factory=schema_casters_factory,
6481
parameter_deserializers_factory=parameter_deserializers_factory,
6582
media_type_deserializers_factory=media_type_deserializers_factory,
6683
)
6784
self.security_provider_factory = security_provider_factory
6885

69-
def iter_errors(
70-
self,
71-
spec: Spec,
72-
request: Request,
73-
base_url: Optional[str] = None,
74-
) -> Iterator[Exception]:
75-
result = self.validate(spec, request, base_url=base_url)
86+
def iter_errors(self, request: Request) -> Iterator[Exception]:
87+
result = self.validate(request)
7688
yield from result.errors
7789

78-
def validate(
79-
self,
80-
spec: Spec,
81-
request: Request,
82-
base_url: Optional[str] = None,
83-
) -> RequestValidationResult:
90+
def validate(self, request: Request) -> RequestValidationResult:
8491
raise NotImplementedError
8592

8693
def _get_parameters(
@@ -143,11 +150,11 @@ def _get_parameter(self, param: Spec, request: Request) -> Any:
143150
raise MissingParameter(name)
144151

145152
def _get_security(
146-
self, spec: Spec, request: Request, operation: Spec
153+
self, request: Request, operation: Spec
147154
) -> Optional[Dict[str, str]]:
148155
security = None
149-
if "security" in spec:
150-
security = spec / "security"
156+
if "security" in self.spec:
157+
security = self.spec / "security"
151158
if "security" in operation:
152159
security = operation / "security"
153160

@@ -157,20 +164,16 @@ def _get_security(
157164
for security_requirement in security:
158165
try:
159166
return {
160-
scheme_name: self._get_security_value(
161-
spec, scheme_name, request
162-
)
167+
scheme_name: self._get_security_value(scheme_name, request)
163168
for scheme_name in list(security_requirement.keys())
164169
}
165170
except SecurityError:
166171
continue
167172

168173
raise InvalidSecurity
169174

170-
def _get_security_value(
171-
self, spec: Spec, scheme_name: str, request: Request
172-
) -> Any:
173-
security_schemes = spec / "components#securitySchemes"
175+
def _get_security_value(self, scheme_name: str, request: Request) -> Any:
176+
security_schemes = self.spec / "components#securitySchemes"
174177
if scheme_name not in security_schemes:
175178
return
176179
scheme = security_schemes[scheme_name]
@@ -207,16 +210,9 @@ def _get_body_value(self, request_body: Spec, request: Request) -> Any:
207210

208211

209212
class RequestParametersValidator(BaseRequestValidator):
210-
def validate(
211-
self,
212-
spec: Spec,
213-
request: Request,
214-
base_url: Optional[str] = None,
215-
) -> RequestValidationResult:
213+
def validate(self, request: Request) -> RequestValidationResult:
216214
try:
217-
path, operation, _, path_result, _ = self._find_path(
218-
spec, request, base_url=base_url
219-
)
215+
path, operation, _, path_result, _ = self._find_path(request)
220216
except PathError as exc:
221217
return RequestValidationResult(errors=[exc])
222218

@@ -239,16 +235,9 @@ def validate(
239235

240236

241237
class RequestBodyValidator(BaseRequestValidator):
242-
def validate(
243-
self,
244-
spec: Spec,
245-
request: Request,
246-
base_url: Optional[str] = None,
247-
) -> RequestValidationResult:
238+
def validate(self, request: Request) -> RequestValidationResult:
248239
try:
249-
_, operation, _, _, _ = self._find_path(
250-
spec, request, base_url=base_url
251-
)
240+
_, operation, _, _, _ = self._find_path(request)
252241
except PathError as exc:
253242
return RequestValidationResult(errors=[exc])
254243

@@ -277,21 +266,14 @@ def validate(
277266

278267

279268
class RequestSecurityValidator(BaseRequestValidator):
280-
def validate(
281-
self,
282-
spec: Spec,
283-
request: Request,
284-
base_url: Optional[str] = None,
285-
) -> RequestValidationResult:
269+
def validate(self, request: Request) -> RequestValidationResult:
286270
try:
287-
_, operation, _, _, _ = self._find_path(
288-
spec, request, base_url=base_url
289-
)
271+
_, operation, _, _, _ = self._find_path(request)
290272
except PathError as exc:
291273
return RequestValidationResult(errors=[exc])
292274

293275
try:
294-
security = self._get_security(spec, request, operation)
276+
security = self._get_security(request, operation)
295277
except InvalidSecurity as exc:
296278
return RequestValidationResult(errors=[exc])
297279

@@ -302,22 +284,15 @@ def validate(
302284

303285

304286
class RequestValidator(BaseRequestValidator):
305-
def validate(
306-
self,
307-
spec: Spec,
308-
request: Request,
309-
base_url: Optional[str] = None,
310-
) -> RequestValidationResult:
287+
def validate(self, request: Request) -> RequestValidationResult:
311288
try:
312-
path, operation, _, path_result, _ = self._find_path(
313-
spec, request, base_url=base_url
314-
)
289+
path, operation, _, path_result, _ = self._find_path(request)
315290
# don't process if operation errors
316291
except PathError as exc:
317292
return RequestValidationResult(errors=[exc])
318293

319294
try:
320-
security = self._get_security(spec, request, operation)
295+
security = self._get_security(request, operation)
321296
except InvalidSecurity as exc:
322297
return RequestValidationResult(errors=[exc])
323298

0 commit comments

Comments
 (0)