Skip to content

Commit

Permalink
Remove use of __slots__ in classes.
Browse files Browse the repository at this point in the history
The __slots__ attribute is not compatible with copy() operator.

Fixes #768.
  • Loading branch information
ChrisCummins committed Oct 5, 2022
1 parent 17348dd commit 685abce
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 25 deletions.
2 changes: 0 additions & 2 deletions compiler_gym/datasets/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ class Benchmark:
"""

__slots__ = ["_proto", "_validation_callbacks", "_sources"]

def __init__(
self,
proto: BenchmarkProto,
Expand Down
10 changes: 0 additions & 10 deletions compiler_gym/envs/llvm/llvm_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ class CostFunctionReward(Reward):
function.
"""

__slots__ = [
"cost_function",
"init_cost_function",
"previous_cost",
]

def __init__(self, cost_function: str, init_cost_function: str, **kwargs):
"""Constructor.
Expand Down Expand Up @@ -61,8 +55,6 @@ def update(
class NormalizedReward(CostFunctionReward):
"""A cost function reward that is normalized to the initial value."""

__slots__ = ["cost_norm", "benchmark"]

def __init__(self, **kwargs):
"""Constructor."""
super().__init__(**kwargs)
Expand Down Expand Up @@ -100,8 +92,6 @@ class BaselineImprovementNormalizedReward(NormalizedReward):
baseline approach.
"""

__slots__ = ["baseline_cost_function"]

def __init__(self, baseline_cost_function: str, **kwargs):
super().__init__(**kwargs)
self.baseline_cost_function: str = baseline_cost_function
Expand Down
10 changes: 0 additions & 10 deletions compiler_gym/spaces/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,6 @@ class Reward(Scalar):
signals based on observation values computed by the backend service.
"""

__slots__ = [
"name",
"observation_spaces",
"default_value",
"default_negates_returns",
"success_threshold",
"deterministic",
"platform_dependent",
]

def __init__(
self,
name: str,
Expand Down
4 changes: 1 addition & 3 deletions compiler_gym/spaces/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
class Scalar(Space):
"""A scalar value."""

__slots__ = ["min", "max", "dtype"]

def __init__(
self,
name: str,
Expand All @@ -35,10 +33,10 @@ def __init__(
:param dtype: The type of this scalar.
"""
self.name = name
self.min = min
self.max = max
self.dtype = dtype
self.name = name

def sample(self):
min = 0 if self.min is None else self.min
Expand Down
9 changes: 9 additions & 0 deletions tests/spaces/scalar_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/spaces:scalar."""
from copy import copy, deepcopy

from compiler_gym.spaces import Scalar
from tests.test_main import main

Expand Down Expand Up @@ -67,5 +69,12 @@ def test_not_equal():
assert scalar != "not_as_scalar"


def test_deepcopy_regression_test():
"""Test to reproduce github.com/facebookresearch/CompilerGym/issues/768."""
x = Scalar(name="foo")
copy(x)
deepcopy(x)


if __name__ == "__main__":
main()

0 comments on commit 685abce

Please sign in to comment.