This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Model] Implement Neural Collaborative Filtering with MXNet #16689

Merged (7 commits) · Nov 16, 2019
101 changes: 101 additions & 0 deletions example/neural_collaborative_filtering/README.md
@@ -0,0 +1,101 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Neural Collaborative Filtering

[![Build Status](https://travis-ci.com/xinyu-intel/ncf_mxnet.svg?branch=master)](https://travis-ci.com/xinyu-intel/ncf_mxnet)

This is an MXNet implementation of the paper:

Xiangnan He, Lizi Liao, Hanwang Zhang, Liqiang Nie, Xia Hu and Tat-Seng Chua (2017). [Neural Collaborative Filtering.](http://dl.acm.org/citation.cfm?id=3052569) In Proceedings of WWW '17, Perth, Australia, April 03-07, 2017.

The example covers three collaborative filtering models: Generalized Matrix Factorization (GMF), Multi-Layer Perceptron (MLP), and Neural Matrix Factorization (NeuMF). To target implicit feedback and the ranking task, we optimize the models using log loss with negative sampling.

Author: Dr. Xiangnan He (http://www.comp.nus.edu.sg/~xiangnan/)

Code Reference: https://github.com/hexiangnan/neural_collaborative_filtering
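
For orientation, the NeuMF scoring path can be sketched in a few lines of Gluon. This is a simplified illustration only (the shipped model is built with the symbol API in core/model.py; the sizes below mirror the ones used in ci.py):

```
import mxnet as mx
from mxnet.gluon import nn

class NeuMFSketch(nn.HybridBlock):
    # GMF branch (element-wise product of embeddings) fused with an MLP branch.
    def __init__(self, num_users, num_items, factor_gmf=64,
                 factor_mlp=128, layers=(256, 128, 64), **kwargs):
        super(NeuMFSketch, self).__init__(**kwargs)
        with self.name_scope():
            self.user_gmf = nn.Embedding(num_users, factor_gmf)
            self.item_gmf = nn.Embedding(num_items, factor_gmf)
            self.user_mlp = nn.Embedding(num_users, factor_mlp)
            self.item_mlp = nn.Embedding(num_items, factor_mlp)
            self.mlp = nn.HybridSequential()
            for units in layers:
                self.mlp.add(nn.Dense(units, activation='relu'))
            self.fc_final = nn.Dense(1)

    def hybrid_forward(self, F, user, item):
        gmf = self.user_gmf(user) * self.item_gmf(item)        # linear, MF-like branch
        mlp = self.mlp(F.concat(self.user_mlp(user),
                                self.item_mlp(item), dim=1))   # nonlinear branch
        return F.sigmoid(self.fc_final(F.concat(gmf, mlp, dim=1)))

net = NeuMFSketch(num_users=138493, num_items=26744)  # ml-20m sizes, as in ci.py
net.initialize()
print(net(mx.nd.array([0]), mx.nd.array([42])))       # predicted interaction probability
```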

## Environment Settings
We use MXNet with MKL-DNN as the backend.
- MXNet version: 1.5.1

## Install
```
pip install -r requirements.txt
```
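
The MKL-DNN-enabled MXNet backend was distributed as a separate `mxnet-mkl` wheel; one way to match the version above (an assumption about your platform, adjust as needed):

```
pip install mxnet-mkl==1.5.1
```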

## Dataset

We provide the processed MovieLens 20 Million (ml-20m) dataset on [Google Drive](https://drive.google.com/drive/folders/1qACR_Zhc2O2W0RrazzcepM2vJeh0MMdO?usp=sharing). You can download it directly, or run the script to prepare the dataset:
```
python convert.py ./data/
```

train-ratings.csv
- Train file (positive instances).
- Each line is a training instance: userID\t itemID\t rating (the rating is always 1 for these implicit-feedback positives).

test-ratings.csv
- Test file (positive instances).
- Each line is a testing instance: userID\t itemID\t rating (again always 1).

test-negative.csv
- Test file (negative instances).
- Each line corresponds to the same line in test-ratings.csv and contains 999 negative samples.
- Each line is in the format: negativeItemID1\t negativeItemID2 ...
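
As a quick sanity check, the converted files can be loaded with pandas (a sketch, assuming the default ./data/ml-20m output location):

```
import pandas as pd

# columns follow the formats described above
train = pd.read_csv('./data/ml-20m/train-ratings.csv', sep='\t', header=None,
                    names=['user_id', 'item_id', 'rating'])
negs = pd.read_csv('./data/ml-20m/test-negative.csv', sep='\t', header=None)
print(train.head())
print(negs.shape)  # one row of 999 negative item IDs per test user
```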

## Pre-trained models

We provide a pre-trained ml-20m model on [Google Drive](https://drive.google.com/drive/folders/1qACR_Zhc2O2W0RrazzcepM2vJeh0MMdO?usp=sharing). You can download it directly for evaluation or calibration.

|dtype|HR@10|NDCG@10|
|:---:|:--:|:--:|
|float32|0.6393|0.3849|
|int8|0.6366|0.3824|
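
HR@10 counts a hit when the held-out positive item ranks within the top 10 of its 1,000 scored candidates (1 positive + 999 negatives); NDCG@10 additionally discounts a hit by the log of its rank. A minimal sketch of both metrics (function and argument names are illustrative; the example's own evaluation code may differ in detail):

```
import math

def hit_ratio_at_k(ranked_items, test_item, k=10):
    return 1.0 if test_item in ranked_items[:k] else 0.0

def ndcg_at_k(ranked_items, test_item, k=10):
    for pos, item in enumerate(ranked_items[:k]):
        if item == test_item:
            return math.log(2) / math.log(pos + 2)  # 1 / log2(rank + 1), rank is 1-based
    return 0.0
```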

## Training

```
# train ncf model with ml-20m dataset
python train.py # --gpu=0
```

## Calibration

```
# neumf calibration on ml-20m dataset (writes the int8 model used for evaluation below)
python ncf.py --prefix=./model/ml-20m/neumf --calibration
```

## Evaluation

```
# neumf float32 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf
# neumf int8 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized
```

## Benchmark

```
# neumf float32 benchmark on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf --benchmark
# neumf int8 benchmark on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized --benchmark
```
64 changes: 64 additions & 0 deletions example/neural_collaborative_filtering/ci.py
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import mxnet as mx
from core.model import get_model

def test_model():
def test_ncf(model_type):
net = get_model(model_type=model_type, factor_size_mlp=128, factor_size_gmf=64,
model_layers=[256, 128, 64], num_hidden=1, max_user=138493, max_item=26744)
mod = mx.module.Module(net, context=mx.cpu(), data_names=['user', 'item'], label_names=['softmax_label'])
        provide_data = [mx.io.DataDesc(name='item', shape=(1,)),
                        mx.io.DataDesc(name='user', shape=(1,))]
        provide_label = [mx.io.DataDesc(name='softmax_label', shape=(1,))]
mod.bind(for_training=True, data_shapes=provide_data, label_shapes=provide_label)
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
data = [mx.nd.full(shape=shape, val=26744, ctx=mx.cpu(), dtype='int32')
for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, [])
mod.forward(batch)
mod.backward()
mx.nd.waitall()

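        # Calibrate and quantize the fp32 graph with MKL-DNN: 'naive' calibration
        # collects min/max ranges from a single dummy batch, while the layers in
        # excluded_sym_names stay in fp32.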
data_dict = {'user': data[0], 'item': data[1]}
calib_data = mx.io.NDArrayIter(data=data_dict, batch_size=1)
calib_data = mx.test_utils.DummyIter(calib_data)
arg_params, aux_params = mod.get_params()
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model_mkldnn(sym=net,
arg_params=arg_params,
aux_params=aux_params,
ctx=mx.cpu(),
quantized_dtype='auto',
calib_mode='naive',
calib_data=calib_data,
data_names=['user', 'item'],
excluded_sym_names=['post_gemm_concat', 'fc_final'],
num_calib_examples=1)
qmod = mx.module.Module(qsym, context=mx.cpu(), data_names=['user', 'item'], label_names=['softmax_label'])
qmod.bind(for_training=True, data_shapes=provide_data, label_shapes=provide_label)
qmod.set_params(qarg_params, qaux_params)
qmod.forward(batch)
mx.nd.waitall()

for model_type in ['neumf', 'mlp', 'gmf']:
test_ncf(model_type)

if __name__ == "__main__":
import nose
nose.runmodule()

127 changes: 127 additions & 0 deletions example/neural_collaborative_filtering/convert.py
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import urllib.request
import zipfile
from argparse import ArgumentParser
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm import tqdm
from core.load import implicit_load

MIN_RATINGS = 20

USER_COLUMN = 'user_id'
ITEM_COLUMN = 'item_id'

TRAIN_RATINGS_FILENAME = 'train-ratings.csv'
TEST_RATINGS_FILENAME = 'test-ratings.csv'
TEST_NEG_FILENAME = 'test-negative.csv'

def parse_args():
parser = ArgumentParser()
    parser.add_argument('--dataset', nargs='?', default='ml-20m', choices=['ml-1m', 'ml-20m'],
                        help='The dataset name; currently ml-1m and ml-20m are supported.')
    parser.add_argument('path', nargs='?', type=str, default='./data/',
                        help='Path to the data directory where MovieLens is downloaded '
                             'and the converted CSV files are written')
    parser.add_argument('-n', '--negatives', type=int, default=999,
                        help='Number of negative samples for each positive '
                             'test example')
parser.add_argument('-s', '--seed', type=int, default=0,
help='Random seed to reproduce same negative samples')
return parser.parse_args()

def get_movielens_data(data_dir, dataset):
    if not os.path.exists(data_dir + '%s.zip' % dataset):
        os.makedirs(data_dir, exist_ok=True)  # tolerate a pre-existing data directory
        urllib.request.urlretrieve('http://files.grouplens.org/datasets/movielens/%s.zip' % dataset, data_dir + dataset + '.zip')
    with zipfile.ZipFile(data_dir + "%s.zip" % dataset, "r") as f:
        f.extractall(data_dir)

def main():
args = parse_args()
np.random.seed(args.seed)

print("download movielens {} dataset".format(args.dataset))
get_movielens_data(args.path, args.dataset)
output = os.path.join(args.path, args.dataset)

print("Loading raw data from {}".format(output))
df = implicit_load(os.path.join(output,"ratings.csv"), sort=False)

print("Filtering out users with less than {} ratings".format(MIN_RATINGS))
grouped = df.groupby(USER_COLUMN)
df = grouped.filter(lambda x: len(x) >= MIN_RATINGS)

print("Mapping original user and item IDs to new sequential IDs")
original_users = df[USER_COLUMN].unique()
original_items = df[ITEM_COLUMN].unique()

user_map = {user: index for index, user in enumerate(original_users)}
item_map = {item: index for index, item in enumerate(original_items)}

df[USER_COLUMN] = df[USER_COLUMN].apply(lambda user: user_map[user])
df[ITEM_COLUMN] = df[ITEM_COLUMN].apply(lambda item: item_map[item])

assert df[USER_COLUMN].max() == len(original_users) - 1
assert df[ITEM_COLUMN].max() == len(original_items) - 1

print("Creating list of items for each user")
# Need to sort before popping to get last item
df.sort_values(by='timestamp', inplace=True)
all_ratings = set(zip(df[USER_COLUMN], df[ITEM_COLUMN]))
user_to_items = defaultdict(list)
for row in tqdm(df.itertuples(), desc='Ratings', total=len(df)):
user_to_items[getattr(row, USER_COLUMN)].append(getattr(row, ITEM_COLUMN)) # noqa: E501

test_ratings = []
test_negs = []
all_items = set(range(len(original_items)))

print("Generating {} negative samples for each user"
.format(args.negatives))

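    # Leave-one-out split: the most recent item per user (the frame was sorted by
    # timestamp above) becomes the test positive; negatives are drawn from items
    # the user never interacted with.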
for user in tqdm(range(len(original_users)), desc='Users', total=len(original_users)): # noqa: E501
test_item = user_to_items[user].pop()

all_ratings.remove((user, test_item))
all_negs = all_items - set(user_to_items[user])
all_negs = sorted(list(all_negs)) # determinism

test_ratings.append((user, test_item))
test_negs.append(list(np.random.choice(all_negs, args.negatives)))

print("Saving train and test CSV files to {}".format(output))
df_train_ratings = pd.DataFrame(list(all_ratings))
df_train_ratings['fake_rating'] = 1
df_train_ratings.to_csv(os.path.join(output, TRAIN_RATINGS_FILENAME),
index=False, header=False, sep='\t')

df_test_ratings = pd.DataFrame(test_ratings)
df_test_ratings['fake_rating'] = 1
df_test_ratings.to_csv(os.path.join(output, TEST_RATINGS_FILENAME),
index=False, header=False, sep='\t')

df_test_negs = pd.DataFrame(test_negs)
df_test_negs.to_csv(os.path.join(output, TEST_NEG_FILENAME),
index=False, header=False, sep='\t')

if __name__ == '__main__':
main()

99 changes: 99 additions & 0 deletions example/neural_collaborative_filtering/core/dataset.py
@@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import mxnet as mx
import numpy as np
import pandas as pd
import scipy.sparse as sp

class NCFTestData(object):
def __init__(self, path):
'''
Constructor
path: converted data root
testRatings: converted test ratings data
testNegatives: negative samples for evaluation dataset
'''
self.testRatings = self.load_rating_file_as_list(os.path.join(path, 'test-ratings.csv'))
self.testNegatives = self.load_negative_file(os.path.join(path ,'test-negative.csv'))
assert len(self.testRatings) == len(self.testNegatives)

def load_rating_file_as_list(self, filename):
ratingList = []
with open(filename, "r") as f:
line = f.readline()
            while line is not None and line != "":
arr = line.split("\t")
user, item = int(arr[0]), int(arr[1])
ratingList.append([user, item])
line = f.readline()
return ratingList

def load_negative_file(self, filename):
negativeList = []
with open(filename, "r") as f:
line = f.readline()
            while line is not None and line != "":
arr = line.split("\t")
negatives = []
for x in arr:
negatives.append(int(x))
negativeList.append(negatives)
line = f.readline()
return negativeList

class NCFTrainData(mx.gluon.data.Dataset):
def __init__(self, train_fname, nb_neg):
'''
Constructor
        train_fname: path to the converted train-ratings.csv file
nb_neg: number of negative samples per positive sample while training
'''
self._load_train_matrix(train_fname)
self.nb_neg = nb_neg

def _load_train_matrix(self, train_fname):
def process_line(line):
tmp = line.split('\t')
return [int(tmp[0]), int(tmp[1]), float(tmp[2]) > 0]
with open(train_fname, 'r') as file:
data = list(map(process_line, file))
self.nb_users = max(data, key=lambda x: x[0])[0] + 1
self.nb_items = max(data, key=lambda x: x[1])[1] + 1

self.data = list(filter(lambda x: x[2], data))
self.mat = sp.dok_matrix(
(self.nb_users, self.nb_items), dtype=np.float32)
for user, item, _ in data:
self.mat[user, item] = 1.

def __len__(self):
return (self.nb_neg + 1) * len(self.data)

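    # Indices are interleaved: every (nb_neg + 1)-th index returns the positive
    # pair; the rest draw a random negative item, re-sampling whenever the drawn
    # item already appears in the user's training interactions.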
def __getitem__(self, idx):
if idx % (self.nb_neg + 1) == 0:
idx = idx // (self.nb_neg + 1)
return self.data[idx][0], self.data[idx][1], np.ones(1, dtype=np.float32).item() # noqa: E501
else:
idx = idx // (self.nb_neg + 1)
u = self.data[idx][0]
j = mx.random.randint(0, self.nb_items).asnumpy().item()
while (u, j) in self.mat:
j = mx.random.randint(0, self.nb_items).asnumpy().item()
return u, j, np.zeros(1, dtype=np.float32).item()
