diff --git a/spanner/tests/_fixtures.py b/spanner/tests/_fixtures.py
index fe7e038b8846..42b9a16ee336 100644
--- a/spanner/tests/_fixtures.py
+++ b/spanner/tests/_fixtures.py
@@ -47,6 +47,36 @@
     name STRING(16),
     tags ARRAY<STRING(16)> )
     PRIMARY KEY (id);
+CREATE TABLE bool_plus_array_of_bool (
+    id INT64,
+    name BOOL,
+    tags ARRAY<BOOL> )
+    PRIMARY KEY (id);
+CREATE TABLE bytes_plus_array_of_bytes (
+    id INT64,
+    name BYTES(16),
+    tags ARRAY<BYTES(16)> )
+    PRIMARY KEY (id);
+CREATE TABLE date_plus_array_of_date (
+    id INT64,
+    name DATE,
+    tags ARRAY<DATE> )
+    PRIMARY KEY (id);
+CREATE TABLE float_plus_array_of_float (
+    id INT64,
+    name FLOAT64,
+    tags ARRAY<FLOAT64> )
+    PRIMARY KEY (id);
+CREATE TABLE int_plus_array_of_int (
+    id INT64,
+    name INT64,
+    tags ARRAY<INT64> )
+    PRIMARY KEY (id);
+CREATE TABLE time_plus_array_of_time (
+    id INT64,
+    name TIMESTAMP,
+    tags ARRAY<TIMESTAMP> )
+    PRIMARY KEY (id);
 """
 
 DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(';') if stmt.strip()]
diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py
index 1f8b7d1ddff9..32e05e4fdd03 100644
--- a/spanner/tests/system/test_system.py
+++ b/spanner/tests/system/test_system.py
@@ -21,6 +21,8 @@
 import time
 import unittest
 
+from six.moves import range
+
 from google.cloud.spanner_v1.proto.type_pb2 import ARRAY
 from google.cloud.spanner_v1.proto.type_pb2 import BOOL
 from google.cloud.spanner_v1.proto.type_pb2 import BYTES
@@ -202,6 +204,80 @@ class _TestData(object):
     ALL = KeySet(all_=True)
     SQL = 'SELECT * FROM contacts ORDER BY contact_id'
 
+    def _test_batch_insert_then_read_arrays_data(self):
+        table = ('bool_plus_array_of_bool',
+                 'bytes_plus_array_of_bytes',
+                 'date_plus_array_of_date',
+                 'float_plus_array_of_float',
+                 'int_plus_array_of_int',
+                 'string_plus_array_of_string',
+                 'time_plus_array_of_time',
+                 )
+        times = (datetime.datetime(1989, 1, 17, 17, 59, 12, 345612),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345613),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345614),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345615),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345616),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345617),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345618),
+                 datetime.datetime(1989, 1, 17, 17, 59, 12, 345619)
+                 )
+        dates = (datetime.date(1989, 1, 17),
+                 datetime.date(1989, 1, 18),
+                 datetime.date(1989, 1, 19),
+                 datetime.date(1989, 1, 11),
+                 datetime.date(1989, 1, 12),
+                 datetime.date(1989, 1, 13),
+                 datetime.date(1989, 1, 14),
+                 datetime.date(1989, 1, 15)
+                 )
+        columns = ('id', 'name', 'tags')
+        rowdata = (
+            (
+                (0, None, None),
+                (1, True, [True, False, False]),
+                (2, False, []),
+                (3, False, [True, None, False]),
+            ),
+            (
+                (0, None, None),
+                (1, b'cGhyZWQ=', [b'eWFiYmE=', b'ZGFiYmE=', b'ZG8=']),
+                (2, b'Ymhhcm5leQ==', []),
+                (3, b'd3lsbWE=', [b'b2g=', None, b'cGhyZWQ='])
+            ),
+            (
+                (0, None, None),
+                (1, dates[0], [dates[3], dates[4], dates[5]]),
+                (2, dates[1], []),
+                (3, dates[2], [dates[6], None, dates[7]]),
+            ),
+            (
+                (0, None, None),
+                (1, 10., [40., 50., 60.]),
+                (2, 20., []),
+                (3, 30., [70., None, 80.]),
+            ),
+            (
+                (0, None, None),
+                (1, 10, [40, 50, 60]),
+                (2, 20, []),
+                (3, 30, [70, None, 80]),
+            ),
+            (
+                (0, None, None),
+                (1, 'phred', ['yabba', 'dabba', 'do']),
+                (2, 'bharney', []),
+                (3, 'wylma', ['oh', None, 'phred']),
+            ),
+            (
+                (0, None, None),
+                (1, times[0], [times[3], times[4], times[5]]),
+                (2, times[1], []),
+                (3, times[2], [times[6], None, times[7]]),
+            )
+        )
+        return table, times, dates, columns, rowdata
+
     def _assert_timestamp(self, value, nano_value):
         self.assertIsInstance(value, datetime.datetime)
         self.assertIsNone(value.tzinfo)
@@ -219,6 +295,22 @@ def _assert_timestamp(self, value, nano_value):
         else:
             self.assertEqual(value.microsecond * 1000, nano_value.nanosecond)
 
+    def _assert_list_timestamp_equal(self, found, expected):
+        for found_cell, expected_cell in zip(found, expected):
+            try:
+                self._assert_timestamp(expected_cell, found_cell)
+            except AssertionError:
+                if not (found_cell is None and expected_cell is None):
+                    raise AssertionError("Found and expected are not both None"
+                                         "and do not compare equal")
+
+    def _list_of_timestamps_or_none(self, timestamps):
+        for timestamp in timestamps:
+            if not (isinstance(timestamp, datetime.datetime)
+                    or timestamp is None):
+                return False
+        return True
+
     def _check_row_data(self, row_data, expected=None):
         if expected is None:
             expected = self.ROW_DATA
@@ -232,7 +324,14 @@ def _check_row_data(self, row_data, expected=None):
             elif isinstance(found_cell, float) and math.isnan(found_cell):
                 self.assertTrue(math.isnan(expected_cell))
             else:
-                self.assertEqual(found_cell, expected_cell)
+                if (isinstance(found_cell, list) and
+                        isinstance(expected_cell, list) and
+                        self._list_of_timestamps_or_none(expected_cell)
+                        ):
+                    self._assert_list_timestamp_equal(found_cell,
+                                                      expected_cell)
+                else:
+                    self.assertEqual(found_cell, expected_cell)
 
 
 class TestDatabaseAPI(unittest.TestCase, _TestData):
@@ -281,7 +380,7 @@ def test_create_database(self):
 
     def test_update_database_ddl(self):
         pool = BurstyPool()
-        temp_db_id = 'temp_db'
+        temp_db_id = 'temp_db' + unique_resource_id('_')
         temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
         create_op = temp_db.create()
         self.to_delete.append(temp_db)
@@ -450,29 +549,88 @@ def test_batch_insert_then_read(self):
         rows = list(snapshot.read(self.TABLE, self.COLUMNS, self.ALL))
         self._check_row_data(rows)
 
-    def test_batch_insert_then_read_string_array_of_string(self):
-        TABLE = 'string_plus_array_of_string'
-        COLUMNS = ['id', 'name', 'tags']
-        ROWDATA = [
-            (0, None, None),
-            (1, 'phred', ['yabba', 'dabba', 'do']),
-            (2, 'bharney', []),
-            (3, 'wylma', ['oh', None, 'phred']),
-        ]
-        retry = RetryInstanceState(_has_all_ddl)
-        retry(self._db.reload)()
+    def test_batch_insert_then_read_arrays(self):
+        table, times, dates, columns, rowdata = (
+            self._test_batch_insert_then_read_arrays_data()
+        )
+
+        for index in range(len(table)):
+            session = self._db.session()
+            session.create()
+            self.to_delete.append(session)
+            with session.batch() as batch:
+                batch.delete(table[index], self.ALL)
+                batch.insert(table[index], columns, rowdata[index])
+
+            snapshot = session.snapshot(read_timestamp=batch.committed)
+            rows = list(snapshot.read(table[index], columns, self.ALL))
+            self._check_row_data(rows, expected=rowdata[index])
+
+    def test_invalid_column(self):
+        table = 'counters'
+        columns = ('name', 'invalid')
+        row_data = (('', 0),)
 
         session = self._db.session()
         session.create()
        self.to_delete.append(session)
-
+        error = "StatusCode.NOT_FOUND, Column not found in table"
+        with self.assertRaisesRegexp(errors.RetryError, error):
+            with session.batch() as batch:
+                batch.insert(table, columns, row_data)
+
+    def test_invalid_table(self):
+        table = 'invalid'
+        columns = ('name', 'value')
+        row_data = (('', 0),)
+        session = self._db.session()
+        session.create()
+        self.to_delete.append(session)
+        error = "StatusCode.NOT_FOUND, Table not found"
+        with self.assertRaisesRegexp(errors.RetryError, error):
+            with session.batch() as batch:
+                batch.insert(table, columns, row_data)
+
+    def test_invalid_type(self):
+        table = 'counters'
+        columns = ('name', 'value')
+        session = self._db.session()
+        session.create()
+        self.to_delete.append(session)
+        wrong = ((0, ''),)
+        right = (('', 0),)
+        error = "StatusCode.FAILED_PRECONDITION, Invalid value for column"
         with session.batch() as batch:
-            batch.delete(TABLE, self.ALL)
-            batch.insert(TABLE, COLUMNS, ROWDATA)
-
-        snapshot = session.snapshot(read_timestamp=batch.committed)
-        rows = list(snapshot.read(TABLE, COLUMNS, self.ALL))
-        self._check_row_data(rows, expected=ROWDATA)
+            batch.delete(table, self.ALL)
+            batch.insert(table, columns, right)
+        with self.assertRaisesRegexp(errors.RetryError, error):
+            with session.batch() as batch:
+                batch.delete(table, self.ALL)
+                batch.insert(table, columns, wrong)
+
+    def test_batch_insert_then_read_random_bytes(self):
+        import random
+        import string
+        random.seed(0)
+        table = 'bytes_plus_array_of_bytes'
+        columns = ('id', 'name', 'tags')
+        session = self._db.session()
+        session.create()
+        self.to_delete.append(session)
+        column_length = 16
+        for index in range(column_length):
+            data = []
+            for rand in range(4):
+                letters = [random.choice(string.ascii_lowercase)
+                           for i in range(index)]
+                words = ''.join(letters)
+                data.append(words.encode('base64').strip())
+            rowdata = ((1, data[0], [data[1], data[2], data[3]]),)
+            with session.batch() as batch:
+                batch.delete(table, self.ALL)
+                batch.insert(table, columns, rowdata)
+            rows = list(session.read(table, columns, self.ALL))
+            self._check_row_data(rows, expected=rowdata)
 
     def test_batch_insert_then_read_all_datatypes(self):
         retry = RetryInstanceState(_has_all_ddl)
@@ -843,7 +1001,8 @@ def test_read_w_index(self):
         ]
         pool = BurstyPool()
         temp_db = Config.INSTANCE.database(
-            'test_read_w_index', ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
+            'temp_db' + unique_resource_id('_'),
+            ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
             pool=pool)
         operation = temp_db.create()
         self.to_delete.append(_DatabaseDropper(temp_db))