diff --git a/django/contrib/gis/utils/layermapping.py b/django/contrib/gis/utils/layermapping.py index a4cd04dc057a..8ff9ca9ccde6 100644 --- a/django/contrib/gis/utils/layermapping.py +++ b/django/contrib/gis/utils/layermapping.py @@ -7,9 +7,11 @@ https://docs.djangoproject.com/en/dev/ref/contrib/gis/layermapping/ """ import sys +from contextlib import nullcontext from decimal import Decimal from decimal import InvalidOperation as DecimalInvalidOperation from pathlib import Path +from typing import Any, Optional from django.contrib.gis.db.models import GeometryField from django.contrib.gis.gdal import ( @@ -20,6 +22,7 @@ OGRGeomType, SpatialReference, ) +from django.contrib.gis.gdal.feature import Feature from django.contrib.gis.gdal.field import ( OFTDate, OFTDateTime, @@ -105,6 +108,7 @@ def __init__( transform=True, unique=None, using=None, + faster_verify_fk=False, ): """ A LayerMapping object is initialized using the given Model (not an @@ -127,10 +131,20 @@ def __init__( self.mapping = mapping self.model = model + # Flag to enable fetch required FK model pks in batch to avoid N+1 queries + self.faster_verify_fk = faster_verify_fk + if self.faster_verify_fk: + print("Faster foreign key verification is enabled") + # Checking the layer -- initialization of the object will fail if # things don't check out before hand. self.check_layer() + self.fk_field_names = self.get_model_fk_field_names() + + # TODO: Improve naming, it stores FK id to actual FK instance pk + self.fks_uid_pk_map = {} + # Getting the geometry column associated with the model (an # exception will be raised if there is no geometry column). if connection.features.supports_transform: @@ -340,6 +354,47 @@ def check_unique(self, unique): "Unique keyword argument must be set with a tuple, list, or string." 
# TODO: assign the fk_field_names by other methods
def get_model_fk_field_names(self):
    """
    Return the names of the mapped fields that are ForeignKey mappings.

    A mapped field is treated as a foreign key when its entry in
    `self.fields` is a model class (a `ModelBase` instance) rather than a
    model field instance.
    """
    return [
        field_name
        for field_name in self.mapping
        if isinstance(self.fields[field_name], models.base.ModelBase)
    ]


# TODO: once real batching exists, add a verify_fk_exists(field_name, fk_ids)
# step that reports every missing foreign key id in a single error.


# TODO: Implement real batching
def batch_fetch_fk_pks(
    self, field_name: str, uids: Optional[list[int]] = None
) -> dict:
    """
    Return a dict mapping related-model lookup values to primary keys for
    the ForeignKey mapping stored under `field_name`.

    `uids` restricts the query to rows whose lookup column is in the given
    values. None (the default) fetches every row of the related model; an
    empty list yields an empty dict instead of a full-table fetch.
    """
    related_model = self.fields[field_name]
    uid_pk_map = {}

    # NOTE(review): values of every mapped column are merged into one dict, so
    # a mapping with several columns can overwrite entries -- confirm that
    # ForeignKey mappings are expected to hold a single column.
    for column_name in self.mapping[field_name]:
        # Read from the same database alias as the rest of the import.
        queryset = related_model.objects.using(self.using)
        # Compare against None so that an explicit empty list filters
        # everything out rather than disabling the filter.
        if uids is not None:
            queryset = queryset.filter(**{f"{column_name}__in": uids})
        uid_pk_map.update(queryset.values_list(column_name, "pk"))
    return uid_pk_map


def load_fks_uid_pk_map(self):
    """
    Fill `self.fks_uid_pk_map` with one lookup-value -> pk dict for each
    field named in `self.fk_field_names`.
    """
    for field_name in self.fk_field_names:
        self.fks_uid_pk_map[field_name] = self.batch_fetch_fk_pks(field_name)
- val = self.verify_fk(feat, model_field, ogr_name) + if not self.faster_verify_fk: + val = self.verify_fk(feat, model_field, ogr_name) + else: + for rel_model_column_name, ogr_name_fk in ogr_name.items(): + # Seems not very helpful + fk_val = self.verify_ogr_field(feat[ogr_name_fk], model_field) + if fk_val not in self.fks_uid_pk_map[field_name]: + # TODO: batch matching to report all missing fk ids + raise Exception( + f"Missing {field_name} foreign key ids: {fk_val}" + ) + else: + rel_model_pk = self.fks_uid_pk_map[field_name][fk_val] + val = rel_model_pk + # TODO: Validate it is always correct + field_name = f"{field_name}_id" + # Should only have one key inside the dict + break else: # Otherwise, verify OGR Field type. val = self.verify_ogr_field(feat[ogr_name], model_field) @@ -460,6 +532,7 @@ def verify_ogr_field(self, ogr_field, model_field): val = ogr_field.value return val + # Will be removed def verify_fk(self, feat, rel_model, rel_mapping): """ Given an OGR Feature, the related model and its dictionary mapping, @@ -478,7 +551,8 @@ def verify_fk(self, feat, rel_model, rel_mapping): # Attempting to retrieve and return the related model. try: - return rel_model.objects.using(self.using).get(**fk_kwargs) + # Lighter query by only fetching pk + return rel_model.objects.using(self.using).only("pk").get(**fk_kwargs) except ObjectDoesNotExist: raise MissingForeignKey( "No ForeignKey %s model found with keyword arguments: %s" @@ -499,6 +573,10 @@ def verify_geom(self, geom, model_field): if self.coord_dim == 2 and geom.is_3d: geom.set_3d(False) + # Downgrade a curved geom to a linear one so that it can be saved + if geom.has_curve: + geom = geom.get_linear_geometry() + if self.make_multi(geom.geom_type, model_field): # Constructing a multi-geometry type to contain the single geometry multi_type = self.MULTI_TYPES[geom.geom_type.num] @@ -540,6 +618,11 @@ def geometry_field(self): Return the GeometryField instance associated with the geographic column. 
""" + + # Allow layer has no geometry field being mapped + if self.geom_field is None: + return None + # Use `get_field()` on the model's options so that we # get the correct field instance if there's model inheritance. opts = self.model._meta @@ -616,6 +699,9 @@ def save( else: progress_interval = progress + if self.faster_verify_fk: + self.load_fks_uid_pk_map() + def _save(feat_range=default_range, num_feat=0, num_saved=0): if feat_range: layer_iter = self.layer[feat_range] @@ -734,3 +820,65 @@ def _save(feat_range=default_range, num_feat=0, num_saved=0): else: # Otherwise, just calling the previously defined _save() function. _save() + + def _split_layer(self, batch_size: int = 1000): + """ + Split the features in the layer into batches of the given size. + """ + current_batch = [] + for feature in self.layer: + current_batch.append(feature) + if len(current_batch) >= batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + + def _bulk_create_batch( + self, + features_batch: list[Feature], + overwrite_kwargs: Optional[dict[str, Any]] = None, + ): + """ + Given a batch of features, bulk create these features. 
+ """ + + if not overwrite_kwargs: + overwrite_kwargs = {} + + if self.faster_verify_fk: + features_kwargs = [ + {**self.feature_kwargs(feature), **overwrite_kwargs} + for feature in features_batch + ] + # Verify FK existence + for fk_field_name in self.fk_field_names: + uids = [kwargs[fk_field_name + "_id"] for kwargs in features_kwargs] + if missing_uids := set(uids) - set( + self.fks_uid_pk_map[fk_field_name].values() + ): + raise Exception( + f"Missing {fk_field_name} foreign key ids: {missing_uids}" + ) + features = [self.model(**kwargs) for kwargs in features_kwargs] + else: + features = [ + self.model(**{**self.feature_kwargs(feature), **overwrite_kwargs}) + for feature in features_batch + ] + self.model.objects.using(self.using).bulk_create(features) + + def bulk_create_all( + self, overwrite_kwargs: Optional[dict[str, Any]] = None, batch_size: int = 1000 + ): + if self.faster_verify_fk: + self.load_fks_uid_pk_map() + + context = ( + transaction.atomic() + if self.transaction_mode == "commit_on_success" + else nullcontext() + ) + with context: + for features_batch in self._split_layer(batch_size): + self._bulk_create_batch(features_batch, overwrite_kwargs)