Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions django/contrib/gis/utils/layermapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
https://docs.djangoproject.com/en/dev/ref/contrib/gis/layermapping/
"""
import sys
from contextlib import nullcontext
from decimal import Decimal
from decimal import InvalidOperation as DecimalInvalidOperation
from pathlib import Path
from typing import Any, Optional

from django.contrib.gis.db.models import GeometryField
from django.contrib.gis.gdal import (
Expand All @@ -20,6 +22,7 @@
OGRGeomType,
SpatialReference,
)
from django.contrib.gis.gdal.feature import Feature
from django.contrib.gis.gdal.field import (
OFTDate,
OFTDateTime,
Expand Down Expand Up @@ -105,6 +108,7 @@
transform=True,
unique=None,
using=None,
faster_verify_fk=False,
):
"""
A LayerMapping object is initialized using the given Model (not an
Expand All @@ -127,10 +131,20 @@
self.mapping = mapping
self.model = model

# Flag to enable fetch required FK model pks in batch to avoid N+1 queries

Check warning on line 134 in django/contrib/gis/utils/layermapping.py

View workflow job for this annotation

GitHub Actions / flake8

doc line too long (82 > 79 characters)
self.faster_verify_fk = faster_verify_fk
if self.faster_verify_fk:
print("Faster foreign key verification is enabled")

# Checking the layer -- initialization of the object will fail if
# things don't check out before hand.
self.check_layer()

self.fk_field_names = self.get_model_fk_field_names()

# TODO: Improve naming, it stores FK id to actual FK instance pk
self.fks_uid_pk_map = {}

# Getting the geometry column associated with the model (an
# exception will be raised if there is no geometry column).
if connection.features.supports_transform:
Expand Down Expand Up @@ -340,6 +354,47 @@
"Unique keyword argument must be set with a tuple, list, or string."
)

# TODO: assign the fk_field_names by other methods
def get_model_fk_field_names(self):
fk_field_names = []
for field_name, ogr_name in self.mapping.items():
model_field = self.fields[field_name]
if isinstance(model_field, models.base.ModelBase):
fk_field_names.append(field_name)
return fk_field_names

# Implement after real batching
# def verify_fk_exists(self, field_name: str, fk_ids: list[int]):
# if field_name not in self.fk_field_names:
# raise Exception("Field name not found in fk_field_names")

# related_model = self.fields[field_name]
# related_model_pks = related_model.objects.filter(pk__in=fk_ids).values_list(

Check warning on line 372 in django/contrib/gis/utils/layermapping.py

View workflow job for this annotation

GitHub Actions / flake8

doc line too long (86 > 79 characters)
# "pk", flat=True
# )
# if missing_pks := set(fk_ids) - set(related_model_pks):
# raise Exception(f"Missing {field_name} foreign key ids: {missing_pks}")

Check warning on line 376 in django/contrib/gis/utils/layermapping.py

View workflow job for this annotation

GitHub Actions / flake8

doc line too long (85 > 79 characters)

# TODO: Implement real batching
def batch_fetch_fk_pks(self, field_name: str, uids: Optional[list[int]] = None):
related_model = self.fields[field_name]

uid_pk_map = {}

for related_model_column_name, ogr_name_fk in self.mapping[field_name].items():
queryset = related_model.objects.all()
if uids:
queryset = queryset.filter(**{f"{related_model_column_name}__in": uids})

result = queryset.values_list(related_model_column_name, "pk")
for uid, pk in result:
uid_pk_map[uid] = pk
return uid_pk_map

def load_fks_uid_pk_map(self):
for field_name in self.fk_field_names:
self.fks_uid_pk_map[field_name] = self.batch_fetch_fk_pks(field_name)

# Keyword argument retrieval routines.
def feature_kwargs(self, feat):
"""
Expand All @@ -363,7 +418,24 @@
elif isinstance(model_field, models.base.ModelBase):
# The related _model_, not a field was passed in -- indicating
# another mapping for the related Model.
val = self.verify_fk(feat, model_field, ogr_name)
if not self.faster_verify_fk:
val = self.verify_fk(feat, model_field, ogr_name)
else:
for rel_model_column_name, ogr_name_fk in ogr_name.items():
# Seems not very helpful
fk_val = self.verify_ogr_field(feat[ogr_name_fk], model_field)
if fk_val not in self.fks_uid_pk_map[field_name]:
# TODO: batch matching to report all missing fk ids
raise Exception(
f"Missing {field_name} foreign key ids: {fk_val}"
)
else:
rel_model_pk = self.fks_uid_pk_map[field_name][fk_val]
val = rel_model_pk
# TODO: Validate it is always correct
field_name = f"{field_name}_id"
# Should only have one key inside the dict
break
else:
# Otherwise, verify OGR Field type.
val = self.verify_ogr_field(feat[ogr_name], model_field)
Expand Down Expand Up @@ -460,6 +532,7 @@
val = ogr_field.value
return val

# Will be removed
def verify_fk(self, feat, rel_model, rel_mapping):
"""
Given an OGR Feature, the related model and its dictionary mapping,
Expand All @@ -478,7 +551,8 @@

# Attempting to retrieve and return the related model.
try:
return rel_model.objects.using(self.using).get(**fk_kwargs)
# Lighter query by only fetching pk
return rel_model.objects.using(self.using).only("pk").get(**fk_kwargs)
except ObjectDoesNotExist:
raise MissingForeignKey(
"No ForeignKey %s model found with keyword arguments: %s"
Expand All @@ -499,6 +573,10 @@
if self.coord_dim == 2 and geom.is_3d:
geom.set_3d(False)

# Downgrade a curved geom to a linear one so that it can be saved
if geom.has_curve:
geom = geom.get_linear_geometry()

if self.make_multi(geom.geom_type, model_field):
# Constructing a multi-geometry type to contain the single geometry
multi_type = self.MULTI_TYPES[geom.geom_type.num]
Expand Down Expand Up @@ -540,6 +618,11 @@
Return the GeometryField instance associated with the geographic
column.
"""

# Allow layer has no geometry field being mapped
if self.geom_field is None:
return None

# Use `get_field()` on the model's options so that we
# get the correct field instance if there's model inheritance.
opts = self.model._meta
Expand Down Expand Up @@ -616,6 +699,9 @@
else:
progress_interval = progress

if self.faster_verify_fk:
self.load_fks_uid_pk_map()

def _save(feat_range=default_range, num_feat=0, num_saved=0):
if feat_range:
layer_iter = self.layer[feat_range]
Expand Down Expand Up @@ -734,3 +820,65 @@
else:
# Otherwise, just calling the previously defined _save() function.
_save()

def _split_layer(self, batch_size: int = 1000):
"""
Split the features in the layer into batches of the given size.
"""
current_batch = []
for feature in self.layer:
current_batch.append(feature)
if len(current_batch) >= batch_size:
yield current_batch
current_batch = []
if current_batch:
yield current_batch

def _bulk_create_batch(
self,
features_batch: list[Feature],
overwrite_kwargs: Optional[dict[str, Any]] = None,
):
"""
Given a batch of features, bulk create these features.
"""

if not overwrite_kwargs:
overwrite_kwargs = {}

if self.faster_verify_fk:
features_kwargs = [
{**self.feature_kwargs(feature), **overwrite_kwargs}
for feature in features_batch
]
# Verify FK existence
for fk_field_name in self.fk_field_names:
uids = [kwargs[fk_field_name + "_id"] for kwargs in features_kwargs]
if missing_uids := set(uids) - set(
self.fks_uid_pk_map[fk_field_name].values()
):
raise Exception(
f"Missing {fk_field_name} foreign key ids: {missing_uids}"
)
features = [self.model(**kwargs) for kwargs in features_kwargs]
else:
features = [
self.model(**{**self.feature_kwargs(feature), **overwrite_kwargs})
for feature in features_batch
]
self.model.objects.using(self.using).bulk_create(features)

def bulk_create_all(
self, overwrite_kwargs: Optional[dict[str, Any]] = None, batch_size: int = 1000
):
if self.faster_verify_fk:
self.load_fks_uid_pk_map()

context = (
transaction.atomic()
if self.transaction_mode == "commit_on_success"
else nullcontext()
)
with context:
for features_batch in self._split_layer(batch_size):
self._bulk_create_batch(features_batch, overwrite_kwargs)
Loading