|
21 | 21 | ServiceConfiguration, |
22 | 22 | UserDataCollection, |
23 | 23 | TLSConfiguration, |
| 24 | + CORSConfiguration, |
24 | 25 | ModelContextProtocolServer, |
25 | 26 | InferenceConfiguration, |
26 | 27 | ) |
@@ -214,6 +215,31 @@ def test_user_data_collection_transcripts_disabled() -> None: |
214 | 215 | UserDataCollection(transcripts_enabled=True, transcripts_storage=None) |
215 | 216 |
|
216 | 217 |
|
| 218 | +def test_cors_default_configuration() -> None: |
| 219 | + """Test the CORS configuration.""" |
| 220 | + cfg = CORSConfiguration() |
| 221 | + assert cfg is not None |
| 222 | + assert cfg.allow_origins == ["*"] |
| 223 | + assert cfg.allow_credentials is True |
| 224 | + assert cfg.allow_methods == ["*"] |
| 225 | + assert cfg.allow_headers == ["*"] |
| 226 | + |
| 227 | + |
| 228 | +def test_cors_custom_configuration() -> None: |
| 229 | + """Test the CORS configuration.""" |
| 230 | + cfg = CORSConfiguration( |
| 231 | + allow_origins=["foo_origin", "bar_origin", "baz_origin"], |
| 232 | + allow_credentials=False, |
| 233 | + allow_methods=["foo_method", "bar_method", "baz_method"], |
| 234 | + allow_headers=["foo_header", "bar_header", "baz_header"], |
| 235 | + ) |
| 236 | + assert cfg is not None |
| 237 | + assert cfg.allow_origins == ["foo_origin", "bar_origin", "baz_origin"] |
| 238 | + assert cfg.allow_credentials is False |
| 239 | + assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"] |
| 240 | + assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"] |
| 241 | + |
| 242 | + |
217 | 243 | def test_tls_configuration() -> None: |
218 | 244 | """Test the TLS configuration.""" |
219 | 245 | cfg = TLSConfiguration( |
@@ -437,7 +463,13 @@ def test_dump_configuration(tmp_path) -> None: |
437 | 463 | tls_certificate_path=Path("tests/configuration/server.crt"), |
438 | 464 | tls_key_path=Path("tests/configuration/server.key"), |
439 | 465 | tls_key_password=Path("tests/configuration/password"), |
440 | | - ) |
| 466 | + ), |
| 467 | + cors=CORSConfiguration( |
| 468 | + allow_origins=["foo_origin", "bar_origin", "baz_origin"], |
| 469 | + allow_credentials=False, |
| 470 | + allow_methods=["foo_method", "bar_method", "baz_method"], |
| 471 | + allow_headers=["foo_header", "bar_header", "baz_header"], |
| 472 | + ), |
441 | 473 | ), |
442 | 474 | llama_stack=LlamaStackConfiguration( |
443 | 475 | use_as_library_client=True, |
@@ -488,6 +520,24 @@ def test_dump_configuration(tmp_path) -> None: |
488 | 520 | "tls_key_password": "tests/configuration/password", |
489 | 521 | "tls_key_path": "tests/configuration/server.key", |
490 | 522 | }, |
| 523 | + "cors": { |
| 524 | + "allow_credentials": False, |
| 525 | + "allow_headers": [ |
| 526 | + "foo_header", |
| 527 | + "bar_header", |
| 528 | + "baz_header", |
| 529 | + ], |
| 530 | + "allow_methods": [ |
| 531 | + "foo_method", |
| 532 | + "bar_method", |
| 533 | + "baz_method", |
| 534 | + ], |
| 535 | + "allow_origins": [ |
| 536 | + "foo_origin", |
| 537 | + "bar_origin", |
| 538 | + "baz_origin", |
| 539 | + ], |
| 540 | + }, |
491 | 541 | }, |
492 | 542 | "llama_stack": { |
493 | 543 | "url": None, |
|
0 commit comments