Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ repos:
^providers/edge3/.*\.py$|
^providers/mysql/.*\.py$|
^providers/openlineage/.*\.py$|
^providers/google/src/airflow/providers/google/cloud/triggers/dataproc\.py$|
^providers/google/src/airflow/providers/google/cloud/triggers/bigquery\.py$|
^providers/google/tests/unit/google/cloud/utils/gcp_authenticator\.py$|
^providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager\.py$|
^providers/google/tests/system/google/gcp_api_client_helpers\.py$|
^providers/standard/tests/unit/standard/operators/test_latest_only_operator\.py$|
^providers/standard/tests/unit/standard/operators/test_trigger_dagrun\.py$|
^providers/standard/tests/unit/standard/operators/test_weekday\.py$|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from sqlalchemy.orm.session import Session

if not AIRFLOW_V_3_0_PLUS:
from sqlalchemy import select

from airflow.models.taskinstance import TaskInstance
from airflow.utils.session import provide_session

Expand Down Expand Up @@ -105,13 +107,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
query = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
task_instance = session.scalar(
select(TaskInstance).where(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
)
)
task_instance = query.one_or_none()
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from sqlalchemy.orm.session import Session

if not AIRFLOW_V_3_0_PLUS:
from sqlalchemy import select

from airflow.models.taskinstance import TaskInstance
from airflow.utils.session import provide_session

Expand Down Expand Up @@ -130,13 +132,14 @@ def get_task_instance(self, session: Session) -> TaskInstance:

:param session: Sqlalchemy session
"""
query = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
task_instance = session.scalar(
select(TaskInstance).where(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
)
)
task_instance = query.one_or_none()
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
Expand Down Expand Up @@ -266,13 +269,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
query = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
task_instance = session.scalar(
select(TaskInstance).where(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
)
)
task_instance = query.one_or_none()
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
Expand Down
10 changes: 6 additions & 4 deletions providers/google/tests/system/google/gcp_api_client_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ def create_airflow_connection(
connection_id=connection_id, connection=connection_conf, is_composer=is_composer
)
else:
from sqlalchemy import delete

from airflow.models import Connection
from airflow.settings import Session

if Session is None:
raise RuntimeError("Session not configured. Call configure_orm() first.")
session = Session()
query = session.query(Connection).filter(Connection.conn_id == connection_id)
query.delete()
session.execute(delete(Connection).where(Connection.conn_id == connection_id))
connection = Connection(conn_id=connection_id, **connection_conf)
session.add(connection)
session.commit()
Expand All @@ -114,12 +115,13 @@ def delete_airflow_connection(connection_id: str, is_composer: bool = False) ->
if AIRFLOW_V_3_0_PLUS:
delete_connection_request(connection_id=connection_id, is_composer=is_composer)
else:
from sqlalchemy import delete

from airflow.models import Connection
from airflow.settings import Session

if Session is None:
raise RuntimeError("Session not configured. Call configure_orm() first.")
session = Session()
query = session.query(Connection).filter(Connection.conn_id == connection_id)
query.delete()
session.execute(delete(Connection).where(Connection.conn_id == connection_id))
session.commit()
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import os
import subprocess

from sqlalchemy import select

from airflow import settings
from airflow.models import Connection
from airflow.providers.common.compat.sdk import AirflowException
Expand Down Expand Up @@ -97,7 +99,7 @@ def set_key_path_in_airflow_connection(self):
:return: None
"""
with settings.Session() as session:
conn = session.query(Connection).filter(Connection.conn_id == "google_cloud_default")[0]
conn = session.scalar(select(Connection).where(Connection.conn_id == "google_cloud_default"))
extras = conn.extra_dejson
extras[KEYPATH_EXTRA] = self.full_key_path
if extras.get(KEYFILE_DICT_EXTRA):
Expand All @@ -113,7 +115,7 @@ def set_dictionary_in_airflow_connection(self):
:return: None
"""
with settings.Session() as session:
conn = session.query(Connection).filter(Connection.conn_id == "google_cloud_default")[0]
conn = session.scalar(select(Connection).where(Connection.conn_id == "google_cloud_default"))
extras = conn.extra_dejson
with open(self.full_key_path) as path_file:
content = json.load(path_file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock

import pytest
from sqlalchemy import delete

from airflow.models import TaskInstance as TI
from airflow.providers.google.marketing_platform.operators.campaign_manager import (
Expand Down Expand Up @@ -91,11 +92,11 @@ def test_execute(self, mock_base_op, hook_mock):
class TestGoogleCampaignManagerDownloadReportOperator:
def setup_method(self):
with create_session() as session:
session.query(TI).delete()
session.execute(delete(TI))

def teardown_method(self):
with create_session() as session:
session.query(TI).delete()
session.execute(delete(TI))

@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.http")
@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.tempfile")
Expand Down