Skip to content

Commit

Permalink
Fixed #35936 -- Used unnest for bulk inserts on Postgres when possible.
Browse files Browse the repository at this point in the history
This should make bulk_create significantly faster on Postgres when provided
only literal values.

Thanks James Sewell for writing about this technique, Tom Forbes for
validating the performance benefits, David Sanders and Mariusz Felisiak
for the review.
  • Loading branch information
charettes authored and sarahboyce committed Dec 11, 2024
1 parent 2638b75 commit a16eedc
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
50 changes: 50 additions & 0 deletions django/db/backends/postgresql/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from django.db.models.sql.compiler import (
SQLAggregateCompiler,
SQLCompiler,
SQLDeleteCompiler,
)
from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler
from django.db.models.sql.compiler import SQLUpdateCompiler

__all__ = [
"SQLAggregateCompiler",
"SQLCompiler",
"SQLDeleteCompiler",
"SQLInsertCompiler",
"SQLUpdateCompiler",
]


class InsertUnnest(list):
"""
Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the
UNNEST strategy should be used for the bulk insert.
"""

def __str__(self):
return "UNNEST(%s)" % ", ".join(self)


class SQLInsertCompiler(BaseSQLInsertCompiler):
def assemble_as_sql(self, fields, value_rows):
# Specialize bulk-insertion of literal non-array values through
# UNNEST to reduce the time spent planning the query.
if (
# The optimization is not worth doing if there is a single
# row as it will result in the same number of placeholders.
len(value_rows) <= 1
# Lack of fields denote the usage of the DEFAULT keyword
# for the insertion of empty rows.
or any(field is None for field in fields)
# Compilable cannot be combined in an array of literal values.
or any(any(hasattr(value, "as_sql") for value in row) for row in value_rows)
):
return super().assemble_as_sql(fields, value_rows)
db_types = [field.db_type(self.connection) for field in fields]
# Abort if any of the fields are arrays as UNNEST indiscriminately
# flatten them instead of reducing their nesting by one.
if any(db_type.endswith("[]") for db_type in db_types):
return super().assemble_as_sql(fields, value_rows)
return InsertUnnest(["(%%s)::%s[]" % db_type for db_type in db_types]), [
list(map(list, zip(*value_rows)))
]
7 changes: 7 additions & 0 deletions django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.postgresql.compiler import InsertUnnest
from django.db.backends.postgresql.psycopg_any import (
Inet,
Jsonb,
Expand All @@ -24,6 +25,7 @@ def get_json_dumps(encoder):


class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.postgresql.compiler"
cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN"
explain_options = frozenset(
Expand Down Expand Up @@ -148,6 +150,11 @@ def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"

def bulk_insert_sql(self, fields, placeholder_rows):
if isinstance(placeholder_rows, InsertUnnest):
return f"SELECT * FROM {placeholder_rows}"
return super().bulk_insert_sql(fields, placeholder_rows)

def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
Expand Down
29 changes: 29 additions & 0 deletions tests/backends/postgresql/test_compilation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest

from django.db import connection
from django.db.models.expressions import RawSQL
from django.test import TestCase

from ..models import Square


@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
class BulkCreateUnnestTests(TestCase):
def test_single_object(self):
with self.assertNumQueries(1) as ctx:
Square.objects.bulk_create([Square(root=2, square=4)])
self.assertNotIn("UNNEST", ctx[0]["sql"])

def test_non_literal(self):
with self.assertNumQueries(1) as ctx:
Square.objects.bulk_create(
[Square(root=2, square=RawSQL("%s", (4,))), Square(root=3, square=9)]
)
self.assertNotIn("UNNEST", ctx[0]["sql"])

def test_unnest_eligible(self):
with self.assertNumQueries(1) as ctx:
Square.objects.bulk_create(
[Square(root=2, square=4), Square(root=3, square=9)]
)
self.assertIn("UNNEST", ctx[0]["sql"])

0 comments on commit a16eedc

Please sign in to comment.