From ae1f655ea3681b81dae7420a4da921b75ca284bb Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 28 Jul 2017 15:18:07 -0400 Subject: [PATCH] Add factories to ease creation of array / struct parameter types. (#3700) Closes: #3364 --- spanner/google/cloud/spanner/__init__.py | 6 +++ spanner/google/cloud/spanner/types.py | 41 ++++++++++++++++ spanner/tests/unit/test_types.py | 61 ++++++++++++++++++++++++ 3 files changed, 108 insertions(+) create mode 100644 spanner/tests/unit/test_types.py diff --git a/spanner/google/cloud/spanner/__init__.py b/spanner/google/cloud/spanner/__init__.py index 6b9366ab6646a..244bdb868f9a7 100644 --- a/spanner/google/cloud/spanner/__init__.py +++ b/spanner/google/cloud/spanner/__init__.py @@ -27,18 +27,22 @@ from google.cloud.spanner.pool import BurstyPool from google.cloud.spanner.pool import FixedSizePool +from google.cloud.spanner.types import ArrayParamType from google.cloud.spanner.types import BOOL_PARAM_TYPE from google.cloud.spanner.types import BYTES_PARAM_TYPE from google.cloud.spanner.types import DATE_PARAM_TYPE from google.cloud.spanner.types import FLOAT64_PARAM_TYPE from google.cloud.spanner.types import INT64_PARAM_TYPE from google.cloud.spanner.types import STRING_PARAM_TYPE +from google.cloud.spanner.types import StructField +from google.cloud.spanner.types import StructParamType from google.cloud.spanner.types import TIMESTAMP_PARAM_TYPE __all__ = [ '__version__', 'AbstractSessionPool', + 'ArrayParamType', 'BOOL_PARAM_TYPE', 'BYTES_PARAM_TYPE', 'BurstyPool', @@ -50,5 +54,7 @@ 'KeyRange', 'KeySet', 'STRING_PARAM_TYPE', + 'StructField', + 'StructParamType', 'TIMESTAMP_PARAM_TYPE', ] diff --git a/spanner/google/cloud/spanner/types.py b/spanner/google/cloud/spanner/types.py index aa0316ee02b93..9e22da94c51f4 100644 --- a/spanner/google/cloud/spanner/types.py +++ b/spanner/google/cloud/spanner/types.py @@ -25,3 +25,44 @@ FLOAT64_PARAM_TYPE = type_pb2.Type(code=type_pb2.FLOAT64) DATE_PARAM_TYPE = type_pb2.Type(code=type_pb2.DATE) TIMESTAMP_PARAM_TYPE = type_pb2.Type(code=type_pb2.TIMESTAMP) + + +def ArrayParamType(element_type): # pylint: disable=invalid-name + """Construct an array paramter type description protobuf. + + :type element_type: :class:`type_pb2.Type` + :param element_type: the type of elements of the array + + :rtype: :class:`type_pb2.Type` + :returns: the appropriate array-type protobuf + """ + return type_pb2.Type(code=type_pb2.ARRAY, array_element_type=element_type) + + +def StructField(name, field_type): # pylint: disable=invalid-name + """Construct a field description protobuf. + + :type name: str + :param name: the name of the field + + :type field_type: :class:`type_pb2.Type` + :param field_type: the type of the field + + :rtype: :class:`type_pb2.StructType.Field` + :returns: the appropriate array-type protobuf + """ + return type_pb2.StructType.Field(name=name, type=field_type) + + +def StructParamType(fields): # pylint: disable=invalid-name + """Construct a struct paramter type description protobuf. + + :type fields: list of :class:`type_pb2.StructType.Field` + :param fields: the fields of the struct + + :rtype: :class:`type_pb2.Type` + :returns: the appropriate struct-type protobuf + """ + return type_pb2.Type( + code=type_pb2.STRUCT, + struct_type=type_pb2.StructType(fields=fields)) diff --git a/spanner/tests/unit/test_types.py b/spanner/tests/unit/test_types.py new file mode 100644 index 0000000000000..4f30779c757f7 --- /dev/null +++ b/spanner/tests/unit/test_types.py @@ -0,0 +1,61 @@ +# Copyright 2017 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + + +class Test_ArrayParamType(unittest.TestCase): + + def test_it(self): + from google.cloud.proto.spanner.v1 import type_pb2 + from google.cloud.spanner.types import ArrayParamType + from google.cloud.spanner.types import INT64_PARAM_TYPE + + expected = type_pb2.Type( + code=type_pb2.ARRAY, + array_element_type=type_pb2.Type(code=type_pb2.INT64)) + + found = ArrayParamType(INT64_PARAM_TYPE) + + self.assertEqual(found, expected) + + +class Test_Struct(unittest.TestCase): + + def test_it(self): + from google.cloud.proto.spanner.v1 import type_pb2 + from google.cloud.spanner.types import INT64_PARAM_TYPE + from google.cloud.spanner.types import STRING_PARAM_TYPE + from google.cloud.spanner.types import StructParamType + from google.cloud.spanner.types import StructField + + struct_type = type_pb2.StructType(fields=[ + type_pb2.StructType.Field( + name='name', + type=type_pb2.Type(code=type_pb2.STRING)), + type_pb2.StructType.Field( + name='count', + type=type_pb2.Type(code=type_pb2.INT64)), + ]) + expected = type_pb2.Type( + code=type_pb2.STRUCT, + struct_type=struct_type) + + found = StructParamType([ + StructField('name', STRING_PARAM_TYPE), + StructField('count', INT64_PARAM_TYPE), + ]) + + self.assertEqual(found, expected)