diff --git a/codegen/apis b/codegen/apis index e9b47c76..062b114b 160000 --- a/codegen/apis +++ b/codegen/apis @@ -1 +1 @@ -Subproject commit e9b47c76f649656002f4911946ca6c4c4a6f04fc +Subproject commit 062b114b6d7b016de2b4d2b68c211a81b8689d1a diff --git a/codegen/python-oas-templates b/codegen/python-oas-templates index b72bd5bf..7e6d01bc 160000 --- a/codegen/python-oas-templates +++ b/codegen/python-oas-templates @@ -1 +1 @@ -Subproject commit b72bd5bf2b03a995f3e78c16bf0483dc756f6506 +Subproject commit 7e6d01bc265c425f9af88496e467c959e63b2117 diff --git a/pinecone/core/openapi/shared/rest.py b/pinecone/core/openapi/shared/rest.py index 2e7a2dcc..609f26ad 100644 --- a/pinecone/core/openapi/shared/rest.py +++ b/pinecone/core/openapi/shared/rest.py @@ -4,7 +4,7 @@ import re import ssl import os -from urllib.parse import urlencode +from urllib.parse import urlencode, quote import urllib3 @@ -182,7 +182,7 @@ def request( if (method != "DELETE") and ("Content-Type" not in headers): headers["Content-Type"] = "application/json" if query_params: - url += "?" + urlencode(query_params) + url += "?" + urlencode(query_params, quote_via=quote) if ("Content-Type" not in headers) or (re.search("json", headers["Content-Type"], re.IGNORECASE)): request_body = None if body is not None: @@ -240,8 +240,10 @@ def request( raise PineconeApiException(status=0, reason=msg) # For `GET`, `HEAD` else: + if query_params: + url += "?" + urlencode(query_params, quote_via=quote) r = self.pool_manager.request( - method, url, fields=query_params, preload_content=_preload_content, timeout=timeout, headers=headers + method, url, preload_content=_preload_content, timeout=timeout, headers=headers ) except urllib3.exceptions.SSLError as e: msg = "{0}\n{1}".format(type(e).__name__, str(e)) diff --git a/tests/integration/data/conftest.py b/tests/integration/data/conftest.py index c69b8d44..b1f8d95c 100644 --- a/tests/integration/data/conftest.py +++ b/tests/integration/data/conftest.py @@ -3,7 +3,7 @@ import time import json from ..helpers import get_environment_var, random_string -from .seed import setup_data, setup_list_data +from .seed import setup_data, setup_list_data, setup_weird_ids_data # Test matrix needs to consider the following dimensions: # - pod vs serverless @@ -60,13 +60,16 @@ def index_name(): @pytest.fixture(scope="session") def namespace(): - # return 'banana' return random_string(10) @pytest.fixture(scope="session") def list_namespace(): - # return 'list-banana' + return random_string(10) + + +@pytest.fixture(scope="session") +def weird_ids_namespace(): return random_string(10) @@ -89,9 +92,12 @@ def index_host(index_name, metric, spec): @pytest.fixture(scope="session", autouse=True) -def seed_data(idx, namespace, index_host, list_namespace): +def seed_data(idx, namespace, index_host, list_namespace, weird_ids_namespace): print("Seeding data in host " + index_host) + print("Seeding data in weird is namespace " + weird_ids_namespace) + setup_weird_ids_data(idx, weird_ids_namespace, True) + print('Seeding list data in namespace "' + list_namespace + '"') setup_list_data(idx, list_namespace, True) diff --git a/tests/integration/data/seed.py b/tests/integration/data/seed.py index fba0fc57..1b3efadf 100644 --- a/tests/integration/data/seed.py +++ b/tests/integration/data/seed.py @@ -1,6 +1,7 @@ from ..helpers import poll_fetch_for_ids_in_namespace from pinecone import Vector from .utils import embedding_values +import itertools def setup_data(idx, target_namespace, wait): @@ -43,3 +44,82 @@ def setup_list_data(idx, target_namespace, wait): if wait: poll_fetch_for_ids_in_namespace(idx, ids=["999"], namespace=target_namespace) + + +def weird_invalid_ids(): + invisible = [ + "⠀", # U+2800 + " ", # U+00A0 + "­", # U+00AD + "឴", # U+17F4 + "᠎", # U+180E + " ", # U+2000 + " ", # U+2001 + " ", # U+2002 + ] + emojis = list("🌲🍦") + two_byte = list("田中さんにあげて下さい") + quotes = ["‘", "’", "“", "”", "„", "‟", "‹", "›", "❛", "❜", "❝", "❞", "❮", "❯", """, "'", "「", "」"] + + return invisible + emojis + two_byte + quotes + + +def weird_valid_ids(): + # Drawing inspiration from the big list of naughty strings https://github.com/minimaxir/big-list-of-naughty-strings/blob/master/blns.txt + ids = [] + + numbers = list("1234567890") + invisible = [" ", "\n", "\t", "\r"] + punctuation = list("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~") + escaped = [f"\\{c}" for c in punctuation] + + characters = numbers + invisible + punctuation + escaped + ids.extend(characters) + ids.extend(["".join(x) for x in itertools.combinations_with_replacement(characters, 2)]) + + boolean_ish = [ + "undefined", + "nil", + "null", + "Null", + "NULL", + "None", + "True", + "False", + "true", + "false", + ] + ids.extend(boolean_ish) + + script_injection = [ + "", + "", + '" onfocus=JaVaSCript:alert(10) autofocus', + "javascript:alert(1)", + "javascript:alert(1);", + '' "1;DROP TABLE users", + "' OR 1=1 -- 1", + "' OR '1'='1", + ] + ids.extend(script_injection) + + unwanted_interpolation = [ + "$HOME", + "$ENV{'HOME'}", + "%d", + "%s", + "%n", + "%x", + "{0}", + ] + ids.extend(unwanted_interpolation) + + return ids + + +def setup_weird_ids_data(idx, target_namespace, wait): + weird_ids = weird_valid_ids() + batch_size = 100 + for i in range(0, len(weird_ids), batch_size): + chunk = weird_ids[i : i + batch_size] + idx.upsert(vectors=[(x, embedding_values(2)) for x in chunk], namespace=target_namespace) diff --git a/tests/integration/data/test_weird_ids.py b/tests/integration/data/test_weird_ids.py new file mode 100644 index 00000000..6a91487e --- /dev/null +++ b/tests/integration/data/test_weird_ids.py @@ -0,0 +1,49 @@ +import pytest +from .seed import weird_valid_ids, weird_invalid_ids + + +class TestHandlingOfWeirdIds: + def test_fetch_weird_ids(self, idx, weird_ids_namespace): + weird_ids = weird_valid_ids() + batch_size = 100 + for i in range(0, len(weird_ids), batch_size): + ids_to_fetch = weird_ids[i : i + batch_size] + results = idx.fetch(ids=ids_to_fetch, namespace=weird_ids_namespace) + assert results.usage["read_units"] > 0 + assert len(results.vectors) == len(ids_to_fetch) + for id in ids_to_fetch: + assert id in results.vectors + assert results.vectors[id].id == id + assert results.vectors[id].metadata == None + assert results.vectors[id].values != None + assert len(results.vectors[id].values) == 2 + + @pytest.mark.parametrize("id_to_query", weird_valid_ids()) + def test_query_weird_ids(self, idx, weird_ids_namespace, id_to_query): + results = idx.query(id=id_to_query, top_k=10, namespace=weird_ids_namespace, include_values=True) + assert results.usage["read_units"] > 0 + assert len(results.matches) == 10 + assert results.namespace == weird_ids_namespace + assert results.matches[0].id != None + assert results.matches[0].metadata == None + assert results.matches[0].values != None + assert len(results.matches[0].values) == 2 + + def test_list_weird_ids(self, idx, weird_ids_namespace): + expected_ids = set(weird_valid_ids()) + id_iterator = idx.list(namespace=weird_ids_namespace) + for page in id_iterator: + for id in page: + assert id in expected_ids + + @pytest.mark.parametrize("id_to_upsert", weird_invalid_ids()) + def test_weird_invalid_ids(self, idx, weird_ids_namespace, id_to_upsert): + with pytest.raises(Exception) as e: + idx.upsert(vectors=[(id_to_upsert, [0.1, 0.1])], namespace=weird_ids_namespace) + assert "Vector ID must be ASCII" in str(e.value) + + def test_null_character(self, idx, weird_ids_namespace): + with pytest.raises(Exception) as e: + idx.upsert(vectors=[("\0", [0.1, 0.1])], namespace=weird_ids_namespace) + + assert "Vector ID must not contain null character" in str(e.value)