Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA & FIX: allow training_neg_sample_num > 1 in Sequential && add clip_grad_norm to trainer.py #533

Merged
merged 2 commits into from
Nov 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions recbole/data/dataloader/sequential_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,9 @@ def _neg_sampling(self, data):
return self.sampling_func(data, neg_iids)

def _neg_sample_by_pair_wise_sampling(self, data, neg_iids):
data[self.neg_item_id] = neg_iids
return data
new_data = {key: np.concatenate([value] * self.times) for key, value in data.items()}
new_data[self.neg_item_id] = neg_iids
return new_data

def _neg_sample_by_point_wise_sampling(self, data, neg_iids):
new_data = {}
Expand Down
1 change: 1 addition & 0 deletions recbole/properties/overall.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ learning_rate: 0.001
training_neg_sample_num: 1
eval_step: 1
stopping_step: 10
clip_grad_norm: {'max_norm': 5, 'norm_type': 2}

# evaluation settings
eval_setting: RO_RS,full
Expand Down
4 changes: 4 additions & 0 deletions recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import itertools
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(self, config, model):
self.epochs = config['epochs']
self.eval_step = min(config['eval_step'], self.epochs)
self.stopping_step = config['stopping_step']
self.clip_grad_norm = config['clip_grad_norm']
self.valid_metric = config['valid_metric'].lower()
self.valid_metric_bigger = config['valid_metric_bigger']
self.test_batch_size = config['eval_batch_size']
Expand Down Expand Up @@ -149,6 +151,8 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None):
total_loss = losses.item() if total_loss is None else total_loss + losses.item()
self._check_nan(loss)
loss.backward()
if self.clip_grad_norm:
clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
self.optimizer.step()
return total_loss

Expand Down