diff --git a/meilisearch/client.py b/meilisearch/client.py index 845d2ad2..85aea7e3 100644 --- a/meilisearch/client.py +++ b/meilisearch/client.py @@ -1,5 +1,9 @@ -from typing import Any, Dict, List, Optional - +import base64 +import hashlib +import hmac +import json +import datetime +from typing import Any, Dict, List, Optional, Union from meilisearch.index import Index from meilisearch.config import Config from meilisearch.task import get_task, get_tasks, wait_for_task @@ -464,3 +468,75 @@ def wait_for_task( An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://docs.meilisearch.com/errors/#meilisearch-errors """ return wait_for_task(self.config, uid, timeout_in_ms, interval_in_ms) + + def generate_tenant_token( + self, + search_rules: Union[Dict[str, Any], List[str]], + *, + expires_at: Optional[datetime.datetime] = None, + api_key: Optional[str] = None + ) -> str: + """Generate a JWT token for the use of multitenancy. + + Parameters + ---------- + search_rules: + A Dictionary or list of string which contains the rules to be enforced at search time for all or specific + accessible indexes for the signing API Key. + In the specific case where you do not want to have any restrictions you can also use a list ["*"]. + expires_at (optional): + Date and time when the key will expire. Note that if an expires_at value is included it should be in UTC time. + api_key (optional): + The API key parent of the token. If you leave it empty the client API Key will be used. + + Returns + ------- + jwt_token: + A string containing the jwt tenant token. + Note: If your token does not work remember that the search_rules is mandatory and should be well formatted. + `exp` must be a `datetime` in the future. It's not possible to create a token from the master key. + """ + # Validate all fields + if api_key == '' or api_key is None and self.config.api_key is None: + raise Exception('An api key is required in the client or should be passed as an argument.') + if not search_rules or search_rules == ['']: + raise Exception('The search_rules field is mandatory and should be defined.') + if expires_at and expires_at < datetime.datetime.utcnow(): + raise Exception('The date expires_at should be in the future.') + + # Standard JWT header for encryption with SHA256/HS256 algorithm + header = { + "typ": "JWT", + "alg": "HS256" + } + + api_key = str(self.config.api_key) if api_key is None else api_key + + # Add the required fields to the payload + payload = { + 'apiKeyPrefix': api_key[0:8], + 'searchRules': search_rules, + 'exp': int(datetime.datetime.timestamp(expires_at)) if expires_at is not None else None + } + + # Serialize the header and the payload + json_header = json.dumps(header, separators=(",",":")).encode() + json_payload = json.dumps(payload, separators=(",",":")).encode() + + # Encode the header and the payload to Base64Url String + header_encode = self._base64url_encode(json_header) + header_payload = self._base64url_encode(json_payload) + + secret_encoded = api_key.encode() + # Create Signature Hash + signature = hmac.new(secret_encoded, (header_encode + "." + header_payload).encode(), hashlib.sha256).digest() + # Create JWT + jwt_token = header_encode + '.' + header_payload + '.' + self._base64url_encode(signature) + + return jwt_token + + @staticmethod + def _base64url_encode( + data: bytes + ) -> str: + return base64.urlsafe_b64encode(data).decode('utf-8').replace('=','') diff --git a/tests/client/test_client_tenant_token.py b/tests/client/test_client_tenant_token.py new file mode 100644 index 00000000..037b07b9 --- /dev/null +++ b/tests/client/test_client_tenant_token.py @@ -0,0 +1,106 @@ +# pylint: disable=invalid-name + +from re import search +import pytest +import meilisearch +from tests import BASE_URL, MASTER_KEY +from meilisearch.errors import MeiliSearchApiError +import datetime + +def test_generate_tenant_token_with_search_rules(get_private_key, index_with_documents): + """Tests create a tenant token with only search rules.""" + index_with_documents() + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + token = client.generate_tenant_token(search_rules=["*"]) + + token_client = meilisearch.Client(BASE_URL, token) + response = token_client.index('indexUID').search('', { + 'limit': 5 + }) + assert isinstance(response, dict) + assert len(response['hits']) == 5 + assert response['query'] == '' + +def test_generate_tenant_token_with_search_rules_on_one_index(get_private_key, empty_index): + """Tests create a tenant token with search rules set for one index.""" + empty_index() + empty_index('tenant_token') + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + token = client.generate_tenant_token(search_rules=['indexUID']) + + token_client = meilisearch.Client(BASE_URL, token) + response = token_client.index('indexUID').search('') + assert isinstance(response, dict) + assert response['query'] == '' + with pytest.raises(MeiliSearchApiError): + response = token_client.index('tenant_token').search('') + +def test_generate_tenant_token_with_api_key(client, get_private_key, empty_index): + """Tests create a tenant token with search rules and an api key.""" + empty_index() + token = client.generate_tenant_token(search_rules=["*"], api_key=get_private_key['key']) + + token_client = meilisearch.Client(BASE_URL, token) + response = token_client.index('indexUID').search('') + assert isinstance(response, dict) + assert response['query'] == '' + +def test_generate_tenant_token_with_expires_at(client, get_private_key, empty_index): + """Tests create a tenant token with search rules and expiration date.""" + empty_index() + client = meilisearch.Client(BASE_URL, get_private_key['key']) + tomorrow = datetime.datetime.now() + datetime.timedelta(days=1) + + token = client.generate_tenant_token(search_rules=["*"], expires_at=tomorrow) + + token_client = meilisearch.Client(BASE_URL, token) + response = token_client.index('indexUID').search('') + assert isinstance(response, dict) + assert response['query'] == '' + +def test_generate_tenant_token_with_empty_search_rules_in_list(get_private_key): + """Tests create a tenant token without search rules.""" + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + with pytest.raises(Exception): + client.generate_tenant_token(search_rules=['']) + +def test_generate_tenant_token_without_search_rules_in_list(get_private_key): + """Tests create a tenant token without search rules.""" + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + with pytest.raises(Exception): + client.generate_tenant_token(search_rules=[]) + +def test_generate_tenant_token_without_search_rules_in_dict(get_private_key): + """Tests create a tenant token without search rules.""" + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + with pytest.raises(Exception): + client.generate_tenant_token(search_rules={}) + +def test_generate_tenant_token_with_empty_search_rules_in_dict(get_private_key): + """Tests create a tenant token without search rules.""" + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + with pytest.raises(Exception): + client.generate_tenant_token(search_rules={''}) + +def test_generate_tenant_token_with_bad_expires_at(client, get_private_key): + """Tests create a tenant token with a bad expires at.""" + client = meilisearch.Client(BASE_URL, get_private_key['key']) + + yesterday = datetime.datetime.utcnow() + datetime.timedelta(days=-1) + + with pytest.raises(Exception): + client.generate_tenant_token(search_rules=["*"], expires_at=yesterday) + +def test_generate_tenant_token_with_no_api_key(client): + """Tests create a tenant token with no api key.""" + client = meilisearch.Client(BASE_URL) + + with pytest.raises(Exception): + client.generate_tenant_token(search_rules=["*"]) + diff --git a/tests/conftest.py b/tests/conftest.py index 60f0861e..43124544 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from tests import common import meilisearch from meilisearch.errors import MeiliSearchApiError +from typing import Optional @fixture(scope='session') def client(): @@ -67,8 +68,9 @@ def songs_ndjson(): return song_ndjson_file.read().encode('utf-8') @fixture(scope='function') -def empty_index(client): - def index_maker(index_name=common.INDEX_UID): +def empty_index(client, index_uid: Optional[str] = None): + index_uid = index_uid if index_uid else common.INDEX_UID + def index_maker(index_name=index_uid): task = client.create_index(uid=index_name) client.wait_for_task(task['uid']) return client.get_index(uid=index_name) @@ -109,3 +111,9 @@ def test_key_info(client): client.delete_key(key['key']) except MeiliSearchApiError: pass + +@fixture(scope='function') +def get_private_key(client): + keys = client.get_keys()['results'] + key = next(x for x in keys if 'Default Search API' in x['description']) + return key