Commit

Format Python code according to PEP8
BishopLiu authored and github-actions[bot] committed Oct 6, 2023
1 parent fefcf67 commit e4ee796
Showing 3 changed files with 233 additions and 90 deletions.
157 changes: 111 additions & 46 deletions recbole/model/general_recommender/diffrec.py
@@ -26,6 +26,7 @@
from recbole.model.layers import MLPLayers
import typing


class ModelMeanType(enum.Enum):
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
@@ -35,7 +36,16 @@ class DNN(nn.Module):
"""
A deep neural network for the reverse diffusion preocess.
"""
def __init__(self, dims: typing.List, emb_size: int, time_type="cat", act_func="tanh", norm=False, dropout=0.5):

def __init__(
self,
dims: typing.List,
emb_size: int,
time_type="cat",
act_func="tanh",
norm=False,
dropout=0.5,
):
super(DNN, self).__init__()
self.dims = dims
self.time_type = time_type
@@ -48,9 +58,13 @@ def __init__(self, dims: typing.List, emb_size: int, time_type="cat", act_func="
# Concatenate timestep embedding with input
self.dims[0] += self.time_emb_dim
else:
raise ValueError("Unimplemented timestep embedding type %s" % self.time_type)
raise ValueError(
"Unimplemented timestep embedding type %s" % self.time_type
)

self.mlp_layers = MLPLayers(layers=self.dims, dropout=0, activation=act_func, last_activation=False)
self.mlp_layers = MLPLayers(
layers=self.dims, dropout=0, activation=act_func, last_activation=False
)
self.drop = nn.Dropout(dropout)

self.apply(xavier_normal_initialization)
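
For context, the "cat" branch above conditions the denoiser on the diffusion step by concatenating a projected timestep embedding with the (optionally normalized) interaction vector before the MLP. The sketch below shows that wiring in plain PyTorch; it is not the RecBole DNN class itself, the layer sizes are invented, and a learned nn.Embedding stands in for the sinusoidal timestep_embedding plus emb_layer pair used by the real model.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyTimeMLP(nn.Module):
    """Minimal stand-in for the DNN above: "cat" time conditioning only."""

    def __init__(self, n_items=100, emb_size=8, hidden=64, n_steps=10):
        super().__init__()
        # Learned lookup as a simple stand-in for timestep_embedding + emb_layer.
        self.time_emb = nn.Embedding(n_steps, emb_size)
        self.net = nn.Sequential(
            nn.Linear(n_items + emb_size, hidden),  # first dim grows by emb_size
            nn.Tanh(),
            nn.Linear(hidden, n_items),
        )

    def forward(self, x, t, norm=False):
        if norm:
            x = F.normalize(x)                      # optional L2 normalization
        h = torch.cat([x, self.time_emb(t)], dim=-1)
        return self.net(h)

x = torch.rand(4, 100)            # a batch of interaction vectors
t = torch.randint(0, 10, (4,))    # one diffusion timestep per row
print(TinyTimeMLP()(x, t, norm=True).shape)   # torch.Size([4, 100])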
@@ -72,9 +86,9 @@ class DiffRec(GeneralRecommender, AutoEncoderMixin):
def __init__(self, config, dataset):
super(DiffRec, self).__init__(config, dataset)

if config["mean_type"] == 'x0':
if config["mean_type"] == "x0":
self.mean_type = ModelMeanType.START_X
elif config["mean_type"] == 'eps':
elif config["mean_type"] == "eps":
self.mean_type = ModelMeanType.EPSILON
else:
raise ValueError("Unimplemented mean type %s" % config["mean_type"])
@@ -92,27 +106,45 @@ def __init__(self, config, dataset):
self.emb_size = config["embedding_size"]
self.norm = config["norm"] # True or False
self.reweight = config["reweight"] # reweight the loss for different timesteps
self.sampling_noise = config["sampling_noise"] # whether sample noise during predict
self.sampling_noise = config[
"sampling_noise"
] # whether sample noise during predict
self.sampling_steps = config["sampling_steps"]
self.mlp_act_func = config["mlp_act_func"]
assert self.sampling_steps <= self.steps, "Too much steps in inference."

self.history_num_per_term = config["history_num_per_term"]
self.Lt_history = torch.zeros(self.steps, self.history_num_per_term, dtype=torch.float64).to(self.device)
self.Lt_history = torch.zeros(
self.steps, self.history_num_per_term, dtype=torch.float64
).to(self.device)
self.Lt_count = torch.zeros(self.steps, dtype=int).to(self.device)

dims = [self.n_items] + config["dims_dnn"] + [self.n_items]

self.mlp = DNN(dims=dims, emb_size=self.emb_size, time_type="cat", norm=self.norm, act_func=self.mlp_act_func).to(self.device)

if self.noise_scale != 0.:
self.betas = torch.tensor(self.get_betas(), dtype=torch.float64).to(self.device)
self.mlp = DNN(
dims=dims,
emb_size=self.emb_size,
time_type="cat",
norm=self.norm,
act_func=self.mlp_act_func,
).to(self.device)

if self.noise_scale != 0.0:
self.betas = torch.tensor(self.get_betas(), dtype=torch.float64).to(
self.device
)
if self.beta_fixed:
self.betas[0] = 0.00001 # Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1
self.betas[
0
] = 0.00001 # Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1
# The variance \beta_1 of the first step is fixed to a small constant to prevent overfitting.
assert len(self.betas.shape) == 1, "betas must be 1-D"
assert len(self.betas) == self.steps, "num of betas must equal to diffusion steps"
assert (self.betas > 0).all() and (self.betas <= 1).all(), "betas out of range"
assert (
len(self.betas) == self.steps
), "num of betas must equal to diffusion steps"
assert (self.betas > 0).all() and (
self.betas <= 1
).all(), "betas out of range"

self.calculate_for_diffusion()

@@ -134,8 +166,8 @@ def build_histroy_items(self, dataset):
row_num = dataset.user_num
row_ids, col_ids = user_ids, item_ids

for uid in range(1, row_num+1):
uindex = np.argwhere(user_ids==uid).flatten()
for uid in range(1, row_num + 1):
uindex = np.argwhere(user_ids == uid).flatten()
int_num = len(uindex)
weight = np.linspace(w_min, w_max, int_num)
values[uindex] = weight
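
The loop above is the time-aware reweighting used by the T-DiffRec variants: within each user's time-ordered interaction list, weights grow linearly from w_min to w_max, so more recent interactions carry more weight in the rating matrix. A toy illustration with placeholder bounds (w_min and w_max come from the config and are not visible in this hunk, and the interactions are assumed already sorted by timestamp):

import numpy as np

w_min, w_max = 0.1, 1.0                      # placeholders for the config values
user_ids = np.array([1, 1, 1, 2, 2])         # interactions sorted by user, then time
values = np.zeros(len(user_ids))

for uid in np.unique(user_ids):
    uindex = np.argwhere(user_ids == uid).flatten()
    values[uindex] = np.linspace(w_min, w_max, len(uindex))

print(values)   # user 1 gets 0.1, 0.55, 1.0; user 2 gets 0.1, 1.0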
@@ -173,28 +205,33 @@ def get_betas(self):
if self.noise_schedule == "linear":
return np.linspace(start, end, self.steps, dtype=np.float64)
else:
return betas_from_linear_variance(self.steps, np.linspace(start, end, self.steps, dtype=np.float64))
return betas_from_linear_variance(
self.steps, np.linspace(start, end, self.steps, dtype=np.float64)
)
elif self.noise_schedule == "cosine":
return betas_for_alpha_bar(
self.steps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
)
self.steps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
)
# Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1
elif self.noise_schedule == "binomial":
ts = np.arange(self.steps)
betas = [1 / (self.steps - t + 1) for t in ts]
return betas
else:
raise NotImplementedError(f"unknown beta schedule: {self.noise_schedule}!")

def calculate_for_diffusion(self):
alphas = 1.0 - self.betas
# [alpha_{1}, ..., alpha_{1}*...*alpha_{T}] shape (steps,)
self.alphas_cumprod = torch.cumprod(alphas, axis=0).to(self.device)
# alpha_{t-1}
self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.device), self.alphas_cumprod[:-1]]).to(self.device)
self.alphas_cumprod_prev = torch.cat(
[torch.tensor([1.0]).to(self.device), self.alphas_cumprod[:-1]]
).to(self.device)
# alpha_{t+1}
self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], torch.tensor([0.0]).to(self.device)]).to(self.device)
self.alphas_cumprod_next = torch.cat(
[self.alphas_cumprod[1:], torch.tensor([0.0]).to(self.device)]
).to(self.device)
assert self.alphas_cumprod_prev.shape == (self.steps,)

self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
Expand All @@ -208,11 +245,15 @@ def calculate_for_diffusion(self):
)

self.posterior_log_variance_clipped = torch.log(
torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
torch.cat(
[self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]]
)
)
# Eq.10 coef for x_theta
self.posterior_mean_coef1 = (
self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.betas
* torch.sqrt(self.alphas_cumprod_prev)
/ (1.0 - self.alphas_cumprod)
)
# Eq.10 coef for x_t
self.posterior_mean_coef2 = (
@@ -231,7 +272,7 @@ def p_sample(self, x_start):

indices = list(range(self.steps))[::-1]

if self.noise_scale == 0.:
if self.noise_scale == 0.0:
for i in indices:
t = torch.tensor([i] * x_t.shape[0]).to(x_start.device)
x_t = self.mlp(x_t, t)
Expand All @@ -245,7 +286,10 @@ def p_sample(self, x_start):
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
) # no noise when t == 0
x_t = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
x_t = (
out["mean"]
+ nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
)
else:
x_t = out["mean"]
return x_t
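
The loop above is standard ancestral sampling: when sampling noise is enabled, each reverse step draws x_{t-1} from a Gaussian with the model's posterior mean and (log-)variance, and the nonzero_mask zeroes the noise term at t == 0 so the final step returns the mean itself. A compact standalone restatement of that single step, with illustrative shapes and values only:

import torch

def reverse_step(mean, log_variance, t):
    # x_{t-1} = mean + sigma * eps, with the noise suppressed where t == 0
    noise = torch.randn_like(mean)
    nonzero_mask = (t != 0).float().view(-1, *([1] * (mean.dim() - 1)))
    return mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

mean = torch.zeros(4, 3)
log_variance = torch.full((4, 3), -2.0)
t = torch.tensor([3, 2, 1, 0])
x_prev = reverse_step(mean, log_variance, t)
print(x_prev[3])   # the t == 0 row comes back as exactly the mean (all zeros)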
@@ -267,9 +311,9 @@ def calculate_loss(self, interaction):
x_start = self.get_rating_matrix(user)

batch_size, device = x_start.size(0), x_start.device
ts, pt = self.sample_timesteps(batch_size, device, 'importance')
ts, pt = self.sample_timesteps(batch_size, device, "importance")
noise = torch.randn_like(x_start)
if self.noise_scale != 0.:
if self.noise_scale != 0.0:
x_t = self.q_sample(x_start, ts, noise)
else:
x_t = x_start
@@ -301,9 +345,15 @@ def reweight_loss(self, x_start, x_t, mse, ts, target, model_output, device):
weight = torch.where((ts == 0), 1.0, weight)
loss = mse
elif self.mean_type == ModelMeanType.EPSILON:
weight = (1 - self.alphas_cumprod[ts]) / ((1-self.alphas_cumprod_prev[ts])**2 * (1-self.betas[ts]))
weight = (1 - self.alphas_cumprod[ts]) / (
(1 - self.alphas_cumprod_prev[ts]) ** 2 * (1 - self.betas[ts])
)
weight = torch.where((ts == 0), 1.0, weight)
likelihood = mean_flat((x_start - self._predict_xstart_from_eps(x_t, ts, model_output))**2 / 2.0)
likelihood = mean_flat(
(x_start - self._predict_xstart_from_eps(x_t, ts, model_output))
** 2
/ 2.0
)
loss = torch.where((ts == 0), likelihood, mse)
else:
weight = torch.tensor([1.0] * len(target)).to(device)
@@ -328,24 +378,26 @@ def update_Lt_history(self, ts, reloss):
print(loss)
raise ValueError

def sample_timesteps(self, batch_size, device, method='uniform', uniform_prob=0.001):
if method == 'importance': # importance sampling
def sample_timesteps(
self, batch_size, device, method="uniform", uniform_prob=0.001
):
if method == "importance": # importance sampling
if not (self.Lt_count == self.history_num_per_term).all():
return self.sample_timesteps(batch_size, device, method='uniform')
return self.sample_timesteps(batch_size, device, method="uniform")

Lt_sqrt = torch.sqrt(torch.mean(self.Lt_history ** 2, axis=-1))
Lt_sqrt = torch.sqrt(torch.mean(self.Lt_history**2, axis=-1))
pt_all = Lt_sqrt / torch.sum(Lt_sqrt)
pt_all *= 1 - uniform_prob
pt_all += uniform_prob / len(pt_all) # ensure the least prob > uniform_prob

assert pt_all.sum(-1) - 1. < 1e-5
assert pt_all.sum(-1) - 1.0 < 1e-5

t = torch.multinomial(pt_all, num_samples=batch_size, replacement=True)
pt = pt_all.gather(dim=0, index=t) * len(pt_all)

return t, pt

elif method == 'uniform': # uniform sampling
elif method == "uniform": # uniform sampling
t = torch.randint(0, self.steps, (batch_size,), device=device).long()
pt = torch.ones_like(t).float()

@@ -359,8 +411,11 @@ def q_sample(self, x_start, t, noise=None):
noise = torch.randn_like(x_start)
assert noise.shape == x_start.shape
return (
self._extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ self._extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
self._extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
* x_start
+ self._extract_into_tensor(
self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
* noise
)
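
q_sample is the usual closed-form forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, which lets training corrupt a clean interaction vector to any step t in one shot instead of iterating. A quick numeric check on a hand-made three-step schedule, with values chosen purely for illustration:

import torch

betas = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)       # alpha_bar_t per step

x0 = torch.tensor([[1.0, -1.0]], dtype=torch.float64)    # a tiny "interaction" row
noise = torch.randn_like(x0)
t = torch.tensor([2])                                    # jump straight to step 2

x_t = (torch.sqrt(alphas_cumprod[t])[:, None] * x0
       + torch.sqrt(1.0 - alphas_cumprod[t])[:, None] * noise)
print(alphas_cumprod.tolist())   # approx [0.9, 0.72, 0.504]
print(x_t.shape)                 # torch.Size([1, 2])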

@@ -374,7 +429,9 @@ def q_posterior_mean_variance(self, x_start, x_t, t):
self._extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ self._extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = self._extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_variance = self._extract_into_tensor(
self.posterior_variance, t, x_t.shape
)
posterior_log_variance_clipped = self._extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
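
These are the standard DDPM posterior statistics for q(x_{t-1} | x_t, x_0): the mean is coef1 * x_0 + coef2 * x_t with coef1 = beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t) and coef2 = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t), and the variance is beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t). The coef2 and variance lines are collapsed out of this hunk, so those two formulas are quoted from the textbook derivation rather than from the diff. A small numeric sanity check on a toy schedule:

import torch

betas = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat(
    [torch.tensor([1.0], dtype=torch.float64), alphas_cumprod[:-1]]
)

coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
coef2 = torch.sqrt(alphas) * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# At t = 0 the posterior collapses onto x_0: coef1 -> 1, coef2 -> 0, variance -> 0.
print(coef1[0].item(), coef2[0].item(), posterior_variance[0].item())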
@@ -392,7 +449,7 @@ def p_mean_variance(self, x, t):
the initial x, x_0.
"""
B, C = x.shape[:2]
assert t.shape == (B, )
assert t.shape == (B,)
model_output = self.mlp(x, t)

model_variance = self.posterior_variance
Expand All @@ -408,7 +465,9 @@ def p_mean_variance(self, x, t):
else:
raise NotImplementedError(self.mean_type)

model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)

assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
@@ -424,8 +483,10 @@ def _predict_xstart_from_eps(self, x_t, t, eps):
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
* x_t
- self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
* eps
)

def SNR(self, t):
@@ -532,8 +593,12 @@ def timestep_embedding(timesteps, dim, max_period=10000):

half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(timesteps.device) # shape (dim//2,)
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(
timesteps.device
) # shape (dim//2,)
args = timesteps[:, None].float() * freqs[None] # (N, dim//2)
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # (N, (dim//2)*2)
if dim % 2:
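
For completeness, here is a self-contained version of the sinusoidal timestep embedding shown in this last hunk, with a quick shape check. The odd-dimension branch is cut off above, so the zero-padding line below is reconstructed following the usual convention rather than copied from the diff.

import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    # Transformer-style sinusoidal embedding of integer timesteps.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
    ).to(timesteps.device)                                  # (dim//2,)
    args = timesteps[:, None].float() * freqs[None]         # (N, dim//2)
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:                                             # pad one zero column
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding

emb = timestep_embedding(torch.tensor([0, 1, 50]), dim=8)
print(emb.shape)    # torch.Size([3, 8])
print(emb[0, :4])   # cosine half at t = 0: tensor([1., 1., 1., 1.])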