Skip to content

Commit

Permalink
Add REAL_VECTOR support
Browse files Browse the repository at this point in the history
  • Loading branch information
kasium committed Sep 10, 2024
1 parent 8aa631a commit d368014
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 4 deletions.
21 changes: 21 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,27 @@ and `SYSUUID <https://help.sap.com/docs/hana-cloud-database/sap-hana-cloud-sap-h
) which can be used to generate e.g. default values like
``Column('id', Uuid, server_default=func.NEWUID)``.

The ``REAL_VECTOR`` datatype is only supported within SAP HANA and needs to be imported from
``sqlalchemy_hana.types``. See below for more details.

Real Vector
~~~~~~~~~~~
By default, vectors are represented using a python ``list``.
This can be changed using the engine parameter ``vector_output_type``, which can be set to
``list`` (default), ``tuple`` or ``memoryview``.
Note that this setting is applied globally and cannot be adapted on a column basis.

For proper typing, the ``REAL_VECTOR`` class is generic and be set to the proper type like

.. code-block:: python
from sqlalchemy_hana.types import REAL_VECTOR
Column("v1", REAL_VECTOR[list[float]](length=10))
Please note, that the generic type and ``vector_output_type`` should be kept in sync; this is not
enforced.

Regex
~~~~~
sqlalchemy-hana supports the ``regexp_match`` and ``regexp_replace``
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ disable = [
"too-many-branches",
"too-many-return-statements",
"too-many-boolean-expressions",
"too-many-arguments",
"duplicate-code",
]

[tool.pylint.basic]
good-names = ["visit_TINYINT", "visit_SMALLDECIMAL", "visit_SECONDDATE", "visit_ALPHANUM", "visit_JSON"]
good-names = ["visit_TINYINT", "visit_SMALLDECIMAL", "visit_SECONDDATE", "visit_ALPHANUM", "visit_JSON", "visit_REAL_VECTOR", "REAL_VECTOR"]

[tool.mypy]
# formatting
Expand Down
19 changes: 18 additions & 1 deletion sqlalchemy_hana/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from contextlib import closing
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

import hdbcli.dbapi
import sqlalchemy
Expand Down Expand Up @@ -383,6 +383,12 @@ def visit_uuid(self, type_: types.TypeEngine[Any], **kw: Any) -> str:
def visit_JSON(self, type_: types.TypeEngine[Any], **kw: Any) -> str:
return self.visit_NCLOB(type_, **kw)

def visit_REAL_VECTOR(self, type_: hana_types.REAL_VECTOR[Any], **kw: Any) -> str:
# SAP HANA special type
if type_.length is not None:
return f"REAL_VECTOR({type_.length})"
return "REAL_VECTOR"


class HANADDLCompiler(compiler.DDLCompiler):
def visit_unique_constraint(
Expand Down Expand Up @@ -555,13 +561,15 @@ def __init__(
use_native_boolean: bool = True,
json_serializer: Callable[[Any], str] | None = None,
json_deserializer: Callable[[str], Any] | None = None,
vector_output_type: Literal["list", "tuple", "memoryview"] = "list",
**kw: Any,
) -> None:
super().__init__(**kw)
self.isolation_level = isolation_level
self.supports_native_boolean = use_native_boolean
self._json_serializer = json_serializer
self._json_deserializer = json_deserializer
self.vector_output_type = vector_output_type

@classmethod
def import_dbapi(cls) -> ModuleType:
Expand All @@ -586,6 +594,13 @@ def create_connect_args(self, url: URL) -> ConnectArgsType:
port = 30013
kwargs.setdefault("port", port)

if "vectoroutputtype" in kwargs:
raise ValueError(
"Explicit vectoroutputtype is not supported, "
"use the vector_output_type kwarg instead"
)
kwargs["vectoroutputtype"] = self.vector_output_type

return (), kwargs

def connect(self, *args: Any, **kw: Any) -> DBAPIConnection:
Expand Down Expand Up @@ -917,6 +932,8 @@ def get_columns(
column["type"] = hana_types.DECIMAL(row[4], row[5])
elif column["type"] == hana_types.FLOAT:
column["type"] = hana_types.FLOAT(row[4])
elif column["type"] == hana_types.REAL_VECTOR:
column["type"] = hana_types.REAL_VECTOR(row[4] if row[4] > 0 else None)
elif column["type"] in self.types_with_length:
column["type"] = column["type"](row[4])

Expand Down
18 changes: 17 additions & 1 deletion sqlalchemy_hana/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
from __future__ import annotations

from datetime import date, datetime, time
from typing import Callable, Literal
from typing import TYPE_CHECKING, Callable, Generic, List, Literal, Tuple, TypeVar

import sqlalchemy
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.type_api import TypeEngine

if TYPE_CHECKING:
StrTypeEngine = TypeEngine[str]

else:
StrTypeEngine = TypeEngine

_RV = TypeVar("_RV", Tuple[float, ...], List[float], memoryview)


class DATE(sqltypes.DATE):
Expand Down Expand Up @@ -148,6 +157,13 @@ class JSON(sqltypes.JSON):
pass


class REAL_VECTOR(TypeEngine[_RV], Generic[_RV]):
__visit_name__ = "REAL_VECTOR"

def __init__(self, length: int | None = None) -> None:
self.length = length


__all__ = [
"ALPHANUM",
"BIGINT",
Expand Down
19 changes: 18 additions & 1 deletion test/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_parsing_userkey_hdbcli(self) -> None:
_, result_kwargs = config.db.dialect.create_connect_args(
make_url("hana://userkey=myuserkeyname")
)
assert result_kwargs == {"userkey": "myuserkeyname"}
assert result_kwargs == {"userkey": "myuserkeyname", "vectoroutputtype": "list"}

def test_pass_uri_query_as_kwargs(self) -> None:
"""SQLAlchemy-HANA should passes all URL parameters to hdbcli."""
Expand Down Expand Up @@ -216,3 +216,20 @@ def test_do_rollback_to_savepoint_ignores_error(self) -> None:
) as super_rollback:
dialect.do_rollback_to_savepoint(connection, "savepoint")
super_rollback.assert_not_called()

def test_vectoroutputtype_is_blocked(self) -> None:
url = "hana://username:secret-password@example.com/?encrypt=true&vectoroutputtype=list"
with pytest.raises(ValueError, match="vectoroutputtype"):
config.db.dialect.create_connect_args(make_url(url))

@pytest.mark.parametrize(
"kwargs,vector_output_type",
[
({}, "list"),
({"vector_output_type": "list"}, "list"),
({"vector_output_type": "memoryview"}, "memoryview"),
],
)
def test_vector_output_type(self, kwargs: dict, vector_output_type: str) -> None:
engine = create_engine("hana://username:secret-password@example.com", **kwargs)
assert engine.dialect.vector_output_type == vector_output_type
24 changes: 24 additions & 0 deletions test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,30 @@ def test_compile(self, connection, metadata) -> None:
)


class RealVectorTest(_TypeBaseTest):
column_type = hana_types.REAL_VECTOR()
data = [1, 2, 3]

@property
def reflected_column_type(self):
return hana_types.REAL_VECTOR()

@testing.provide_metadata
def test_reflection_with_length(self):
with testing.db.connect() as connection, connection.begin():
table = Table(
"t",
self.metadata,
Column("vec1", hana_types.REAL_VECTOR(length=10)),
Column("vec2", hana_types.REAL_VECTOR()),
)
table.create(bind=connection)

columns = sqlalchemy.inspect(connection).get_columns("t")
assert columns[0]["type"].length == 10
assert columns[1]["type"].length is None


if sqlalchemy.__version__ >= "2":

class StringUUIDAsStringTest(_TypeBaseTest):
Expand Down

0 comments on commit d368014

Please sign in to comment.