From bf2ce9a905cd3ebaa087e6b67593c252ba509f54 Mon Sep 17 00:00:00 2001 From: Dmitriy Korneev Date: Thu, 22 Nov 2018 13:23:20 +0200 Subject: [PATCH] correctly handle null values in columns --- decryptor/postgresql/packet_handler.go | 3 +++ tests/test.py | 11 ++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/decryptor/postgresql/packet_handler.go b/decryptor/postgresql/packet_handler.go index 2a0d6e1df..cca32bea9 100644 --- a/decryptor/postgresql/packet_handler.go +++ b/decryptor/postgresql/packet_handler.go @@ -125,6 +125,9 @@ type ColumnData struct { // Length return column length converted from LengthBuf func (column *ColumnData) Length() int { + if column.isNull { + return 0 + } return int(binary.BigEndian.Uint32(column.LengthBuf[:])) } diff --git a/tests/test.py b/tests/test.py index 59dfe2367..293c91705 100644 --- a/tests/test.py +++ b/tests/test.py @@ -88,6 +88,7 @@ sa.Column('id', sa.Integer, primary_key=True), sa.Column('data', sa.LargeBinary(length=COLUMN_DATA_SIZE)), sa.Column('raw_data', sa.Text), + sa.Column('nullable_column', sa.Text, nullable=True), ) acrarollback_output_table = sa.Table('acrarollback_output', metadata, @@ -1781,7 +1782,7 @@ def testShutdown(self): result = self.engine1.execute( sa.select([sa.cast(zone, BYTEA), test_table]) .where(test_table.c.id == row_id)) - for zone, _, data, raw_data in result: + for zone, _, data, raw_data, _ in result: self.assertEqual(zone, zone) self.assertEqual(data, poison_record) @@ -1796,7 +1797,7 @@ def testShutdown2(self): result = self.engine1.execute( sa.select([test_table]) .where(test_table.c.id == row_id)) - for _, data, raw_data in result: + for _, data, raw_data, _ in result: self.assertEqual(data, poison_record) def testShutdown3(self): @@ -1809,7 +1810,7 @@ def testShutdown3(self): result = self.engine1.execute( sa.select([test_table])) - for _, data, raw_data in result: + for _, data, raw_data, _ in result: self.assertEqual(data, poison_record) def testShutdown4(self): @@ -1826,7 +1827,7 @@ def testShutdown4(self): result = self.engine1.execute( sa.select([test_table])) - for _, data, raw_data in result: + for _, data, raw_data, _ in result: self.assertEqual(testData, data) @@ -1920,7 +1921,7 @@ def testNoDetect(self): self.assertNotIn('Check poison records', log) result = self.engine1.execute( sa.select([test_table])) - for _, data, raw_data in result: + for _, data, raw_data, _ in result: self.assertEqual(poison_record, data)