Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get latest code #1

Merged
merged 14 commits into from
Nov 12, 2020
32 changes: 19 additions & 13 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class Config(object):
def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None):
"""
Args:
model (str): the model name, default is None, if it is None, config will search the parameter 'model'
from the external input as the model name.
model (str/AbstractRecommender): the model name or the model class, default is None, if it is None, config
will search the parameter 'model' from the external input as the model name or model class.
dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset'
from the external input as the dataset name.
config_file_list (list of str): the external config file, it allows multiple config files, default is None.
Expand All @@ -71,8 +71,9 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=
self.variable_config_dict = self._load_variable_config_dict(config_dict)
self.cmd_config_dict = self._load_cmd_line()
self._merge_external_config_dict()
self.model, self.dataset = self._get_model_and_dataset(model, dataset)
self._load_internal_config_dict(self.model, self.dataset)

self.model, self.model_class, self.dataset = self._get_model_and_dataset(model, dataset)
self._load_internal_config_dict(self.model, self.model_class, self.dataset)
self.final_config_dict = self._get_final_config_dict()
self._set_default_parameters()
self._init_device()
Expand Down Expand Up @@ -167,14 +168,20 @@ def _merge_external_config_dict(self):
self.external_config_dict = external_config_dict

def _get_model_and_dataset(self, model, dataset):

if model is None:
try:
final_model = self.external_config_dict['model']
model = self.external_config_dict['model']
except KeyError:
raise KeyError('model need to be specified in at least one of the these ways: '
'[model variable, config file, config dict, command line] ')
raise KeyError(
'model need to be specified in at least one of the these ways: '
'[model variable, config file, config dict, command line] ')
if not isinstance(model, str):
final_model_class = model
final_model = model.__name__
else:
final_model = model
final_model_class = get_model(final_model)

if dataset is None:
try:
Expand All @@ -185,9 +192,9 @@ def _get_model_and_dataset(self, model, dataset):
else:
final_dataset = dataset

return final_model, final_dataset
return final_model, final_model_class, final_dataset

def _load_internal_config_dict(self, model, dataset):
def _load_internal_config_dict(self, model, model_class, dataset):
current_path = os.path.dirname(os.path.realpath(__file__))
overall_init_file = os.path.join(current_path, '../properties/overall.yaml')
model_init_file = os.path.join(current_path, '../properties/model/' + model + '.yaml')
Expand All @@ -204,8 +211,7 @@ def _load_internal_config_dict(self, model, dataset):
key not in self.parameters['Dataset']]
if config_dict is not None:
self.internal_config_dict.update(config_dict)

self.internal_config_dict['MODEL_TYPE'] = get_model(model).type
self.internal_config_dict['MODEL_TYPE'] = model_class.type
if self.internal_config_dict['MODEL_TYPE'] == ModelType.GENERAL:
pass
elif self.internal_config_dict['MODEL_TYPE'] == ModelType.CONTEXT:
Expand Down Expand Up @@ -271,8 +277,8 @@ def _set_default_parameters(self):
else:
self.final_config_dict['data_path'] = os.path.join(self.final_config_dict['data_path'], self.dataset)

if hasattr(get_model(self.model), 'input_type'):
self.final_config_dict['MODEL_INPUT_TYPE'] = get_model(self.model).input_type
if hasattr(self.model_class, 'input_type'):
self.final_config_dict['MODEL_INPUT_TYPE'] = self.model_class.input_type
elif 'loss_type' in self.final_config_dict:
if self.final_config_dict['loss_type'] in ['CE']:
self.final_config_dict['MODEL_INPUT_TYPE'] = InputType.POINTWISE
Expand Down
26 changes: 25 additions & 1 deletion recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2020/10/28 2020/10/13, 2020/10/25
# @Time : 2020/10/28 2020/10/13, 2020/11/10
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : houyupeng@ruc.edu.cn, panxy@ruc.edu.cn, chenyushuo@ruc.edu.cn

Expand Down Expand Up @@ -164,6 +164,7 @@ def _data_filtering(self):
thus :meth:`~recbole.data.dataset.dataset.Dataset._reset_index()` will reset the index of feats.
"""
self._filter_nan_user_or_item()
self._remove_duplication()
self._filter_by_field_value()
self._filter_by_inter_num()
self._reset_index()
Expand Down Expand Up @@ -603,6 +604,29 @@ def _filter_nan_user_or_item(self):
name, list(dropped_inter + 2), field))
self.inter_feat.drop(self.inter_feat.index[dropped_inter], inplace=True)

def _remove_duplication(self):
    """Remove duplicated user-item interactions from :attr:`inter_feat`.

    The strategy is read from ``self.config['rm_dup_inter']`` and is passed unchanged as the ``keep``
    argument of :meth:`pandas.DataFrame.drop_duplicates`. If it is ``None``, this method does nothing.

    Note:
        If :attr:`time_field` is a column of :attr:`inter_feat`, the interactions are sorted in place by
        :attr:`time_field` in ascending order before duplicates are removed. The sort is stable, so
        interactions that share the same timestamp keep their current relative order and the record
        selected by ``keep`` is deterministic. Otherwise a warning is logged and duplicates are resolved
        in the current row order.
    """
    keep = self.config['rm_dup_inter']
    if keep is None:
        return
    self._check_field('uid_field', 'iid_field')

    if self.time_field in self.inter_feat:
        # 'mergesort' is a stable algorithm; the default quicksort may reorder rows with equal
        # timestamps, which would make the kept duplicate arbitrary.
        self.inter_feat.sort_values(by=[self.time_field], ascending=True, inplace=True, kind='mergesort')
        self.logger.info('Records in original dataset have been sorted by value of [{}] in ascending order.'.format(
            self.time_field))
    else:
        self.logger.warning('Timestamp field has not been loaded or specified, '
                            'thus strategy [{}] of duplication removal may be meaningless.'.format(keep))
    # Only the (user, item) pair defines a duplicate; other columns are ignored.
    self.inter_feat.drop_duplicates(subset=[self.uid_field, self.iid_field], keep=keep, inplace=True)

def _filter_by_inter_num(self):
"""Filter by number of interaction.

Expand Down
1 change: 1 addition & 0 deletions recbole/properties/dataset/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ unload_col: ~
additional_feat_suffix: ~

# Filtering
rm_dup_inter: ~
max_user_inter_num: ~
min_user_inter_num: 0
max_item_inter_num: ~
Expand Down