diff --git a/flow/record/fieldtypes/net/__init__.py b/flow/record/fieldtypes/net/__init__.py index 524ca23..5c46c05 100644 --- a/flow/record/fieldtypes/net/__init__.py +++ b/flow/record/fieldtypes/net/__init__.py @@ -3,7 +3,6 @@ from flow.record.fieldtypes import string from flow.record.fieldtypes.net.ip import ( IPAddress, - IPInterface, IPNetwork, ipaddress, ipinterface, @@ -12,7 +11,6 @@ __all__ = [ "IPAddress", - "IPInterface", "IPNetwork", "ipaddress", "ipinterface", diff --git a/flow/record/fieldtypes/net/ip.py b/flow/record/fieldtypes/net/ip.py index 20e05e2..a202550 100644 --- a/flow/record/fieldtypes/net/ip.py +++ b/flow/record/fieldtypes/net/ip.py @@ -19,16 +19,18 @@ _IPNetwork = Union[IPv4Network, IPv6Network] _IPAddress = Union[IPv4Address, IPv6Address] _IPInterface = Union[IPv4Interface, IPv6Interface] +_ConversionTypes = Union[str, int, bytes] +_IP = Union[_IPNetwork, _IPAddress, _IPInterface] class ipaddress(FieldType): - val = None + val: _IPAddress = None _type = "net.ipaddress" - def __init__(self, addr: str | int | bytes): + def __init__(self, addr: _ConversionTypes | _IPAddress): self.val = ip_address(addr) - def __eq__(self, b: str | int | bytes | _IPAddress) -> bool: + def __eq__(self, b: _ConversionTypes | _IPAddress) -> bool: try: return self.val == ip_address(b) except ValueError: @@ -57,13 +59,13 @@ def _unpack(data: int) -> ipaddress: class ipnetwork(FieldType): - val = None + val: _IPNetwork = None _type = "net.ipnetwork" - def __init__(self, addr: str | int | bytes): + def __init__(self, addr: _ConversionTypes | _IPNetwork): self.val = ip_network(addr) - def __eq__(self, b: str | int | bytes | _IPNetwork) -> bool: + def __eq__(self, b: _ConversionTypes | _IPNetwork) -> bool: try: return self.val == ip_network(b) except ValueError: @@ -108,13 +110,13 @@ def netmask(self) -> ipaddress: class ipinterface(FieldType): - val = None + val: _IPInterface = None _type = "net.ipinterface" - def __init__(self, addr: int) -> None: + def __init__(self, addr: _ConversionTypes | _IP) -> None: self.val = ip_interface(addr) - def __eq__(self, b: str | int | bytes | _IPInterface) -> bool: + def __eq__(self, b: _ConversionTypes | _IP) -> bool: try: return self.val == ip_interface(b) except ValueError: @@ -138,8 +140,8 @@ def network(self) -> ipnetwork: return ipnetwork(self.val.network) @property - def netmask(self) -> ipnetwork: - return self.network.netmask + def netmask(self) -> ipaddress: + return ipaddress(self.val.netmask) def _pack(self) -> str: return self.val.compressed @@ -151,7 +153,5 @@ def _unpack(data: str) -> ipinterface: # alias: net.IPAddress -> net.ipaddress # alias: net.IPNetwork -> net.ipnetwork -# alias: net.IPInterface -> net.ipinterface IPAddress = ipaddress IPNetwork = ipnetwork -IPInterface = ipinterface diff --git a/tests/test_fieldtype_ip.py b/tests/test_fieldtype_ip.py index d39c62e..a1c224e 100644 --- a/tests/test_fieldtype_ip.py +++ b/tests/test_fieldtype_ip.py @@ -162,6 +162,17 @@ def test_record_ipinterface() -> None: assert r.interface == "::1" assert r.interface == "::1/128" + r = TestRecord("64:ff9b::2/96") + assert r.interface == "64:ff9b::2/96" + assert r.interface.ip == "64:ff9b::2" + assert r.interface.network == "64:ff9b::/96" + assert r.interface.netmask == "ffff:ffff:ffff:ffff:ffff:ffff::" + + # instantiate from different types + assert TestRecord(1).interface == "0.0.0.1/32" + assert TestRecord(0x7F0000FF).interface == "127.0.0.255/32" + assert TestRecord(b"\x7f\xff\xff\xff").interface == "127.255.255.255/32" + # Test whether it functions in a set data = {TestRecord(x).interface for x in ["192.168.0.0/24", "192.168.0.0/24", "::1", "::1"]} assert len(data) == 2 @@ -170,6 +181,32 @@ def test_record_ipinterface() -> None: assert "::1" not in data +def test_record_ipinterface_types() -> None: + TestRecord = RecordDescriptor( + "test/ipinterface", + [ + ( + "net.ipinterface", + "interface", + ) + ], + ) + + r = TestRecord("192.168.0.255/24") + _if = r.interface + assert isinstance(_if, net.ipinterface) + assert isinstance(_if.ip, net.ipaddress) + assert isinstance(_if.network, net.ipnetwork) + assert isinstance(_if.netmask, net.ipaddress) + + r = TestRecord("64:ff9b::/96") + _if = r.interface + assert isinstance(_if, net.ipinterface) + assert isinstance(_if.ip, net.ipaddress) + assert isinstance(_if.network, net.ipnetwork) + assert isinstance(_if.netmask, net.ipaddress) + + @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) def test_selector_ipaddress(PSelector: type[Selector]) -> None: TestRecord = RecordDescriptor(