Skip to content

Commit

Permalink
[Fix] Fetch when vector id string contains spaces (#372)
Browse files Browse the repository at this point in the history
## Problem

Some data operations fail when the vector id string contains a space.

```python
from pinecone import Pinecone

pc = Pinecone()
index = pc.Index("my-index")
index.fetch(ids=["id with string"]) # no results returned, even when vector exists
```

## Solution

The problem occurred because spaces in URL query params were being
encoded as `+` instead of `%20`. The fix was a small adjustment to
our code generation templates.

I added test coverage for upsert / query / fetch with various weird ids
to make sure the change in encoding hasn't broken any other use cases
that could pop up.

## Type of Change

- [x] Bug fix (non-breaking change which fixes an issue)
  • Loading branch information
jhamon authored Jul 30, 2024
1 parent d9df375 commit 0f57dca
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 9 deletions.
2 changes: 1 addition & 1 deletion codegen/apis
Submodule apis updated from e9b47c to 062b11
2 changes: 1 addition & 1 deletion codegen/python-oas-templates
8 changes: 5 additions & 3 deletions pinecone/core/openapi/shared/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import ssl
import os
from urllib.parse import urlencode
from urllib.parse import urlencode, quote

import urllib3

Expand Down Expand Up @@ -182,7 +182,7 @@ def request(
if (method != "DELETE") and ("Content-Type" not in headers):
headers["Content-Type"] = "application/json"
if query_params:
url += "?" + urlencode(query_params)
url += "?" + urlencode(query_params, quote_via=quote)
if ("Content-Type" not in headers) or (re.search("json", headers["Content-Type"], re.IGNORECASE)):
request_body = None
if body is not None:
Expand Down Expand Up @@ -240,8 +240,10 @@ def request(
raise PineconeApiException(status=0, reason=msg)
# For `GET`, `HEAD`
else:
if query_params:
url += "?" + urlencode(query_params, quote_via=quote)
r = self.pool_manager.request(
method, url, fields=query_params, preload_content=_preload_content, timeout=timeout, headers=headers
method, url, preload_content=_preload_content, timeout=timeout, headers=headers
)
except urllib3.exceptions.SSLError as e:
msg = "{0}\n{1}".format(type(e).__name__, str(e))
Expand Down
14 changes: 10 additions & 4 deletions tests/integration/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import json
from ..helpers import get_environment_var, random_string
from .seed import setup_data, setup_list_data
from .seed import setup_data, setup_list_data, setup_weird_ids_data

# Test matrix needs to consider the following dimensions:
# - pod vs serverless
Expand Down Expand Up @@ -60,13 +60,16 @@ def index_name():

@pytest.fixture(scope="session")
def namespace():
    """Session-scoped fixture providing a namespace name from ``random_string(10)``."""
    return random_string(10)


@pytest.fixture(scope="session")
def list_namespace():
    """Session-scoped fixture providing the namespace name used for list data.

    The name comes from ``random_string(10)``.
    """
    return random_string(10)


@pytest.fixture(scope="session")
def weird_ids_namespace():
    """Session-scoped fixture providing the namespace name used for weird-id data.

    The name comes from ``random_string(10)``.
    """
    generated_name = random_string(10)
    return generated_name


Expand All @@ -89,9 +92,12 @@ def index_host(index_name, metric, spec):


@pytest.fixture(scope="session", autouse=True)
def seed_data(idx, namespace, index_host, list_namespace):
def seed_data(idx, namespace, index_host, list_namespace, weird_ids_namespace):
print("Seeding data in host " + index_host)

print("Seeding data in weird is namespace " + weird_ids_namespace)
setup_weird_ids_data(idx, weird_ids_namespace, True)

print('Seeding list data in namespace "' + list_namespace + '"')
setup_list_data(idx, list_namespace, True)

Expand Down
80 changes: 80 additions & 0 deletions tests/integration/data/seed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ..helpers import poll_fetch_for_ids_in_namespace
from pinecone import Vector
from .utils import embedding_values
import itertools


def setup_data(idx, target_namespace, wait):
Expand Down Expand Up @@ -43,3 +44,82 @@ def setup_list_data(idx, target_namespace, wait):

if wait:
poll_fetch_for_ids_in_namespace(idx, ids=["999"], namespace=target_namespace)


def weird_invalid_ids():
    """Return single-character vector ids that all contain non-ASCII text.

    The list is the concatenation of four groups: invisible/blank characters,
    emoji, Japanese characters, and typographic quotation marks. Invisible
    characters are written as ``\\u`` escapes so they cannot be silently
    normalised to ordinary ASCII whitespace by editors or copy/paste.
    """
    invisible = [
        "\u2800",  # U+2800
        "\u00a0",  # U+00A0
        "\u00ad",  # U+00AD
        "\u17f4",  # U+17F4 (NOTE(review): confirm; U+17B4 is the invisible Khmer vowel)
        "\u180e",  # U+180E
        "\u2000",  # U+2000
        "\u2001",  # U+2001
        "\u2002",  # U+2002
    ]
    emojis = list("🌲🍦")
    two_byte = list("田中さんにあげて下さい")
    quotes = [
        "‘",
        "’",
        "“",
        "”",
        "„",
        "‟",
        "‹",
        "›",
        "❛",
        "❜",
        "❝",
        "❞",
        "❮",
        "❯",
        # Fullwidth forms: the plain ASCII quote and apostrophe are valid ids
        # and therefore do not belong in the invalid list.
        "\uff02",  # U+FF02 FULLWIDTH QUOTATION MARK
        "\uff07",  # U+FF07 FULLWIDTH APOSTROPHE
        "「",
        "」",
    ]

    return invisible + emojis + two_byte + quotes


def weird_valid_ids():
    """Return a list of unusual ASCII vector ids used to seed and query test data.

    The list contains, in order:

    * every digit, ASCII whitespace character, punctuation character and
      backslash-escaped punctuation character on its own;
    * every two-item combination (with replacement) of those strings;
    * boolean/null look-alikes;
    * script/SQL injection strings;
    * strings that could trigger unwanted interpolation.

    The list may contain duplicate strings (for example a backslash followed
    by ``]`` arises both as an escaped character and as a combination).
    """
    # Drawing inspiration from the big list of naughty strings https://github.com/minimaxir/big-list-of-naughty-strings/blob/master/blns.txt
    ids = []

    numbers = list("1234567890")
    invisible = [" ", "\n", "\t", "\r"]
    punctuation = list("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
    escaped = [f"\\{c}" for c in punctuation]

    characters = numbers + invisible + punctuation + escaped
    ids.extend(characters)
    ids.extend(["".join(x) for x in itertools.combinations_with_replacement(characters, 2)])

    boolean_ish = [
        "undefined",
        "nil",
        "null",
        "Null",
        "NULL",
        "None",
        "True",
        "False",
        "true",
        "false",
    ]
    ids.extend(boolean_ish)

    script_injection = [
        "<script>alert(0)</script>",
        "<svg><script>123<1>alert(3)</script>",
        '" onfocus=JaVaSCript:alert(10) autofocus',
        "javascript:alert(1)",
        "javascript:alert(1);",
        # NOTE(review): in a non-raw literal the \x32 escape below is the
        # character "2", not a literal backslash sequence; confirm intent.
        '<img src\x32=x onerror="javascript:alert(182)">',
        # Separate entry: without the comma above, implicit string
        # concatenation merged this id into the previous one.
        "1;DROP TABLE users",
        "' OR 1=1 -- 1",
        "' OR '1'='1",
    ]
    ids.extend(script_injection)

    unwanted_interpolation = [
        "$HOME",
        "$ENV{'HOME'}",
        "%d",
        "%s",
        "%n",
        "%x",
        "{0}",
    ]
    ids.extend(unwanted_interpolation)

    return ids


def setup_weird_ids_data(idx, target_namespace, wait):
    """Upsert one vector per id from ``weird_valid_ids()`` into ``target_namespace``.

    Vectors are sent in batches of 100, each with values from
    ``embedding_values(2)``.

    Args:
        idx: Index client exposing ``upsert(vectors=..., namespace=...)``.
        target_namespace: Namespace the vectors are written to.
        wait: When true, poll after the upserts until the last seeded id can be
            fetched, as the other seed helper in this module does for its
            final id.
    """
    weird_ids = weird_valid_ids()
    batch_size = 100
    for start in range(0, len(weird_ids), batch_size):
        chunk = weird_ids[start : start + batch_size]
        idx.upsert(vectors=[(x, embedding_values(2)) for x in chunk], namespace=target_namespace)

    if wait and weird_ids:
        # Previously `wait` was accepted but ignored, so callers asking to wait
        # could start reading before the seeded vectors were available.
        # NOTE(review): assumes poll_fetch_for_ids_in_namespace blocks until the
        # given ids are fetchable; confirm against the helper.
        poll_fetch_for_ids_in_namespace(idx, ids=[weird_ids[-1]], namespace=target_namespace)
49 changes: 49 additions & 0 deletions tests/integration/data/test_weird_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from .seed import weird_valid_ids, weird_invalid_ids


class TestHandlingOfWeirdIds:
    """Integration tests for fetch, query, list and upsert with unusual vector ids.

    The read tests use the ``weird_ids_namespace`` fixture and the ids from
    ``weird_valid_ids()``; the upsert tests expect ids from
    ``weird_invalid_ids()`` and the null character to be rejected.
    """

    def test_fetch_weird_ids(self, idx, weird_ids_namespace):
        """Fetch every valid weird id in batches of 100 and check each result."""
        weird_ids = weird_valid_ids()
        batch_size = 100
        for i in range(0, len(weird_ids), batch_size):
            ids_to_fetch = weird_ids[i : i + batch_size]
            results = idx.fetch(ids=ids_to_fetch, namespace=weird_ids_namespace)
            assert results.usage["read_units"] > 0
            assert len(results.vectors) == len(ids_to_fetch)
            for vector_id in ids_to_fetch:
                assert vector_id in results.vectors
                assert results.vectors[vector_id].id == vector_id
                assert results.vectors[vector_id].metadata is None
                assert results.vectors[vector_id].values is not None
                assert len(results.vectors[vector_id].values) == 2

    @pytest.mark.parametrize("id_to_query", weird_valid_ids())
    def test_query_weird_ids(self, idx, weird_ids_namespace, id_to_query):
        """Query by each valid weird id and check the shape of the top match."""
        results = idx.query(id=id_to_query, top_k=10, namespace=weird_ids_namespace, include_values=True)
        assert results.usage["read_units"] > 0
        assert len(results.matches) == 10
        assert results.namespace == weird_ids_namespace
        assert results.matches[0].id is not None
        assert results.matches[0].metadata is None
        assert results.matches[0].values is not None
        assert len(results.matches[0].values) == 2

    def test_list_weird_ids(self, idx, weird_ids_namespace):
        """Every id listed from the namespace must be one of the valid weird ids."""
        expected_ids = set(weird_valid_ids())
        id_iterator = idx.list(namespace=weird_ids_namespace)
        for page in id_iterator:
            for listed_id in page:
                assert listed_id in expected_ids

    @pytest.mark.parametrize("id_to_upsert", weird_invalid_ids())
    def test_weird_invalid_ids(self, idx, weird_ids_namespace, id_to_upsert):
        """Upserting an invalid id must raise with the ASCII-only error message."""
        with pytest.raises(Exception) as e:
            idx.upsert(vectors=[(id_to_upsert, [0.1, 0.1])], namespace=weird_ids_namespace)

        # Checked after the `with` block: statements following the raising
        # call inside the block would never run.
        assert "Vector ID must be ASCII" in str(e.value)

    def test_null_character(self, idx, weird_ids_namespace):
        """Upserting the null character as an id must raise with its specific message."""
        with pytest.raises(Exception) as e:
            idx.upsert(vectors=[("\0", [0.1, 0.1])], namespace=weird_ids_namespace)

        assert "Vector ID must not contain null character" in str(e.value)

0 comments on commit 0f57dca

Please sign in to comment.