diff --git a/airflow/contrib/hooks/spark_jdbc_hook.py b/airflow/contrib/hooks/spark_jdbc_hook.py new file mode 100644 index 0000000000000..cbc35b16bde33 --- /dev/null +++ b/airflow/contrib/hooks/spark_jdbc_hook.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook +from airflow.exceptions import AirflowException + + +class SparkJDBCHook(SparkSubmitHook): + """ + This hook extends the SparkSubmitHook specifically for performing data + transfers to/from JDBC-based databases with Apache Spark. + + :param spark_app_name: Name of the job (default airflow-spark-jdbc) + :type spark_app_name: str + :param spark_conn_id: Connection id as configured in Airflow administration + :type spark_conn_id: str + :param spark_conf: Any additional Spark configuration properties + :type spark_conf: dict + :param spark_py_files: Additional python files used (.zip, .egg, or .py) + :type spark_py_files: str + :param spark_files: Additional files to upload to the container running the job + :type spark_files: str + :param spark_jars: Additional jars to upload and add to the driver and + executor classpath + :type spark_jars: str + :param num_executors: number of executor to run. This should be set so as to manage + the number of connections made with the JDBC database + :type num_executors: int + :param executor_cores: Number of cores per executor + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) + :type executor_memory: str + :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) + :type driver_memory: str + :param verbose: Whether to pass the verbose flag to spark-submit for debugging + :type verbose: bool + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param principal: The name of the kerberos principal used for keytab + :type principal: str + :param cmd_type: Which way the data should flow. 2 possible values: + spark_to_jdbc: data written by spark from metastore to jdbc + jdbc_to_spark: data written by spark from jdbc to metastore + :type cmd_type: str + :param jdbc_table: The name of the JDBC table + :type jdbc_table: str + :param jdbc_conn_id: Connection id used for connection to JDBC database + :type: jdbc_conn_id: str + :param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This + driver (usually a jar) should be passed in the 'jars' parameter + :type jdbc_driver: str + :param metastore_table: The name of the metastore table, + :type metastore_table: str + :param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or + drop and recreate the JDBC table. This only takes effect if + 'save_mode' is set to Overwrite. Also, if the schema is + different, Spark cannot truncate, and will drop and recreate + :type jdbc_truncate: bool + :param save_mode: The Spark save-mode to use (e.g. overwrite, append, etc.) 
+ :type save_mode: str + :param save_format: (jdbc_to_spark-only) The Spark save-format to use (e.g. parquet) + :type save_format: str + :param batch_size: (spark_to_jdbc only) The size of the batch to insert per round + trip to the JDBC database. Defaults to 1000 + :type batch_size: int + :param fetch_size: (jdbc_to_spark only) The size of the batch to fetch per round trip + from the JDBC database. Default depends on the JDBC driver + :type fetch_size: int + :param num_partitions: The maximum number of partitions that can be used by Spark + simultaneously, both for spark_to_jdbc and jdbc_to_spark + operations. This will also cap the number of JDBC connections + that can be opened + :type num_partitions: int + :param partition_column: (jdbc_to_spark-only) A numeric column to be used to + partition the metastore table by. If specified, you must + also specify: + num_partitions, lower_bound, upper_bound + :type partition_column: str + :param lower_bound: (jdbc_to_spark-only) Lower bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, upper_bound + :type lower_bound: int + :param upper_bound: (jdbc_to_spark-only) Upper bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, lower_bound + :type upper_bound: int + :param create_table_column_types: (spark_to_jdbc-only) The database column data types + to use instead of the defaults, when creating the + table. Data type information should be specified in + the same format as CREATE TABLE columns syntax + (e.g: "name CHAR(64), comments VARCHAR(1024)"). + The specified types should be valid spark sql data + types. + """ + def __init__(self, + spark_app_name='airflow-spark-jdbc', + spark_conn_id='spark-default', + spark_conf=None, + spark_py_files=None, + spark_files=None, + spark_jars=None, + num_executors=None, + executor_cores=None, + executor_memory=None, + driver_memory=None, + verbose=False, + principal=None, + keytab=None, + cmd_type='spark_to_jdbc', + jdbc_table=None, + jdbc_conn_id='jdbc-default', + jdbc_driver=None, + metastore_table=None, + jdbc_truncate=False, + save_mode=None, + save_format=None, + batch_size=None, + fetch_size=None, + num_partitions=None, + partition_column=None, + lower_bound=None, + upper_bound=None, + create_table_column_types=None, + *args, + **kwargs + ): + super(SparkJDBCHook, self).__init__(*args, **kwargs) + self._name = spark_app_name + self._conn_id = spark_conn_id + self._conf = spark_conf + self._py_files = spark_py_files + self._files = spark_files + self._jars = spark_jars + self._num_executors = num_executors + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._driver_memory = driver_memory + self._verbose = verbose + self._keytab = keytab + self._principal = principal + self._cmd_type = cmd_type + self._jdbc_table = jdbc_table + self._jdbc_conn_id = jdbc_conn_id + self._jdbc_driver = jdbc_driver + self._metastore_table = metastore_table + self._jdbc_truncate = jdbc_truncate + self._save_mode = save_mode + self._save_format = save_format + self._batch_size = batch_size + self._fetch_size = fetch_size + self._num_partitions = num_partitions + self._partition_column = partition_column + self._lower_bound = lower_bound + self._upper_bound = upper_bound + self._create_table_column_types = create_table_column_types + self._jdbc_connection = self._resolve_jdbc_connection() + + def _resolve_jdbc_connection(self): + 
conn_data = {'url': '', + 'schema': '', + 'conn_prefix': '', + 'user': '', + 'password': '' + } + try: + conn = self.get_connection(self._jdbc_conn_id) + if conn.port: + conn_data['url'] = "{}:{}".format(conn.host, conn.port) + else: + conn_data['url'] = conn.host + conn_data['schema'] = conn.schema + conn_data['user'] = conn.login + conn_data['password'] = conn.password + extra = conn.extra_dejson + conn_data['conn_prefix'] = extra.get('conn_prefix', '') + except AirflowException: + self.log.debug( + "Could not load jdbc connection string %s, defaulting to %s", + self._jdbc_conn_id, "" + ) + return conn_data + + def _build_jdbc_application_arguments(self, jdbc_conn): + arguments = [] + arguments += ["-cmdType", self._cmd_type] + if self._jdbc_connection['url']: + arguments += ['-url', "{0}{1}/{2}".format( + jdbc_conn['conn_prefix'], jdbc_conn['url'], jdbc_conn['schema'] + )] + if self._jdbc_connection['user']: + arguments += ['-user', self._jdbc_connection['user']] + if self._jdbc_connection['password']: + arguments += ['-password', self._jdbc_connection['password']] + if self._metastore_table: + arguments += ['-metastoreTable', self._metastore_table] + if self._jdbc_table: + arguments += ['-jdbcTable', self._jdbc_table] + if self._jdbc_truncate: + arguments += ['-jdbcTruncate', str(self._jdbc_truncate)] + if self._jdbc_driver: + arguments += ['-jdbcDriver', self._jdbc_driver] + if self._batch_size: + arguments += ['-batchsize', str(self._batch_size)] + if self._fetch_size: + arguments += ['-fetchsize', str(self._fetch_size)] + if self._num_partitions: + arguments += ['-numPartitions', str(self._num_partitions)] + if (self._partition_column and self._lower_bound and + self._upper_bound and self._num_partitions): + # these 3 parameters need to be used all together to take effect. + arguments += ['-partitionColumn', self._partition_column, + '-lowerBound', self._lower_bound, + '-upperBound', self._upper_bound] + if self._save_mode: + arguments += ['-saveMode', self._save_mode] + if self._save_format: + arguments += ['-saveFormat', self._save_format] + if self._create_table_column_types: + arguments += ['-createTableColumnTypes', self._create_table_column_types] + return arguments + + def submit_jdbc_job(self): + self._application_args = \ + self._build_jdbc_application_arguments(self._jdbc_connection) + self.submit(application=os.path.dirname(os.path.abspath(__file__)) + + "/spark_jdbc_script.py") + + def get_conn(self): + pass diff --git a/airflow/contrib/hooks/spark_jdbc_script.py b/airflow/contrib/hooks/spark_jdbc_script.py new file mode 100644 index 0000000000000..41d32aa26096c --- /dev/null +++ b/airflow/contrib/hooks/spark_jdbc_script.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
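Before moving on to the companion spark_jdbc_script.py below, here is a minimal sketch of how the hook above is meant to fit together: the Airflow JDBC connection is resolved into a small dict at construction time, and that dict plus the constructor options is turned into the `-...` style arguments passed to the script. This is illustrative only; the connection ids, table names, and jar path are placeholders, it assumes the `jdbc-default` and `spark-default` connections exist in the Airflow metadata database, and the private-method calls are shown only because the unit tests in this PR exercise them the same way.

```python
# Sketch only: assumes a 'jdbc-default' connection (host=localhost, port=5432,
# schema=default, login/password set, extra='{"conn_prefix": "jdbc:postgresql://"}')
# and a 'spark-default' connection. Table names and the jar path are placeholders.
from airflow.contrib.hooks.spark_jdbc_hook import SparkJDBCHook

hook = SparkJDBCHook(
    cmd_type='jdbc_to_spark',
    jdbc_conn_id='jdbc-default',
    jdbc_table='public.orders',
    metastore_table='staging.orders',
    jdbc_driver='org.postgresql.Driver',
    spark_jars='/path/to/postgresql-jdbc.jar',  # driver jar must be shipped to Spark
)

# The JDBC connection is resolved in __init__:
# hook._jdbc_connection == {'url': 'localhost:5432', 'schema': 'default',
#                           'conn_prefix': 'jdbc:postgresql://',
#                           'user': '...', 'password': '...'}

args = hook._build_jdbc_application_arguments(hook._jdbc_connection)
# ['-cmdType', 'jdbc_to_spark',
#  '-url', 'jdbc:postgresql://localhost:5432/default',
#  '-user', '...', '-password', '...',
#  '-metastoreTable', 'staging.orders',
#  '-jdbcTable', 'public.orders',
#  '-jdbcDriver', 'org.postgresql.Driver']

# submit_jdbc_job() rebuilds the same arguments and hands spark_jdbc_script.py
# to spark-submit via the parent SparkSubmitHook.
hook.submit_jdbc_job()
```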
+# +import argparse +from pyspark.sql import SparkSession + + +def set_common_options(spark_source, + url='localhost:5432', + jdbc_table='default.default', + user='root', + password='root', + driver='driver'): + + spark_source = spark_source \ + .format('jdbc') \ + .option('url', url) \ + .option('dbtable', jdbc_table) \ + .option('user', user) \ + .option('password', password) \ + .option('driver', driver) + return spark_source + + +def spark_write_to_jdbc(spark, url, user, password, metastore_table, jdbc_table, driver, + truncate, save_mode, batch_size, num_partitions, + create_table_column_types): + writer = spark \ + .table(metastore_table) \ + .write \ + + # first set common options + writer = set_common_options(writer, url, jdbc_table, user, password, driver) + + # now set write-specific options + if truncate: + writer = writer.option('truncate', truncate) + if batch_size: + writer = writer.option('batchsize', batch_size) + if num_partitions: + writer = writer.option('numPartitions', num_partitions) + if create_table_column_types: + writer = writer.option("createTableColumnTypes", create_table_column_types) + + writer \ + .save(mode=save_mode) + + +def spark_read_from_jdbc(spark, url, user, password, metastore_table, jdbc_table, driver, + save_mode, save_format, fetch_size, num_partitions, + partition_column, lower_bound, upper_bound): + + # first set common options + reader = set_common_options(spark.read, url, jdbc_table, user, password, driver) + + # now set specific read options + if fetch_size: + reader = reader.option('fetchsize', fetch_size) + if num_partitions: + reader = reader.option('numPartitions', num_partitions) + if partition_column and lower_bound and upper_bound: + reader = reader \ + .option('partitionColumn', partition_column) \ + .option('lowerBound', lower_bound) \ + .option('upperBound', upper_bound) + + reader \ + .load() \ + .write \ + .saveAsTable(metastore_table, format=save_format, mode=save_mode) + + +if __name__ == "__main__": # pragma: no cover + # parse the parameters + parser = argparse.ArgumentParser(description='Spark-JDBC') + parser.add_argument('-cmdType', dest='cmd_type', action='store') + parser.add_argument('-url', dest='url', action='store') + parser.add_argument('-user', dest='user', action='store') + parser.add_argument('-password', dest='password', action='store') + parser.add_argument('-metastoreTable', dest='metastore_table', action='store') + parser.add_argument('-jdbcTable', dest='jdbc_table', action='store') + parser.add_argument('-jdbcDriver', dest='jdbc_driver', action='store') + parser.add_argument('-jdbcTruncate', dest='truncate', action='store') + parser.add_argument('-saveMode', dest='save_mode', action='store') + parser.add_argument('-saveFormat', dest='save_format', action='store') + parser.add_argument('-batchsize', dest='batch_size', action='store') + parser.add_argument('-fetchsize', dest='fetch_size', action='store') + parser.add_argument('-name', dest='name', action='store') + parser.add_argument('-numPartitions', dest='num_partitions', action='store') + parser.add_argument('-partitionColumn', dest='partition_column', action='store') + parser.add_argument('-lowerBound', dest='lower_bound', action='store') + parser.add_argument('-upperBound', dest='upper_bound', action='store') + parser.add_argument('-createTableColumnTypes', + dest='create_table_column_types', action='store') + arguments = parser.parse_args() + + # Disable dynamic allocation by default to allow num_executors to take effect. 
+ spark = SparkSession.builder \ + .appName(arguments.name) \ + .enableHiveSupport() \ + .getOrCreate() + + if arguments.cmd_type == "spark_to_jdbc": + spark_write_to_jdbc(spark, + arguments.url, + arguments.user, + arguments.password, + arguments.metastore_table, + arguments.jdbc_table, + arguments.jdbc_driver, + arguments.truncate, + arguments.save_mode, + arguments.batch_size, + arguments.num_partitions, + arguments.create_table_column_types) + elif arguments.cmd_type == "jdbc_to_spark": + spark_read_from_jdbc(spark, + arguments.url, + arguments.user, + arguments.password, + arguments.metastore_table, + arguments.jdbc_table, + arguments.jdbc_driver, + arguments.save_mode, + arguments.save_format, + arguments.fetch_size, + arguments.num_partitions, + arguments.partition_column, + arguments.lower_bound, + arguments.upper_bound) diff --git a/airflow/contrib/operators/spark_jdbc_operator.py b/airflow/contrib/operators/spark_jdbc_operator.py new file mode 100644 index 0000000000000..cab4336ca54d2 --- /dev/null +++ b/airflow/contrib/operators/spark_jdbc_operator.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator +from airflow.contrib.hooks.spark_jdbc_hook import SparkJDBCHook +from airflow.utils.decorators import apply_defaults + + +class SparkJDBCOperator(SparkSubmitOperator): + """ + This operator extends the SparkSubmitOperator specifically for performing data + transfers to/from JDBC-based databases with Apache Spark. As with the + SparkSubmitOperator, it assumes that the "spark-submit" binary is available on the + PATH. + + :param spark_app_name: Name of the job (default airflow-spark-jdbc) + :type spark_app_name: str + :param spark_conn_id: Connection id as configured in Airflow administration + :type spark_conn_id: str + :param spark_conf: Any additional Spark configuration properties + :type spark_conf: dict + :param spark_py_files: Additional python files used (.zip, .egg, or .py) + :type spark_py_files: str + :param spark_files: Additional files to upload to the container running the job + :type spark_files: str + :param spark_jars: Additional jars to upload and add to the driver and + executor classpath + :type spark_jars: str + :param num_executors: number of executor to run. This should be set so as to manage + the number of connections made with the JDBC database + :type num_executors: int + :param executor_cores: Number of cores per executor + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) + :type executor_memory: str + :param driver_memory: Memory allocated to the driver (e.g. 
1000M, 2G) + :type driver_memory: str + :param verbose: Whether to pass the verbose flag to spark-submit for debugging + :type verbose: bool + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param principal: The name of the kerberos principal used for keytab + :type principal: str + :param cmd_type: Which way the data should flow. 2 possible values: + spark_to_jdbc: data written by spark from metastore to jdbc + jdbc_to_spark: data written by spark from jdbc to metastore + :type cmd_type: str + :param jdbc_table: The name of the JDBC table + :type jdbc_table: str + :param jdbc_conn_id: Connection id used for connection to JDBC database + :type: jdbc_conn_id: str + :param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This + driver (usually a jar) should be passed in the 'jars' parameter + :type jdbc_driver: str + :param metastore_table: The name of the metastore table, + :type metastore_table: str + :param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or + drop and recreate the JDBC table. This only takes effect if + 'save_mode' is set to Overwrite. Also, if the schema is + different, Spark cannot truncate, and will drop and recreate + :type jdbc_truncate: bool + :param save_mode: The Spark save-mode to use (e.g. overwrite, append, etc.) + :type save_mode: str + :param save_format: (jdbc_to_spark-only) The Spark save-format to use (e.g. parquet) + :type save_format: str + :param batch_size: (spark_to_jdbc only) The size of the batch to insert per round + trip to the JDBC database. Defaults to 1000 + :type batch_size: int + :param fetch_size: (jdbc_to_spark only) The size of the batch to fetch per round trip + from the JDBC database. Default depends on the JDBC driver + :type fetch_size: int + :param num_partitions: The maximum number of partitions that can be used by Spark + simultaneously, both for spark_to_jdbc and jdbc_to_spark + operations. This will also cap the number of JDBC connections + that can be opened + :type num_partitions: int + :param partition_column: (jdbc_to_spark-only) A numeric column to be used to + partition the metastore table by. If specified, you must + also specify: + num_partitions, lower_bound, upper_bound + :type partition_column: str + :param lower_bound: (jdbc_to_spark-only) Lower bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, upper_bound + :type lower_bound: int + :param upper_bound: (jdbc_to_spark-only) Upper bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, lower_bound + :type upper_bound: int + :param create_table_column_types: (spark_to_jdbc-only) The database column data types + to use instead of the defaults, when creating the + table. Data type information should be specified in + the same format as CREATE TABLE columns syntax + (e.g: "name CHAR(64), comments VARCHAR(1024)"). + The specified types should be valid spark sql data + types. 
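To make the partitioned-read parameters documented above concrete, the following is roughly the plain PySpark read that spark_jdbc_script.py performs for cmd_type='jdbc_to_spark'. All values (URL, credentials, tables, bounds) are placeholders; the point is how fetch_size, num_partitions, partition_column, lower_bound, and upper_bound map onto the JDBC DataFrameReader options.

```python
# Placeholder values throughout; this mirrors the options set by spark_jdbc_script.py.
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName('jdbc-to-spark-example') \
    .enableHiveSupport() \
    .getOrCreate()

df = (spark.read
      .format('jdbc')
      .option('url', 'jdbc:postgresql://localhost:5432/default')
      .option('dbtable', 'public.orders')
      .option('user', 'user')
      .option('password', 'secret')
      .option('driver', 'org.postgresql.Driver')
      .option('fetchsize', 1000)
      .option('numPartitions', 10)            # caps the number of parallel JDBC connections
      .option('partitionColumn', 'order_id')  # numeric column used to split the read
      .option('lowerBound', 1)
      .option('upperBound', 1000000)
      .load())

df.write.saveAsTable('staging.orders', format='parquet', mode='overwrite')
```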
+ """ + + @apply_defaults + def __init__(self, + spark_app_name='airflow-spark-jdbc', + spark_conn_id='spark-default', + spark_conf=None, + spark_py_files=None, + spark_files=None, + spark_jars=None, + num_executors=None, + executor_cores=None, + executor_memory=None, + driver_memory=None, + verbose=False, + keytab=None, + principal=None, + cmd_type='spark_to_jdbc', + jdbc_table=None, + jdbc_conn_id='jdbc-default', + jdbc_driver=None, + metastore_table=None, + jdbc_truncate=False, + save_mode=None, + save_format=None, + batch_size=None, + fetch_size=None, + num_partitions=None, + partition_column=None, + lower_bound=None, + upper_bound=None, + create_table_column_types=None, + *args, + **kwargs): + super(SparkJDBCOperator, self).__init__(*args, **kwargs) + self._spark_app_name = spark_app_name + self._spark_conn_id = spark_conn_id + self._spark_conf = spark_conf + self._spark_py_files = spark_py_files + self._spark_files = spark_files + self._spark_jars = spark_jars + self._num_executors = num_executors + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._driver_memory = driver_memory + self._verbose = verbose + self._keytab = keytab + self._principal = principal + self._cmd_type = cmd_type + self._jdbc_table = jdbc_table + self._jdbc_conn_id = jdbc_conn_id + self._jdbc_driver = jdbc_driver + self._metastore_table = metastore_table + self._jdbc_truncate = jdbc_truncate + self._save_mode = save_mode + self._save_format = save_format + self._batch_size = batch_size + self._fetch_size = fetch_size + self._num_partitions = num_partitions + self._partition_column = partition_column + self._lower_bound = lower_bound + self._upper_bound = upper_bound + self._create_table_column_types = create_table_column_types + + def execute(self, context): + """ + Call the SparkSubmitHook to run the provided spark job + """ + self._hook = SparkJDBCHook( + spark_app_name=self._spark_app_name, + spark_conn_id=self._spark_conn_id, + spark_conf=self._spark_conf, + spark_py_files=self._spark_py_files, + spark_files=self._spark_files, + spark_jars=self._spark_jars, + num_executors=self._num_executors, + executor_cores=self._executor_cores, + executor_memory=self._executor_memory, + driver_memory=self._driver_memory, + verbose=self._verbose, + keytab=self._keytab, + principal=self._principal, + cmd_type=self._cmd_type, + jdbc_table=self._jdbc_table, + jdbc_conn_id=self._jdbc_conn_id, + jdbc_driver=self._jdbc_driver, + metastore_table=self._metastore_table, + jdbc_truncate=self._jdbc_truncate, + save_mode=self._save_mode, + save_format=self._save_format, + batch_size=self._batch_size, + fetch_size=self._fetch_size, + num_partitions=self._num_partitions, + partition_column=self._partition_column, + lower_bound=self._lower_bound, + upper_bound=self._upper_bound, + create_table_column_types=self._create_table_column_types + ) + self._hook.submit_jdbc_job() + + def on_kill(self): + self._hook.on_kill() diff --git a/tests/contrib/hooks/test_spark_jdbc_hook.py b/tests/contrib/hooks/test_spark_jdbc_hook.py new file mode 100644 index 0000000000000..377183dc691e9 --- /dev/null +++ b/tests/contrib/hooks/test_spark_jdbc_hook.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from airflow import configuration, models +from airflow.utils import db + +from airflow.contrib.hooks.spark_jdbc_hook import SparkJDBCHook + + +class TestSparkJDBCHook(unittest.TestCase): + + _config = { + 'cmd_type': 'spark_to_jdbc', + 'jdbc_table': 'tableMcTableFace', + 'jdbc_driver': 'org.postgresql.Driver', + 'metastore_table': 'hiveMcHiveFace', + 'jdbc_truncate': False, + 'save_mode': 'append', + 'save_format': 'parquet', + 'batch_size': 100, + 'fetch_size': 200, + 'num_partitions': 10, + 'partition_column': 'columnMcColumnFace', + 'lower_bound': '10', + 'upper_bound': '20', + 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' + 'comments VARCHAR(1024)' + } + + # this config is invalid because if one of [partitionColumn, lowerBound, upperBound] + # is set, all of the options must be enabled (enforced by Spark) + _invalid_config = { + 'cmd_type': 'spark_to_jdbc', + 'jdbc_table': 'tableMcTableFace', + 'jdbc_driver': 'org.postgresql.Driver', + 'metastore_table': 'hiveMcHiveFace', + 'jdbc_truncate': False, + 'save_mode': 'append', + 'save_format': 'parquet', + 'batch_size': 100, + 'fetch_size': 200, + 'num_partitions': 10, + 'partition_column': 'columnMcColumnFace', + 'upper_bound': '20', + 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' + 'comments VARCHAR(1024)' + } + + def setUp(self): + configuration.load_test_config() + db.merge_conn( + models.Connection( + conn_id='spark-default', conn_type='spark', + host='yarn://yarn-master', + extra='{"queue": "root.etl", "deploy-mode": "cluster"}') + ) + db.merge_conn( + models.Connection( + conn_id='jdbc-default', conn_type='postgres', + host='localhost', schema='default', port=5432, + login='user', password='supersecret', + extra='{"conn_prefix":"jdbc:postgresql://"}' + ) + ) + + def test_resolve_jdbc_connection(self): + # Given + hook = SparkJDBCHook(jdbc_conn_id='jdbc-default') + expected_connection = { + 'url': 'localhost:5432', + 'schema': 'default', + 'conn_prefix': 'jdbc:postgresql://', + 'user': 'user', + 'password': 'supersecret' + } + + # When + connection = hook._resolve_jdbc_connection() + + # Then + self.assertEqual(connection, expected_connection) + + def test_build_jdbc_arguments(self): + # Given + hook = SparkJDBCHook(**self._config) + + # When + cmd = hook._build_jdbc_application_arguments(hook._resolve_jdbc_connection()) + + # Then + expected_jdbc_arguments = [ + '-cmdType', 'spark_to_jdbc', + '-url', 'jdbc:postgresql://localhost:5432/default', + '-user', 'user', + '-password', 'supersecret', + '-metastoreTable', 'hiveMcHiveFace', + '-jdbcTable', 'tableMcTableFace', + '-jdbcDriver', 'org.postgresql.Driver', + '-batchsize', '100', + '-fetchsize', '200', + '-numPartitions', '10', + '-partitionColumn', 'columnMcColumnFace', + '-lowerBound', '10', + '-upperBound', '20', + '-saveMode', 'append', + '-saveFormat', 'parquet', + '-createTableColumnTypes', 'columnMcColumnFace INTEGER(100), name CHAR(64),' + 'comments VARCHAR(1024)' + ] + self.assertEquals(expected_jdbc_arguments, cmd) + + def test_build_jdbc_arguments_invalid(self): + # Given + 
hook = SparkJDBCHook(**self._invalid_config) + + # Expect Exception + hook._build_jdbc_application_arguments(hook._resolve_jdbc_connection()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/operators/test_spark_jdbc_operator.py b/tests/contrib/operators/test_spark_jdbc_operator.py new file mode 100644 index 0000000000000..5996db0c4765d --- /dev/null +++ b/tests/contrib/operators/test_spark_jdbc_operator.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from airflow import DAG, configuration + +from airflow.contrib.operators.spark_jdbc_operator import SparkJDBCOperator +from airflow.utils import timezone + +DEFAULT_DATE = timezone.datetime(2017, 1, 1) + + +class TestSparkJDBCOperator(unittest.TestCase): + _config = { + 'spark_app_name': '{{ task_instance.task_id }}', + 'spark_conf': { + 'parquet.compression': 'SNAPPY' + }, + 'spark_files': 'hive-site.xml', + 'spark_py_files': 'sample_library.py', + 'spark_jars': 'parquet.jar', + 'num_executors': 4, + 'executor_cores': 4, + 'executor_memory': '22g', + 'driver_memory': '3g', + 'verbose': True, + 'keytab': 'privileged_user.keytab', + 'principal': 'user/spark@airflow.org', + 'cmd_type': 'spark_to_jdbc', + 'jdbc_table': 'tableMcTableFace', + 'jdbc_driver': 'org.postgresql.Driver', + 'metastore_table': 'hiveMcHiveFace', + 'jdbc_truncate': False, + 'save_mode': 'append', + 'save_format': 'parquet', + 'batch_size': 100, + 'fetch_size': 200, + 'num_partitions': 10, + 'partition_column': 'columnMcColumnFace', + 'lower_bound': '10', + 'upper_bound': '20', + 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' + 'comments VARCHAR(1024)' + } + + def setUp(self): + configuration.load_test_config() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + self.dag = DAG('test_dag_id', default_args=args) + + def test_execute(self): + # Given / When + spark_conn_id = 'spark-default' + jdbc_conn_id = 'jdbc-default' + + operator = SparkJDBCOperator( + task_id='spark_jdbc_job', + dag=self.dag, + **self._config + ) + + # Then + expected_dict = { + 'spark_app_name': '{{ task_instance.task_id }}', + 'spark_conf': { + 'parquet.compression': 'SNAPPY' + }, + 'spark_files': 'hive-site.xml', + 'spark_py_files': 'sample_library.py', + 'spark_jars': 'parquet.jar', + 'num_executors': 4, + 'executor_cores': 4, + 'executor_memory': '22g', + 'driver_memory': '3g', + 'verbose': True, + 'keytab': 'privileged_user.keytab', + 'principal': 'user/spark@airflow.org', + 'cmd_type': 'spark_to_jdbc', + 'jdbc_table': 'tableMcTableFace', + 'jdbc_driver': 'org.postgresql.Driver', + 'metastore_table': 'hiveMcHiveFace', + 'jdbc_truncate': False, + 'save_mode': 'append', + 'save_format': 'parquet', + 'batch_size': 100, + 'fetch_size': 200, + 'num_partitions': 10, + 'partition_column': 'columnMcColumnFace', + 'lower_bound': '10', + 'upper_bound': '20', + 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' + 'comments VARCHAR(1024)' + } 
+ + self.assertEqual(spark_conn_id, operator._spark_conn_id) + self.assertEqual(jdbc_conn_id, operator._jdbc_conn_id) + self.assertEqual(expected_dict['spark_app_name'], operator._spark_app_name) + self.assertEqual(expected_dict['spark_conf'], operator._spark_conf) + self.assertEqual(expected_dict['spark_files'], operator._spark_files) + self.assertEqual(expected_dict['spark_py_files'], operator._spark_py_files) + self.assertEqual(expected_dict['spark_jars'], operator._spark_jars) + self.assertEqual(expected_dict['num_executors'], operator._num_executors) + self.assertEqual(expected_dict['executor_cores'], operator._executor_cores) + self.assertEqual(expected_dict['executor_memory'], operator._executor_memory) + self.assertEqual(expected_dict['driver_memory'], operator._driver_memory) + self.assertEqual(expected_dict['verbose'], operator._verbose) + self.assertEqual(expected_dict['keytab'], operator._keytab) + self.assertEqual(expected_dict['principal'], operator._principal) + self.assertEqual(expected_dict['cmd_type'], operator._cmd_type) + self.assertEqual(expected_dict['jdbc_table'], operator._jdbc_table) + self.assertEqual(expected_dict['jdbc_driver'], operator._jdbc_driver) + self.assertEqual(expected_dict['metastore_table'], operator._metastore_table) + self.assertEqual(expected_dict['jdbc_truncate'], operator._jdbc_truncate) + self.assertEqual(expected_dict['save_mode'], operator._save_mode) + self.assertEqual(expected_dict['save_format'], operator._save_format) + self.assertEqual(expected_dict['batch_size'], operator._batch_size) + self.assertEqual(expected_dict['fetch_size'], operator._fetch_size) + self.assertEqual(expected_dict['num_partitions'], operator._num_partitions) + self.assertEqual(expected_dict['partition_column'], operator._partition_column) + self.assertEqual(expected_dict['lower_bound'], operator._lower_bound) + self.assertEqual(expected_dict['upper_bound'], operator._upper_bound) + self.assertEqual(expected_dict['create_table_column_types'], + operator._create_table_column_types) + + +if __name__ == '__main__': + unittest.main()
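Finally, the default connection ids assumed by the hook and operator ('spark-default' and 'jdbc-default') can be registered the same way the test fixtures above do it. In a real deployment these would normally be created through the Airflow UI or CLI; the snippet below is only for local experimentation and uses placeholder hosts and credentials.

```python
# Mirrors the test setUp fixtures; hosts and credentials are placeholders.
from airflow import models
from airflow.utils import db

db.merge_conn(models.Connection(
    conn_id='spark-default', conn_type='spark',
    host='yarn://yarn-master',
    extra='{"queue": "root.etl", "deploy-mode": "cluster"}'))

db.merge_conn(models.Connection(
    conn_id='jdbc-default', conn_type='postgres',
    host='localhost', schema='default', port=5432,
    login='user', password='supersecret',
    # conn_prefix is prepended to host:port/schema when the hook builds the -url argument
    extra='{"conn_prefix": "jdbc:postgresql://"}'))
```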