-
Notifications
You must be signed in to change notification settings - Fork 0
/
state.py
32 lines (21 loc) · 907 Bytes
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import util
class State:
    """Thin wrapper around a NumPy array with mask-style combination helpers.

    Instances compare by array contents and are hashable, so they can be used
    as dictionary keys or set members.
    """

    def __init__(self, state):
        """Store ``state`` as-is (no copy is made)."""
        # NOTE(review): minus() calls ``.astype`` on the stored value, so it
        # is presumably always a numpy.ndarray -- confirm against callers.
        self.state = state

    def __repr__(self) -> str:
        """Return the string form of the wrapped array."""
        return str(self.state)

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` is a State with the same shape and values."""
        if not isinstance(other, State):
            # Let Python fall back to its default handling instead of failing
            # with AttributeError on ``other.state``.
            return NotImplemented
        # array_equal checks the shapes first, so arrays of different shapes
        # compare unequal instead of being broadcast against each other.
        return bool(np.array_equal(self.state, other.state))

    def __hash__(self) -> int:
        """Hash the textual form of the wrapped array."""
        # NOTE(review): arrays that compare equal but print differently
        # (e.g. integer vs. float dtype) hash differently -- keep dtypes
        # consistent when using State objects as dictionary keys.
        return hash(str(self.state))

    def __len__(self) -> int:
        """Return the length of the wrapped array's first dimension."""
        return len(self.state)

    def minus(self, other: "State", hide_value: float = 0) -> "State":
        """Combine two states element-wise through a logical AND.

        Despite the name, no subtraction is performed: the result is 1.0
        where both states are non-zero and 0.0 elsewhere, and ``hide_value``
        is then added once for every operand that is zero at that position.

        Args:
            other: State to combine with; its array must broadcast with this one.
            hide_value: Amount added per operand that is zero at a position.

        Returns:
            A new State holding the combined float array.
        """
        both_set = np.logical_and(
            self.state.astype(bool), other.state.astype(bool)
        ).astype(float)
        return State(
            both_set + hide_value * self.invert() + hide_value * other.invert()
        )

    def apply(self, other: "State", hide_value: float = 0) -> "State":
        """Multiply this state by ``other`` element-wise.

        ``hide_value`` is added wherever this state (not ``other``) is zero.

        Args:
            other: State whose array is multiplied with this one.
            hide_value: Amount added at positions where this state is zero.

        Returns:
            A new State holding the resulting array.
        """
        return State(self.state * other.state + hide_value * self.invert())

    def invert(self) -> np.ndarray:
        """Return a float array that is 1.0 where the state is zero, else 0.0.

        The result is a plain NumPy array, not a State.
        """
        # XOR against an all-ones array is a logical NOT.  Logical ufuncs
        # produce booleans and have no float-output loop, so the conversion
        # to float happens afterwards via astype rather than through ``dtype=``.
        return np.logical_not(self.state).astype(float)

    def get_state(self):
        """Return the wrapped array itself (not a copy)."""
        return self.state