7
7
from functools import partial
8
8
from typing import (
9
9
TYPE_CHECKING ,
10
+ AbstractSet ,
10
11
Any ,
11
12
Callable ,
12
13
ContextManager ,
16
17
Literal ,
17
18
Optional ,
18
19
Protocol ,
20
+ Sequence ,
19
21
Tuple ,
20
22
Union ,
21
23
)
@@ -119,6 +121,54 @@ def django_db_createdb(request: pytest.FixtureRequest) -> bool:
119
121
return create_db
120
122
121
123
124
+ def _get_databases_for_test (test : pytest .Item ) -> tuple [Iterable [str ], bool ]:
125
+ """Get the database aliases that need to be setup for a test, and whether
126
+ they need to be serialized."""
127
+ from django .db import DEFAULT_DB_ALIAS , connections
128
+ from django .test import TransactionTestCase
129
+
130
+ test_cls = getattr (test , "cls" , None )
131
+ if test_cls and issubclass (test_cls , TransactionTestCase ):
132
+ serialized_rollback = getattr (test , "serialized_rollback" , False )
133
+ databases = getattr (test , "databases" , None )
134
+ else :
135
+ fixtures = getattr (test , "fixturenames" , ())
136
+ marker_db = test .get_closest_marker ("django_db" )
137
+ if marker_db :
138
+ (
139
+ transaction ,
140
+ reset_sequences ,
141
+ databases ,
142
+ serialized_rollback ,
143
+ available_apps ,
144
+ ) = validate_django_db (marker_db )
145
+ elif "db" in fixtures or "transactional_db" in fixtures or "live_server" in fixtures :
146
+ serialized_rollback = "django_db_serialized_rollback" in fixtures
147
+ databases = None
148
+ else :
149
+ return (), False
150
+ if databases is None :
151
+ return (DEFAULT_DB_ALIAS ,), serialized_rollback
152
+ elif databases == "__all__" :
153
+ return connections , serialized_rollback
154
+ else :
155
+ return databases , serialized_rollback
156
+
157
+
158
+ def _get_databases_for_setup (items : Sequence [pytest .Item ]) -> tuple [AbstractSet [str ], AbstractSet [str ]]:
159
+ """Get the database aliases that need to be setup, and the subset that needs
160
+ to be serialized."""
161
+ # Code derived from django.test.utils.DiscoverRunner.get_databases().
162
+ aliases : set [str ] = set ()
163
+ serialized_aliases : set [str ] = set ()
164
+ for test in items :
165
+ databases , serialized_rollback = _get_databases_for_test (test )
166
+ aliases .update (databases )
167
+ if serialized_rollback :
168
+ serialized_aliases .update (databases )
169
+ return aliases , serialized_aliases
170
+
171
+
122
172
@pytest .fixture (scope = "session" )
123
173
def django_db_setup (
124
174
request : pytest .FixtureRequest ,
@@ -140,10 +190,14 @@ def django_db_setup(
140
190
if django_db_keepdb and not django_db_createdb :
141
191
setup_databases_args ["keepdb" ] = True
142
192
193
+ aliases , serialized_aliases = _get_databases_for_setup (request .session .items )
194
+
143
195
with django_db_blocker .unblock ():
144
196
db_cfg = setup_databases (
145
197
verbosity = request .config .option .verbose ,
146
198
interactive = False ,
199
+ aliases = aliases ,
200
+ serialized_aliases = serialized_aliases ,
147
201
** setup_databases_args ,
148
202
)
149
203
0 commit comments