Skip to content

Commit

Permalink
Add Oracle Profile mapping (#1404)
Browse files Browse the repository at this point in the history
This PR adds the ability to map Oracle connections from Airflow to
Cosmos.

Co-authored-by: Shad L. Lords
[slords@lordsfam.net](mailto:slords@lordsfam.net)

Original PR by @slords:
#1190

closes: #1189
  • Loading branch information
pankajkoti authored Dec 19, 2024
1 parent c5edba0 commit ad89757
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .databricks.oauth import DatabricksOauthProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .oracle.user_pass import OracleUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
Expand All @@ -34,6 +35,7 @@
GoogleCloudOauthProfileMapping,
DatabricksTokenProfileMapping,
DatabricksOauthProfileMapping,
OracleUserPasswordProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
Expand Down Expand Up @@ -77,6 +79,7 @@ def get_automatic_profile_mapping(
"DatabricksTokenProfileMapping",
"DatabricksOauthProfileMapping",
"DbtProfileConfigVars",
"OracleUserPasswordProfileMapping",
"PostgresUserPasswordProfileMapping",
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
Expand Down
5 changes: 5 additions & 0 deletions cosmos/profiles/oracle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Oracle Airflow connection -> dbt profile mappings"""

from .user_pass import OracleUserPasswordProfileMapping

__all__ = ["OracleUserPasswordProfileMapping"]
89 changes: 89 additions & 0 deletions cosmos/profiles/oracle/user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Maps Airflow Oracle connections using user + password authentication to dbt profiles."""

from __future__ import annotations

import re
from typing import Any

from ..base import BaseProfileMapping


class OracleUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow Oracle connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/reference/warehouse-setups/oracle-setup
https://airflow.apache.org/docs/apache-airflow-providers-oracle/stable/connections/oracle.html
"""

airflow_connection_type: str = "oracle"
dbt_profile_type: str = "oracle"
is_community: bool = True

required_fields = [
"user",
"password",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"host": "host",
"port": "port",
"service": "extra.service_name",
"user": "login",
"password": "password",
"database": "extra.service_name",
"connection_string": "extra.dsn",
}

@property
def env_vars(self) -> dict[str, str]:
"""Set oracle thick mode."""
env_vars = super().env_vars
if self._get_airflow_conn_field("extra.thick_mode"):
env_vars["ORA_PYTHON_DRIVER_TYPE"] = "thick"
return env_vars

@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile. The password is stored in an environment variable."""
profile = {
"protocol": "tcp",
"port": 1521,
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

if "schema" not in profile and "user" in profile:
proxy = re.search(r"\[([^]]+)\]", profile["user"])
if proxy:
profile["schema"] = proxy.group(1)
else:
profile["schema"] = profile["user"]
if "schema" in self.profile_args:
profile["schema"] = self.profile_args["schema"]

return self.filter_null(profile)

@property
def mock_profile(self) -> dict[str, Any | None]:
"""Gets mock profile. Defaults port to 1521."""
profile_dict = {
"protocol": "tcp",
"port": 1521,
**super().mock_profile,
}

if "schema" not in profile_dict and "user" in profile_dict:
proxy = re.search(r"\[([^]]+)\]", profile_dict["user"])
if proxy:
profile_dict["schema"] = proxy.group(1)
else:
profile_dict["schema"] = profile_dict["user"]

user_defined_schema = self.profile_args.get("schema")
if user_defined_schema:
profile_dict["schema"] = user_defined_schema
return profile_dict
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dbt-all = [
# See: https://github.com/astronomer/astronomer-cosmos/issues/1379
"dbt-databricks!=1.9.0",
"dbt-exasol",
"dbt-oracle",
"dbt-postgres",
"dbt-redshift",
"dbt-snowflake",
Expand All @@ -61,6 +62,7 @@ dbt-bigquery = ["dbt-bigquery"]
dbt-clickhouse = ["dbt-clickhouse"]
dbt-databricks = ["dbt-databricks"]
dbt-exasol = ["dbt-exasol"]
dbt-oracle = ["dbt-oracle"]
dbt-postgres = ["dbt-postgres"]
dbt-redshift = ["dbt-redshift"]
dbt-snowflake = ["dbt-snowflake"]
Expand Down
254 changes: 254 additions & 0 deletions tests/profiles/oracle/test_oracle_user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
"""Tests for the Oracle profile."""

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.oracle.user_pass import OracleUserPasswordProfileMapping


@pytest.fixture()
def mock_oracle_conn(): # type: ignore
"""
Sets the Oracle connection as an environment variable.
"""
conn = Connection(
conn_id="my_oracle_connection",
conn_type="oracle",
host="my_host",
login="my_user",
password="my_password",
port=1521,
extra='{"service_name": "my_service"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_oracle_conn_custom_port(): # type: ignore
"""
Sets the Oracle connection with a custom port as an environment variable.
"""
conn = Connection(
conn_id="my_oracle_connection",
conn_type="oracle",
host="my_host",
login="my_user",
password="my_password",
port=1600,
extra='{"service_name": "my_service"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the Oracle profile mapping claims the correct connection type.
"""
potential_values = {
"conn_type": "oracle",
"login": "my_user",
"password": "my_password",
}

# if we're missing any of the required values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = OracleUserPasswordProfileMapping(conn, {"schema": "my_schema"})
assert not profile_mapping.can_claim_connection()

# if we have all the required values, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = OracleUserPasswordProfileMapping(conn, {"schema": "my_schema"})
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_oracle_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_automatic_profile_mapping(
mock_oracle_conn.conn_id,
{"schema": "my_schema"},
)
assert isinstance(profile_mapping, OracleUserPasswordProfileMapping)


def test_profile_mapping_keeps_custom_port(mock_oracle_conn_custom_port: Connection) -> None:
profile = OracleUserPasswordProfileMapping(mock_oracle_conn_custom_port.conn_id, {"schema": "my_schema"})
assert profile.profile["port"] == 1600


def test_profile_args(
mock_oracle_conn: Connection,
) -> None:
"""
Tests that the profile values are set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_oracle_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.profile_args == {
"schema": "my_schema",
}

assert profile_mapping.profile == {
"type": mock_oracle_conn.conn_type,
"host": mock_oracle_conn.host,
"user": mock_oracle_conn.login,
"password": "{{ env_var('COSMOS_CONN_ORACLE_PASSWORD') }}",
"port": mock_oracle_conn.port,
"database": "my_service",
"service": "my_service",
"schema": "my_schema",
"protocol": "tcp",
}


def test_profile_args_overrides(
mock_oracle_conn: Connection,
) -> None:
"""
Tests that profile values can be overridden.
"""
profile_mapping = get_automatic_profile_mapping(
mock_oracle_conn.conn_id,
profile_args={
"schema": "my_schema_override",
"database": "my_database_override",
"service": "my_service_override",
},
)
assert profile_mapping.profile_args == {
"schema": "my_schema_override",
"database": "my_database_override",
"service": "my_service_override",
}

assert profile_mapping.profile == {
"type": mock_oracle_conn.conn_type,
"host": mock_oracle_conn.host,
"user": mock_oracle_conn.login,
"password": "{{ env_var('COSMOS_CONN_ORACLE_PASSWORD') }}",
"port": mock_oracle_conn.port,
"database": "my_database_override",
"service": "my_service_override",
"schema": "my_schema_override",
"protocol": "tcp",
}


def test_profile_env_vars(
mock_oracle_conn: Connection,
) -> None:
"""
Tests that environment variables are set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_oracle_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_ORACLE_PASSWORD": mock_oracle_conn.password,
}


def test_env_vars_thick_mode(mock_oracle_conn: Connection) -> None:
"""
Tests that `env_vars` includes `ORA_PYTHON_DRIVER_TYPE` when `extra.thick_mode` is enabled.
"""
mock_oracle_conn.extra = '{"service_name": "my_service", "thick_mode": true}'
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": "my_schema"})
assert profile_mapping.env_vars == {
"COSMOS_CONN_ORACLE_PASSWORD": mock_oracle_conn.password,
"ORA_PYTHON_DRIVER_TYPE": "thick",
}


def test_profile_filter_null(mock_oracle_conn: Connection) -> None:
"""
Tests that `profile` filters out null values.
"""
mock_oracle_conn.extra = '{"service_name": "my_service"}'
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": None})
profile = profile_mapping.profile
assert "schema" not in profile


def test_mock_profile(mock_oracle_conn: Connection) -> None:
"""
Tests that `mock_profile` sets default port and schema correctly.
"""
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": "my_schema"})
mock_profile = profile_mapping.mock_profile
assert mock_profile["port"] == 1521
assert mock_profile["schema"] == "my_schema"
assert mock_profile["protocol"] == "tcp"


def test_invalid_connection_type() -> None:
"""
Tests that the profile mapping does not claim a non-oracle connection type.
"""
conn = Connection(conn_id="invalid_conn", conn_type="postgres", login="my_user", password="my_password")
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = OracleUserPasswordProfileMapping(conn, {})
assert not profile_mapping.can_claim_connection()


def test_airflow_param_mapping(mock_oracle_conn: Connection) -> None:
"""
Tests that `airflow_param_mapping` correctly maps Airflow fields to dbt profile fields.
"""
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": "my_schema"})
mapped_params = profile_mapping.mapped_params

assert mapped_params["host"] == mock_oracle_conn.host
assert mapped_params["port"] == mock_oracle_conn.port
assert mapped_params["service"] == "my_service"
assert mapped_params["user"] == mock_oracle_conn.login
assert mapped_params["password"] == mock_oracle_conn.password


def test_profile_schema_extraction_with_proxy(mock_oracle_conn: Connection) -> None:
"""
Tests that the `schema` is extracted correctly from the `user` field
when a proxy schema is provided in square brackets.
"""
mock_oracle_conn.login = "my_user[proxy_schema]"
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {})

assert profile_mapping.profile["schema"] == "proxy_schema"


def test_profile_schema_defaults_to_user(mock_oracle_conn: Connection) -> None:
"""
Tests that the `schema` defaults to the `user` field when no proxy schema is provided.
"""
mock_oracle_conn.login = "my_user"
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {})

assert profile_mapping.profile["schema"] == "my_user"


def test_mock_profile_schema_extraction_with_proxy_gets_mock_value(mock_oracle_conn: Connection) -> None:
mock_oracle_conn.login = "my_user[proxy_schema]"
profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {})

mock_profile = profile_mapping.mock_profile

assert mock_profile["schema"] == "mock_value"

0 comments on commit ad89757

Please sign in to comment.