Skip to content

Commit 8cf5462

Browse files
authored
Merge pull request #406 from tisnik/lcore-405-configurable-cors-settings
LCORE-405: configurable CORS settings
2 parents db61452 + 8f3970b commit 8cf5462

File tree

7 files changed

+104
-5
lines changed

7 files changed

+104
-5
lines changed

docs/config.png

11 KB
Loading

docs/config.puml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ class "AuthenticationConfiguration" as src.models.config.AuthenticationConfigura
99
skip_tls_verification : bool
1010
check_authentication_model() -> Self
1111
}
12+
class "CORSConfiguration" as src.models.config.CORSConfiguration {
13+
allow_credentials : bool
14+
allow_headers : list[str]
15+
allow_methods : list[str]
16+
allow_origins : list[str]
17+
check_cors_configuration() -> Self
18+
}
1219
class "Configuration" as src.models.config.Configuration {
1320
authentication
1421
customization : Optional[Customization]
@@ -78,6 +85,7 @@ class "ServiceConfiguration" as src.models.config.ServiceConfiguration {
7885
access_log : bool
7986
auth_enabled : bool
8087
color_log : bool
88+
cors
8189
host : str
8290
port : int
8391
tls_config
@@ -98,6 +106,7 @@ class "UserDataCollection" as src.models.config.UserDataCollection {
98106
check_storage_location_is_set_when_needed() -> Self
99107
}
100108
src.models.config.AuthenticationConfiguration --* src.models.config.Configuration : authentication
109+
src.models.config.CORSConfiguration --* src.models.config.ServiceConfiguration : cors
101110
src.models.config.DatabaseConfiguration --* src.models.config.Configuration : database
102111
src.models.config.InferenceConfiguration --* src.models.config.Configuration : inference
103112
src.models.config.JwtConfiguration --* src.models.config.JwkConfiguration : jwt_configuration

src/app/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@
4040
],
4141
)
4242

43+
cors = configuration.service_configuration.cors
44+
4345
app.add_middleware(
4446
CORSMiddleware,
45-
allow_origins=["*"],
46-
allow_credentials=True,
47-
allow_methods=["*"],
48-
allow_headers=["*"],
47+
allow_origins=cors.allow_origins,
48+
allow_credentials=cors.allow_credentials,
49+
allow_methods=cors.allow_methods,
50+
allow_headers=cors.allow_headers,
4951
)
5052

5153

src/models/config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@ def check_tls_configuration(self) -> Self:
2424
return self
2525

2626

27+
class CORSConfiguration(BaseModel):
28+
"""CORS configuration."""
29+
30+
allow_origins: list[str] = [
31+
"*"
32+
] # not AnyHttpUrl: we need to support "*" that is not valid URL
33+
allow_credentials: bool = True
34+
allow_methods: list[str] = ["*"]
35+
allow_headers: list[str] = ["*"]
36+
37+
@model_validator(mode="after")
38+
def check_cors_configuration(self) -> Self:
39+
"""Check CORS configuration."""
40+
return self
41+
42+
2743
class SQLiteDatabaseConfiguration(BaseModel):
2844
"""SQLite database configuration."""
2945

@@ -106,6 +122,7 @@ class ServiceConfiguration(BaseModel):
106122
color_log: bool = True
107123
access_log: bool = True
108124
tls_config: TLSConfiguration = TLSConfiguration()
125+
cors: CORSConfiguration = CORSConfiguration()
109126

110127
@model_validator(mode="after")
111128
def check_service_configuration(self) -> Self:

tests/configuration/lightspeed-stack.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@ service:
66
workers: 1
77
color_log: true
88
access_log: true
9+
cors:
10+
allow_origins:
11+
- foo_origin
12+
- bar_origin
13+
- baz_origin
14+
allow_credentials: false
15+
allow_methods:
16+
- foo_method
17+
- bar_method
18+
- baz_method
19+
allow_headers:
20+
- foo_header
21+
- bar_header
22+
- baz_header
923
llama_stack:
1024
# Uses a remote llama-stack service
1125
# The instance would have already been started with a llama-stack-run.yaml file

tests/integration/test_configuration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ def test_loading_proper_configuration(configuration_filename: str) -> None:
4747
assert svc_config.color_log is True
4848
assert svc_config.access_log is True
4949

50+
# check 'service.cors' section
51+
cors_config = cfg.service_configuration.cors
52+
assert cors_config.allow_origins == ["foo_origin", "bar_origin", "baz_origin"]
53+
assert cors_config.allow_credentials is False
54+
assert cors_config.allow_methods == ["foo_method", "bar_method", "baz_method"]
55+
assert cors_config.allow_headers == ["foo_header", "bar_header", "baz_header"]
56+
5057
# check 'llama_stack' section
5158
ls_config = cfg.llama_stack_configuration
5259
assert ls_config.use_as_library_client is False

tests/unit/models/test_config.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ServiceConfiguration,
2222
UserDataCollection,
2323
TLSConfiguration,
24+
CORSConfiguration,
2425
ModelContextProtocolServer,
2526
InferenceConfiguration,
2627
)
@@ -214,6 +215,31 @@ def test_user_data_collection_transcripts_disabled() -> None:
214215
UserDataCollection(transcripts_enabled=True, transcripts_storage=None)
215216

216217

218+
def test_cors_default_configuration() -> None:
219+
"""Test the CORS configuration."""
220+
cfg = CORSConfiguration()
221+
assert cfg is not None
222+
assert cfg.allow_origins == ["*"]
223+
assert cfg.allow_credentials is True
224+
assert cfg.allow_methods == ["*"]
225+
assert cfg.allow_headers == ["*"]
226+
227+
228+
def test_cors_custom_configuration() -> None:
229+
"""Test the CORS configuration."""
230+
cfg = CORSConfiguration(
231+
allow_origins=["foo_origin", "bar_origin", "baz_origin"],
232+
allow_credentials=False,
233+
allow_methods=["foo_method", "bar_method", "baz_method"],
234+
allow_headers=["foo_header", "bar_header", "baz_header"],
235+
)
236+
assert cfg is not None
237+
assert cfg.allow_origins == ["foo_origin", "bar_origin", "baz_origin"]
238+
assert cfg.allow_credentials is False
239+
assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"]
240+
assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"]
241+
242+
217243
def test_tls_configuration() -> None:
218244
"""Test the TLS configuration."""
219245
cfg = TLSConfiguration(
@@ -437,7 +463,13 @@ def test_dump_configuration(tmp_path) -> None:
437463
tls_certificate_path=Path("tests/configuration/server.crt"),
438464
tls_key_path=Path("tests/configuration/server.key"),
439465
tls_key_password=Path("tests/configuration/password"),
440-
)
466+
),
467+
cors=CORSConfiguration(
468+
allow_origins=["foo_origin", "bar_origin", "baz_origin"],
469+
allow_credentials=False,
470+
allow_methods=["foo_method", "bar_method", "baz_method"],
471+
allow_headers=["foo_header", "bar_header", "baz_header"],
472+
),
441473
),
442474
llama_stack=LlamaStackConfiguration(
443475
use_as_library_client=True,
@@ -488,6 +520,24 @@ def test_dump_configuration(tmp_path) -> None:
488520
"tls_key_password": "tests/configuration/password",
489521
"tls_key_path": "tests/configuration/server.key",
490522
},
523+
"cors": {
524+
"allow_credentials": False,
525+
"allow_headers": [
526+
"foo_header",
527+
"bar_header",
528+
"baz_header",
529+
],
530+
"allow_methods": [
531+
"foo_method",
532+
"bar_method",
533+
"baz_method",
534+
],
535+
"allow_origins": [
536+
"foo_origin",
537+
"bar_origin",
538+
"baz_origin",
539+
],
540+
},
491541
},
492542
"llama_stack": {
493543
"url": None,

0 commit comments

Comments
 (0)