Skip to content

Use new lazy_import function to lazily import third-party libraries #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions redisvl/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import sys
from argparse import Namespace

from tabulate import tabulate

from redisvl.cli.utils import add_index_parsing_options, create_redis_url
from redisvl.index import SearchIndex
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.redis.utils import convert_bytes, make_dict
from redisvl.schema.schema import IndexSchema
from redisvl.utils.log import get_logger
from redisvl.utils.utils import lazy_import

logger = get_logger("[RedisVL]")

Expand Down Expand Up @@ -127,6 +126,8 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:


def _display_in_table(index_info, output_format="rounded_outline"):
tabulate = lazy_import("tabulate")

print("\n")
attributes = index_info.get("attributes", [])
definition = make_dict(index_info.get("index_definition"))
Expand Down
5 changes: 3 additions & 2 deletions redisvl/cli/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import sys
from argparse import Namespace

from tabulate import tabulate

from redisvl.cli.utils import add_index_parsing_options, create_redis_url
from redisvl.index import SearchIndex
from redisvl.schema.schema import IndexSchema
from redisvl.utils.log import get_logger
from redisvl.utils.utils import lazy_import

logger = get_logger("[RedisVL]")

Expand Down Expand Up @@ -87,6 +86,8 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:


def _display_stats(index_info, output_format="rounded_outline"):
tabulate = lazy_import("tabulate")

# Extracting the statistics
stats_data = [(key, str(index_info.get(key))) for key in STATS_KEYS]

Expand Down
12 changes: 6 additions & 6 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.utils import lazy_import

nltk = lazy_import("nltk")
nltk_stopwords = lazy_import("nltk.corpus.stopwords")


class AggregationQuery(AggregateRequest):
Expand Down Expand Up @@ -162,17 +166,13 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
if not stopwords:
self._stopwords = set()
elif isinstance(stopwords, str):
# Lazy import because nltk is an optional dependency
try:
import nltk
from nltk.corpus import stopwords as nltk_stopwords
nltk.download("stopwords", quiet=True)
self._stopwords = set(nltk_stopwords.words(stopwords))
except ImportError:
raise ValueError(
f"Loading stopwords for {stopwords} failed: nltk is not installed."
)
try:
nltk.download("stopwords", quiet=True)
self._stopwords = set(nltk_stopwords.words(stopwords))
except Exception as e:
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore
Expand Down
13 changes: 6 additions & 7 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.utils import denorm_cosine_distance
from redisvl.utils.utils import denorm_cosine_distance, lazy_import

nltk = lazy_import("nltk")
nltk_stopwords = lazy_import("nltk.corpus.stopwords")


class BaseQuery(RedisQuery):
Expand Down Expand Up @@ -893,17 +896,13 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
if not stopwords:
self._stopwords = set()
elif isinstance(stopwords, str):
# Lazy import because nltk is an optional dependency
try:
import nltk
from nltk.corpus import stopwords as nltk_stopwords
nltk.download("stopwords", quiet=True)
self._stopwords = set(nltk_stopwords.words(stopwords))
except ImportError:
raise ValueError(
f"Loading stopwords for {stopwords} failed: nltk is not installed."
)
try:
nltk.download("stopwords", quiet=True)
self._stopwords = set(nltk_stopwords.words(stopwords))
except Exception as e:
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore
Expand Down
21 changes: 19 additions & 2 deletions redisvl/redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import hashlib
from typing import Any, Dict, List, Optional

import numpy as np
from ml_dtypes import bfloat16
from redisvl.utils.utils import lazy_import

# Lazy import numpy
np = lazy_import("numpy")

from redisvl.schema.fields import VectorDataType

Expand Down Expand Up @@ -41,6 +43,13 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes:
raise ValueError(
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
)

# Special handling for bfloat16 which requires explicit import from ml_dtypes
if dtype.lower() == "bfloat16":
from ml_dtypes import bfloat16

return np.array(array, dtype=bfloat16).tobytes()

return np.array(array, dtype=dtype.lower()).tobytes()


Expand All @@ -52,6 +61,14 @@ def buffer_to_array(buffer: bytes, dtype: str) -> List[Any]:
raise ValueError(
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
)

# Special handling for bfloat16 which requires explicit import from ml_dtypes
# because otherwise the (lazily imported) numpy is unaware of the type
if dtype.lower() == "bfloat16":
from ml_dtypes import bfloat16

return np.frombuffer(buffer, dtype=bfloat16).tolist() # type: ignore[return-value]

return np.frombuffer(buffer, dtype=dtype.lower()).tolist() # type: ignore[return-value]


Expand Down
4 changes: 3 additions & 1 deletion redisvl/utils/optimize/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Callable, Dict, List

import numpy as np
from redisvl.utils.utils import lazy_import

np = lazy_import("numpy")
from ranx import Qrels, Run, evaluate

from redisvl.extensions.cache.llm.semantic import SemanticCache
Expand Down
4 changes: 3 additions & 1 deletion redisvl/utils/optimize/router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import random
from typing import Any, Callable, Dict, List

import numpy as np
from redisvl.utils.utils import lazy_import

np = lazy_import("numpy")
from ranx import Qrels, Run, evaluate

from redisvl.extensions.router.semantic import SemanticRouter
Expand Down
4 changes: 3 additions & 1 deletion redisvl/utils/optimize/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import List

import numpy as np
from redisvl.utils.utils import lazy_import

np = lazy_import("numpy")
from ranx import Qrels

from redisvl.utils.optimize.schema import LabeledData
Expand Down
6 changes: 4 additions & 2 deletions redisvl/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,17 @@ def __getattr__(self, name: str) -> Any:
else:
# This means we couldn't find the attribute in the module path
raise AttributeError(
f"{self._parts[0]} has no attribute '{self._parts[1]}'"
f"module '{self._parts[0]}' has no attribute '{self._parts[1]}'"
)

# If we have a module, get the requested attribute
if hasattr(self._module, name):
return getattr(self._module, name)

# If the attribute doesn't exist, raise AttributeError
raise AttributeError(f"{self._module_path} has no attribute '{name}'")
raise AttributeError(
f"module '{self._module_path}' has no attribute '{name}'"
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Import the module if it hasn't been imported yet
Expand Down
113 changes: 112 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,15 @@ def test_empty_list_to_bytes():
def test_conversion_with_various_dtypes(dtype):
"""Test conversion of a list of floats to bytes with various dtypes"""
array = [1.0, -2.0, 3.5]
expected = np.array(array, dtype=dtype).tobytes()

# Special handling for bfloat16 which requires explicit import from ml_dtypes
if dtype == "bfloat16":
from ml_dtypes import bfloat16 as bf16

expected = np.array(array, dtype=bf16).tobytes()
else:
expected = np.array(array, dtype=dtype).tobytes()

assert array_to_buffer(array, dtype=dtype) == expected


Expand Down Expand Up @@ -541,6 +549,26 @@ def test_import_standard_library(self):
assert "json" in sys.modules
assert result == '{"key": "value"}'

def test_cached_module_import(self):
"""Test that _import_module returns the cached module if it exists"""
# Remove the module from sys.modules if it's already imported
if "json" in sys.modules:
del sys.modules["json"]

# Lazy import the module
json = lazy_import("json")

# Access an attribute to trigger the import
json.dumps

# The module should now be cached
# We need to access the private _import_module method directly
# to test the cached path
module = json._import_module()

# Verify that the cached module was returned
assert module is json._module

def test_import_already_imported_module(self):
"""Test lazy importing of an already imported module"""
# Make sure the module is imported
Expand Down Expand Up @@ -610,6 +638,17 @@ def test_import_nonexistent_module(self):

assert "Failed to lazily import nonexistent_module_xyz" in str(excinfo.value)

def test_call_nonexistent_module(self):
"""Test calling a nonexistent module"""
# Lazy import a nonexistent module
nonexistent = lazy_import("nonexistent_module_xyz")

# Calling the nonexistent module should raise ImportError
with pytest.raises(ImportError) as excinfo:
nonexistent()

assert "Failed to lazily import nonexistent_module_xyz" in str(excinfo.value)

def test_import_nonexistent_attribute(self):
"""Test lazy importing of a nonexistent attribute"""
# Lazy import a nonexistent attribute
Expand All @@ -623,6 +662,19 @@ def test_import_nonexistent_attribute(self):
excinfo.value
)

def test_getattr_on_nonexistent_attribute_path(self):
"""Test accessing an attribute on a nonexistent attribute path"""
# Lazy import a nonexistent attribute path
nonexistent_attr = lazy_import("math.nonexistent_attribute")

# Accessing an attribute on the nonexistent attribute should raise AttributeError
with pytest.raises(AttributeError) as excinfo:
nonexistent_attr.some_attribute

assert "module 'math' has no attribute 'nonexistent_attribute'" in str(
excinfo.value
)

def test_import_noncallable(self):
"""Test calling a non-callable lazy imported object"""
# Lazy import a non-callable attribute
Expand All @@ -646,3 +698,62 @@ def test_attribute_error(self):
assert "module 'math' has no attribute 'nonexistent_attribute'" in str(
excinfo.value
)

def test_attribute_error_after_import(self):
"""Test accessing a nonexistent attribute on a module after it's been imported"""
# Create a simple module with a known attribute
import types

test_module = types.ModuleType("test_module")
test_module.existing_attr = "exists"

# Add it to sys.modules so lazy_import can find it
sys.modules["test_module"] = test_module

try:
# Lazy import the module
lazy_mod = lazy_import("test_module")

# Access the existing attribute to trigger the import
assert lazy_mod.existing_attr == "exists"

# Now access a nonexistent attribute
with pytest.raises(AttributeError) as excinfo:
lazy_mod.nonexistent_attribute

assert (
"module 'test_module' has no attribute 'nonexistent_attribute'"
in str(excinfo.value)
)
finally:
# Clean up
if "test_module" in sys.modules:
del sys.modules["test_module"]

def test_attribute_error_with_direct_module_access(self):
"""Test accessing a nonexistent attribute by directly setting the _module attribute"""
# Get the lazy_import function
from redisvl.utils.utils import lazy_import

# Create a lazy import for a module that doesn't exist yet
lazy_mod = lazy_import("test_direct_module")

# Create a simple object with no __getattr__ method
class SimpleObject:
pass

obj = SimpleObject()

# Directly set the _module attribute to our simple object
# This bypasses the normal import mechanism
lazy_mod._module = obj

# Now access a nonexistent attribute
# This should go through our LazyModule.__getattr__ and hit line 332
with pytest.raises(AttributeError) as excinfo:
lazy_mod.nonexistent_attribute

assert (
"module 'test_direct_module' has no attribute 'nonexistent_attribute'"
in str(excinfo.value)
)