Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a wrapper to synchronously log environment state transitions to SQLite database #679

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler_gym/util/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def humanize_duration(seconds: float) -> str:


def humanize_duration_hms(seconds: float) -> str:
    """Render a duration as an ``hours:minutes:seconds`` string.

    The input is truncated to a whole number of seconds with :code:`int()`.
    Hours are not zero-padded; minutes and seconds are padded to two digits.

    :param seconds: The duration, in seconds.

    :return: The formatted duration, e.g. :code:`"1:01:01"` for 3661 seconds.
    """
    whole_seconds = int(seconds)
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"

Expand Down
1 change: 1 addition & 0 deletions compiler_gym/wrappers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ py_library(
"core.py",
"datasets.py",
"llvm.py",
"sqlite_logger.py",
"time_limit.py",
"validation.py",
],
Expand Down
1 change: 1 addition & 0 deletions compiler_gym/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(WRAPPERS_SRCS
)
if(COMPILER_GYM_ENABLE_LLVM_ENV)
list(APPEND WRAPPERS_SRCS "llvm.py")
list(APPEND WRAPPERS_SRCS "sqlite_logger.py")
endif()
if(COMPILER_GYM_ENABLE_MLIR_ENV)
list(APPEND WRAPPERS_SRCS "mlir.py")
Expand Down
4 changes: 4 additions & 0 deletions compiler_gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@

if config.enable_llvm_env:
from compiler_gym.wrappers.llvm import RuntimePointEstimateReward # noqa: F401
from compiler_gym.wrappers.sqlite_logger import ( # noqa: F401
SynchronousSqliteLogger,
)

from compiler_gym.wrappers.time_limit import TimeLimit

Expand All @@ -69,3 +72,4 @@

if config.enable_llvm_env:
__all__.append("RuntimePointEstimateReward")
__all__.append("SynchronousSqliteLogger")
274 changes: 274 additions & 0 deletions compiler_gym/wrappers/sqlite_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements a wrapper that logs state transitions to an sqlite
database.
"""
import logging
import pickle
import sqlite3
import zlib
from pathlib import Path
from time import time
from typing import Iterable, Optional, Union

import numpy as np

from compiler_gym.envs import LlvmEnv
from compiler_gym.spaces import Reward
from compiler_gym.util.gym_type_hints import ActionType
from compiler_gym.util.timer import Timer, humanize_duration
from compiler_gym.views import ObservationSpaceSpec
from compiler_gym.wrappers import CompilerEnvWrapper

# SQL schema for the logging database. SynchronousSqliteLogger.__init__() runs
# this script with executescript() each time a logger is constructed; the
# "IF NOT EXISTS" clauses mean that an already-populated database is left
# untouched.
#
# The column order of each table matters: SynchronousSqliteLogger.flush()
# inserts rows positionally ("INSERT ... VALUES (?, ?, ...)") without naming
# columns, so any change here must be mirrored in flush() and _record().
DB_CREATION_SCRIPT = """
CREATE TABLE IF NOT EXISTS States (
benchmark_uri TEXT NOT NULL, -- The URI of the benchmark.
done INTEGER NOT NULL, -- 0 = False, 1 = True.
ir_instruction_count_oz_reward REAL NULLABLE,
state_id TEXT NOT NULL, -- 40-char sha1.
actions TEXT NOT NULL, -- Decode: [int(x) for x in field.split()]
PRIMARY KEY (benchmark_uri, actions),
FOREIGN KEY (state_id) REFERENCES Observations(state_id) ON UPDATE CASCADE
);

CREATE TABLE IF NOT EXISTS Observations (
state_id TEXT NOT NULL, -- 40-char sha1.
ir_instruction_count INTEGER NOT NULL,
compressed_llvm_ir BLOB NOT NULL, -- Decode: zlib.decompress(...)
pickled_compressed_programl BLOB NOT NULL, -- Decode: pickle.loads(zlib.decompress(...))
autophase TEXT NOT NULL, -- Decode: np.array([int(x) for x in field.split()], dtype=np.int64)
instcount TEXT NOT NULL, -- Decode: np.array([int(x) for x in field.split()], dtype=np.int64)
PRIMARY KEY (state_id)
);
"""


class SynchronousSqliteLogger(CompilerEnvWrapper):
    """A wrapper for an LLVM environment that logs all transitions to an sqlite
    database.

    Wrap an existing LLVM environment and then use it as per normal:

        >>> env = SynchronousSqliteLogger(
        ...     env=gym.make("llvm-autophase-ic-v0"),
        ...     db_path="example.db",
        ... )

    Connect to the database file you specified:

    .. code-block::

        $ sqlite3 example.db

    There are two tables:

    1. States: records every unique combination of benchmark + actions. For each
       entry, records an identifying state ID, the episode reward, and whether
       the episode is terminated:

       .. code-block::

           sqlite> .mode markdown
           sqlite> .headers on
           sqlite> select * from States limit 5;
           |      benchmark_uri       | done | ir_instruction_count_oz_reward |                 state_id                 |    actions     |
           |--------------------------|------|--------------------------------|------------------------------------------|----------------|
           | generator://csmith-v0/99 | 0    | 0.0                            | d625b874e58f6d357b816e21871297ac5c001cf0 |                |
           | generator://csmith-v0/99 | 0    | 0.0                            | d625b874e58f6d357b816e21871297ac5c001cf0 | 31             |
           | generator://csmith-v0/99 | 0    | 0.0                            | 52f7142ef606d8b1dec2ff3371c7452c8d7b81ea | 31 116         |
           | generator://csmith-v0/99 | 0    | 0.268005818128586              | d8c05bd41b7a6c6157b6a8f0f5093907c7cc7ecf | 31 116 103     |
           | generator://csmith-v0/99 | 0    | 0.288621664047241              | c4d7ecd3807793a0d8bc281104c7f5a8aa4670f9 | 31 116 103 109 |

    2. Observations: records pickled, compressed, and text observation values
       for each unique state.

    Caveats of this implementation:

    1. Only :class:`LlvmEnv <compiler_gym.envs.LlvmEnv>` environments may be
       wrapped.

    2. The wrapped environment must have an observation space and reward space
       set.

    3. The observation spaces and reward spaces that are logged to database
       are hardcoded. To change what is recorded, you must copy and modify this
       implementation.

    4. Writing to the database is synchronous and adds significant overhead to
       the compute cost of the environment.
    """

    def __init__(
        self,
        env: LlvmEnv,
        db_path: Union[str, Path],
        commit_frequency_in_seconds: int = 300,
        max_step_buffer_length: int = 5000,
    ):
        """Constructor.

        :param env: The environment to wrap.

        :param db_path: The path of the database to log to, as a string or a
            :code:`Path`. This file may already exist. If it does, new entries
            are appended. If the file does not exist, it is created, along with
            any missing parent directories.

        :param commit_frequency_in_seconds: The maximum amount of time to elapse
            before writing pending logs to the database.

        :param max_step_buffer_length: The maximum number of calls to
            :code:`step()` before writing pending logs to the database.

        :raises TypeError: If the base environment is not an :code:`LlvmEnv`.
        """
        super().__init__(env)
        if not hasattr(env, "unwrapped"):
            raise TypeError("Requires LlvmEnv base environment")
        if not isinstance(self.unwrapped, LlvmEnv):
            raise TypeError("Requires LlvmEnv base environment")

        # Accept plain strings too: the class docstring's own example passes
        # db_path="example.db", which has no `.parent` attribute.
        db_path = Path(db_path)
        db_path.parent.mkdir(exist_ok=True, parents=True)
        self.connection = sqlite3.connect(str(db_path))
        self.cursor = self.connection.cursor()
        self.commit_frequency = commit_frequency_in_seconds
        self.max_step_buffer_length = max_step_buffer_length

        # Create the tables if this is a fresh database.
        self.cursor.executescript(DB_CREATION_SCRIPT)
        self.connection.commit()
        self.last_commit = time()

        # Pending rows. Observations are keyed by state ID so that a state
        # which is visited several times between flushes is stored only once.
        self.observations_buffer = {}
        self.step_buffer = []

        # House keeping notice: Keep these lists in sync with record().
        self._observations = [
            self.env.observation.spaces["IrSha1"],
            self.env.observation.spaces["Ir"],
            self.env.observation.spaces["Programl"],
            self.env.observation.spaces["Autophase"],
            self.env.observation.spaces["InstCount"],
            self.env.observation.spaces["IrInstructionCount"],
        ]
        self._rewards = [
            self.env.reward.spaces["IrInstructionCountOz"],
            self.env.reward.spaces["IrInstructionCount"],
        ]
        # Running per-episode totals, one entry per space in self._rewards.
        self._reward_totals = np.zeros(len(self._rewards))

    def flush(self) -> None:
        """Flush the buffered steps and observations to database.

        Does nothing if there are no buffered steps. Otherwise, inserts the
        buffered rows (ignoring rows whose primary key already exists), clears
        the buffers, commits the transaction, and logs a summary.
        """
        n_steps, n_observations = len(self.step_buffer), len(self.observations_buffer)

        # Nothing to flush.
        if not n_steps:
            return

        with Timer() as flush_time:
            # House keeping notice: Keep these statements in sync with record().
            self.cursor.executemany(
                "INSERT OR IGNORE INTO States VALUES (?, ?, ?, ?, ?)",
                self.step_buffer,
            )
            self.cursor.executemany(
                "INSERT OR IGNORE INTO Observations VALUES (?, ?, ?, ?, ?, ?)",
                ((k, *v) for k, v in self.observations_buffer.items()),
            )
            self.step_buffer = []
            self.observations_buffer = {}

            self.connection.commit()

        logging.info(
            "Wrote %d state records and %d observations in %s. Last flush %s ago",
            n_steps,
            n_observations,
            flush_time,
            humanize_duration(time() - self.last_commit),
        )
        self.last_commit = time()

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment and record the initial state.

        All arguments are forwarded to the wrapped environment's
        :code:`reset()`.

        :return: The observation returned by the wrapped environment's
            :code:`reset()`.
        """
        observation = self.env.reset(*args, **kwargs)
        # An empty multistep() computes the logged observation and reward
        # spaces for the initial state without applying any action.
        observations, rewards, done, info = self.env.multistep(
            actions=[],
            observation_spaces=self._observations,
            reward_spaces=self._rewards,
        )
        assert not done, f"reset() failed! {info}"
        # A new episode starts here, so the running totals are replaced
        # rather than accumulated.
        self._reward_totals = np.array(rewards, dtype=np.float32)
        self._record(
            actions=self.actions,
            observations=observations,
            rewards=self._reward_totals,
            done=False,
        )
        return observation

    def step(
        self,
        action: ActionType,
        observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
        reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
        observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
        rewards: Optional[Iterable[Union[str, Reward]]] = None,
    ):
        """Take a step in the wrapped environment and record the new state.

        :param action: The action to forward to the wrapped environment.

        :param observation_spaces: Not supported. Must be :code:`None`.

        :param reward_spaces: Not supported. Must be :code:`None`.

        :param observations: Not supported. Must be :code:`None`.

        :param rewards: Not supported. Must be :code:`None`.

        :return: A tuple of the observation and reward for the environment's
            own observation and reward spaces, the done flag, and the info
            value returned by the wrapped environment.
        """
        assert self.observation_space, "No observation space set"
        assert self.reward_space, "No reward space set"
        assert (
            observation_spaces is None
        ), "SynchronousSqliteLogger does not support observation_spaces"
        assert (
            reward_spaces is None
        ), "SynchronousSqliteLogger does not support reward_spaces"
        assert (
            observations is None
        ), "SynchronousSqliteLogger does not support observations"
        assert rewards is None, "SynchronousSqliteLogger does not support rewards"

        # Request the logged spaces first and the user's own space last, so
        # that [:-1] is what gets recorded and [-1] is what gets returned.
        observations, rewards, done, info = self.env.step(
            action=action,
            observation_spaces=self._observations + [self.observation_space_spec],
            reward_spaces=self._rewards + [self.reward_space],
        )
        self._reward_totals += rewards[:-1]
        self._record(
            actions=self.actions,
            observations=observations[:-1],
            rewards=self._reward_totals,
            done=done,
        )
        return observations[-1], rewards[-1], done, info

    def _record(self, actions, observations, rewards, done) -> None:
        """Buffer one state transition, flushing if a limit has been reached.

        :param actions: The actions taken so far in the episode. Stored as a
            space-separated string.

        :param observations: The observation values, in the same order as
            :code:`self._observations`.

        :param rewards: The cumulative episode rewards, in the same order as
            :code:`self._rewards`. Only the first entry is stored.

        :param done: Whether the episode has terminated.
        """
        state_id, ir, programl, autophase, instcount, instruction_count = observations
        instruction_count_reward = float(rewards[0])

        self.step_buffer.append(
            (
                str(self.benchmark.uri),
                1 if done else 0,
                instruction_count_reward,
                state_id,
                " ".join(str(x) for x in actions),
            )
        )

        self.observations_buffer[state_id] = (
            instruction_count,
            zlib.compress(ir.encode("utf-8")),
            zlib.compress(pickle.dumps(programl)),
            " ".join(str(x) for x in autophase),
            " ".join(str(x) for x in instcount),
        )

        # Flush when either the buffer is full or enough time has passed.
        if (
            len(self.step_buffer) >= self.max_step_buffer_length
            or time() - self.last_commit >= self.commit_frequency
        ):
            self.flush()

    def close(self):
        """Flush pending records to the database, then close the wrapped
        environment.

        The wrapped environment is closed even if flushing raises; the
        exception from :code:`flush()` then propagates to the caller.
        """
        # NOTE(review): the sqlite connection is left open here, as in the
        # original implementation. Confirm whether it should also be closed.
        try:
            self.flush()
        finally:
            self.env.close()

    def fork(self):
        """Forking is not supported by this wrapper.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError
7 changes: 7 additions & 0 deletions docs/source/compiler_gym/wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ LLVM Environment wrappers
.. autoclass:: RuntimePointEstimateReward

.. automethod:: __init__


.. autoclass:: SynchronousSqliteLogger

.. automethod:: __init__

.. automethod:: flush
11 changes: 11 additions & 0 deletions tests/wrappers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ py_test(
],
)

py_test(
name = "sqlite_logger_test",
timeout = "short",
srcs = ["sqlite_logger_test.py"],
deps = [
"//compiler_gym/wrappers",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
)

py_test(
name = "time_limit_wrappers_test",
timeout = "short",
Expand Down
22 changes: 22 additions & 0 deletions tests/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ cg_py_test(
tests::pytest_plugins::llvm
)

cg_py_test(
NAME sqlite_logger_test
SRCS "sqlite_logger_test.py"
DEPS
compiler_gym::envs::llvm::llvm
compiler_gym::errors::errors
compiler_gym::wrappers::wrappers
tests::test_main
tests::pytest_plugins::llvm
)

cg_py_test(
NAME
time_limit_wrappers_test
Expand All @@ -63,3 +74,14 @@ cg_py_test(
tests::pytest_plugins::llvm
tests::test_main
)

cg_py_test(
NAME validation_test
SRCS "validation_test.py"
DEPS
compiler_gym::envs::llvm::llvm
compiler_gym::errors::errors
compiler_gym::wrappers::wrappers
tests::test_main
tests::pytest_plugins::llvm
)
Loading