Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Further optimization for NCF model #17148

Merged
merged 5 commits into from
Dec 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions example/neural_collaborative_filtering/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@ Author: Dr. Xiangnan He (http://www.comp.nus.edu.sg/~xiangnan/)

Code Reference: https://github.com/hexiangnan/neural_collaborative_filtering

## Environment Settings
We use MXnet with MKL-DNN as the backend.
- MXNet version: '1.5.1'

## Install
```
pip install -r requirements.txt
```

## Dataset

We provide the processed datasets on [Google Drive](https://drive.google.com/drive/folders/1qACR_Zhc2O2W0RrazzcepM2vJeh0MMdO?usp=sharing): MovieLens 20 Million (ml-20m), you can download directly or
Expand Down Expand Up @@ -66,7 +57,9 @@ We provide the pretrained ml-20m model on [Google Drive](https://drive.google.co
|dtype|HR@10|NDCG@10|
|:---:|:--:|:--:|
|float32|0.6393|0.3849|
|int8|0.6366|0.3824|
|float32 opt|0.6393|0.3849|
|int8|0.6395|0.3852|
|int8 opt|0.6396|0.3852|

## Training

Expand All @@ -75,27 +68,46 @@ We provide the pretrained ml-20m model on [Google Drive](https://drive.google.co
python train.py # --gpu=0
```

## Model Optimizer

```
# optimize model
python model_optimizer.py
```

## Calibration

```
# neumf calibration on ml-20m dataset
python ncf.py --prefix=./model/ml-20m/neumf --calibration
# optimized neumf calibration on ml-20m dataset
python ncf.py --prefix=./model/ml-20m/neumf-opt --calibration
```

## Evaluation

```
# neumf float32 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf
# optimized neumf float32 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-opt
# neumf int8 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized
# optimized neumf int8 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-opt-quantized
```

## Benchmark

```
usage: bash ./benchmark.sh [[[-p prefix ] [-e epoch] [-d dataset] [-b batch_size] [-i instance] [-c cores/instance]] | [-h]]

# neumf float32 benchmark on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf --benchmark
bash benchmark.sh -p model/ml-20m/neumf
# optimized neumf float32 benchmark on ml-20m dataset
bash benchmark.sh -p model/ml-20m/neumf-opt
# neumf int8 benchmark on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized --benchmark
bash benchmark.sh -p model/ml-20m/neumf-quantized
# optimized neumf int8 benchmark on ml-20m dataset
bash benchmark.sh -p model/ml-20m/neumf-opt-quantized
```
114 changes: 114 additions & 0 deletions example/neural_collaborative_filtering/benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Print the one-line command synopsis to stdout.
usage() {
  printf '%s\n' "usage: bash ./benchmark.sh [[[-p prefix ] [-e epoch] [-d dataset] [-b batch_size] [-i instance] [-c cores/instance]] | [-h]]"
}

#######################################
# Parse the command-line options into the script's global settings.
# Globals:   PREFIX, EPOCH, DATASET, BS, INS, CORES (written when the
#            matching option is given; otherwise left untouched)
# Arguments: the script's command-line arguments
# Outputs:   usage text on stdout; error diagnostics on stderr
# Returns:   exits 0 for -h/--help; exits 1 for an unknown option or for
#            an option that is not followed by a value
#######################################
parse_args() {
  local opt
  while [ $# -gt 0 ]; do
    opt=$1
    case "$opt" in
      --help | -h)
        usage
        exit 0
        ;;
      --prefix | -p | --epoch | -e | --dataset | -d | \
        --batch-size | -b | --instance | -i | --core | -c)
        # Every remaining option takes exactly one value; reject a trailing
        # option instead of silently storing an empty string.
        if [ $# -lt 2 ]; then
          echo "Error: option '$opt' requires a value." >&2
          usage
          exit 1
        fi
        ;;
      *)
        echo "Error: unknown option '$opt'." >&2
        usage
        exit 1
        ;;
    esac

    case "$opt" in
      --prefix | -p) PREFIX=$2 ;;
      --epoch | -e) EPOCH=$2 ;;
      --dataset | -d) DATASET=$2 ;;
      --batch-size | -b) BS=$2 ;;
      --instance | -i) INS=$2 ;;
      --core | -c) CORES=$2 ;;
    esac
    # Consume the option and its value.
    shift 2
  done
}

parse_args "$@"

#######################################
# Extract one numeric field from `lscpu` output.
# Arguments: $1 - label to look for, e.g. 'Socket(s)'
# Inputs:    lscpu output on stdin
# Outputs:   the last word of the first line containing the label, or 0
#            when no line matches or the value is not a plain integer
#######################################
lscpu_field() {
  awk -v label="$1" '
    index($0, label) { value = $NF; found = 1; exit }
    END { if (found && value ~ /^[0-9]+$/) { print value } else { print 0 } }
  '
}

# LC_ALL=C keeps the lscpu labels in English regardless of the user's locale,
# so the label matching below does not silently come back empty.
LSCPU_OUTPUT=$(LC_ALL=C lscpu)
NUM_SOCKET=$(printf '%s\n' "$LSCPU_OUTPUT" | lscpu_field 'Socket(s)')
NUM_NUMA_NODE=$(printf '%s\n' "$LSCPU_OUTPUT" | lscpu_field 'NUMA node(s)')
CORES_PER_SOCKET=$(printf '%s\n' "$LSCPU_OUTPUT" | lscpu_field 'Core(s) per socket')
NUM_CORES=$((CORES_PER_SOCKET * NUM_SOCKET))
# Guard the division: lscpu_field yields 0 for a field it could not read.
if [ "$NUM_NUMA_NODE" -gt 0 ]; then
  CORES_PER_NUMA=$((NUM_CORES / NUM_NUMA_NODE))
else
  CORES_PER_NUMA=0
fi
if [ "$NUM_CORES" -le 0 ] || [ "$CORES_PER_NUMA" -le 0 ]; then
  echo "Warning: could not determine the CPU topology from lscpu." >&2
fi
echo "target machine has $NUM_CORES physical core(s) on $NUM_NUMA_NODE numa nodes of $NUM_SOCKET socket(s)."

# Succeeds only when $1 is a decimal integer greater than zero with no
# leading zero (a leading zero would be read as octal by $(( ))).
is_positive_int() {
  case "$1" in
    '' | *[!0-9]* | 0*) return 1 ;;
  esac
  return 0
}

# Validate the required inputs first and fail with a non-zero status so that
# callers can detect the error.
if [ -z "$PREFIX" ]; then
  echo "Error: Need a model prefix." >&2
  exit 1
fi
if ! is_positive_int "$NUM_CORES" || ! is_positive_int "$CORES_PER_NUMA"; then
  echo "Error: could not determine the CPU topology (physical cores: '$NUM_CORES', cores per NUMA node: '$CORES_PER_NUMA')." >&2
  exit 1
fi

# Fill in defaults for every option that was not given on the command line.
if [ -z "$EPOCH" ]; then
  echo "Default: set epoch of model parameters to 7."
  EPOCH=7
fi
if [ -z "$DATASET" ]; then
  echo "Default: set dataset to ml-20m."
  DATASET='ml-20m'
fi
if [ -z "$INS" ]; then
  echo "Default: launch one instance per physical core."
  INS=$NUM_CORES
elif ! is_positive_int "$INS"; then
  # INS is used as a divisor and as the loop bound below.
  echo "Error: instance count must be a positive integer, got '$INS'." >&2
  exit 1
fi
if [ -z "$CORES" ]; then
  echo "Default: divide full physical cores."
  CORES=$((NUM_CORES / INS))
fi
# A value of 0 (e.g. more instances than physical cores) would produce an
# invalid core range for each instance.
if ! is_positive_int "$CORES"; then
  echo "Error: cores/instance must be a positive integer, got '$CORES' (instances: $INS, physical cores: $NUM_CORES)." >&2
  exit 1
fi
if [ -z "$BS" ]; then
  echo "Default: set batch size to 700."
  BS=700
fi

echo " cores/instance: $CORES"
echo " total instances: $INS"
echo " batch size: $BS"
echo ""

# Remove logs left by a previous run; -f keeps a first run (no logs yet)
# from printing a "No such file" error.
rm -f -- NCF_*.log

# Launch INS benchmark instances in the background. Instance i is pinned to
# the core range [a, b] of CORES consecutive core ids, and its memory is
# bound to node b / CORES_PER_NUMA (the node index computed from the last
# core of the range).
for ((i = 0; i < INS; i++)); do
  a=$((i * CORES))
  b=$((a + CORES - 1))
  memid=$((b / CORES_PER_NUMA))
  LOG="NCF_${i}.log"
  echo " $i instance use $a-$b cores with $LOG"
  KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 \
  OMP_NUM_THREADS="$CORES" \
  numactl --physcpubind="$a-$b" --membind="$memid" python ncf.py --batch-size="$BS" --dataset="$DATASET" --epoch="$EPOCH" --benchmark --prefix="$PREFIX" 2>&1 | tee "$LOG" &
done
# Block until every background instance has finished.
wait

# Sum the next-to-last field of every log line containing "speed";
# "sum + 0" prints 0 instead of an empty value when no such line exists.
grep speed NCF_*.log | awk '{ sum += $(NF-1) }; END { print "Total Performance is " (sum + 0) " samples/sec"}'
2 changes: 1 addition & 1 deletion example/neural_collaborative_filtering/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def parse_args():
parser = ArgumentParser()
parser.add_argument('--dataset', nargs='?', default='ml-20m', choices=['ml-1m', 'ml-20m'],
help='The dataset name, temporary support ml-1m and ml-20m.')
parser.add_argument('path', type=str, default = './data/',
parser.add_argument('--path', type=str, default = './data/',
help='Path to reviews CSV file from MovieLens')
parser.add_argument('-n', '--negatives', type=int, default=999,
help='Number of negative samples for each positive'
Expand Down
60 changes: 44 additions & 16 deletions example/neural_collaborative_filtering/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ def _init_weight(self, _, arr):
limit = np.sqrt(3. / self._fan_in)
mx.random.uniform(-limit, limit, out=arr)

# Inference-only variant of mlp(), used when optimizing a trained model.
def mlp_opt(user, item, factor_size, model_layers, max_user, max_item):
    """Build the MLP tower with the first FC layer folded into the embeddings.

    In mlp() the user and item embeddings are concatenated and passed through
    fc_0. Here the two embedding tables ('fused_mlp_user_weight' and
    'fused_mlp_item_weight') are expected to already contain the fc_0 result
    for every row (model_optimizer.py computes them that way), so the
    concat + fc_0 pair is replaced by a plain sum of the two lookups.

    Args:
        user: symbol fed to the user embedding lookup.
        item: symbol fed to the item embedding lookup.
        factor_size: MLP embedding size of the un-fused model; used in the
            embedding layer names and as a fallback width.
        model_layers: layer widths; fc_{i-1} maps model_layers[i-1] to
            model_layers[i].
        max_user: number of rows of the user embedding table.
        max_item: number of rows of the item embedding table.

    Returns:
        The symbol of the last activation of the tower (or the plain sum of
        the two embeddings when model_layers has fewer than two entries).
    """
    # The fused tables hold fc_0 outputs, whose width is model_layers[1].
    # factor_size * 2 (the fc_0 *input* width) is only correct when it happens
    # to equal model_layers[1], as it does for the default configuration.
    if len(model_layers) > 1:
        fused_dim = model_layers[1]
    else:
        fused_dim = factor_size * 2

    user_weight = mx.sym.Variable('fused_mlp_user_weight', init=mx.init.Normal(0.01))
    item_weight = mx.sym.Variable('fused_mlp_item_weight', init=mx.init.Normal(0.01))
    embed_user = mx.sym.Embedding(data=user, weight=user_weight, input_dim=max_user,
                                  output_dim=fused_dim, name='fused_embed_user'+str(factor_size))
    embed_item = mx.sym.Embedding(data=item, weight=item_weight, input_dim=max_item,
                                  output_dim=fused_dim, name='fused_embed_item'+str(factor_size))
    # Sum of the two pre-multiplied halves replaces concat + fc_0.
    pre_gemm_concat = embed_user + embed_item

    for i in range(1, len(model_layers)):
        # fc_0 (i == 1) is already folded into the embeddings; only its
        # activation remains. Later layers keep their FC + activation pair.
        if i > 1:
            mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i])
            mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i-1), init=mlp_weight_init)
            pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight,
                                                    num_hidden=model_layers[i], name='fc_'+str(i-1))
        pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i-1))

    return pre_gemm_concat

def mlp(user, item, factor_size, model_layers, max_user, max_item):
user_weight = mx.sym.Variable('mlp_user_weight', init=mx.init.Normal(0.01))
Expand All @@ -47,14 +68,11 @@ def mlp(user, item, factor_size, model_layers, max_user, max_item):
output_dim=factor_size, name='embed_item'+str(factor_size))
pre_gemm_concat = mx.sym.concat(embed_user, embed_item, dim=1, name='pre_gemm_concat')

for i, layer in enumerate(model_layers):
if i==0:
mlp_weight_init = golorot_uniform(2 * factor_size, model_layers[i])
else:
mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i])
mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i), init=mlp_weight_init)
pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight, num_hidden=layer, name='fc_'+str(i))
pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i))
for i in range(1, len(model_layers)):
mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i])
mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i-1), init=mlp_weight_init)
pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight, num_hidden=model_layers[i], name='fc_'+str(i-1))
pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i-1))

return pre_gemm_concat

Expand All @@ -70,24 +88,34 @@ def gmf(user, item, factor_size, max_user, max_item):
return pred

def get_model(model_type='neumf', factor_size_mlp=128, factor_size_gmf=64,
model_layers=[256, 128, 64], num_hidden=1,
max_user=138493, max_item=26744):
model_layers=[256, 256, 128, 64], num_hidden=1,
max_user=138493, max_item=26744, opt=False):
# input
user = mx.sym.Variable('user')
item = mx.sym.Variable('item')

if model_type == 'mlp':
net = mlp(user=user, item=item,
factor_size=factor_size_mlp, model_layers=model_layers,
max_user=max_user, max_item=max_item)
if opt:
net = mlp_opt(user=user, item=item,
factor_size=factor_size_mlp, model_layers=model_layers,
max_user=max_user, max_item=max_item)
else:
net = mlp(user=user, item=item,
factor_size=factor_size_mlp, model_layers=model_layers,
max_user=max_user, max_item=max_item)
elif model_type == 'gmf':
net = gmf(user=user, item=item,
factor_size=factor_size_gmf,
max_user=max_user, max_item=max_item)
elif model_type == 'neumf':
net_mlp = mlp(user=user, item=item,
factor_size=factor_size_mlp, model_layers=model_layers,
max_user=max_user, max_item=max_item)
if opt:
net_mlp = mlp_opt(user=user, item=item,
factor_size=factor_size_mlp, model_layers=model_layers,
max_user=max_user, max_item=max_item)
else:
net_mlp = mlp(user=user, item=item,
factor_size=factor_size_mlp, model_layers=model_layers,
max_user=max_user, max_item=max_item)
net_gmf = gmf(user=user, item=item,
factor_size=factor_size_gmf,
max_user=max_user, max_item=max_item)
Expand Down
81 changes: 81 additions & 0 deletions example/neural_collaborative_filtering/model_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import argparse
import ast
import logging
import math
import os
import random
import time

import mxnet as mx
import numpy as np

from core.dataset import NCFTrainData
from core.model import get_model

# Model optimizer: folds the first MLP layer (fc_0) of a trained NCF checkpoint
# into the MLP embedding tables and saves the result as '<model-prefix>-opt'.

# logging.basicConfig only has an effect on its first call, so the level and
# the format are configured together. DEBUG is the level that was effectively
# active before; the timestamp format was previously never applied.
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

parser = argparse.ArgumentParser(description="Run model optimizer.",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--path', nargs='?', default='./data/',
                    help='Input data path.')
parser.add_argument('--dataset', nargs='?', default='ml-20m',
                    help='The dataset name.')
parser.add_argument('--model-prefix', type=str, default='./model/ml-20m/neumf')
parser.add_argument('--epoch', type=int, default=7, help='parameters epoch')
parser.add_argument('--model-type', type=str, default='neumf', choices=['neumf', 'gmf', 'mlp'],
                    help="model type")
parser.add_argument('--layers', default='[256, 256, 128, 64]',
                    help="list of number hiddens of fc layers in mlp model.")
parser.add_argument('--factor-size-gmf', type=int, default=64,
                    help="outdim of gmf embedding layers.")
parser.add_argument('--num-hidden', type=int, default=1,
                    help="num-hidden of neumf fc layer")

# arg parser
args = parser.parse_args()
logging.info(args)

model_prefix = args.model_prefix
model_type = args.model_type

# Parse --layers as a Python literal only; unlike eval() this cannot execute
# arbitrary code passed on the command line.
try:
    model_layers = ast.literal_eval(args.layers)
except (ValueError, SyntaxError):
    parser.error("--layers must be a list of integers, e.g. '[256, 256, 128, 64]'")
if (not isinstance(model_layers, (list, tuple)) or not model_layers
        or not all(isinstance(width, int) for width in model_layers)):
    parser.error("--layers must be a non-empty list of integers, e.g. '[256, 256, 128, 64]'")

factor_size_gmf = args.factor_size_gmf
factor_size_mlp = model_layers[0] // 2
num_hidden = args.num_hidden

# os.path.join also works when --path is given without a trailing slash.
train_path = os.path.join(args.path, args.dataset, 'train-ratings.csv')
train_dataset = NCFTrainData(train_path, nb_neg=4)
net = get_model(model_type, factor_size_mlp, factor_size_gmf,
                model_layers, num_hidden, train_dataset.nb_users, train_dataset.nb_items, opt=True)

raw_params, _ = mx.model.load_params(model_prefix, args.epoch)

# The fusion below needs the MLP embedding tables and fc_0; fail with a clear
# message when the checkpoint does not contain them.
required_params = ('fc_0_weight', 'fc_0_bias', 'mlp_user_weight', 'mlp_item_weight')
missing_params = [name for name in required_params if name not in raw_params]
if missing_params:
    raise SystemExit("Checkpoint '{}' (epoch {}) has no parameter(s) {}; nothing to fuse."
                     .format(model_prefix, args.epoch, ', '.join(missing_params)))

# fc_0 consumes concat(user_embedding, item_embedding), so splitting its weight
# along the input axis gives the part applied to each embedding.
fc_0_weight_split = mx.nd.split(raw_params['fc_0_weight'], axis=1, num_outputs=2)
fc_0_left = fc_0_weight_split[0]
fc_0_right = fc_0_weight_split[1]
# A FullyConnected weight is laid out as (num_hidden, input_dim); take the
# output width from the weight itself rather than from model_layers[0], which
# is the input width and only matches when the first two layer sizes are equal.
fc_0_out = raw_params['fc_0_weight'].shape[0]

# Pre-compute the fc_0 output for every embedding row. The bias is added once,
# on the user side, so that user + item reproduces the full fc_0 result.
user_weight_fusion = mx.nd.FullyConnected(data=raw_params['mlp_user_weight'], weight=fc_0_left,
                                          bias=raw_params['fc_0_bias'], no_bias=False,
                                          num_hidden=fc_0_out)
item_weight_fusion = mx.nd.FullyConnected(data=raw_params['mlp_item_weight'], weight=fc_0_right,
                                          no_bias=True, num_hidden=fc_0_out)

# Work on a copy so the loaded parameter dict is not modified in place.
opt_params = dict(raw_params)
del opt_params['mlp_user_weight']
del opt_params['mlp_item_weight']
del opt_params['fc_0_bias']
opt_params['fused_mlp_user_weight'] = user_weight_fusion
opt_params['fused_mlp_item_weight'] = item_weight_fusion

mx.model.save_checkpoint(model_prefix + '-opt', args.epoch, net, opt_params, {})

Loading