FEA: Distributed Recommendation Implementation #1338

Merged: 10 commits, Jul 8, 2022
6 changes: 3 additions & 3 deletions docs/source/user_guide/config/environment_settings.rst
@@ -2,9 +2,8 @@ Environment settings
===========================
Environment settings are designed to set basic parameters of running environment.

- ``gpu_id (int or str)`` : The id of GPU device. Defaults to ``0``.
- ``use_gpu (bool)`` : Whether or not to use GPU. If True, using GPU, else using CPU.
Defaults to ``True``.
- ``gpu_id (str)`` : The id of available GPU devices. Defaults to ``0``.
- ``worker (int)`` : The number of workers processing the data.
- ``seed (int)`` : Random seed. Defaults to ``2020``.
- ``state (str)`` : Logging level. Defaults to ``'INFO'``.
Range in ``['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']``.
@@ -39,3 +38,4 @@ Environment settings are designed to set basic parameters of running environment
Defaults to ``False``.
- ``wandb_project (str)``: The project to conduct experiment in W&B.
Defaults to ``'recbole'``.
- ``shuffle (bool)``: Whether or not to shuffle the training data before each epoch. Defaults to ``True``.
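
As a quick illustration (a sketch with example values rather than defaults, assuming RecBole's standard Config entry point), these environment settings can be supplied through a parameter dictionary when building the Config object:

from recbole.config import Config

config_dict = {
    'gpu_id': '0,1',   # ids of the GPU devices visible to this run
    'worker': 4,       # number of data-loading workers
    'seed': 2020,
    'state': 'INFO',
    'shuffle': True,   # shuffle the training data before each epoch
}
config = Config(model='BPR', dataset='ml-100k', config_dict=config_dict)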
28 changes: 23 additions & 5 deletions recbole/config/configurator.py
@@ -16,7 +16,6 @@
import os
import sys
import yaml
import torch
Member: It seems that regardless of the existence of local_rank, we still need to import torch. So what is the concern with removing this line here?

Member Author: In order to make the setting of the environment variable effective, we must set os.environ["CUDA_VISIBLE_DEVICES"] before import torch; that is why the module-level import is removed and torch is imported inside _init_device instead.

Member: Cool! Thanks.
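
A small illustration of the point above (a sketch, not code from this PR): CUDA reads CUDA_VISIBLE_DEVICES when it is initialized, so exporting the variable before the first import torch in the process is the simplest way to guarantee it takes effect.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"   # must happen before torch touches CUDA

import torch
print(torch.cuda.device_count())  # expected: 1 on a machine with at least two GPUs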

from logging import getLogger

from recbole.evaluator import metric_types, smaller_metrics
@@ -342,10 +341,29 @@ def _set_default_parameters(self):
raise NotImplementedError('Full sort evaluation do not match value-based metrics!')

def _init_device(self):
use_gpu = self.final_config_dict['use_gpu']
if use_gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.final_config_dict['gpu_id'])
self.final_config_dict['device'] = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
gpu_id = self.final_config_dict['gpu_id']
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
import torch

if 'local_rank' not in self.final_config_dict:
self.final_config_dict['single_spec'] = True
self.final_config_dict['local_rank'] = 0
self.final_config_dict['device'] = torch.device("cpu") if len(gpu_id) == 0 or not torch.cuda.is_available() else torch.device("cuda")
else:
assert len(gpu_id.split(',')) >= self.final_config_dict['nproc']
torch.distributed.init_process_group(
backend='nccl',
rank=self.final_config_dict['local_rank'],
world_size=self.final_config_dict['world_size'],
init_method='tcp://' + self.final_config_dict['ip'] + ':' + str(self.final_config_dict['port'])
)
self.final_config_dict['device'] = torch.device("cuda", self.final_config_dict['local_rank'])
self.final_config_dict['single_spec'] = False
torch.cuda.set_device(self.final_config_dict['local_rank'])
if self.final_config_dict['local_rank'] != 0:
self.final_config_dict['state'] = 'error'
self.final_config_dict['show_progress'] = False
self.final_config_dict['verbose'] = False

def _set_train_neg_sample_args(self):
neg_sampling = self.final_config_dict['neg_sampling']
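To make the control flow of the new _init_device() easier to follow, here is a minimal, self-contained sketch of the same pattern in plain PyTorch (not RecBole code; the ip, port and world_size values are illustrative assumptions):

import os

def init_device(gpu_id, local_rank=None, world_size=1, ip='127.0.0.1', port=5678):
    # CUDA_VISIBLE_DEVICES must be exported before torch initializes CUDA,
    # hence the deferred import below.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    import torch

    if local_rank is None:
        # single-process run: fall back to CPU if no GPU was requested or none is available
        if len(gpu_id) == 0 or not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device("cuda")

    # multi-process run: one NCCL process group, one GPU per rank
    torch.distributed.init_process_group(
        backend='nccl',
        rank=local_rank,
        world_size=world_size,
        init_method='tcp://' + ip + ':' + str(port),
    )
    torch.cuda.set_device(local_rank)
    return torch.device("cuda", local_rank)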
79 changes: 33 additions & 46 deletions recbole/data/dataloader/abstract_dataloader.py
@@ -22,7 +22,7 @@
from recbole.utils import InputType, FeatureType, FeatureSource


class AbstractDataLoader:
class AbstractDataLoader(torch.utils.data.DataLoader):
""":class:`AbstractDataLoader` is an abstract object which would return a batch of data which is loaded by
:class:`~recbole.data.interaction.Interaction` when it is iterated.
And it is also the ancestor of all other dataloader.
@@ -42,14 +42,27 @@ class AbstractDataLoader:
"""

def __init__(self, config, dataset, sampler, shuffle=False):
self.config = config
self.logger = getLogger()
self.dataset = dataset
self.sampler = sampler
self.batch_size = self.step = self.model = None
self.shuffle = shuffle
self.pr = 0
self.config = config
self._dataset = dataset
self._sampler = sampler
self._batch_size = self.step = self.model = None
self._init_batch_size_and_step()
index_sampler = None
if not config['single_spec']:
index_sampler = torch.utils.data.distributed.DistributedSampler(
list(range(self.sample_size)), shuffle=shuffle, drop_last=False
)
self.step = max(1, self.step // config['world_size'])
shuffle = False
super().__init__(
dataset=list(range(self.sample_size)),
batch_size=self.step,
collate_fn=self.collate_fn,
num_workers=config['worker'],
shuffle=shuffle,
sampler=index_sampler
)

def _init_batch_size_and_step(self):
"""Initializing :attr:`step` and :attr:`batch_size`."""
@@ -64,47 +77,18 @@ def update_config(self, config):
self.config = config
self._init_batch_size_and_step()

def __len__(self):
return math.ceil(self.pr_end / self.step)

def __iter__(self):
if self.shuffle:
self._shuffle()
return self

def __next__(self):
if self.pr >= self.pr_end:
self.pr = 0
raise StopIteration()
return self._next_batch_data()

@property
def pr_end(self):
"""This property marks the end of dataloader.pr which is used in :meth:`__next__`."""
raise NotImplementedError('Method [pr_end] should be implemented')

def _shuffle(self):
"""Shuffle the order of data, and it will be called by :meth:`__iter__` if self.shuffle is True.
"""
raise NotImplementedError('Method [shuffle] should be implemented.')

def _next_batch_data(self):
"""Assemble next batch of data in form of Interaction, and return these data.

Returns:
Interaction: The next batch of data.
"""
raise NotImplementedError('Method [next_batch_data] should be implemented.')

def set_batch_size(self, batch_size):
"""Reset the batch_size of the dataloader, but it can't be called when dataloader is being iterated.

Args:
batch_size (int): the new batch_size of dataloader.
"""
if self.pr != 0:
raise PermissionError('Cannot change dataloader\'s batch_size while iteration')
self.batch_size = batch_size
self._batch_size = batch_size

def collate_fn(self):
"""Collect the sampled index, and apply neg_sampling or other methods to get the final data.
"""
raise NotImplementedError('Method [collate_fn] must be implemented.')


class NegSampleDataLoader(AbstractDataLoader):
@@ -120,6 +104,7 @@ class NegSampleDataLoader(AbstractDataLoader):
"""

def __init__(self, config, dataset, sampler, shuffle=True):
self.logger = getLogger()
super().__init__(config, dataset, sampler, shuffle=shuffle)

def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args):
@@ -159,7 +144,9 @@ def _neg_sampling(self, inter_feat):
candidate_num = self.neg_sample_args['dynamic']
user_ids = inter_feat[self.uid_field].numpy()
item_ids = inter_feat[self.iid_field].numpy()
neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num)
neg_candidate_ids = self._sampler.sample_by_user_ids(
user_ids, item_ids, self.neg_sample_num * candidate_num
)
self.model.eval()
interaction = copy.deepcopy(inter_feat).to(self.model.device)
interaction = interaction.repeat(self.neg_sample_num * candidate_num)
@@ -174,15 +161,15 @@ def _neg_sampling(self, inter_feat):
elif self.neg_sample_args['strategy'] == 'by':
user_ids = inter_feat[self.uid_field].numpy()
item_ids = inter_feat[self.iid_field].numpy()
neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num)
neg_item_ids = self._sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num)
return self.sampling_func(inter_feat, neg_item_ids)
else:
return inter_feat

def _neg_sample_by_pair_wise_sampling(self, inter_feat, neg_item_ids):
inter_feat = inter_feat.repeat(self.times)
neg_item_feat = Interaction({self.iid_field: neg_item_ids})
neg_item_feat = self.dataset.join(neg_item_feat)
neg_item_feat = self._dataset.join(neg_item_feat)
neg_item_feat.add_prefix(self.neg_prefix)
inter_feat.update(neg_item_feat)
return inter_feat
@@ -191,7 +178,7 @@ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_item_ids):
pos_inter_num = len(inter_feat)
new_data = inter_feat.repeat(self.times)
new_data[self.iid_field][pos_inter_num:] = neg_item_ids
new_data = self.dataset.join(new_data)
new_data = self._dataset.join(new_data)
labels = torch.zeros(pos_inter_num * self.times)
labels[:pos_inter_num] = 1.0
new_data.update(Interaction({self.label_field: labels}))
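The key change in this file is that AbstractDataLoader now subclasses torch.utils.data.DataLoader: the loader iterates over sample indices, an optional DistributedSampler shards those indices across processes, and collate_fn turns an index batch into actual data. A minimal, self-contained sketch of that pattern (plain PyTorch, not RecBole code; names are illustrative):

import torch

class IndexDataLoader(torch.utils.data.DataLoader):
    def __init__(self, data, batch_size, world_size=1, rank=0, shuffle=True):
        self.data = data
        sampler = None
        if world_size > 1:
            # each process sees a disjoint shard of the indices
            sampler = torch.utils.data.distributed.DistributedSampler(
                list(range(len(data))), num_replicas=world_size, rank=rank,
                shuffle=shuffle, drop_last=False
            )
            batch_size = max(1, batch_size // world_size)
            shuffle = False  # shuffling is delegated to the sampler
        super().__init__(
            dataset=list(range(len(data))),
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            collate_fn=self.collate_fn,
        )

    def collate_fn(self, index):
        # map a list of integer indices to the actual samples
        return [self.data[i] for i in index]

# usage: for batch in IndexDataLoader(list('abcdefgh'), batch_size=4): print(batch)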
88 changes: 36 additions & 52 deletions recbole/data/dataloader/general_dataloader.py
@@ -14,7 +14,7 @@

import numpy as np
import torch

from logging import getLogger
from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader, NegSampleDataLoader
from recbole.data.interaction import Interaction, cat_interactions
from recbole.utils import InputType, ModelType
@@ -34,7 +34,9 @@ class TrainDataLoader(NegSampleDataLoader):
"""

def __init__(self, config, dataset, sampler, shuffle=False):
self.logger = getLogger()
self._set_neg_sample_args(config, dataset, config['MODEL_INPUT_TYPE'], config['train_neg_sample_args'])
self.sample_size = len(dataset)
super().__init__(config, dataset, sampler, shuffle=shuffle)

def _init_batch_size_and_step(self):
@@ -49,20 +51,13 @@ def _init_batch_size_and_step(self):
self.set_batch_size(batch_size)

def update_config(self, config):
self._set_neg_sample_args(config, self.dataset, config['MODEL_INPUT_TYPE'], config['train_neg_sample_args'])
self._set_neg_sample_args(config, self._dataset, config['MODEL_INPUT_TYPE'], config['train_neg_sample_args'])
super().update_config(config)

@property
def pr_end(self):
return len(self.dataset)

def _shuffle(self):
self.dataset.shuffle()

def _next_batch_data(self):
cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step])
self.pr += self.step
return cur_data
def collate_fn(self, index):
index = np.array(index)
data = self._dataset[index]
return self._neg_sampling(data)


class NegSampleEvalDataLoader(NegSampleDataLoader):
@@ -79,6 +74,7 @@ class NegSampleEvalDataLoader(NegSampleDataLoader):
"""

def __init__(self, config, dataset, sampler, shuffle=False):
self.logger = getLogger()
self._set_neg_sample_args(config, dataset, InputType.POINTWISE, config['eval_neg_sample_args'])
if self.neg_sample_args['strategy'] == 'by':
user_num = dataset.user_num
@@ -96,7 +92,12 @@ def __init__(self, config, dataset, sampler, shuffle=False):
self.uid2index[uid] = slice(start[uid], end[uid] + 1)
self.uid2items_num[uid] = end[uid] - start[uid] + 1
self.uid_list = np.array(self.uid_list)

self.sample_size = len(self.uid_list)
else:
self.sample_size = len(dataset)
if shuffle:
self.logger.warning('NegSampleEvalDataLoader can\'t shuffle')
shuffle = False
super().__init__(config, dataset, sampler, shuffle=shuffle)

def _init_batch_size_and_step(self):
@@ -117,44 +118,33 @@ def _init_batch_size_and_step(self):
self.set_batch_size(batch_size)

def update_config(self, config):
self._set_neg_sample_args(config, self.dataset, InputType.POINTWISE, config['eval_neg_sample_args'])
self._set_neg_sample_args(config, self._dataset, InputType.POINTWISE, config['eval_neg_sample_args'])
super().update_config(config)

@property
def pr_end(self):
if self.neg_sample_args['strategy'] == 'by':
return len(self.uid_list)
else:
return len(self.dataset)

def _shuffle(self):
self.logger.warnning('NegSampleEvalDataLoader can\'t shuffle')

def _next_batch_data(self):
def collate_fn(self, index):
index = np.array(index)
if self.neg_sample_args['strategy'] == 'by':
uid_list = self.uid_list[self.pr:self.pr + self.step]
uid_list = self.uid_list[index]
data_list = []
idx_list = []
positive_u = []
positive_i = torch.tensor([], dtype=torch.int64)

for idx, uid in enumerate(uid_list):
index = self.uid2index[uid]
data_list.append(self._neg_sampling(self.dataset[index]))
data_list.append(self._neg_sampling(self._dataset[index]))
idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)]
positive_u += [idx for i in range(self.uid2items_num[uid])]
positive_i = torch.cat((positive_i, self.dataset[index][self.iid_field]), 0)
positive_i = torch.cat((positive_i, self._dataset[index][self.iid_field]), 0)

cur_data = cat_interactions(data_list)
idx_list = torch.from_numpy(np.array(idx_list))
positive_u = torch.from_numpy(np.array(positive_u))

self.pr += self.step
idx_list = torch.from_numpy(np.array(idx_list)).long()
positive_u = torch.from_numpy(np.array(positive_u)).long()

return cur_data, idx_list, positive_u, positive_i
else:
cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step])
self.pr += self.step
data = self._dataset[index]
cur_data = self._neg_sampling(data)
return cur_data, None, None, None
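
A tiny worked example (assumed numbers, not from this PR) of the bookkeeping in the 'by' branch above: with two users in the index batch holding 2 and 3 positive items and times = 2 (pointwise sampling, one negative per positive), the lists come out as follows.

# Toy reproduction of the idx_list / positive_u construction above.
uid2items_num = {10: 2, 20: 3}   # assumed positive-item counts per user
times = 2                        # 1 positive + 1 sampled negative per interaction
idx_list, positive_u = [], []
for idx, uid in enumerate([10, 20]):
    idx_list += [idx] * (uid2items_num[uid] * times)
    positive_u += [idx] * uid2items_num[uid]
print(idx_list)    # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
print(positive_u)  # [0, 0, 1, 1, 1]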


@@ -171,6 +161,7 @@ class FullSortEvalDataLoader(AbstractDataLoader):
"""

def __init__(self, config, dataset, sampler, shuffle=False):
self.logger = getLogger()
self.uid_field = dataset.uid_field
self.iid_field = dataset.iid_field
self.is_sequential = config['MODEL_TYPE'] == ModelType.SEQUENTIAL
@@ -196,6 +187,10 @@ def __init__(self, config, dataset, sampler, shuffle=False):
self.uid_list = torch.tensor(self.uid_list, dtype=torch.int64)
self.user_df = dataset.join(Interaction({self.uid_field: self.uid_list}))

self.sample_size = len(self.user_df) if not self.is_sequential else len(dataset)
if shuffle:
self.logger.warning('FullSortEvalDataLoader can\'t shuffle')
shuffle = False
super().__init__(config, dataset, sampler, shuffle=shuffle)

def _set_user_property(self, uid, used_item, positive_item):
@@ -209,27 +204,18 @@ def _set_user_property(self, uid, used_item, positive_item):
def _init_batch_size_and_step(self):
batch_size = self.config['eval_batch_size']
if not self.is_sequential:
batch_num = max(batch_size // self.dataset.item_num, 1)
new_batch_size = batch_num * self.dataset.item_num
batch_num = max(batch_size // self._dataset.item_num, 1)
new_batch_size = batch_num * self._dataset.item_num
self.step = batch_num
self.set_batch_size(new_batch_size)
else:
self.step = batch_size
self.set_batch_size(batch_size)

@property
def pr_end(self):
if not self.is_sequential:
return len(self.uid_list)
else:
return len(self.dataset)

def _shuffle(self):
self.logger.warnning('FullSortEvalDataLoader can\'t shuffle')

def _next_batch_data(self):
def collate_fn(self, index):
index = np.array(index)
if not self.is_sequential:
user_df = self.user_df[self.pr:self.pr + self.step]
user_df = self.user_df[index]
uid_list = list(user_df[self.uid_field])

history_item = self.uid2history_item[uid_list]
@@ -241,13 +227,11 @@ def _next_batch_data(self):
positive_u = torch.cat([torch.full_like(pos_iid, i) for i, pos_iid in enumerate(positive_item)])
positive_i = torch.cat(list(positive_item))

self.pr += self.step
return user_df, (history_u, history_i), positive_u, positive_i
else:
interaction = self.dataset[self.pr:self.pr + self.step]
interaction = self._dataset[index]
inter_num = len(interaction)
positive_u = torch.arange(inter_num)
positive_i = interaction[self.iid_field]

self.pr += self.step
return interaction, None, positive_u, positive_i
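
Finally, a hedged sketch of how the pieces fit together at launch time: one process per GPU, with the local_rank, world_size, nproc, ip and port keys read by _init_device() filled in per rank. The run_worker function and the spawn-based entry point are illustrative assumptions, not this PR's public API.

import torch.multiprocessing as mp

def run_worker(local_rank, nproc, base_config):
    config_dict = dict(base_config)
    config_dict.update({
        'local_rank': local_rank,   # presence of this key selects the distributed branch
        'world_size': nproc,
        'nproc': nproc,
        'ip': '127.0.0.1',          # illustrative rendezvous address
        'port': 5678,
    })
    # ... build Config / dataset / model and start training with config_dict ...

if __name__ == '__main__':
    nproc = 2                              # one process per GPU listed in gpu_id
    base = {'gpu_id': '0,1', 'worker': 4}
    mp.spawn(run_worker, args=(nproc, base), nprocs=nproc)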