Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce https in search #11337

Merged
merged 8 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from azure.core.tracing.decorator import distributed_trace

from ._search_service_client_base import SearchServiceClientBase
from ._generated import SearchServiceClient as _SearchServiceClient
from .._headers_mixin import HeadersMixin
from .._version import SDK_MONIKER
from ._datasources_client import SearchDataSourcesClient
from ._indexes_client import SearchIndexesClient
Expand All @@ -22,7 +22,7 @@
from azure.core.credentials import AzureKeyCredential


class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-methods
class SearchServiceClient(SearchServiceClientBase): # pylint: disable=too-many-public-methods
"""A client to interact with an existing Azure search service.

:param endpoint: The URL endpoint of an Azure search service
Expand All @@ -44,13 +44,10 @@ class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-meth

def __init__(self, endpoint, credential, **kwargs):
# type: (str, AzureKeyCredential, **Any) -> None

self._endpoint = endpoint # type: str
self._credential = credential # type: AzureKeyCredential
super(SearchServiceClient, self).__init__(endpoint, credential, **kwargs)
self._client = _SearchServiceClient(
endpoint=endpoint, sdk_moniker=SDK_MONIKER, **kwargs
) # type: _SearchServiceClient

self._indexes_client = SearchIndexesClient(endpoint, credential, **kwargs)

self._synonym_maps_client = SearchSynonymMapsClient(
Expand All @@ -65,10 +62,6 @@ def __init__(self, endpoint, credential, **kwargs):

self._indexers_client = SearchIndexersClient(endpoint, credential, **kwargs)

def __repr__(self):
# type: () -> str
return "<SearchServiceClient [endpoint={}]>".format(repr(self._endpoint))[:1024]

def __enter__(self):
# type: () -> SearchServiceClient
self._client.__enter__() # pylint:disable=no-member
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING

from .._headers_mixin import HeadersMixin
from ._utils import _normalize_endpoint

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, List, Optional, Sequence
from azure.core.credentials import AzureKeyCredential


class SearchServiceClientBase(HeadersMixin): # pylint: disable=too-many-public-methods
"""A client to interact with an existing Azure search service.

:param endpoint: The URL endpoint of an Azure search service
:type endpoint: str
:param credential: A credential to authorize search client requests
:type credential: ~azure.core.credentials import AzureKeyCredential
"""

_ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str

def __init__(self, endpoint, credential):
# type: (str, AzureKeyCredential) -> None

self._endpoint = _normalize_endpoint(endpoint) # type: str
self._credential = credential # type: AzureKeyCredential

def __repr__(self):
# type: () -> str
return "<SearchServiceClient [endpoint={}]>".format(repr(self._endpoint))[:1024]
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,13 @@ def get_access_conditions(model, match_condition=MatchConditions.Unconditionally
return (error_map, AccessCondition(if_match=if_match, if_none_match=if_none_match))
except AttributeError:
raise ValueError("Unable to get e_tag from the model")

def _normalize_endpoint(endpoint):
try:
if not endpoint.lower().startswith('http'):
endpoint = "https://" + endpoint
elif not endpoint.lower().startswith('https'):
raise ValueError("Bearer token authentication is not permitted for non-TLS protected (non-https) URLs.")
return endpoint
except AttributeError:
raise ValueError("Endpoint must be a string.")
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from azure.core.tracing.decorator_async import distributed_trace_async

from .._generated.aio import SearchServiceClient as _SearchServiceClient
from ..._headers_mixin import HeadersMixin
from .._search_service_client_base import SearchServiceClientBase
from ..._version import SDK_MONIKER
from ._datasources_client import SearchDataSourcesClient
from ._indexes_client import SearchIndexesClient
Expand All @@ -22,7 +22,7 @@
from azure.core.credentials import AzureKeyCredential


class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-methods
class SearchServiceClient(SearchServiceClientBase): # pylint: disable=too-many-public-methods
"""A client to interact with an existing Azure search service.

:param endpoint: The URL endpoint of an Azure search service
Expand All @@ -45,8 +45,7 @@ class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-meth
def __init__(self, endpoint, credential, **kwargs):
# type: (str, AzureKeyCredential, **Any) -> None

self._endpoint = endpoint # type: str
self._credential = credential # type: AzureKeyCredential
super().__init__(endpoint, credential, **kwargs)
self._client = _SearchServiceClient(
endpoint=endpoint, sdk_moniker=SDK_MONIKER, **kwargs
) # type: _SearchServiceClient
Expand All @@ -65,10 +64,6 @@ def __init__(self, endpoint, credential, **kwargs):

self._indexers_client = SearchIndexersClient(endpoint, credential, **kwargs)

def __repr__(self):
# type: () -> str
return "<SearchServiceClient [endpoint={}]>".format(repr(self._endpoint))[:1024]

async def __aenter__(self):
# type: () -> SearchServiceClient
await self._client.__aenter__() # pylint:disable=no-member
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_credential_roll(self):
def test_repr(self):
client = SearchServiceClient("endpoint", CREDENTIAL)
assert repr(client) == "<SearchServiceClient [endpoint={}]>".format(
repr("endpoint")
repr("https://endpoint")
)

@mock.patch(
Expand All @@ -52,3 +52,17 @@ def test_get_service_statistics(self, mock_get_stats):
assert mock_get_stats.called
assert mock_get_stats.call_args[0] == ()
assert mock_get_stats.call_args[1] == {"headers": client._headers}

def test_endpoint_https(self):
credential = AzureKeyCredential(key="old_api_key")
client = SearchServiceClient("endpoint", credential)
assert client._endpoint.startswith('https')

client = SearchServiceClient("https://endpoint", credential)
assert client._endpoint.startswith('https')

with pytest.raises(ValueError):
client = SearchServiceClient("http://endpoint", credential)

with pytest.raises(ValueError):
client = SearchServiceClient(12345, credential)