Skip to content

[Fix] Fetch when vector id string contains spaces #372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codegen/apis
Submodule apis updated from e9b47c to 062b11
2 changes: 1 addition & 1 deletion codegen/python-oas-templates
8 changes: 5 additions & 3 deletions pinecone/core/openapi/shared/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import ssl
import os
from urllib.parse import urlencode
from urllib.parse import urlencode, quote

import urllib3

Expand Down Expand Up @@ -182,7 +182,7 @@ def request(
if (method != "DELETE") and ("Content-Type" not in headers):
headers["Content-Type"] = "application/json"
if query_params:
url += "?" + urlencode(query_params)
url += "?" + urlencode(query_params, quote_via=quote)
if ("Content-Type" not in headers) or (re.search("json", headers["Content-Type"], re.IGNORECASE)):
request_body = None
if body is not None:
Expand Down Expand Up @@ -240,8 +240,10 @@ def request(
raise PineconeApiException(status=0, reason=msg)
# For `GET`, `HEAD`
else:
if query_params:
url += "?" + urlencode(query_params, quote_via=quote)
r = self.pool_manager.request(
method, url, fields=query_params, preload_content=_preload_content, timeout=timeout, headers=headers
method, url, preload_content=_preload_content, timeout=timeout, headers=headers
)
except urllib3.exceptions.SSLError as e:
msg = "{0}\n{1}".format(type(e).__name__, str(e))
Expand Down
14 changes: 10 additions & 4 deletions tests/integration/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import json
from ..helpers import get_environment_var, random_string
from .seed import setup_data, setup_list_data
from .seed import setup_data, setup_list_data, setup_weird_ids_data

# Test matrix needs to consider the following dimensions:
# - pod vs serverless
Expand Down Expand Up @@ -60,13 +60,16 @@ def index_name():

@pytest.fixture(scope="session")
def namespace():
# return 'banana'
return random_string(10)


@pytest.fixture(scope="session")
def list_namespace():
# return 'list-banana'
return random_string(10)


@pytest.fixture(scope="session")
def weird_ids_namespace():
return random_string(10)


Expand All @@ -89,9 +92,12 @@ def index_host(index_name, metric, spec):


@pytest.fixture(scope="session", autouse=True)
def seed_data(idx, namespace, index_host, list_namespace):
def seed_data(idx, namespace, index_host, list_namespace, weird_ids_namespace):
print("Seeding data in host " + index_host)

print("Seeding data in weird is namespace " + weird_ids_namespace)
setup_weird_ids_data(idx, weird_ids_namespace, True)

print('Seeding list data in namespace "' + list_namespace + '"')
setup_list_data(idx, list_namespace, True)

Expand Down
80 changes: 80 additions & 0 deletions tests/integration/data/seed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ..helpers import poll_fetch_for_ids_in_namespace
from pinecone import Vector
from .utils import embedding_values
import itertools


def setup_data(idx, target_namespace, wait):
Expand Down Expand Up @@ -43,3 +44,82 @@ def setup_list_data(idx, target_namespace, wait):

if wait:
poll_fetch_for_ids_in_namespace(idx, ids=["999"], namespace=target_namespace)


def weird_invalid_ids():
    """Return single-character strings expected to be rejected as vector ids.

    Every entry is a non-ASCII character. Invisible and look-alike characters
    are spelled as escape sequences so that they cannot be silently turned
    into a plain ASCII space or quote when the file is copied, rendered or
    normalized.

    Returns:
        list[str]: invisible characters, emojis, CJK/kana characters and
        typographic quotes, in that order.
    """
    invisible = [
        "\u2800",  # U+2800
        "\u00a0",  # U+00A0
        "\u00ad",  # U+00AD
        "\u17f4",  # U+17F4
        "\u180e",  # U+180E
        "\u2000",  # U+2000
        "\u2001",  # U+2001
        "\u2002",  # U+2002
    ]
    emojis = list("🌲🍦")
    two_byte = list("田中さんにあげて下さい")
    quotes = [
        "‘",
        "’",
        "“",
        "”",
        "„",
        "‟",
        "‹",
        "›",
        "❛",
        "❜",
        "❝",
        "❞",
        "❮",
        "❯",
        # ASCII `"` and `'` are valid ids, so they cannot belong in this list;
        # presumably the fullwidth forms were intended (TODO confirm).
        "\uff02",  # U+FF02 FULLWIDTH QUOTATION MARK
        "\uff07",  # U+FF07 FULLWIDTH APOSTROPHE
        "「",
        "」",
    ]

    return invisible + emojis + two_byte + quotes


def weird_valid_ids():
    """Return ASCII strings that are unusual but expected to work as vector ids.

    The list contains, in order: single characters (digits, whitespace,
    punctuation and backslash-escaped punctuation), every two-element
    combination (with replacement) of those characters, boolean/null
    look-alikes, script/SQL injection payloads and format-string
    placeholders.

    Returns:
        list[str]: the ids, in a deterministic order.
    """
    # Drawing inspiration from the big list of naughty strings https://github.com/minimaxir/big-list-of-naughty-strings/blob/master/blns.txt
    numbers = list("1234567890")
    invisible = [" ", "\n", "\t", "\r"]
    punctuation = list("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
    escaped = [f"\\{c}" for c in punctuation]

    characters = numbers + invisible + punctuation + escaped
    pairs = ["".join(x) for x in itertools.combinations_with_replacement(characters, 2)]

    boolean_ish = [
        "undefined",
        "nil",
        "null",
        "Null",
        "NULL",
        "None",
        "True",
        "False",
        "true",
        "false",
    ]

    script_injection = [
        "<script>alert(0)</script>",
        "<svg><script>123<1>alert(3)</script>",
        '" onfocus=JaVaSCript:alert(10) autofocus',
        "javascript:alert(1)",
        "javascript:alert(1);",
        # These two must stay separate entries: without the comma Python
        # concatenates the adjacent literals into a single id.
        '<img src\x32=x onerror="javascript:alert(182)">',
        "1;DROP TABLE users",
        "' OR 1=1 -- 1",
        "' OR '1'='1",
    ]

    unwanted_interpolation = [
        "$HOME",
        "$ENV{'HOME'}",
        "%d",
        "%s",
        "%n",
        "%x",
        "{0}",
    ]

    # Built by concatenation rather than repeated ids.extend() calls, as
    # suggested in review.
    return characters + pairs + boolean_ish + script_injection + unwanted_interpolation


def setup_weird_ids_data(idx, target_namespace, wait):
    """Upsert one 2-dimensional vector per id from ``weird_valid_ids()``.

    The ids are sent to ``idx.upsert`` in consecutive chunks of at most 100
    vectors, all into ``target_namespace``.

    NOTE(review): ``wait`` is accepted but not used by this function.
    """
    all_ids = weird_valid_ids()
    chunk_len = 100
    start = 0
    while start < len(all_ids):
        current = all_ids[start : start + chunk_len]
        vectors = [(vector_id, embedding_values(2)) for vector_id in current]
        idx.upsert(vectors=vectors, namespace=target_namespace)
        start += chunk_len
49 changes: 49 additions & 0 deletions tests/integration/data/test_weird_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from .seed import weird_valid_ids, weird_invalid_ids


class TestHandlingOfWeirdIds:
    """Integration tests for vector ids made of unusual characters.

    The tests use the ``idx`` and ``weird_ids_namespace`` fixtures and the id
    lists returned by ``weird_valid_ids()`` and ``weird_invalid_ids()``.
    """

    def test_fetch_weird_ids(self, idx, weird_ids_namespace):
        """Fetch every valid weird id, in batches of 100, and check each result."""
        weird_ids = weird_valid_ids()
        batch_size = 100
        for i in range(0, len(weird_ids), batch_size):
            ids_to_fetch = weird_ids[i : i + batch_size]
            results = idx.fetch(ids=ids_to_fetch, namespace=weird_ids_namespace)
            assert results.usage["read_units"] > 0
            assert len(results.vectors) == len(ids_to_fetch)
            for vector_id in ids_to_fetch:
                assert vector_id in results.vectors
                fetched = results.vectors[vector_id]
                assert fetched.id == vector_id
                assert fetched.metadata is None
                assert fetched.values is not None
                assert len(fetched.values) == 2

    @pytest.mark.parametrize("id_to_query", weird_valid_ids())
    def test_query_weird_ids(self, idx, weird_ids_namespace, id_to_query):
        """Query by each valid weird id and check the shape of the response."""
        results = idx.query(id=id_to_query, top_k=10, namespace=weird_ids_namespace, include_values=True)
        assert results.usage["read_units"] > 0
        assert len(results.matches) == 10
        assert results.namespace == weird_ids_namespace
        first_match = results.matches[0]
        assert first_match.id is not None
        assert first_match.metadata is None
        assert first_match.values is not None
        assert len(first_match.values) == 2

    def test_list_weird_ids(self, idx, weird_ids_namespace):
        """Every id listed in the namespace must be one of the valid weird ids."""
        expected_ids = set(weird_valid_ids())
        listed = 0
        for page in idx.list(namespace=weird_ids_namespace):
            for vector_id in page:
                assert vector_id in expected_ids
                listed += 1
        # Without this the test would pass vacuously if nothing were listed.
        assert listed > 0

    @pytest.mark.parametrize("id_to_upsert", weird_invalid_ids())
    def test_weird_invalid_ids(self, idx, weird_ids_namespace, id_to_upsert):
        """Upserting a non-ASCII id must fail with the ASCII error message."""
        with pytest.raises(Exception) as e:
            idx.upsert(vectors=[(id_to_upsert, [0.1, 0.1])], namespace=weird_ids_namespace)

        # Checked after the `with` block: statements placed after the raising
        # call inside the block would never run.
        assert "Vector ID must be ASCII" in str(e.value)

    def test_null_character(self, idx, weird_ids_namespace):
        """Upserting an id containing a null character must be rejected."""
        with pytest.raises(Exception) as e:
            idx.upsert(vectors=[("\0", [0.1, 0.1])], namespace=weird_ids_namespace)

        assert "Vector ID must not contain null character" in str(e.value)
Loading