diff --git a/django/db/backends/postgresql/compiler.py b/django/db/backends/postgresql/compiler.py new file mode 100644 index 000000000000..2394d90f55d7 --- /dev/null +++ b/django/db/backends/postgresql/compiler.py @@ -0,0 +1,50 @@ +from django.db.models.sql.compiler import ( + SQLAggregateCompiler, + SQLCompiler, + SQLDeleteCompiler, +) +from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler +from django.db.models.sql.compiler import SQLUpdateCompiler + +__all__ = [ + "SQLAggregateCompiler", + "SQLCompiler", + "SQLDeleteCompiler", + "SQLInsertCompiler", + "SQLUpdateCompiler", +] + + +class InsertUnnest(list): + """ + Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the + UNNEST strategy should be used for the bulk insert. + """ + + def __str__(self): + return "UNNEST(%s)" % ", ".join(self) + + +class SQLInsertCompiler(BaseSQLInsertCompiler): + def assemble_as_sql(self, fields, value_rows): + # Specialize bulk-insertion of literal non-array values through + # UNNEST to reduce the time spent planning the query. + if ( + # The optimization is not worth doing if there is a single + # row as it will result in the same number of placeholders. + len(value_rows) <= 1 + # Lack of fields denote the usage of the DEFAULT keyword + # for the insertion of empty rows. + or any(field is None for field in fields) + # Compilable cannot be combined in an array of literal values. + or any(any(hasattr(value, "as_sql") for value in row) for row in value_rows) + ): + return super().assemble_as_sql(fields, value_rows) + db_types = [field.db_type(self.connection) for field in fields] + # Abort if any of the fields are arrays as UNNEST indiscriminately + # flatten them instead of reducing their nesting by one. + if any(db_type.endswith("[]") for db_type in db_types): + return super().assemble_as_sql(fields, value_rows) + return InsertUnnest(["(%%s)::%s[]" % db_type for db_type in db_types]), [ + list(map(list, zip(*value_rows))) + ] diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 8a0ca36a29f7..9db755bb8919 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -3,6 +3,7 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.backends.postgresql.compiler import InsertUnnest from django.db.backends.postgresql.psycopg_any import ( Inet, Jsonb, @@ -24,6 +25,7 @@ def get_json_dumps(encoder): class DatabaseOperations(BaseDatabaseOperations): + compiler_module = "django.db.backends.postgresql.compiler" cast_char_field_without_max_length = "varchar" explain_prefix = "EXPLAIN" explain_options = frozenset( @@ -148,6 +150,11 @@ def time_trunc_sql(self, lookup_type, sql, params, tzname=None): def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" + def bulk_insert_sql(self, fields, placeholder_rows): + if isinstance(placeholder_rows, InsertUnnest): + return f"SELECT * FROM {placeholder_rows}" + return super().bulk_insert_sql(fields, placeholder_rows) + def fetch_returned_insert_rows(self, cursor): """ Given a cursor object that has just performed an INSERT...RETURNING diff --git a/tests/backends/postgresql/test_compilation.py b/tests/backends/postgresql/test_compilation.py new file mode 100644 index 000000000000..67fe893e35d2 --- /dev/null +++ b/tests/backends/postgresql/test_compilation.py @@ -0,0 +1,29 @@ +import unittest + +from django.db import connection +from django.db.models.expressions import RawSQL +from django.test import TestCase + +from ..models import Square + + +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests") +class BulkCreateUnnestTests(TestCase): + def test_single_object(self): + with self.assertNumQueries(1) as ctx: + Square.objects.bulk_create([Square(root=2, square=4)]) + self.assertNotIn("UNNEST", ctx[0]["sql"]) + + def test_non_literal(self): + with self.assertNumQueries(1) as ctx: + Square.objects.bulk_create( + [Square(root=2, square=RawSQL("%s", (4,))), Square(root=3, square=9)] + ) + self.assertNotIn("UNNEST", ctx[0]["sql"]) + + def test_unnest_eligible(self): + with self.assertNumQueries(1) as ctx: + Square.objects.bulk_create( + [Square(root=2, square=4), Square(root=3, square=9)] + ) + self.assertIn("UNNEST", ctx[0]["sql"])