-
Notifications
You must be signed in to change notification settings - Fork 35
/
discrete_ebm.py
267 lines (214 loc) · 9.7 KB
/
discrete_ebm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
from abc import ABC, abstractmethod
from typing import Literal, Tuple
import torch
import torch.nn as nn
from gfn.actions import Actions
from gfn.env import DiscreteEnv
from gfn.preprocessors import EnumPreprocessor, IdentityPreprocessor
from gfn.states import DiscreteStates, States
class EnergyFunction(nn.Module, ABC):
    """Abstract interface for energy functions over batches of discrete states."""

    @abstractmethod
    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Compute the energy of each state in a batch.

        Args:
            states: tensor of states of shape (*batch_shape, *state_shape).

        Returns:
            Tensor of energies of shape (*batch_shape).
        """
        ...
class IsingModel(EnergyFunction):
    """Ising model energy function: E(s) = -s^T J^T s for each batch element."""

    def __init__(self, J: torch.Tensor):
        """Ising model energy function.

        Args:
            J: interaction matrix of shape (state_shape, state_shape).
        """
        super().__init__()
        self.J = J
        self._state_shape, _ = J.shape
        assert J.shape == (self._state_shape, self._state_shape)
        # Hold J inside a Linear layer so it follows the module through
        # .to(device) / state_dict(). The layer is declared (d -> d) so its
        # declared weight shape matches the (d, d) matrix assigned to it —
        # the original declared (d -> 1) and then silently overwrote the
        # (1, d) weight with a (d, d) tensor.
        self.linear = nn.Linear(self._state_shape, self._state_shape, bias=False)
        self.linear.weight.data = J

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Forward pass of the Ising model.

        Args:
            states: tensor of states of shape (*batch_shape, *state_shape).

        Returns:
            Tensor of energies of shape (*batch_shape).
        """
        assert states.shape[-1] == self._state_shape
        states = states.float()
        tmp = self.linear(states)  # states @ J.T, shape (*batch_shape, d)
        return -(states * tmp).sum(-1)
class DiscreteEBM(DiscreteEnv):
    """Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf

    States are vectors in {-1, 0, 1}^ndim, where -1 marks a coordinate not yet
    set. Trajectories start at s0 = (-1, ..., -1); each forward action fills one
    remaining -1 with 0 or 1, and the exit action becomes available once the
    state is complete (no -1 left).
    """

    def __init__(
        self,
        ndim: int,
        energy: EnergyFunction | None = None,
        alpha: float = 1.0,
        device_str: Literal["cpu", "cuda"] = "cpu",
        preprocessor_name: Literal["Identity", "Enum"] = "Identity",
    ):
        """Discrete EBM environment.

        Args:
            ndim: dimension D of the sampling space {0, 1}^D.
            energy: energy function of the EBM. Defaults to None. If None,
                an Ising model with an all-ones interaction matrix is used.
            alpha: interaction strength of the EBM. Defaults to 1.0.
            device_str: "cpu" or "cuda". Defaults to "cpu".
            preprocessor_name: "Identity" or "Enum". Defaults to "Identity".

        Raises:
            ValueError: if preprocessor_name is not one of the supported names.
        """
        self.ndim = ndim
        # s0 is the fully-unset state (-1 everywhere); sf is a sink-state
        # sentinel filled with 2, a value no reachable state contains.
        s0 = torch.full((ndim,), -1, dtype=torch.long, device=torch.device(device_str))
        sf = torch.full((ndim,), 2, dtype=torch.long, device=torch.device(device_str))
        if energy is None:
            energy = IsingModel(
                torch.ones((ndim, ndim), device=torch.device(device_str))
            )
        self.energy: EnergyFunction = energy
        self.alpha = alpha
        n_actions = 2 * ndim + 1
        # the last action is the exit action that is only available for complete states
        # Action i in [0, ndim - 1] corresponds to replacing s[i] with 0
        # Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1
        if preprocessor_name == "Identity":
            preprocessor = IdentityPreprocessor(output_dim=ndim)
        elif preprocessor_name == "Enum":
            preprocessor = EnumPreprocessor(
                get_states_indices=self.get_states_indices,
            )
        else:
            raise ValueError(f"Unknown preprocessor {preprocessor_name}")
        super().__init__(
            s0=s0,
            state_shape=(self.ndim,),
            # dummy_action=,  # left to the base-class default
            # exit_action=,  # left to the base-class default
            n_actions=n_actions,
            sf=sf,
            device_str=device_str,
            preprocessor=preprocessor,
        )

    def update_masks(self, states: DiscreteStates) -> None:
        """Updates the forward and backward action masks of `states` in place.

        Forward: the actions writing a 0 or a 1 at index i are allowed only
        where s[i] is still unset (-1), and the exit action (last index) is
        allowed only for fully-set states. Backward: the action undoing a 0
        (resp. a 1) at index i is allowed only where s[i] == 0 (resp. == 1).
        """
        states.forward_masks[..., : self.ndim] = states.tensor == -1
        states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1
        states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1)
        states.backward_masks[..., : self.ndim] = states.tensor == 0
        states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1

    def make_random_states_tensor(self, batch_shape: Tuple) -> torch.Tensor:
        """Generates a random states tensor of shape (*batch_shape, ndim).

        Entries are sampled uniformly from {-1, 0, 1} (randint upper bound is
        exclusive), so partially-set states are included.
        """
        return torch.randint(
            -1,
            2,
            batch_shape + (self.ndim,),
            dtype=torch.long,
            device=self.device,
        )

    def is_exit_actions(self, actions: torch.Tensor) -> torch.Tensor:
        """Determines if the actions are exit actions (the last action index).

        Args:
            actions: tensor of actions of shape (*batch_shape, *action_shape).

        Returns a tensor of booleans of shape (*batch_shape).
        """
        return actions == self.n_actions - 1

    def step(self, states: States, actions: Actions) -> torch.Tensor:
        """Performs a forward step.

        NOTE: this mutates `states.tensor` in place and returns it.

        Args:
            states: States object representing the current states.
            actions: Actions object representing the actions to be taken.

        Returns the next states as a tensor of shape (*batch_shape, ndim).
        """
        # First, we select the actions that replace a -1 with a 0.
        # Remove singleton dimension for broadcasting.
        mask_0 = (actions.tensor < self.ndim).squeeze(-1)
        states.tensor[mask_0] = states.tensor[mask_0].scatter(
            -1, actions.tensor[mask_0], 0  # Set indices to 0.
        )
        # Then, we select the actions that replace a -1 with a 1.
        mask_1 = (
            (actions.tensor >= self.ndim) & (actions.tensor < 2 * self.ndim)
        ).squeeze(
            -1
        )  # Remove singleton dimension for broadcasting.
        states.tensor[mask_1] = states.tensor[mask_1].scatter(
            -1, (actions.tensor[mask_1] - self.ndim), 1  # Set indices to 1.
        )
        return states.tensor

    def backward_step(self, states: States, actions: Actions) -> torch.Tensor:
        """Performs a backward step.

        In this env, states are n-dim vectors. s0 is empty (represented as -1),
        so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a
        0 or 1. Action i in [0, ndim-1] is replacing s[i] with 0, whereas
        action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1.
        A backward action asks "what index should be set back to -1", hence the fmod
        to enable wrapping of indices (action i and i + ndim target the same index).
        """
        return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1)

    def reward(self, final_states: DiscreteStates) -> torch.Tensor:
        """Not used during training but provided for completeness.

        Note the effect of clipping will be seen in these values.

        Args:
            final_states: DiscreteStates object representing the final states.

        Returns the reward exp(log_reward) as a tensor of shape (*batch_shape).
        """
        reward = torch.exp(self.log_reward(final_states))
        assert reward.shape == final_states.batch_shape
        return reward

    def log_reward(self, final_states: DiscreteStates) -> torch.Tensor:
        """The energy weighted by alpha is our log reward.

        Args:
            final_states: DiscreteStates object representing the final (i.e.
                fully-set, entries in {0, 1}) states.

        Returns the log reward as a tensor of shape (*batch_shape)."""
        raw_states = final_states.tensor
        # Map {0, 1} entries to the canonical Ising spins {-1, 1}.
        canonical = 2 * raw_states - 1
        log_reward = -self.alpha * self.energy(canonical)
        assert log_reward.shape == final_states.batch_shape
        return log_reward

    def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
        """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3

        Args:
            states: DiscreteStates object representing the states.

        Returns the states indices as a tensor of shape (*batch_shape).
        """
        states_raw = states.tensor
        # Powers of 3 in big-endian order: index = sum_i (s[i] + 1) * 3^(ndim-1-i).
        canonical_base = 3 ** torch.arange(self.ndim - 1, -1, -1, device=self.device)
        states_indices = (states_raw + 1).mul(canonical_base).sum(-1).long()
        assert states_indices.shape == states.batch_shape
        return states_indices

    def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
        """Returns the indices of the terminating states.

        Terminating states have entries in {0, 1}, so they are read directly
        as big-endian base-2 numbers.

        Args:
            states: DiscreteStates object representing the states.

        Returns the indices of the terminating states as a tensor of shape (*batch_shape).
        """
        states_raw = states.tensor
        canonical_base = 2 ** torch.arange(self.ndim - 1, -1, -1, device=self.device)
        states_indices = (states_raw).mul(canonical_base).sum(-1).long()
        assert states_indices.shape == states.batch_shape
        return states_indices

    @property
    def n_states(self) -> int:
        # Every coordinate takes a value in {-1, 0, 1}.
        return 3**self.ndim

    @property
    def n_terminating_states(self) -> int:
        # Terminating states are fully set, i.e. elements of {0, 1}^ndim.
        return 2**self.ndim

    @property
    def all_states(self) -> DiscreteStates:
        # This is brute force ! Enumerates all of {-1, 0, 1}^ndim via the
        # cartesian product of {0, 1, 2} shifted down by 1.
        digits = torch.arange(3, device=self.device)
        all_states = torch.cartesian_prod(*[digits] * self.ndim)
        all_states = all_states - 1
        return self.States(all_states)

    @property
    def terminating_states(self) -> DiscreteStates:
        # Enumerates all of {0, 1}^ndim (brute force).
        digits = torch.arange(2, device=self.device)
        all_states = torch.cartesian_prod(*[digits] * self.ndim)
        return self.States(all_states)

    @property
    def true_dist_pmf(self) -> torch.Tensor:
        # Normalized reward over all terminating states; exponential in ndim.
        true_dist = self.reward(self.terminating_states)
        return true_dist / true_dist.sum()

    @property
    def log_partition(self) -> float:
        # log Z = logsumexp over the log rewards of all terminating states.
        log_rewards = self.log_reward(self.terminating_states)
        return torch.logsumexp(log_rewards, -1).item()