class SpannerPickleType(TypeDecorator):
    """PickleType implementation that can be used with Spanner.

    Binary values are automatically encoded to base64 when they are bound
    as statement parameters, and decoded from base64 when they are read
    from a result.

    Usage::

        class User(Base):
            __tablename__ = 'users'

            user_id = Column(Integer, primary_key=True)
            username = Column(String(50), nullable=False)
            preferences = Column(PickleType(impl=SpannerPickleType))
    """

    impl = PickleType

    # This type carries no per-instance configuration, so it is safe to use
    # as part of a SQLAlchemy statement cache key. Leaving this unset makes
    # SQLAlchemy 1.4+ emit a warning and prevents a cache key from being
    # generated for statements that use this type.
    cache_ok = True

    def bind_processor(self, dialect):
        """Return a function that base64-encodes bound values.

        ``None`` is passed through unchanged.
        """

        def process(value):
            if value is None:
                return None
            return base64.standard_b64encode(value)

        return process

    def result_processor(self, dialect, coltype):
        """Return a function that base64-decodes result values.

        ``None`` is passed through unchanged.
        """

        def process(value):
            if value is None:
                return None
            return base64.standard_b64decode(value)

        return process
def pickle_type():
    """Show how to use PickleType with Spanner.

    Inserts one singer whose ``preferences`` attribute is a plain dict and
    then prints the singer's name and preferences.
    """
    database_url = (
        "spanner:///projects/sample-project/"
        "instances/sample-instance/"
        "databases/sample-database"
    )
    engine = create_engine(database_url, echo=True)

    # Preferences are stored as an opaque BYTES column
    # in the database.
    singer_preferences = {
        "wakeup_call": "yes",
        "vegetarian": "no",
    }

    with Session(engine) as session:
        singer = Singer(
            id=str(uuid.uuid4()),
            first_name="John",
            last_name="Smith",
            preferences=singer_preferences,
        )
        session.add(singer)
        session.commit()

        # Use AUTOCOMMIT for sessions that only read. This is more
        # efficient than using a read/write transaction to only read.
        read_only_options = {"isolation_level": "AUTOCOMMIT"}
        session.connection(execution_options=read_only_options)
        print(
            f"Inserted singer {singer.full_name} has these preferences: {singer.preferences}"
        )
class Base(DeclarativeBase):
    """Declarative base class for the PickleType test model."""

    pass


class UserPreferences(Base):
    """Model for the ``user_preferences`` table used by the PickleType tests.

    The ``preferences`` column uses ``PickleType`` with ``SpannerPickleType``
    as its implementation type.
    """

    __tablename__ = "user_preferences"

    user_id = Column(Integer, primary_key=True)
    username = Column(String(50), nullable=False)
    # Arbitrary Python value; nullable so that rows without preferences
    # can be stored.
    preferences = Column(PickleType(impl=SpannerPickleType), nullable=True)
    # Declared as a string column, not as a timestamp column.
    created_at = Column(String(30), nullable=False)
class TestPickleType(MockServerTestBase):
    """Mock server tests for PickleType columns backed by SpannerPickleType."""

    def _new_engine(self):
        """Create an engine that talks to the mock server of this test."""
        return create_engine(
            "spanner:///projects/p/instances/i/databases/d",
            connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
        )

    def test_create_table(self):
        """The pickle column must be created as BYTES(MAX)."""
        from test.mockserver_tests.pickle_type_model import Base

        table_exists_sql = """SELECT true
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA="" AND TABLE_NAME="user_preferences"
LIMIT 1
"""
        # An empty result means that the table does not exist yet.
        add_result(table_exists_sql, ResultSet())

        Base.metadata.create_all(self._new_engine())

        expected_ddl = (
            "CREATE TABLE user_preferences (\n"
            "\tuser_id INT64 NOT NULL GENERATED BY DEFAULT"
            " AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
            "\tusername STRING(50) NOT NULL, \n"
            "\tpreferences BYTES(MAX), \n"
            "\tcreated_at STRING(30) NOT NULL\n"
            ") PRIMARY KEY (user_id)"
        )
        ddl_requests = self.database_admin_service.requests
        eq_(1, len(ddl_requests))
        ddl_request = ddl_requests[0]
        is_instance_of(ddl_request, UpdateDatabaseDdlRequest)
        eq_(1, len(ddl_request.statements))
        eq_(expected_ddl, ddl_request.statements[0])

    def test_insert_and_query(self):
        """A pickled value is sent as base64 BYTES and is decoded on read."""
        from test.mockserver_tests.pickle_type_model import UserPreferences

        insert_sql = (
            "INSERT INTO user_preferences (user_id, username, preferences, created_at) "
            "VALUES (@a0, @a1, @a2, @a3)"
        )
        add_update_count(insert_sql, 1)
        engine = self._new_engine()

        preferences = {"setting": "true"}
        preferences_base64 = "gAWVFQAAAAAAAAB9lIwHc2V0dGluZ5SMBHRydWWUcy4="
        with Session(engine) as session:
            session.add(
                UserPreferences(
                    user_id=1,
                    username="test_user",
                    preferences=preferences,
                    created_at="2025-05-04T00:00:00.000000",
                )
            )
            session.commit()

            # Verify the requests that we got.
            requests = self.spanner_service.requests
            eq_(4, len(requests))
            expected_request_types = (
                BatchCreateSessionsRequest,
                BeginTransactionRequest,
                ExecuteSqlRequest,
                CommitRequest,
            )
            for received, expected_type in zip(requests, expected_request_types):
                is_instance_of(received, expected_type)

            request: ExecuteSqlRequest = requests[2]
            eq_(4, len(request.params))
            expected_params = (
                ("a0", "1"),
                ("a1", "test_user"),
                ("a2", preferences_base64),
            )
            for param_name, expected_value in expected_params:
                eq_(expected_value, request.params[param_name])
            expected_param_types = (
                ("a0", TypeCode.INT64),
                ("a1", TypeCode.STRING),
                ("a2", TypeCode.BYTES),
            )
            for param_name, expected_code in expected_param_types:
                eq_(expected_code, request.param_types[param_name].code)

            select_sql = (
                "SELECT user_preferences.user_id AS user_preferences_user_id, "
                "user_preferences.username AS user_preferences_username, "
                "user_preferences.preferences AS user_preferences_preferences, "
                "user_preferences.created_at AS user_preferences_created_at\n"
                "FROM user_preferences\n"
                "WHERE user_preferences.user_id = @a0\n"
                " LIMIT @a1"
            )
            add_user_preferences_result(select_sql, preferences_base64)
            user = session.query(UserPreferences).filter_by(user_id=1).first()
            eq_(preferences, user.preferences)


def add_user_preferences_result(sql: str, preferences_base64: object):
    """Register a single-row result for ``sql`` on the mock server.

    The row contains ``preferences_base64`` as the value of the
    ``preferences`` column; all other values are fixed.
    """
    # NOTE(review): the metadata below uses "user_name" and TIMESTAMP, while
    # the model declares `username` and String(30). The test only asserts on
    # the preferences value - confirm whether the mismatch is intentional.
    column_codes = (
        ("user_id", spanner_type.TypeCode.INT64),
        ("user_name", spanner_type.TypeCode.STRING),
        ("preferences", spanner_type.TypeCode.BYTES),
        ("created_at", spanner_type.TypeCode.TIMESTAMP),
    )
    fields = [
        spanner_type.StructType.Field(
            dict(
                name=column_name,
                type=spanner_type.Type(dict(code=type_code)),
            )
        )
        for column_name, type_code in column_codes
    ]
    row_type = spanner_type.StructType(dict(fields=fields))
    metadata = result_set.ResultSetMetadata(dict(row_type=row_type))
    result = result_set.ResultSet(dict(metadata=metadata))

    row = (
        "1",
        "Test",
        preferences_base64,
        "2025-05-05T00:00:00.000000Z",
    )
    result.rows.extend([row])
    add_result(sql, result)