Skip to content

Commit

Permalink
Merge pull request #1 from RUCAIBox/master
Browse files Browse the repository at this point in the history
update
  • Loading branch information
2017pxy authored Nov 13, 2020
2 parents 4e1b06a + 9c6898c commit 953bc7a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 18 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ oriented to the GPU environment.
for testing and comparing recommendation algorithms.

## RecBole News
**11/03/2020**: We release the first version of RecBole **v0.1.0 release**.
**11/03/2020**: We release the first version of RecBole **v0.1.1**.


## Installation
Expand Down Expand Up @@ -156,7 +156,7 @@ python run_recbole.py --model=[model_name]
## RecBole Major Releases
| Releases | Date | Features |
|-----------|--------|-------------------------|
| v0.1.0 | 11/03/2020 | Basic RecBole |
| v0.1.1 | 11/03/2020 | Basic RecBole |


## Contributing
Expand All @@ -169,7 +169,7 @@ We expect all contributions discussed in the issue tracker and going through PRs


## Cite
If you find RecBole useful for your research or development, please cite the following [paper](https://arxiv.org/abs/2011.01731).
If you find RecBole useful for your research or development, please cite the following [paper](https://arxiv.org/abs/2011.01731):

```
@article{recbole,
Expand Down
2 changes: 1 addition & 1 deletion conda/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: recbole
version: 0.1.0
version: 0.1.1

source:
path: ../
Expand Down
32 changes: 19 additions & 13 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class Config(object):
def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None):
"""
Args:
model (str): the model name, default is None, if it is None, config will search the parameter 'model'
from the external input as the model name.
model (str/AbstractRecommender): the model name or the model class, default is None, if it is None, config
will search the parameter 'model' from the external input as the model name or model class.
dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset'
from the external input as the dataset name.
config_file_list (list of str): the external config file, it allows multiple config files, default is None.
Expand All @@ -71,8 +71,9 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=
self.variable_config_dict = self._load_variable_config_dict(config_dict)
self.cmd_config_dict = self._load_cmd_line()
self._merge_external_config_dict()
self.model, self.dataset = self._get_model_and_dataset(model, dataset)
self._load_internal_config_dict(self.model, self.dataset)

self.model, self.model_class, self.dataset = self._get_model_and_dataset(model, dataset)
self._load_internal_config_dict(self.model, self.model_class, self.dataset)
self.final_config_dict = self._get_final_config_dict()
self._set_default_parameters()
self._init_device()
Expand Down Expand Up @@ -167,14 +168,20 @@ def _merge_external_config_dict(self):
self.external_config_dict = external_config_dict

def _get_model_and_dataset(self, model, dataset):

if model is None:
try:
final_model = self.external_config_dict['model']
model = self.external_config_dict['model']
except KeyError:
raise KeyError('model need to be specified in at least one of the these ways: '
'[model variable, config file, config dict, command line] ')
raise KeyError(
'model need to be specified in at least one of the these ways: '
'[model variable, config file, config dict, command line] ')
if not isinstance(model, str):
final_model_class = model
final_model = model.__name__
else:
final_model = model
final_model_class = get_model(final_model)

if dataset is None:
try:
Expand All @@ -185,9 +192,9 @@ def _get_model_and_dataset(self, model, dataset):
else:
final_dataset = dataset

return final_model, final_dataset
return final_model, final_model_class, final_dataset

def _load_internal_config_dict(self, model, dataset):
def _load_internal_config_dict(self, model, model_class, dataset):
current_path = os.path.dirname(os.path.realpath(__file__))
overall_init_file = os.path.join(current_path, '../properties/overall.yaml')
model_init_file = os.path.join(current_path, '../properties/model/' + model + '.yaml')
Expand All @@ -204,8 +211,7 @@ def _load_internal_config_dict(self, model, dataset):
key not in self.parameters['Dataset']]
if config_dict is not None:
self.internal_config_dict.update(config_dict)

self.internal_config_dict['MODEL_TYPE'] = get_model(model).type
self.internal_config_dict['MODEL_TYPE'] = model_class.type
if self.internal_config_dict['MODEL_TYPE'] == ModelType.GENERAL:
pass
elif self.internal_config_dict['MODEL_TYPE'] == ModelType.CONTEXT:
Expand Down Expand Up @@ -271,8 +277,8 @@ def _set_default_parameters(self):
else:
self.final_config_dict['data_path'] = os.path.join(self.final_config_dict['data_path'], self.dataset)

if hasattr(get_model(self.model), 'input_type'):
self.final_config_dict['MODEL_INPUT_TYPE'] = get_model(self.model).input_type
if hasattr(self.model_class, 'input_type'):
self.final_config_dict['MODEL_INPUT_TYPE'] = self.model_class.input_type
elif 'loss_type' in self.final_config_dict:
if self.final_config_dict['loss_type'] in ['CE']:
self.final_config_dict['MODEL_INPUT_TYPE'] = InputType.POINTWISE
Expand Down
26 changes: 25 additions & 1 deletion recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2020/10/28 2020/10/13, 2020/10/25
# @Time : 2020/10/28, 2020/10/13, 2020/11/10
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : houyupeng@ruc.edu.cn, panxy@ruc.edu.cn, chenyushuo@ruc.edu.cn

Expand Down Expand Up @@ -164,6 +164,7 @@ def _data_filtering(self):
thus :meth:`~recbole.data.dataset.dataset.Dataset._reset_index()` will reset the index of feats.
"""
self._filter_nan_user_or_item()
self._remove_duplication()
self._filter_by_field_value()
self._filter_by_inter_num()
self._reset_index()
Expand Down Expand Up @@ -603,6 +604,29 @@ def _filter_nan_user_or_item(self):
name, list(dropped_inter + 2), field))
self.inter_feat.drop(self.inter_feat.index[dropped_inter], inplace=True)

def _remove_duplication(self):
    """Remove duplicated user-item interactions from :attr:`inter_feat`.

    The strategy is read from ``self.config['rm_dup_inter']``. If it is ``None``,
    nothing is done. Otherwise the value is forwarded unchanged as the ``keep``
    argument of ``drop_duplicates`` (presumably pandas'
    ``DataFrame.drop_duplicates``, i.e. ``'first'``, ``'last'`` or ``False`` --
    confirm against the config documentation), and rows sharing the same
    (:attr:`uid_field`, :attr:`iid_field`) pair are treated as duplicates.

    Note:
        Before removing duplicated user-item interactions, if :attr:`time_field`
        is a column of :attr:`inter_feat`, :attr:`inter_feat` is sorted in place
        by :attr:`time_field` in ascending order. The sort is stable, so
        interactions with identical timestamps keep their original relative
        order and the record that survives de-duplication is deterministic.
        If :attr:`time_field` is not available, only a warning is logged and
        the existing row order decides which duplicate is kept.
    """
    keep = self.config['rm_dup_inter']
    if keep is None:
        return
    # NOTE(review): `_check_field` is defined outside this view; presumably it
    # verifies that the named field attributes are set -- confirm.
    self._check_field('uid_field', 'iid_field')

    if self.time_field in self.inter_feat:
        # The default quicksort is not stable: rows with equal timestamps could be
        # shuffled, making keep='first'/'last' pick an arbitrary record among ties.
        self.inter_feat.sort_values(by=[self.time_field], ascending=True, inplace=True, kind='mergesort')
        self.logger.info('Records in original dataset have been sorted by value of [{}] in ascending order.'.format(
            self.time_field))
    else:
        self.logger.warning('Timestamp field has not been loaded or specified, '
                            'thus strategy [{}] of duplication removal may be meaningless.'.format(keep))
    self.inter_feat.drop_duplicates(subset=[self.uid_field, self.iid_field], keep=keep, inplace=True)

def _filter_by_inter_num(self):
"""Filter by number of interaction.
Expand Down
1 change: 1 addition & 0 deletions recbole/properties/dataset/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ unload_col: ~
additional_feat_suffix: ~

# Filtering
rm_dup_inter: ~
max_user_inter_num: ~
min_user_inter_num: 0
max_item_inter_num: ~
Expand Down

0 comments on commit 953bc7a

Please sign in to comment.