Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,29 @@ def __init__(
self.conn_id = conn_id
self.timeout = timeout
self.proxies = proxies
self.host = host
self.host = self._ensure_protocol(host)
if isinstance(scopes, str):
self.scopes = [scopes]
else:
self.scopes = scopes or [self.DEFAULT_SCOPE]
self.api_version = self.resolve_api_version_from_value(api_version)

def _ensure_protocol(self, host: str | None, schema: str = "https") -> str | None:
"""Ensure URL has http:// or https:// protocol prefix."""
if not host:
return None

if host.startswith(("http://", "https://")):
return host

self.log.warning(
"URL '%s' is missing protocol prefix. Automatically adding '%s://'. "
"Please update your connection configuration to include the full URL with protocol.",
host,
schema,
)
return f"{schema}://{host}"

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Return connection widgets to add to connection form."""
Expand Down Expand Up @@ -232,9 +248,9 @@ def get_host(self, connection: Connection) -> str:
if connection.schema and connection.host:
return f"{connection.schema}://{connection.host}"
return NationalClouds.Global.value
if not self.host.startswith("http://") or not self.host.startswith("https://"):
return f"{connection.schema}://{self.host}"
return self.host

schema = connection.schema or "https"
return cast("str", self._ensure_protocol(self.host, schema))

def get_base_url(self, host: str, api_version: str, config: dict) -> str:
base_url = config.get("base_url", urljoin(host, api_version)).strip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,85 @@ async def test_build_request_adapter_masks_secrets(self):
mock_redact.assert_any_call("my_secret_password", name="client_secret")


class TestKiotaRequestAdapterHookProtocol:
"""Test protocol handling in KiotaRequestAdapterHook."""

def test_init_with_https_protocol(self):
"""Test that URL with https protocol is preserved."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="https://api.powerbi.com")
assert hook.host == "https://api.powerbi.com"

def test_init_with_http_protocol(self):
"""Test that URL with http protocol is preserved."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="http://api.powerbi.com")
assert hook.host == "http://api.powerbi.com"

def test_init_without_protocol(self):
"""Test that URL without protocol gets https added."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="api.powerbi.com")
assert hook.host == "https://api.powerbi.com"

def test_init_with_none_host(self):
"""Test that None host remains None."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None)
assert hook.host is None

def test_init_with_empty_host(self):
"""Test that empty string host becomes None."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="")
assert hook.host is None

def test_get_host_with_protocol_in_host_parameter(self):
"""Test get_host returns self.host when it already has protocol."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host="https://api.powerbi.com")
connection = mock_connection(schema="https", host="graph.microsoft.com")
actual = hook.get_host(connection)
assert actual == "https://api.powerbi.com"

def test_get_host_without_host_parameter_uses_connection(self):
"""Test get_host builds URL from connection when self.host is None."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None)
connection = mock_connection(schema="https", host="graph.microsoft.com")
actual = hook.get_host(connection)
assert actual == "https://graph.microsoft.com"

def test_get_host_fallback_to_default_when_no_connection_info(self):
"""Test get_host returns default when no host info available."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None)
connection = mock_connection(schema=None, host=None)
actual = hook.get_host(connection)
assert actual == NationalClouds.Global.value

def test_get_host_with_none_schema_uses_https_fallback(self):
"""Test get_host uses https fallback when connection.schema is None but host exists."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=None)
hook.host = "api.powerbi.com"
connection = mock_connection(schema=None, host="dummy.com")
actual = hook.get_host(connection)
assert actual == "https://api.powerbi.com"

def test_ensure_protocol_warns_when_adding_protocol(self):
"""Test that _ensure_protocol logs warning when adding protocol."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")

with patch.object(hook.log, "warning") as mock_warning:
result = hook._ensure_protocol("api.powerbi.com")

assert result == "https://api.powerbi.com"
mock_warning.assert_called_once()
assert "missing protocol prefix" in mock_warning.call_args[0][0].lower()


class TestResponseHandler:
def test_default_response_handler_when_json(self):
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
Expand Down