diff --git a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py index 2c57f493444c8..527f002e21001 100644 --- a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py @@ -25,6 +25,7 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook +DEFAULT_DB_PORT = 1521 PARAM_TYPES = {bool, float, int, str} @@ -183,7 +184,7 @@ def get_conn(self) -> oracledb.Connection: # Set up DSN service_name = conn.extra_dejson.get("service_name") - port = conn.port if conn.port else 1521 + port = conn.port if conn.port else DEFAULT_DB_PORT if conn.host and sid and not service_name: conn_config["dsn"] = oracledb.makedsn(conn.host, port, sid) elif conn.host and service_name and not sid: @@ -443,3 +444,26 @@ def handler(cursor): ) return result + + def get_uri(self) -> str: + """Get the URI for the Oracle connection.""" + conn = self.get_connection(self.oracle_conn_id) # type: ignore[attr-defined] + login = conn.login + password = conn.password + host = conn.host + port = conn.port or DEFAULT_DB_PORT + service_name = conn.extra_dejson.get("service_name") + sid = conn.extra_dejson.get("sid") + + if sid and service_name: + raise ValueError("At most one allowed for 'sid', and 'service name'.") + + uri = f"oracle://{login}:{password}@{host}:{port}" + if sid: + uri = f"{uri}/{sid}" + elif service_name: + uri = f"{uri}/{service_name}" + elif conn.schema: + uri = f"{uri}/{conn.schema}" + + return uri diff --git a/providers/oracle/tests/unit/oracle/hooks/test_oracle.py b/providers/oracle/tests/unit/oracle/hooks/test_oracle.py index 9f86df3013fa4..a82a6d207c111 100644 --- a/providers/oracle/tests/unit/oracle/hooks/test_oracle.py +++ b/providers/oracle/tests/unit/oracle/hooks/test_oracle.py @@ -256,6 +256,59 @@ def test_type_checking_thick_mode_config_dir(self): with pytest.raises(TypeError, match=r"thick_mode_config_dir expected str or None, got.*"): self.db_hook.get_conn() + @pytest.mark.parametrize( + "connection_params, expected_uri", + [ + pytest.param( + {"extra": '{"service_name": "service"}', "schema": None, "port": 1521}, + "oracle://login:password@host:1521/service", + id="service_name_in_extra", + ), + pytest.param( + {"extra": '{"sid": "sid"}', "schema": None, "port": 1521}, + "oracle://login:password@host:1521/sid", + id="sid_in_extra", + ), + pytest.param( + {"extra": "{}", "schema": "db_schema", "port": 1521}, + "oracle://login:password@host:1521/db_schema", + id="schema_only", + ), + pytest.param( + {"extra": "{}", "schema": None, "port": 1521}, + "oracle://login:password@host:1521", + id="no_schema_no_extra", + ), + pytest.param( + {"extra": "{}", "schema": "db_schema", "port": None}, + "oracle://login:password@host:1521/db_schema", + id="schema_only_default_port", + ), + pytest.param( + {"extra": '{"service_name": "service"}', "schema": "db_schema", "port": 1521}, + "oracle://login:password@host:1521/service", + id="service_name_with_schema", + ), + pytest.param( + { + "extra": '{"service_name": "(DESCRIPTION=(ADDRESS=(host=oracle://somedb.example.com)(protocol=TCP)(port=1521))(CONNECT_DATA=(SERVICE_NAME=orclpdb)))"}', + "schema": None, + "port": 1521, + }, + "oracle://login:password@host:1521/(DESCRIPTION=(ADDRESS=(host=oracle://somedb.example.com)(protocol=TCP)(port=1521))(CONNECT_DATA=(SERVICE_NAME=orclpdb)))", + id="complex_service_name", + ), + ], + ) + @mock.patch("airflow.providers.oracle.hooks.oracle.oracledb.connect") + def test_get_uri(self, mock_connect, connection_params, expected_uri): + self.connection.extra = connection_params["extra"] + self.connection.schema = connection_params["schema"] + self.connection.port = connection_params["port"] + + uri = self.db_hook.get_uri() + assert uri == expected_uri + class TestOracleHook: def setup_method(self):