Skip to content
Merged
26 changes: 25 additions & 1 deletion providers/oracle/src/airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.providers.common.sql.hooks.sql import DbApiHook

DEFAULT_DB_PORT = 1521
PARAM_TYPES = {bool, float, int, str}


Expand Down Expand Up @@ -183,7 +184,7 @@ def get_conn(self) -> oracledb.Connection:

# Set up DSN
service_name = conn.extra_dejson.get("service_name")
port = conn.port if conn.port else 1521
port = conn.port if conn.port else DEFAULT_DB_PORT
if conn.host and sid and not service_name:
conn_config["dsn"] = oracledb.makedsn(conn.host, port, sid)
elif conn.host and service_name and not sid:
Expand Down Expand Up @@ -443,3 +444,26 @@ def handler(cursor):
)

return result

def get_uri(self) -> str:
"""Get the URI for the Oracle connection."""
conn = self.get_connection(self.oracle_conn_id) # type: ignore[attr-defined]
login = conn.login
password = conn.password
host = conn.host
port = conn.port or DEFAULT_DB_PORT
service_name = conn.extra_dejson.get("service_name")
sid = conn.extra_dejson.get("sid")

if sid and service_name:
raise ValueError("At most one allowed for 'sid', and 'service name'.")

uri = f"oracle://{login}:{password}@{host}:{port}"
if sid:
uri = f"{uri}/{sid}"
elif service_name:
uri = f"{uri}/{service_name}"
elif conn.schema:
uri = f"{uri}/{conn.schema}"

return uri
53 changes: 53 additions & 0 deletions providers/oracle/tests/unit/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,59 @@ def test_type_checking_thick_mode_config_dir(self):
with pytest.raises(TypeError, match=r"thick_mode_config_dir expected str or None, got.*"):
self.db_hook.get_conn()

@pytest.mark.parametrize(
"connection_params, expected_uri",
[
pytest.param(
{"extra": '{"service_name": "service"}', "schema": None, "port": 1521},
"oracle://login:password@host:1521/service",
id="service_name_in_extra",
),
pytest.param(
{"extra": '{"sid": "sid"}', "schema": None, "port": 1521},
"oracle://login:password@host:1521/sid",
id="sid_in_extra",
),
pytest.param(
{"extra": "{}", "schema": "db_schema", "port": 1521},
"oracle://login:password@host:1521/db_schema",
id="schema_only",
),
pytest.param(
{"extra": "{}", "schema": None, "port": 1521},
"oracle://login:password@host:1521",
id="no_schema_no_extra",
),
pytest.param(
{"extra": "{}", "schema": "db_schema", "port": None},
"oracle://login:password@host:1521/db_schema",
id="schema_only_default_port",
),
pytest.param(
{"extra": '{"service_name": "service"}', "schema": "db_schema", "port": 1521},
"oracle://login:password@host:1521/service",
id="service_name_with_schema",
),
pytest.param(
{
"extra": '{"service_name": "(DESCRIPTION=(ADDRESS=(host=oracle://somedb.example.com)(protocol=TCP)(port=1521))(CONNECT_DATA=(SERVICE_NAME=orclpdb)))"}',
"schema": None,
"port": 1521,
},
"oracle://login:password@host:1521/(DESCRIPTION=(ADDRESS=(host=oracle://somedb.example.com)(protocol=TCP)(port=1521))(CONNECT_DATA=(SERVICE_NAME=orclpdb)))",
id="complex_service_name",
),
],
)
@mock.patch("airflow.providers.oracle.hooks.oracle.oracledb.connect")
def test_get_uri(self, mock_connect, connection_params, expected_uri):
self.connection.extra = connection_params["extra"]
self.connection.schema = connection_params["schema"]
self.connection.port = connection_params["port"]

uri = self.db_hook.get_uri()
assert uri == expected_uri


class TestOracleHook:
def setup_method(self):
Expand Down