Skip to content

Commit

Permalink
[datasets] Tighten the requirements of benchmark name regex.
Browse files Browse the repository at this point in the history
A benchmark URI is no longer considered well formed if the benchmark
name is missing, as we no longer support dataset-only URIs.

Issue facebookresearch#45.
  • Loading branch information
ChrisCummins authored and bwasti committed Aug 3, 2021
1 parent 3cc33c2 commit 87f02b5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
10 changes: 4 additions & 6 deletions compiler_gym/datasets/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,19 @@
# Regular expression that matches the full two-part URI prefix of a dataset:
# {{protocol}}://{{dataset}}
#
# A trailing slash is permitted.
# An optional trailing slash is permitted.
#
# Example matches: "benchmark://foo-v0", "benchmark://foo-v0/".
# Example matches: "benchmark://foo-v0", "generator://bar-v0/".
# Named groups exposed by the pattern:
#   dataset          - the full "{protocol}://{name}" prefix, without the trailing slash
#   dataset_protocol - the part before "://"
#   dataset_name     - the dataset name including its "-v<N>" version suffix
#   dataset_version  - the numeric version <N>
#
# The character classes use "A-Z" (not "A-z"): the ASCII range "A-z" also
# spans the punctuation characters "[", "\", "]", "^" and "`", which must not
# be accepted in a protocol or dataset name. The hyphen is placed last in the
# class so that it is unambiguously a literal.
DATASET_NAME_RE = re.compile(
    r"(?P<dataset>(?P<dataset_protocol>[a-zA-Z0-9_-]+)://(?P<dataset_name>[a-zA-Z0-9_-]+-v(?P<dataset_version>[0-9]+)))/?"
)

# Regular expression that matches the full three-part format of a benchmark URI:
# {{protocol}}://{{dataset}}/{{id}}
#
# The {{id}} is optional.
#
# Example matches: "benchmark://foo-v0/" or "benchmark://foo-v0/program".
# Example matches: "benchmark://foo-v0/foo" or "generator://bar-v1/foo/bar.txt".
BENCHMARK_URI_RE = re.compile(
r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))(/(?P<benchmark_name>[^\s]*))?$"
r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))/(?P<benchmark_name>[^\s]+)$"
)


Expand Down
56 changes: 33 additions & 23 deletions tests/datasets/benchmark_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,54 +28,58 @@ def _rgx_match(regex, groupname, string) -> str:
return match.group(groupname)


@pytest.mark.parametrize("regex", (DATASET_NAME_RE, BENCHMARK_URI_RE))
def test_benchmark_uri_protocol(regex):
assert not regex.match("B?://cbench-v1/") # Invalid characters
assert not regex.match("cbench-v1/") # Missing protocol

def test_benchmark_uri_protocol():
assert (
_rgx_match(regex, "dataset_protocol", "benchmark://cbench-v1/") == "benchmark"
_rgx_match(DATASET_NAME_RE, "dataset_protocol", "benchmark://cbench-v1/")
== "benchmark"
)
assert (
_rgx_match(regex, "dataset_protocol", "Generator13://gen-v11/") == "Generator13"
_rgx_match(DATASET_NAME_RE, "dataset_protocol", "Generator13://gen-v11/")
== "Generator13"
)


def test_benchmark_uri_dataset():
assert not BENCHMARK_URI_RE.match("benchmark://cBench?v0/") # Invalid character
assert not BENCHMARK_URI_RE.match("benchmark://cBench/") # Missing version suffix
def test_invalid_benchmark_uris():
# Invalid protocol
assert not DATASET_NAME_RE.match("B?://cbench-v1/") # Invalid characters
assert not DATASET_NAME_RE.match("cbench-v1/") # Missing protocol

# Invalid dataset name
assert not BENCHMARK_URI_RE.match("benchmark://cbench?v0/foo") # Invalid character
assert not BENCHMARK_URI_RE.match(
"benchmark://cbench/foo"
) # Missing version suffix
assert not BENCHMARK_URI_RE.match("benchmark://cbench-v0") # Missing benchmark ID
assert not BENCHMARK_URI_RE.match("benchmark://cbench-v0/") # Missing benchmark ID

# Invalid benchmark ID
assert not BENCHMARK_URI_RE.match("benchmark://cbench-v1/ whitespace") # Whitespace
assert not BENCHMARK_URI_RE.match("benchmark://cbench-v1/\t") # Whitespace


def test_benchmark_uri_dataset():
assert (
_rgx_match(BENCHMARK_URI_RE, "dataset_name", "benchmark://cbench-v1/")
_rgx_match(BENCHMARK_URI_RE, "dataset_name", "benchmark://cbench-v1/foo")
== "cbench-v1"
)
assert (
_rgx_match(BENCHMARK_URI_RE, "dataset_name", "Generator13://gen-v11/")
_rgx_match(BENCHMARK_URI_RE, "dataset_name", "Generator13://gen-v11/foo")
== "gen-v11"
)


def test_benchmark_dataset_name():
assert (
_rgx_match(BENCHMARK_URI_RE, "dataset", "benchmark://cbench-v1/")
_rgx_match(BENCHMARK_URI_RE, "dataset", "benchmark://cbench-v1/foo")
== "benchmark://cbench-v1"
)
assert (
_rgx_match(BENCHMARK_URI_RE, "dataset", "Generator13://gen-v11/")
_rgx_match(BENCHMARK_URI_RE, "dataset", "Generator13://gen-v11/foo")
== "Generator13://gen-v11"
)


def test_benchmark_uri_id():
assert not BENCHMARK_URI_RE.match("benchmark://cbench-v1/ whitespace") # Whitespace
assert not BENCHMARK_URI_RE.match("benchmark://cbench-v1/\t") # Whitespace

assert (
_rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1") is None
)
assert (
_rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/") == ""
)
assert (
_rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/foo")
== "foo"
Expand All @@ -84,6 +88,12 @@ def test_benchmark_uri_id():
_rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/foo/123")
== "foo/123"
)
assert (
_rgx_match(
BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/foo/123.txt"
)
== "foo/123.txt"
)
assert (
_rgx_match(
BENCHMARK_URI_RE,
Expand Down

0 comments on commit 87f02b5

Please sign in to comment.