diff --git a/gcp_flowlogs_reader/gcp_flowlogs_reader.py b/gcp_flowlogs_reader/gcp_flowlogs_reader.py index 4260002..f7235dc 100644 --- a/gcp_flowlogs_reader/gcp_flowlogs_reader.py +++ b/gcp_flowlogs_reader/gcp_flowlogs_reader.py @@ -62,6 +62,12 @@ class ResourceLabels(NamedTuple): subnetwork_name: str +def safe_tuple_from_dict(cls, attrs): + attr_payload = {k: attrs[k] for k in cls._fields} + return cls(**attr_payload) + + + class FlowRecord: src_ip: Union[IPv4Address, IPv6Address] src_port: int @@ -115,9 +121,8 @@ def __init__(self, entry: StructEntry): ('dest_location', GeographicDetails), ]: try: - attr_payload = {k: flow_payload[k] for k in cls._fields} - value = cls(**attr_payload) - except (KeyError, TypeError): + value = safe_tuple_from_dict(cls, flow_payload[attr]) + except (KeyError, TypeError) as e: setattr(self, attr, None) else: setattr(self, attr, value) diff --git a/tests/test_gcp_flowlogs_reader.py b/tests/test_gcp_flowlogs_reader.py index f1699c2..1aa924c 100644 --- a/tests/test_gcp_flowlogs_reader.py +++ b/tests/test_gcp_flowlogs_reader.py @@ -23,6 +23,8 @@ GeographicDetails, ResourceLabels, ) +from gcp_flowlogs_reader.gcp_flowlogs_reader import safe_tuple_from_dict + PREFIX = 'gcp_flowlogs_reader.gcp_flowlogs_reader.{}'.format SAMPLE_PAYLOADS = [ @@ -222,10 +224,10 @@ def test_init_outbound(self): ('rtt_msec', 61), ('reporter', 'DEST'), ('src_instance', None), - ('dest_instance', InstanceDetails(**SAMPLE_PAYLOADS[0]['dest_instance'])), + ('dest_instance', safe_tuple_from_dict(InstanceDetails, SAMPLE_PAYLOADS[0]['dest_instance'])), ('src_vpc', None), - ('dest_vpc', VpcDetails(**SAMPLE_PAYLOADS[0]['dest_vpc'])), - ('src_location', GeographicDetails(**SAMPLE_PAYLOADS[0]['src_location'])), + ('dest_vpc', safe_tuple_from_dict(VpcDetails, SAMPLE_PAYLOADS[0]['dest_vpc'])), + ('src_location', safe_tuple_from_dict(GeographicDetails, SAMPLE_PAYLOADS[0]['src_location'])), ('dest_location', None), ]: with self.subTest(attr=attr): @@ -247,12 +249,12 @@ def test_init_inbound(self): ('packets_sent', 6), ('rtt_msec', None), ('reporter', 'SRC'), - ('src_instance', InstanceDetails(**SAMPLE_PAYLOADS[1]['src_instance'])), + ('src_instance', safe_tuple_from_dict(InstanceDetails, SAMPLE_PAYLOADS[1]['src_instance'])), ('dest_instance', None), - ('src_vpc', VpcDetails(**SAMPLE_PAYLOADS[1]['src_vpc'])), + ('src_vpc', safe_tuple_from_dict(VpcDetails, SAMPLE_PAYLOADS[1]['src_vpc'])), ('dest_vpc', None), ('src_location', None), - ('dest_location', GeographicDetails(**SAMPLE_PAYLOADS[1]['dest_location'])), + ('dest_location', safe_tuple_from_dict(GeographicDetails, SAMPLE_PAYLOADS[1]['dest_location'])), ]: with self.subTest(attr=attr): actual = getattr(flow_record, attr) @@ -319,6 +321,8 @@ def test_to_dict(self): ]: with self.subTest(attr=attr): actual = flow_dict[attr] + if isinstance(expected, dict): + expected = {k: v for k, v in expected.items() if k != 'subnetwork_region'} self.assertEqual(actual, expected) def test_from_payload(self): @@ -498,7 +502,7 @@ def test_multiple_projects( actual = list(reader) expected = [FlowRecord(x) for x in SAMPLE_ENTRIES] self.assertEqual(actual, expected) - self.assertEqual(reader.bytes_processed, 576) + self.assertEqual(reader.bytes_processed, 544) # Test the client getting called correctly with multiple projects expression = (