diff --git a/README.rst b/README.rst index 7083d34..6b70a97 100644 --- a/README.rst +++ b/README.rst @@ -210,6 +210,27 @@ and `SYSUUID str: def visit_JSON(self, type_: types.TypeEngine[Any], **kw: Any) -> str: return self.visit_NCLOB(type_, **kw) + def visit_REAL_VECTOR(self, type_: hana_types.REAL_VECTOR[Any], **kw: Any) -> str: + # SAP HANA special type + if type_.length is not None: + return f"REAL_VECTOR({type_.length})" + return "REAL_VECTOR" + class HANADDLCompiler(compiler.DDLCompiler): def visit_unique_constraint( @@ -555,6 +561,7 @@ def __init__( use_native_boolean: bool = True, json_serializer: Callable[[Any], str] | None = None, json_deserializer: Callable[[str], Any] | None = None, + vector_output_type: Literal["list", "tuple", "memoryview"] = "list", **kw: Any, ) -> None: super().__init__(**kw) @@ -562,6 +569,7 @@ def __init__( self.supports_native_boolean = use_native_boolean self._json_serializer = json_serializer self._json_deserializer = json_deserializer + self.vector_output_type = vector_output_type @classmethod def import_dbapi(cls) -> ModuleType: @@ -586,6 +594,13 @@ def create_connect_args(self, url: URL) -> ConnectArgsType: port = 30013 kwargs.setdefault("port", port) + if "vectoroutputtype" in kwargs: + raise ValueError( + "Explicit vectoroutputtype is not supported, " + "use the vector_output_type kwarg instead" + ) + kwargs["vectoroutputtype"] = self.vector_output_type + return (), kwargs def connect(self, *args: Any, **kw: Any) -> DBAPIConnection: @@ -917,6 +932,8 @@ def get_columns( column["type"] = hana_types.DECIMAL(row[4], row[5]) elif column["type"] == hana_types.FLOAT: column["type"] = hana_types.FLOAT(row[4]) + elif column["type"] == hana_types.REAL_VECTOR: + column["type"] = hana_types.REAL_VECTOR(row[4] if row[4] > 0 else None) elif column["type"] in self.types_with_length: column["type"] = column["type"](row[4]) diff --git a/sqlalchemy_hana/types.py b/sqlalchemy_hana/types.py index 5d3c589..6bfe003 100644 --- a/sqlalchemy_hana/types.py +++ b/sqlalchemy_hana/types.py @@ -3,11 +3,20 @@ from __future__ import annotations from datetime import date, datetime, time -from typing import Callable, Literal +from typing import TYPE_CHECKING, Callable, Generic, List, Literal, Tuple, TypeVar import sqlalchemy from sqlalchemy import types as sqltypes from sqlalchemy.engine import Dialect +from sqlalchemy.sql.type_api import TypeEngine + +if TYPE_CHECKING: + StrTypeEngine = TypeEngine[str] + +else: + StrTypeEngine = TypeEngine + +_RV = TypeVar("_RV", Tuple[float, ...], List[float], memoryview) class DATE(sqltypes.DATE): @@ -148,6 +157,13 @@ class JSON(sqltypes.JSON): pass +class REAL_VECTOR(TypeEngine[_RV], Generic[_RV]): + __visit_name__ = "REAL_VECTOR" + + def __init__(self, length: int | None = None) -> None: + self.length = length + + __all__ = [ "ALPHANUM", "BIGINT", diff --git a/test/test_dialect.py b/test/test_dialect.py index 19e8a71..357e21c 100644 --- a/test/test_dialect.py +++ b/test/test_dialect.py @@ -86,7 +86,7 @@ def test_parsing_userkey_hdbcli(self) -> None: _, result_kwargs = config.db.dialect.create_connect_args( make_url("hana://userkey=myuserkeyname") ) - assert result_kwargs == {"userkey": "myuserkeyname"} + assert result_kwargs == {"userkey": "myuserkeyname", "vectoroutputtype": "list"} def test_pass_uri_query_as_kwargs(self) -> None: """SQLAlchemy-HANA should passes all URL parameters to hdbcli.""" @@ -216,3 +216,20 @@ def test_do_rollback_to_savepoint_ignores_error(self) -> None: ) as super_rollback: dialect.do_rollback_to_savepoint(connection, "savepoint") super_rollback.assert_not_called() + + def test_vectoroutputtype_is_blocked(self) -> None: + url = "hana://username:secret-password@example.com/?encrypt=true&vectoroutputtype=list" + with pytest.raises(ValueError, match="vectoroutputtype"): + config.db.dialect.create_connect_args(make_url(url)) + + @pytest.mark.parametrize( + "kwargs,vector_output_type", + [ + ({}, "list"), + ({"vector_output_type": "list"}, "list"), + ({"vector_output_type": "memoryview"}, "memoryview"), + ], + ) + def test_vector_output_type(self, kwargs: dict, vector_output_type: str) -> None: + engine = create_engine("hana://username:secret-password@example.com", **kwargs) + assert engine.dialect.vector_output_type == vector_output_type diff --git a/test/test_types.py b/test/test_types.py index 9fe8443..2a066f1 100644 --- a/test/test_types.py +++ b/test/test_types.py @@ -232,6 +232,30 @@ def test_compile(self, connection, metadata) -> None: ) +class RealVectorTest(_TypeBaseTest): + column_type = hana_types.REAL_VECTOR() + data = [1, 2, 3] + + @property + def reflected_column_type(self): + return hana_types.REAL_VECTOR() + + @testing.provide_metadata + def test_reflection_with_length(self): + with testing.db.connect() as connection, connection.begin(): + table = Table( + "t", + self.metadata, + Column("vec1", hana_types.REAL_VECTOR(length=10)), + Column("vec2", hana_types.REAL_VECTOR()), + ) + table.create(bind=connection) + + columns = sqlalchemy.inspect(connection).get_columns("t") + assert columns[0]["type"].length == 10 + assert columns[1]["type"].length is None + + if sqlalchemy.__version__ >= "2": class StringUUIDAsStringTest(_TypeBaseTest):