diff --git a/compiler_gym/compiler_env_state.py b/compiler_gym/compiler_env_state.py
index e3c4307aa..07bc5f398 100644
--- a/compiler_gym/compiler_env_state.py
+++ b/compiler_gym/compiler_env_state.py
@@ -4,9 +4,12 @@
 # LICENSE file in the root directory of this source tree.
 """This module defines a class to represent a compiler environment state."""
 import csv
+import re
 import sys
+from io import StringIO
 from typing import Iterable, List, Optional, TextIO
 
+import requests
 from pydantic import BaseModel, Field, validator
 
 from compiler_gym.datasets.uri import BenchmarkUri
@@ -23,10 +26,7 @@ class CompilerEnvState(BaseModel):
 
     benchmark: str = Field(
         allow_mutation=False,
-        examples=[
-            "benchmark://cbench-v1/crc32",
-            "generator://csmith-v0/0",
-        ],
+        examples=["benchmark://cbench-v1/crc32", "generator://csmith-v0/0",],
     )
     """The URI of the benchmark used for this episode."""
 
@@ -37,9 +37,7 @@ class CompilerEnvState(BaseModel):
     """The walltime of the episode in seconds. Must be non-negative."""
 
     reward: Optional[float] = Field(
-        required=False,
-        default=None,
-        allow_mutation=True,
+        required=False, default=None, allow_mutation=True,
     )
     """The cumulative reward for this episode. Optional."""
 
@@ -229,6 +227,17 @@ def read_paths(paths: Iterable[str]) -> Iterable[CompilerEnvState]:
         for path in paths:
             if path == "-":
                 yield from iter(CompilerEnvStateReader(sys.stdin))
+            # Remote CSV files are fetched over HTTP(S). The "-" is placed last
+            # in the character class so it is a literal hyphen, not a range.
+            elif re.match(r"^https?://[a-zA-Z0-9._/:-]+\.csv$", path):
+                # A timeout keeps an unresponsive server from blocking forever.
+                response: requests.Response = requests.get(path, timeout=30)
+                if response.status_code == 200:
+                    yield from iter(CompilerEnvStateReader(StringIO(response.text)))
+                else:
+                    raise requests.exceptions.InvalidURL(
+                        f"Url {path} content could not be obtained"
+                    )
             else:
                 with open(path) as f:
                     yield from iter(CompilerEnvStateReader(f))
diff --git a/tests/compiler_env_state_test.py b/tests/compiler_env_state_test.py
index 1d722d410..2ccf30675 100644
--- a/tests/compiler_env_state_test.py
+++ b/tests/compiler_env_state_test.py
@@ -8,6 +8,7 @@
 from pathlib import Path
 
 import pytest
+import requests
 from pydantic import ValidationError as PydanticValidationError
 
 from compiler_gym import CompilerEnvState, CompilerEnvStateWriter
@@ -321,5 +322,98 @@ def test_state_serialize_deserialize_equality_no_reward():
     assert state_from_csv.commandline == "-a -b -c"
 
 
+def test_read_paths_stdin(monkeypatch):
+    monkeypatch.setattr(
+        "sys.stdin",
+        StringIO(
+            "benchmark,reward,walltime,commandline\n"
+            "benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n"
+        ),
+    )
+    reader = CompilerEnvStateReader.read_paths(["-"])
+    assert list(reader) == [
+        CompilerEnvState(
+            benchmark="benchmark://cbench-v0/foo",
+            walltime=5,
+            commandline="-a -b -c",
+            reward=2,
+        )
+    ]
+
+
+def test_read_paths_file(tmp_path):
+    file_dir = f"{tmp_path}/test.csv"
+    with open(file_dir, "w") as csv_file:
+        csv_file.write(
+            "benchmark,reward,walltime,commandline\n"
+            "benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n"
+        )
+    reader = CompilerEnvStateReader.read_paths([file_dir])
+    assert list(reader) == [
+        CompilerEnvState(
+            benchmark="benchmark://cbench-v0/foo",
+            walltime=5,
+            commandline="-a -b -c",
+            reward=2,
+        )
+    ]
+
+
+def test_read_paths_url(monkeypatch):
+    # The second URL contains hyphens, which the URL pattern must accept.
+    urls = [
+        "https://compilergym.ai/benchmarktest.csv",
+        "https://compiler-gym.ai/benchmark-test.csv",
+    ]
+
+    class MockResponse:
+        def __init__(self, text, status_code):
+            self.text = text
+            self.status_code = status_code
+
+    def ok_mock_response(*args, **kwargs):
+        return MockResponse(
+            (
+                "benchmark,reward,walltime,commandline\n"
+                "benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n"
+            ),
+            200,
+        )
+
+    monkeypatch.setattr(requests, "get", ok_mock_response)
+    reader = CompilerEnvStateReader.read_paths(urls)
+    assert list(reader) == [
+        CompilerEnvState(
+            benchmark="benchmark://cbench-v0/foo",
+            walltime=5,
+            commandline="-a -b -c",
+            reward=2,
+        )
+    ] * len(urls)
+
+    def bad_mock_response(*args, **kwargs):
+        return MockResponse("", 404)
+
+    monkeypatch.setattr(requests, "get", bad_mock_response)
+    with pytest.raises(requests.exceptions.InvalidURL):
+        reader = CompilerEnvStateReader.read_paths(urls)
+        list(reader)
+
+
+def test_read_paths_bad_inputs():
+    bad_dirs = [
+        "/fake/directory/file.csv",
+        "fake/directory/file.csv",
+        "https://www.compilergym.ai/benchmark",
+        "htts://www.compilergym.ai/benchmark.csv",
+        "htts://www.compilergym.ai/benchmark",
+    ]
+    # Check each path on its own: read_paths() is lazy and stops at the first
+    # failure, so a single call would only ever exercise the first entry.
+    for bad_dir in bad_dirs:
+        with pytest.raises(FileNotFoundError):
+            list(CompilerEnvStateReader.read_paths([bad_dir]))
+
+
 if __name__ == "__main__":
     main()