Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions airflow/providers/google/cloud/transfers/mssql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import datetime
import decimal
from typing import Sequence

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
Expand All @@ -29,6 +30,10 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
"""Copy data from Microsoft SQL Server to Google Cloud Storage
in JSON, CSV or Parquet format.

:param bit_fields: Sequence of fields names of MSSQL "BIT" data type,
to be interpreted in the schema as "BOOLEAN". "BIT" fields that won't
be included in this sequence, will be interpreted as "INTEGER" by
default.
:param mssql_conn_id: Reference to a specific MSSQL hook.

**Example**:
Expand All @@ -39,6 +44,7 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
export_customers = MsSqlToGoogleCloudStorageOperator(
task_id='export_customers',
sql='SELECT * FROM dbo.Customers;',
bit_fields=['some_bit_field', 'another_bit_field'],
bucket='mssql-export',
filename='data/customers/export.json',
schema_filename='schemas/export.json',
Expand All @@ -55,11 +61,18 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):

ui_color = "#e0a98c"

type_map = {3: "INTEGER", 4: "TIMESTAMP", 5: "NUMERIC"}
# Map pymssql cursor type codes to BigQuery schema types; any code not
# listed here falls back to "STRING" (see field_to_bigquery's .get default).
type_map = {2: "BOOLEAN", 3: "INTEGER", 4: "TIMESTAMP", 5: "NUMERIC"}

def __init__(self, *, mssql_conn_id="mssql_default", **kwargs):
def __init__(
    self,
    *,
    bit_fields: Sequence[str] | None = None,
    mssql_conn_id="mssql_default",
    **kwargs,
):
    """Initialize the operator.

    :param bit_fields: names of MSSQL "BIT" columns that should be mapped to
        "BOOLEAN" in the exported schema; BIT columns not listed here keep the
        default type mapping.
    :param mssql_conn_id: reference to a specific MSSQL hook.
    """
    super().__init__(**kwargs)
    # Normalize the optional sequence to a concrete (possibly empty) container
    # so membership tests never have to deal with None.
    self.bit_fields = bit_fields or []
    self.mssql_conn_id = mssql_conn_id

def query(self):
"""
Expand All @@ -74,6 +87,9 @@ def query(self):
return cursor

def field_to_bigquery(self, field) -> dict[str, str]:
if field[0] in self.bit_fields:
field = (field[0], 2)

return {
"name": field[0].replace(" ", "_"),
"type": self.type_map.get(field[1], "STRING"),
Expand Down
37 changes: 30 additions & 7 deletions tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,35 @@
JSON_FILENAME = "test_{}.ndjson"
GZIP = False

ROWS = [("mock_row_content_1", 42), ("mock_row_content_2", 43), ("mock_row_content_3", 44)]
ROWS = [
("mock_row_content_1", 42, True, True),
("mock_row_content_2", 43, False, False),
("mock_row_content_3", 44, True, True),
]
CURSOR_DESCRIPTION = (
("some_str", 0, None, None, None, None, None),
("some_num", 3, None, None, None, None, None),
("some_binary", 2, None, None, None, None, None),
("some_bit", 3, None, None, None, None, None),
)
NDJSON_LINES = [
b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
b'{"some_num": 44, "some_str": "mock_row_content_3"}\n',
b'{"some_binary": true, "some_bit": true, "some_num": 42, "some_str": "mock_row_content_1"}\n',
b'{"some_binary": false, "some_bit": false, "some_num": 43, "some_str": "mock_row_content_2"}\n',
b'{"some_binary": true, "some_bit": true, "some_num": 44, "some_str": "mock_row_content_3"}\n',
]
SCHEMA_FILENAME = "schema_test.json"
SCHEMA_JSON = [
b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ',
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]',
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}, ',
b'{"mode": "NULLABLE", "name": "some_binary", "type": "BOOLEAN"}, ',
b'{"mode": "NULLABLE", "name": "some_bit", "type": "BOOLEAN"}]',
]

SCHEMA_JSON_BIT_FIELDS = [
b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ',
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}, ',
b'{"mode": "NULLABLE", "name": "some_binary", "type": "BOOLEAN"}, ',
b'{"mode": "NULLABLE", "name": "some_bit", "type": "INTEGER"}]',
]


Expand Down Expand Up @@ -148,7 +163,10 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metada

@mock.patch("airflow.providers.google.cloud.transfers.mssql_to_gcs.MsSqlHook")
@mock.patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class):
@pytest.mark.parametrize(
"bit_fields,schema_json", [(None, SCHEMA_JSON), (["bit_fields", SCHEMA_JSON_BIT_FIELDS])]
)
def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class, bit_fields, schema_json):
"""Test writing schema files."""
mssql_hook_mock = mssql_hook_mock_class.return_value
mssql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
Expand All @@ -164,7 +182,12 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
gcs_hook_mock.upload.side_effect = _assert_upload

op = MSSQLToGCSOperator(
task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME
task_id=TASK_ID,
sql=SQL,
bucket=BUCKET,
filename=JSON_FILENAME,
schema_filename=SCHEMA_FILENAME,
bit_fields=["some_bit"],
)
op.execute(None)

Expand Down