Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Miauwkeru committed Feb 10, 2025
1 parent 88259ec commit 407ac03
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
2 changes: 0 additions & 2 deletions flow/record/fieldtypes/net/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from flow.record.fieldtypes import string
from flow.record.fieldtypes.net.ip import (
IPAddress,
IPInterface,
IPNetwork,
ipaddress,
ipinterface,
Expand All @@ -12,7 +11,6 @@

__all__ = [
"IPAddress",
"IPInterface",
"IPNetwork",
"ipaddress",
"ipinterface",
Expand Down
26 changes: 13 additions & 13 deletions flow/record/fieldtypes/net/ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@
_IPNetwork = Union[IPv4Network, IPv6Network]
_IPAddress = Union[IPv4Address, IPv6Address]
_IPInterface = Union[IPv4Interface, IPv6Interface]
_ConversionTypes = Union[str, int, bytes]
_IP = Union[_IPNetwork, _IPAddress, _IPInterface]


class ipaddress(FieldType):
val = None
val: _IPAddress = None
_type = "net.ipaddress"

def __init__(self, addr: str | int | bytes):
def __init__(self, addr: _ConversionTypes | _IPAddress):
self.val = ip_address(addr)

def __eq__(self, b: str | int | bytes | _IPAddress) -> bool:
def __eq__(self, b: _ConversionTypes | _IPAddress) -> bool:
try:
return self.val == ip_address(b)
except ValueError:
Expand Down Expand Up @@ -57,13 +59,13 @@ def _unpack(data: int) -> ipaddress:


class ipnetwork(FieldType):
val = None
val: _IPNetwork = None
_type = "net.ipnetwork"

def __init__(self, addr: str | int | bytes):
def __init__(self, addr: _ConversionTypes | _IPNetwork):
self.val = ip_network(addr)

def __eq__(self, b: str | int | bytes | _IPNetwork) -> bool:
def __eq__(self, b: _ConversionTypes | _IPNetwork) -> bool:
try:
return self.val == ip_network(b)
except ValueError:
Expand Down Expand Up @@ -108,13 +110,13 @@ def netmask(self) -> ipaddress:


class ipinterface(FieldType):
val = None
val: _IPInterface = None
_type = "net.ipinterface"

def __init__(self, addr: int) -> None:
def __init__(self, addr: _ConversionTypes | _IP) -> None:
self.val = ip_interface(addr)

def __eq__(self, b: str | int | bytes | _IPInterface) -> bool:
def __eq__(self, b: _ConversionTypes | _IP) -> bool:
try:
return self.val == ip_interface(b)
except ValueError:
Expand All @@ -138,8 +140,8 @@ def network(self) -> ipnetwork:
return ipnetwork(self.val.network)

@property
def netmask(self) -> ipnetwork:
return self.network.netmask
def netmask(self) -> ipaddress:
return ipaddress(self.val.netmask)

def _pack(self) -> str:
return self.val.compressed
Expand All @@ -151,7 +153,5 @@ def _unpack(data: str) -> ipinterface:

# alias: net.IPAddress -> net.ipaddress
# alias: net.IPNetwork -> net.ipnetwork
# alias: net.IPInterface -> net.ipinterface
IPAddress = ipaddress
IPNetwork = ipnetwork
IPInterface = ipinterface
37 changes: 37 additions & 0 deletions tests/test_fieldtype_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,17 @@ def test_record_ipinterface() -> None:
assert r.interface == "::1"
assert r.interface == "::1/128"

r = TestRecord("64:ff9b::2/96")
assert r.interface == "64:ff9b::2/96"
assert r.interface.ip == "64:ff9b::2"
assert r.interface.network == "64:ff9b::/96"
assert r.interface.netmask == "ffff:ffff:ffff:ffff:ffff:ffff::"

# instantiate from different types
assert TestRecord(1).interface == "0.0.0.1/32"
assert TestRecord(0x7F0000FF).interface == "127.0.0.255/32"
assert TestRecord(b"\x7f\xff\xff\xff").interface == "127.255.255.255/32"

# Test whether it functions in a set
data = {TestRecord(x).interface for x in ["192.168.0.0/24", "192.168.0.0/24", "::1", "::1"]}
assert len(data) == 2
Expand All @@ -170,6 +181,32 @@ def test_record_ipinterface() -> None:
assert "::1" not in data


def test_record_ipinterface_types() -> None:
TestRecord = RecordDescriptor(
"test/ipinterface",
[
(
"net.ipinterface",
"interface",
)
],
)

r = TestRecord("192.168.0.255/24")
_if = r.interface
assert isinstance(_if, net.ipinterface)
assert isinstance(_if.ip, net.ipaddress)
assert isinstance(_if.network, net.ipnetwork)
assert isinstance(_if.netmask, net.ipaddress)

r = TestRecord("64:ff9b::/96")
_if = r.interface
assert isinstance(_if, net.ipinterface)
assert isinstance(_if.ip, net.ipaddress)
assert isinstance(_if.network, net.ipnetwork)
assert isinstance(_if.netmask, net.ipaddress)


@pytest.mark.parametrize("PSelector", [Selector, CompiledSelector])
def test_selector_ipaddress(PSelector: type[Selector]) -> None:
TestRecord = RecordDescriptor(
Expand Down

0 comments on commit 407ac03

Please sign in to comment.