Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #210 - Incorrect variable for calc_batch_size #211

Merged
merged 3 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions dbt/include/sqlserver/macros/materializations/seeds/helpers.sql
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
{% macro calc_batch_size(num_columns,max_batch_size) %}
    {#
        SQL Server allows at most 2100 bind parameters in one statement.
        Keep the configured max_batch_size when num_columns * max_batch_size
        stays under that cap; otherwise fall back to the largest row count
        whose parameters still fit (integer division of 2100 by the column
        count).
    #}
    {% set rows_that_fit = (2100 / num_columns) | int %}
    {% if num_columns * max_batch_size < 2100 %}
        {% set batch_size = max_batch_size %}
    {% else %}
        {% set batch_size = rows_that_fit %}
    {% endif %}
    {{ return(batch_size) }}
{% endmacro %}

{% macro sqlserver__get_binding_char() %}
{# Positional bind-parameter placeholder used when inserting seed rows;
   "?" is the ODBC-style marker SQL Server connections expect. #}
{{ return('?') }}
{% endmacro %}
Expand All @@ -21,13 +6,27 @@
{{ return(400) }}
{% endmacro %}

{% macro basic_load_csv_rows(model, batch_size, agate_table) %}
{% macro calc_batch_size(num_columns) %}
    {#
        SQL Server allows for a max of 2100 parameters in a single
        statement. Start from the adapter's configured batch size
        (get_batch_size) and shrink it whenever that many rows of
        num_columns parameters each would exceed the 2100 cap.
    #}
    {% set configured_max = get_batch_size() %}
    {% set cap_limited = (2100 / num_columns) | int %}
    {% if cap_limited < configured_max %}
        {% set chosen_size = cap_limited %}
    {% else %}
        {% set chosen_size = configured_max %}
    {% endif %}
    {{ return(chosen_size) }}
{% endmacro %}

{% macro sqlserver__load_csv_rows(model, agate_table) %}
{% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %}
{% set batch_size = calc_batch_size(agate_table.column_names|length) %}
{% set bindings = [] %}

{% set statements = [] %}

{{ log("Inserting batches of " ~ batch_size ~ " records") }}

{% for chunk in agate_table.rows | batch(batch_size) %}
{% set bindings = [] %}

Expand Down Expand Up @@ -56,11 +55,3 @@
{# Return SQL so we can render it out into the compiled files #}
{{ return(statements[0]) }}
{% endmacro %}

{% macro sqlserver__load_csv_rows(model, agate_table) %}
    {#
        Seed loader entry point: compute a batch size that keeps each
        INSERT under SQL Server's 2100-bind-parameter limit, then
        delegate the actual row loading to basic_load_csv_rows.
    #}
    {% set max_batch_size = get_batch_size() %}
    {% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %}
    {# Fix for issue #210: calc_batch_size must receive the NUMBER OF
       COLUMNS. The previous code passed cols_sql|length, which is the
       character count of the rendered "[col1], [col2], ..." string,
       so the computed batch size was far smaller than intended. #}
    {% set batch_size = calc_batch_size(agate_table.column_names | length, max_batch_size) %}
    {{ return(basic_load_csv_rows(model, batch_size, agate_table)) }}
{% endmacro %}
45 changes: 44 additions & 1 deletion tests/functional/adapter/test_seed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os

import pytest
from dbt.tests.adapter.simple_seed.seeds import seeds__expected_sql
from dbt.tests.adapter.simple_seed.test_seed import SeedConfigBase
from dbt.tests.adapter.simple_seed.test_seed import TestBasicSeedTests as BaseBasicSeedTests
from dbt.tests.adapter.simple_seed.test_seed import (
TestSeedConfigFullRefreshOff as BaseSeedConfigFullRefreshOff,
Expand All @@ -20,7 +23,7 @@
seeds__disabled_in_config_csv,
seeds__enabled_in_config_csv,
)
from dbt.tests.util import get_connection
from dbt.tests.util import get_connection, run_dbt

from dbt.adapters.sqlserver import SQLServerAdapter

Expand Down Expand Up @@ -180,3 +183,43 @@ def setUp(self, project):

class TestSeedSpecificFormatsSQLServer(BaseSeedSpecificFormats):
    """Reuse the inherited seed-format test cases, unmodified, against SQL Server."""

    pass


class TestSeedBatchSizeMaxSQLServer(SeedConfigBase):
    """Seed with 5 columns: 5 * 400 parameters stays under SQL Server's
    2100-parameter cap, so the full default batch size of 400 is used."""

    @pytest.fixture(scope="class")
    def seeds(self, test_data_dir):
        # CSV content must stay byte-identical: 5 columns, 3 rows.
        csv_body = """
seed_id,first_name,email,ip_address,birthday
1,Larry,lking0@miitbeian.gov.cn,69.135.206.194,2008-09-12 19:08:31
2,Larry,lperkins1@toplist.cz,64.210.133.162,1978-05-09 04:15:14
3,Anna,amontgomery2@miitbeian.gov.cn,168.104.64.114,2011-10-16 04:07:57
"""
        return {"five_columns.csv": csv_body}

    def test_max_batch_size(self, project, logs_dir):
        run_dbt(["seed"])
        log_path = os.path.join(logs_dir, "dbt.log")
        with open(log_path, "r") as log_file:
            log_text = log_file.read()

        assert "Inserting batches of 400 records" in log_text


class TestSeedBatchSizeCustomSQLServer(SeedConfigBase):
    """Seed with 6 columns: 2100 // 6 = 350, smaller than the default 400,
    so the adapter must shrink the insert batches to 350 rows."""

    @pytest.fixture(scope="class")
    def seeds(self, test_data_dir):
        # CSV content must stay byte-identical: 6 columns, 3 rows.
        csv_body = """
seed_id,first_name,last_name,email,ip_address,birthday
1,Larry,King,lking0@miitbeian.gov.cn,69.135.206.194,2008-09-12 19:08:31
2,Larry,Perkins,lperkins1@toplist.cz,64.210.133.162,1978-05-09 04:15:14
3,Anna,Montgomery,amontgomery2@miitbeian.gov.cn,168.104.64.114,2011-10-16 04:07:57
"""
        return {"six_columns.csv": csv_body}

    def test_custom_batch_size(self, project, logs_dir):
        run_dbt(["seed"])
        log_path = os.path.join(logs_dir, "dbt.log")
        with open(log_path, "r") as log_file:
            log_text = log_file.read()

        assert "Inserting batches of 350 records" in log_text