Skip to content

Commit

Permalink
Add http and https to address validation regex (#1800)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Apr 18, 2023
1 parent da1be20 commit bea5dc8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/py/flwr/common/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
IPV6: int = 6

DOMAIN_PATTERN: Pattern[str] = re.compile(
r"^(localhost)|(([a-zA-Z]{1})|([a-zA-Z]{1}[a-zA-Z]{1})|"
r"^(https?:\/\/)?(localhost)|(([a-zA-Z]{1})|([a-zA-Z]{1}[a-zA-Z]{1})|"
r"([a-zA-Z]{1}[0-9]{1})|([0-9]{1}[a-zA-Z]{1})|"
r"([a-zA-Z0-9][-_.a-zA-Z0-9]{0,61}[a-zA-Z0-9]))\."
r"([a-zA-Z]{2,13}|[a-zA-Z0-9-]{2,30}.[a-zA-Z]{2,3})$"
)


def parse_address(address: str) -> Optional[Tuple[str, int, bool]]:
def parse_address(address: str) -> Optional[Tuple[str, int, Optional[bool]]]:
"""Parses an IP address into a host, port, and version.
Parameters
Expand All @@ -52,8 +52,9 @@ def parse_address(address: str) -> Optional[Tuple[str, int, bool]]:
raw_host, _, raw_port = address.rpartition(":")

if DOMAIN_PATTERN.match(raw_host):
print(raw_host)
host = raw_host
version = False
version = None
else:
host = raw_host.translate({ord(i): None for i in "[]"})
version = ip_address(host).version == IPV6
Expand Down
49 changes: 43 additions & 6 deletions src/py/flwr/common/address_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ def test_ipv4_correct() -> None:
("127.0.0.1:8080", ("127.0.0.1", 8080, False)),
("0.0.0.0:12", ("0.0.0.0", 12, False)),
("0.0.0.0:65535", ("0.0.0.0", 65535, False)),
("flower.dev:123", ("flower.dev", 123, False)),
("flower.dev:123", ("flower.dev", 123, False)),
("sub.flower.dev:123", ("sub.flower.dev", 123, False)),
("sub2.sub1.flower.dev:123", ("sub2.sub1.flower.dev", 123, False)),
("s5.s4.s3.s2.s1.flower.dev:123", ("s5.s4.s3.s2.s1.flower.dev", 123, False)),
("localhost:123", ("localhost", 123, False)),
]

for address, expected in addresses:
Expand Down Expand Up @@ -119,3 +113,46 @@ def test_ipv6_incorrect() -> None:

# Assert
assert actual is None


def test_domain_correct() -> None:
"""Test if a correct domain address is correctly parsed."""

# Prepare
addresses = [
("flower.dev:123", ("flower.dev", 123, None)),
("flower.dev:123", ("flower.dev", 123, None)),
("sub.flower.dev:123", ("sub.flower.dev", 123, None)),
("sub2.sub1.flower.dev:123", ("sub2.sub1.flower.dev", 123, None)),
("s5.s4.s3.s2.s1.flower.dev:123", ("s5.s4.s3.s2.s1.flower.dev", 123, None)),
("localhost:123", ("localhost", 123, None)),
("https://localhost:123", ("https://localhost", 123, None)),
("http://localhost:123", ("http://localhost", 123, None)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


def test_domain_incorrect() -> None:
"""Test if an incorrect domain address returns None."""

# Prepare
addresses = [
"flower.dev::8080",
"flower.dev/html/index.html:12",
"http://fl:50",
"flower.dev",
"flower.dev:65536",
]

for address in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual is None

0 comments on commit bea5dc8

Please sign in to comment.