Skip to content

Commit

Permalink
Avoid race condition in path check
Browse files Browse the repository at this point in the history
File could be deleted between check and use
  • Loading branch information
twiggler committed Sep 17, 2024
1 parent c3d4f61 commit 9499af4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
13 changes: 3 additions & 10 deletions dissect/target/loaders/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,6 @@ def case(value: str) -> str:
raise argparse.ArgumentTypeError(f"Invalid case name specified: '{value}'")


def file_exists(file_path: str) -> str:
"""Checks if the file exists."""
if not os.path.isfile(file_path):
raise argparse.ArgumentTypeError(f"File does not exist: '{file_path}'")
return file_path


@arg("--mqtt-peers", type=strictly_positive, dest="peers", help="minimum number of peers to await for first alias")
@arg(
"--mqtt-case",
Expand All @@ -563,9 +556,9 @@ def file_exists(file_path: str) -> str:
)
@arg("--mqtt-port", type=port, dest="port", help="broker connection port")
@arg("--mqtt-broker", type=host_name_or_ip_address, dest="broker", help="broker ip-address")
@arg("--mqtt-key", type=file_exists, dest="key", help="private key file")
@arg("--mqtt-crt", type=file_exists, dest="crt", help="client certificate file")
@arg("--mqtt-ca", type=file_exists, dest="ca", help="certificate authority file")
@arg("--mqtt-key", type=Path, dest="key", help="private key file")
@arg("--mqtt-crt", type=Path, dest="crt", help="client certificate file")
@arg("--mqtt-ca", type=Path, dest="ca", help="certificate authority file")
@arg("--mqtt-command", dest="command", help="direct command to client(s)")
@arg("--mqtt-diag", action="store_true", dest="diag", help="show MQTT diagnostic information")
@arg("--mqtt-username", dest="username", help="Username for connection")
Expand Down
8 changes: 4 additions & 4 deletions tests/loaders/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def generate_longest_valid_hostname():
("example-label.com", True),
("example..com", False),
(generate_longest_valid_hostname(), True),
], ids=[
],
ids=[
"valid_domain",
"valid_domain_with_trailing_dot",
"invalid_double_dot",
Expand All @@ -183,9 +184,8 @@ def generate_longest_valid_hostname():
"invalid_end_hyphen",
"valid_domain_with_hyphen",
"invalid_empty_label",
"valid_max_length"
]
"valid_max_length",
],
)
def test_host_name_parser(hostname, is_valid_hostname) -> None:
assert host_name(hostname) == is_valid_hostname

0 comments on commit 9499af4

Please sign in to comment.