Adjust test for Airflow-2.9 (#2149)
- Pin google-cloud-bigquery
- Remove conditional query building for different Airflow versions
- Adjust dataset and cleanup operator tests
- Clean up check operator tests
pankajastro authored May 16, 2024
1 parent 690db3c commit 149dd9e
Showing 6 changed files with 13 additions and 20 deletions.
python-sdk/pyproject.toml (6 changes: 4 additions & 2 deletions)
@@ -57,7 +57,8 @@ google = [
     "protobuf",
     "apache-airflow-providers-google>=10.15.0",
     "sqlalchemy-bigquery>=1.3.0",
-    "smart-open[gcs]>=5.2.1,<7.0.0"
+    "smart-open[gcs]>=5.2.1,<7.0.0",
+    "google-cloud-bigquery<3.21.0"
 ]
 snowflake = [
     "apache-airflow-providers-snowflake>=5.3.0",
@@ -126,7 +127,8 @@ all = [
     "azure-storage-blob",
     "apache-airflow-providers-microsoft-mssql>=3.2",
     "airflow-provider-duckdb>=0.0.2",
-    "apache-airflow-providers-mysql"
+    "apache-airflow-providers-mysql",
+    "google-cloud-bigquery<3.21.0"
 ]
 doc = [
     "myst-parser>=0.17",
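Note: the google-cloud-bigquery<3.21.0 pin now applies to both the "google" and "all" extras. As an illustrative sketch only (not part of this commit), an installed environment can be checked against the new pin, assuming the packaging library is available:

from importlib.metadata import version

from packaging.version import Version

# Read the installed distribution version and compare it with the new upper bound.
installed = Version(version("google-cloud-bigquery"))
assert installed < Version("3.21.0"), f"google-cloud-bigquery {installed} violates the <3.21.0 pin"
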
@@ -56,12 +56,7 @@ def execute(self, context: "Context"):
         db = create_database(self.dataset.conn_id)
         self.table = db.get_table_qualified_name(self.dataset)
         self.conn_id = self.dataset.conn_id
-        # apache-airflow-providers-common-sql == 1.2.0 which is compatible with airflow 2.2.5 implements the self.sql
-        # differently compared to apache-airflow-providers-common-sql == 1.3.3
-        try:
-            self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"
-        except AttributeError:
-            self.sql = f"SELECT * FROM {self.table};"
+        self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"
         super().execute(context)

     def get_db_hook(self) -> Any:
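With the version-dependent fallback removed, the operator always wraps the SQL produced by _generate_sql_query() in an outer SELECT over check_name and check_result. A rough sketch of the resulting statement, using a hypothetical inner query in place of _generate_sql_query(), which is not shown in this diff:

# Hypothetical inner query standing in for self._generate_sql_query();
# the real method builds one row per configured column check.
inner_query = (
    "SELECT 'null_check' AS check_name, "
    "CASE WHEN COUNT(*) = 0 THEN 1 ELSE 0 END AS check_result "
    "FROM imdb_movies WHERE title IS NULL"
)
# This mirrors the single assignment that now runs unconditionally in execute().
sql = f"SELECT check_name, check_result FROM ({inner_query}) AS check_table"
print(sql)
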
python-sdk/tests/airflow_tests/test_datasets.py (4 changes: 1 addition & 3 deletions)
@@ -104,7 +104,6 @@ def test_kwargs_with_temp_table():
 @pytest.mark.skipif(airflow.__version__ < "2.4.0", reason="Require Airflow version >= 2.4.0")
 def test_example_dataset_dag():
     from airflow.datasets import Dataset
-    from airflow.models.dataset import DatasetModel

     dir_path = os.path.dirname(os.path.realpath(__file__))
     db = DagBag(dir_path + "/../../example_dags/example_datasets.py")
@@ -115,9 +114,8 @@ def test_example_dataset_dag():
     outlets = producer_dag.tasks[-1].outlets
     assert isinstance(outlets[0], Dataset)
     # Test that dataset_triggers is only set if all the instances passed to the DAG object are Datasets
-    assert consumer_dag.dataset_triggers == outlets
+    assert consumer_dag.dataset_triggers.objects[0] == outlets[0]
     assert outlets[0].uri == "astro://postgres_conn@?table=imdb_movies"
-    assert DatasetModel.from_public(outlets[0]) == Dataset("astro://postgres_conn@?table=imdb_movies")


 def test_disable_auto_inlets_outlets():
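The assertion changes because, on the Airflow version this commit targets, DAG.dataset_triggers is no longer a plain list of Dataset objects but a condition wrapper that exposes the underlying datasets via .objects. A minimal sketch of that access pattern, assuming an Airflow 2.9 environment (DAG name and variable names here are illustrative):

from datetime import datetime

from airflow import DAG
from airflow.datasets import Dataset

movies = Dataset("astro://postgres_conn@?table=imdb_movies")
consumer = DAG("dataset_consumer", start_date=datetime(2022, 1, 1), schedule=[movies])

# On Airflow 2.9 the dataset schedule is wrapped in a condition object
# (a DatasetAll-style wrapper), so individual datasets live under .objects.
assert consumer.dataset_triggers.objects[0] == movies
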
python-sdk/tests/sql/operators/test_cleanup.py (12 changes: 7 additions & 5 deletions)
@@ -105,19 +105,20 @@ def test_error_raised_with_blocking_op_executors(
     reason="BackfillJobRunner and Job classes are only available in airflow >= 2.6",
 )
 @pytest.mark.parametrize(
-    "executor_in_job,executor_in_cfg,expected_val",
+    "executor_in_job, executor_in_cfg, expected_val",
     [
-        (SequentialExecutor(), "LocalExecutor", True),
+        (SequentialExecutor(), "SequentialExecutor", True),
         (LocalExecutor(), "LocalExecutor", False),
         (None, "LocalExecutor", False),
         (None, "SequentialExecutor", True),
     ],
 )
-def test_single_worker_mode_backfill(executor_in_job, executor_in_cfg, expected_val):
+def test_single_worker_mode_backfill(monkeypatch, executor_in_job, executor_in_cfg, expected_val):
     """Test that if we run Backfill Job it should be marked as single worker node"""
     from airflow.jobs.backfill_job_runner import BackfillJobRunner
     from airflow.jobs.job import Job

+    monkeypatch.setattr("airflow.executors.executor_loader._executor_names", [])
     dag = DAG("test_single_worker_mode_backfill", start_date=datetime(2022, 1, 1))
     dr = DagRun(dag_id=dag.dag_id)

@@ -175,17 +176,18 @@ def test_single_worker_mode_backfill_airflow_2_5(executor_in_job, executor_in_cf
 @pytest.mark.parametrize(
     "executor_in_job,executor_in_cfg,expected_val",
     [
-        (SequentialExecutor(), "LocalExecutor", True),
+        (SequentialExecutor(), "SequentialExecutor", True),
         (LocalExecutor(), "LocalExecutor", False),
         (None, "LocalExecutor", False),
         (None, "SequentialExecutor", True),
     ],
 )
-def test_single_worker_mode_scheduler_job(executor_in_job, executor_in_cfg, expected_val):
+def test_single_worker_mode_scheduler_job(monkeypatch, executor_in_job, executor_in_cfg, expected_val):
     """Test that if we run Scheduler Job it should be marked as single worker node"""
     from airflow.jobs.job import Job
     from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

+    monkeypatch.setattr("airflow.executors.executor_loader._executor_names", [])
     dag = DAG("test_single_worker_mode_scheduler_job", start_date=datetime(2022, 1, 1))
     dr = DagRun(dag_id=dag.dag_id)

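Both tests now accept pytest's monkeypatch fixture and clear airflow.executors.executor_loader._executor_names before running, so the executor configured for each parametrized case is re-read instead of being served from a cache populated by an earlier test. A small sketch of the same reset as a reusable fixture; the fixture name is illustrative and _executor_names is a private Airflow attribute, referenced here only because the diff patches it directly:

import pytest


@pytest.fixture
def reset_executor_name_cache(monkeypatch):
    # Clear the module-level cache of parsed executor names so each test
    # sees the executor set via config rather than a previously cached value.
    monkeypatch.setattr("airflow.executors.executor_loader._executor_names", [])
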
@@ -5,7 +5,6 @@
 from astro import sql as aql
 from astro.constants import Database
 from astro.files import File
-from astro.table import Table
 from tests.sql.operators import utils as test_utils

 CWD = pathlib.Path(__file__).parent
@@ -22,7 +21,6 @@
     {
         "database": Database.BIGQUERY,
         "file": File(path=str(CWD) + "/../../../data/data_validation.csv"),
-        "table": Table(conn_id="gcp_conn_project"),
     },
     {
         "database": Database.POSTGRES,
@@ -5,7 +5,6 @@
 from astro import sql as aql
 from astro.constants import Database
 from astro.files import File
-from astro.table import Table
 from tests.sql.operators import utils as test_utils

 CWD = pathlib.Path(__file__).parent
@@ -22,7 +21,6 @@
     {
         "database": Database.BIGQUERY,
         "file": File(path=str(CWD) + "/../../../data/homes_main.csv"),
-        "table": Table(conn_id="gcp_conn_project"),
     },
     {
         "database": Database.POSTGRES,
