Skip to content

Commit

Permalink
Feature/refactor azure synapse pipeline class (#38723)
Browse files Browse the repository at this point in the history
* redesign azure synapse pipeline

* add breaking change

* fix statics

* add breaking change

* change mapping

* change changelog

* fix provider.yaml
  • Loading branch information
romsharon98 authored Apr 4, 2024
1 parent be89300 commit 901c3a6
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 139 deletions.
9 changes: 9 additions & 0 deletions airflow/providers/microsoft/azure/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@
Changelog
---------

10.0.0
......

Breaking changes
~~~~~~~~~~~~~~~~
.. warning::
* ``azure_synapse_pipeline`` connection type has been changed to ``azure_synapse``.
* The usage of ``default_conn_name=azure_synapse_connection`` is deprecated and will be removed in future. The new default connection name for ``AzureSynapsePipelineHook`` is: ``default_conn_name=azure_synapse_default``.

9.0.1
.....

Expand Down
76 changes: 60 additions & 16 deletions airflow/providers/microsoft/azure/hooks/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
from __future__ import annotations

import time
import warnings
from typing import TYPE_CHECKING, Any, Union

from azure.core.exceptions import ServiceRequestError
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.synapse.artifacts import ArtifactsClient
from azure.synapse.spark import SparkClient

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
Expand Down Expand Up @@ -240,20 +241,20 @@ class AzureSynapsePipelineRunException(AirflowException):
"""An exception that indicates a pipeline run failed to complete."""


class AzureSynapsePipelineHook(BaseHook):
class BaseAzureSynapseHook(BaseHook):
"""
A hook to interact with Azure Synapse Pipeline.
A base hook class to create session and connection to Azure Synapse using connection id.
:param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
:param azure_synapse_workspace_dev_endpoint: The Azure Synapse Workspace development endpoint.
"""

conn_type: str = "azure_synapse_pipeline"
conn_type: str = "azure_synapse"
conn_name_attr: str = "azure_synapse_conn_id"
default_conn_name: str = "azure_synapse_connection"
hook_name: str = "Azure Synapse Pipeline"
default_conn_name: str = "azure_synapse_default"
hook_name: str = "Azure Synapse"

@classmethod
@add_managed_identity_connection_widgets
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Return connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
Expand All @@ -262,23 +263,59 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:

return {
"tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()),
}

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
"relabeling": {"login": "Client ID", "password": "Secret", "host": "Synapse Workspace URL"},
"relabeling": {
"login": "Client ID",
"password": "Secret",
"host": "Synapse Workspace URL",
},
}

def __init__(self, azure_synapse_conn_id: str = default_conn_name, **kwargs) -> None:
super().__init__(**kwargs)
self.conn_id = azure_synapse_conn_id

def _get_field(self, extras: dict, field_name: str) -> str:
return get_field(
conn_id=self.conn_id,
conn_type=self.conn_type,
extras=extras,
field_name=field_name,
)


class AzureSynapsePipelineHook(BaseAzureSynapseHook):
"""
A hook to interact with Azure Synapse Pipeline.
:param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
:param azure_synapse_workspace_dev_endpoint: The Azure Synapse Workspace development endpoint.
"""

default_conn_name: str = "azure_synapse_connection"

def __init__(
self, azure_synapse_workspace_dev_endpoint: str, azure_synapse_conn_id: str = default_conn_name
self,
azure_synapse_workspace_dev_endpoint: str,
azure_synapse_conn_id: str = default_conn_name,
**kwargs,
):
self._conn = None
self.conn_id = azure_synapse_conn_id
# Handling deprecation of "default_conn_name"
if azure_synapse_conn_id == self.default_conn_name:
warnings.warn(
"The usage of `default_conn_name=azure_synapse_connection` is deprecated and will be removed in future. Please update your code to use the new default connection name: `default_conn_name=azure_synapse_default`. ",
AirflowProviderDeprecationWarning,
)
self._conn: ArtifactsClient | None = None
self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
super().__init__()
super().__init__(azure_synapse_conn_id=azure_synapse_conn_id, **kwargs)

def _get_field(self, extras, name):
return get_field(
Expand All @@ -297,15 +334,22 @@ def get_conn(self) -> ArtifactsClient:
tenant = self._get_field(extras, "tenantId")

credential: Credentials
if conn.login is not None and conn.password is not None:
if not conn.login or not conn.password:
managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")

credential = get_sync_default_azure_credential(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
else:
if not tenant:
raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")

credential = ClientSecretCredential(
client_id=conn.login, client_secret=conn.password, tenant_id=tenant
)
else:
credential = DefaultAzureCredential()

self._conn = self._create_client(credential, self.azure_synapse_workspace_dev_endpoint)

if self._conn is not None:
Expand All @@ -314,7 +358,7 @@ def get_conn(self) -> ArtifactsClient:
raise ValueError("Failed to create ArtifactsClient")

@staticmethod
def _create_client(credential: Credentials, endpoint: str):
def _create_client(credential: Credentials, endpoint: str) -> ArtifactsClient:
return ArtifactsClient(credential=credential, endpoint=endpoint)

def run_pipeline(self, pipeline_name: str, **config: Any) -> CreateRunResponse:
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ state: ready
source-date-epoch: 1709555852
# note that those versions are maintained by release manager - do not update them manually
versions:
- 10.0.0
- 9.0.1
- 9.0.0
- 8.5.1
Expand Down Expand Up @@ -302,12 +303,10 @@ connection-types:
connection-type: azure_container_registry
- hook-class-name: airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook
connection-type: azure_service_bus
- hook-class-name: airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook
- hook-class-name: airflow.providers.microsoft.azure.hooks.synapse.BaseAzureSynapseHook
connection-type: azure_synapse
- hook-class-name: airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook
connection-type: adls
- hook-class-name: airflow.providers.microsoft.azure.hooks.synapse.AzureSynapsePipelineHook
connection-type: azure_synapse_pipeline

secrets-backends:
- airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend
Expand Down
Loading

0 comments on commit 901c3a6

Please sign in to comment.