diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 376292e1b3c6f..1fc35d779a230 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -160,13 +160,29 @@ def __init__( self.conn_id = conn_id self.timeout = timeout self.proxies = proxies - self.host = host + self.host = self._ensure_protocol(host) if isinstance(scopes, str): self.scopes = [scopes] else: self.scopes = scopes or [self.DEFAULT_SCOPE] self.api_version = self.resolve_api_version_from_value(api_version) + def _ensure_protocol(self, host: str | None, schema: str = "https") -> str | None: + """Ensure URL has http:// or https:// protocol prefix.""" + if not host: + return None + + if host.startswith(("http://", "https://")): + return host + + self.log.warning( + "URL '%s' is missing protocol prefix. Automatically adding '%s://'. " + "Please update your connection configuration to include the full URL with protocol.", + host, + schema, + ) + return f"{schema}://{host}" + @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to connection form.""" @@ -232,9 +248,9 @@ def get_host(self, connection: Connection) -> str: if connection.schema and connection.host: return f"{connection.schema}://{connection.host}" return NationalClouds.Global.value - if not self.host.startswith("http://") or not self.host.startswith("https://"): - return f"{connection.schema}://{self.host}" - return self.host + + schema = connection.schema or "https" + return cast("str", self._ensure_protocol(self.host, schema)) def get_base_url(self, host: str, api_version: str, config: dict) -> str: base_url = config.get("base_url", urljoin(host, api_version)).strip() diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py index 6f6216d8fa9d1..a414663d6f461 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py @@ -433,6 +433,85 @@ async def test_build_request_adapter_masks_secrets(self): mock_redact.assert_any_call("my_secret_password", name="client_secret") +class TestKiotaRequestAdapterHookProtocol: + """Test protocol handling in KiotaRequestAdapterHook.""" + + def test_init_with_https_protocol(self): + """Test that URL with https protocol is preserved.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="https://api.powerbi.com") + assert hook.host == "https://api.powerbi.com" + + def test_init_with_http_protocol(self): + """Test that URL with http protocol is preserved.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="http://api.powerbi.com") + assert hook.host == "http://api.powerbi.com" + + def test_init_without_protocol(self): + """Test that URL without protocol gets https added.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="api.powerbi.com") + assert hook.host == "https://api.powerbi.com" + + def test_init_with_none_host(self): + """Test that None host remains None.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None) + assert hook.host is None + + def test_init_with_empty_host(self): + """Test that empty string host becomes None.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="") + assert hook.host is None + + def test_get_host_with_protocol_in_host_parameter(self): + """Test get_host returns self.host when it already has protocol.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="https://api.powerbi.com") + connection = mock_connection(schema="https", host="graph.microsoft.com") + actual = hook.get_host(connection) + assert actual == "https://api.powerbi.com" + + def test_get_host_without_host_parameter_uses_connection(self): + """Test get_host builds URL from connection when self.host is None.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None) + connection = mock_connection(schema="https", host="graph.microsoft.com") + actual = hook.get_host(connection) + assert actual == "https://graph.microsoft.com" + + def test_get_host_fallback_to_default_when_no_connection_info(self): + """Test get_host returns default when no host info available.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None) + connection = mock_connection(schema=None, host=None) + actual = hook.get_host(connection) + assert actual == NationalClouds.Global.value + + def test_get_host_with_none_schema_uses_https_fallback(self): + """Test get_host uses https fallback when connection.schema is None but host exists.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None) + hook.host = "api.powerbi.com" + connection = mock_connection(schema=None, host="dummy.com") + actual = hook.get_host(connection) + assert actual == "https://api.powerbi.com" + + def test_ensure_protocol_warns_when_adding_protocol(self): + """Test that _ensure_protocol logs warning when adding protocol.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + + with patch.object(hook.log, "warning") as mock_warning: + result = hook._ensure_protocol("api.powerbi.com") + + assert result == "https://api.powerbi.com" + mock_warning.assert_called_once() + assert "missing protocol prefix" in mock_warning.call_args[0][0].lower() + + class TestResponseHandler: def test_default_response_handler_when_json(self): users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")