diff --git a/example/neural_collaborative_filtering/README.md b/example/neural_collaborative_filtering/README.md index 819f4d94dff9..00d3ed12295b 100644 --- a/example/neural_collaborative_filtering/README.md +++ b/example/neural_collaborative_filtering/README.md @@ -29,15 +29,6 @@ Author: Dr. Xiangnan He (http://www.comp.nus.edu.sg/~xiangnan/) Code Reference: https://github.com/hexiangnan/neural_collaborative_filtering -## Environment Settings -We use MXnet with MKL-DNN as the backend. -- MXNet version: '1.5.1' - -## Install -``` -pip install -r requirements.txt -``` - ## Dataset We provide the processed datasets on [Google Drive](https://drive.google.com/drive/folders/1qACR_Zhc2O2W0RrazzcepM2vJeh0MMdO?usp=sharing): MovieLens 20 Million (ml-20m), you can download directly or @@ -66,7 +57,9 @@ We provide the pretrained ml-20m model on [Google Drive](https://drive.google.co |dtype|HR@10|NDCG@10| |:---:|:--:|:--:| |float32|0.6393|0.3849| -|int8|0.6366|0.3824| +|float32 opt|0.6393|0.3849| +|int8|0.6395|0.3852| +|int8 opt|0.6396|0.3852| ## Training @@ -75,11 +68,20 @@ We provide the pretrained ml-20m model on [Google Drive](https://drive.google.co python train.py # --gpu=0 ``` +## Model Optimizer + +``` +# optimize model +python model_optimizer.py +``` + ## Calibration ``` # neumf calibration on ml-20m dataset python ncf.py --prefix=./model/ml-20m/neumf --calibration +# optimized neumf calibration on ml-20m dataset +python ncf.py --prefix=./model/ml-20m/neumf-opt --calibration ``` ## Evaluation @@ -87,15 +89,25 @@ python ncf.py --prefix=./model/ml-20m/neumf --calibration ``` # neumf float32 inference on ml-20m dataset python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf +# optimized neumf float32 inference on ml-20m dataset +python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-opt # neumf int8 inference on ml-20m dataset python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized +# optimized neumf int8 inference on ml-20m dataset +python 
ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-opt-quantized ``` ## Benchmark ``` +usage: bash ./benchmark.sh [[[-p prefix ] [-e epoch] [-d dataset] [-b batch_size] [-i instance] [-c cores/instance]] | [-h]] + # neumf float32 benchmark on ml-20m dataset -python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf --benchmark +bash benchmark.sh -p model/ml-20m/neumf +# optimized neumf float32 benchmark on ml-20m dataset +bash benchmark.sh -p model/ml-20m/neumf-opt # neumf int8 benchmark on ml-20m dataset -python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized --benchmark +bash benchmark.sh -p model/ml-20m/neumf-quantized +# optimized neumf int8 benchmark on ml-20m dataset +bash benchmark.sh -p model/ml-20m/neumf-opt-quantized ``` diff --git a/example/neural_collaborative_filtering/benchmark.sh b/example/neural_collaborative_filtering/benchmark.sh new file mode 100644 index 000000000000..60fec746cd20 --- /dev/null +++ b/example/neural_collaborative_filtering/benchmark.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +usage() +{ + echo "usage: bash ./benchmark.sh [[[-p prefix ] [-e epoch] [-d dataset] [-b batch_size] [-i instance] [-c cores/instance]] | [-h]]" +} + +while [ $# -gt 0 ]; do + case "$1" in + --prefix | -p) + shift + PREFIX=$1 + ;; + --epoch | -e) + shift + EPOCH=$1 + ;; + --dataset | -d) + shift + DATASET=$1 + ;; + --batch-size | -b) + shift + BS=$1 + ;; + --instance | -i) + shift + INS=$1 + ;; + --core | -c) + shift + CORES=$1 + ;; + --help | -h) + usage + exit 1 + ;; + *) + usage + exit 1 + esac + shift +done + +NUM_SOCKET=`lscpu | grep 'Socket(s)' | awk '{print $NF}'` +NUM_NUMA_NODE=`lscpu | grep 'NUMA node(s)' | awk '{print $NF}'` +CORES_PER_SOCKET=`lscpu | grep 'Core(s) per socket' | awk '{print $NF}'` +NUM_CORES=$((CORES_PER_SOCKET * NUM_SOCKET)) +CORES_PER_NUMA=$((NUM_CORES / NUM_NUMA_NODE)) +echo "target machine has $NUM_CORES physical core(s) on $NUM_NUMA_NODE numa nodes of $NUM_SOCKET socket(s)." + +if [ -z $PREFIX ]; then + echo "Error: Need a model prefix." + exit +fi +if [ -z $EPOCH ]; then + echo "Default: set epoch of model parameters to 7." + EPOCH=7 +fi +if [ -z $DATASET ]; then + echo "Default: set dataset to ml-20m." + DATASET='ml-20m' +fi +if [ -z $INS ]; then + echo "Default: launch one instance per physical core." + INS=$NUM_CORES +fi +if [ -z $CORES ]; then + echo "Default: divide full physical cores." + CORES=$((NUM_CORES / $INS)) +fi +if [ -z $BS ]; then + echo "Default: set batch size to 700." 
+ BS=700 +fi + +echo " cores/instance: $CORES" +echo " total instances: $INS" +echo " batch size: $BS" +echo "" + +rm NCF_*.log + +for((i=0;i<$INS;i++)); +do + ((a=$i*$CORES)) + ((b=$a+$CORES-1)) + memid=$((b/CORES_PER_NUMA)) + LOG=NCF_$i.log + echo " $i instance use $a-$b cores with $LOG" + KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 \ + OMP_NUM_THREADS=$CORES \ + numactl --physcpubind=$a-$b --membind=$memid python ncf.py --batch-size=$BS --dataset=$DATASET --epoch=$EPOCH --benchmark --prefix=$PREFIX 2>&1 | tee $LOG & +done +wait + +grep speed NCF_*.log | awk '{ sum += $(NF-1) }; END { print "Total Performance is " sum " samples/sec"}' diff --git a/example/neural_collaborative_filtering/convert.py b/example/neural_collaborative_filtering/convert.py index 4c64d2cdedab..7fb7f1ede9e4 100644 --- a/example/neural_collaborative_filtering/convert.py +++ b/example/neural_collaborative_filtering/convert.py @@ -38,7 +38,7 @@ def parse_args(): parser = ArgumentParser() parser.add_argument('--dataset', nargs='?', default='ml-20m', choices=['ml-1m', 'ml-20m'], help='The dataset name, temporary support ml-1m and ml-20m.') - parser.add_argument('path', type=str, default = './data/', + parser.add_argument('--path', type=str, default = './data/', help='Path to reviews CSV file from MovieLens') parser.add_argument('-n', '--negatives', type=int, default=999, help='Number of negative samples for each positive' diff --git a/example/neural_collaborative_filtering/core/model.py b/example/neural_collaborative_filtering/core/model.py index b516e5039fed..6c03bb01a357 100644 --- a/example/neural_collaborative_filtering/core/model.py +++ b/example/neural_collaborative_filtering/core/model.py @@ -37,6 +37,27 @@ def _init_weight(self, _, arr): limit = np.sqrt(3. 
/ self._fan_in) mx.random.uniform(-limit, limit, out=arr) +# Inference-only variant of mlp(): fc_0 is pre-folded into the fused user/item embedding weights +def mlp_opt(user, item, factor_size, model_layers, max_user, max_item): + user_weight = mx.sym.Variable('fused_mlp_user_weight', init=mx.init.Normal(0.01)) + item_weight = mx.sym.Variable('fused_mlp_item_weight', init=mx.init.Normal(0.01)) + embed_user = mx.sym.Embedding(data=user, weight=user_weight, input_dim=max_user, + output_dim=factor_size * 2, name='fused_embed_user'+str(factor_size)) + embed_item = mx.sym.Embedding(data=item, weight=item_weight, input_dim=max_item, + output_dim=factor_size * 2, name='fused_embed_item'+str(factor_size)) + pre_gemm_concat = embed_user + embed_item + + for i in range(1, len(model_layers)): + if i==1: + pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i-1)) + continue + else: + mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i]) + mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i-1), init=mlp_weight_init) + pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight, num_hidden=model_layers[i], name='fc_'+str(i-1)) + pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i-1)) + + return pre_gemm_concat def mlp(user, item, factor_size, model_layers, max_user, max_item): user_weight = mx.sym.Variable('mlp_user_weight', init=mx.init.Normal(0.01)) @@ -47,14 +68,11 @@ def mlp(user, item, factor_size, model_layers, max_user, max_item): output_dim=factor_size, name='embed_item'+str(factor_size)) pre_gemm_concat = mx.sym.concat(embed_user, embed_item, dim=1, name='pre_gemm_concat') - for i, layer in enumerate(model_layers): - if i==0: - mlp_weight_init = golorot_uniform(2 * factor_size, model_layers[i]) - else: - mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i]) - mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i), init=mlp_weight_init) - pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, 
weight=mlp_weight, num_hidden=layer, name='fc_'+str(i)) - pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i)) + for i in range(1, len(model_layers)): + mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i]) + mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i-1), init=mlp_weight_init) + pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight, num_hidden=model_layers[i], name='fc_'+str(i-1)) + pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i-1)) return pre_gemm_concat @@ -70,24 +88,34 @@ def gmf(user, item, factor_size, max_user, max_item): return pred def get_model(model_type='neumf', factor_size_mlp=128, factor_size_gmf=64, - model_layers=[256, 128, 64], num_hidden=1, - max_user=138493, max_item=26744): + model_layers=[256, 256, 128, 64], num_hidden=1, + max_user=138493, max_item=26744, opt=False): # input user = mx.sym.Variable('user') item = mx.sym.Variable('item') if model_type == 'mlp': - net = mlp(user=user, item=item, - factor_size=factor_size_mlp, model_layers=model_layers, - max_user=max_user, max_item=max_item) + if opt: + net = mlp_opt(user=user, item=item, + factor_size=factor_size_mlp, model_layers=model_layers, + max_user=max_user, max_item=max_item) + else: + net = mlp(user=user, item=item, + factor_size=factor_size_mlp, model_layers=model_layers, + max_user=max_user, max_item=max_item) elif model_type == 'gmf': net = gmf(user=user, item=item, factor_size=factor_size_gmf, max_user=max_user, max_item=max_item) elif model_type == 'neumf': - net_mlp = mlp(user=user, item=item, - factor_size=factor_size_mlp, model_layers=model_layers, - max_user=max_user, max_item=max_item) + if opt: + net_mlp = mlp_opt(user=user, item=item, + factor_size=factor_size_mlp, model_layers=model_layers, + max_user=max_user, max_item=max_item) + else: + net_mlp = mlp(user=user, item=item, + factor_size=factor_size_mlp, model_layers=model_layers, + 
max_user=max_user, max_item=max_item) net_gmf = gmf(user=user, item=item, factor_size=factor_size_gmf, max_user=max_user, max_item=max_item) diff --git a/example/neural_collaborative_filtering/model_optimizer.py b/example/neural_collaborative_filtering/model_optimizer.py new file mode 100644 index 000000000000..2866ae7e7e05 --- /dev/null +++ b/example/neural_collaborative_filtering/model_optimizer.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import os +import time +import argparse +import logging +import math +import random +import numpy as np +import mxnet as mx +from core.model import get_model +from core.dataset import NCFTrainData + +logging.basicConfig(level=logging.DEBUG) + +parser = argparse.ArgumentParser(description="Run model optimizer.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--path', nargs='?', default='./data/', + help='Input data path.') +parser.add_argument('--dataset', nargs='?', default='ml-20m', + help='The dataset name.') +parser.add_argument('--model-prefix', type=str, default='./model/ml-20m/neumf') +parser.add_argument('--epoch', type=int, default=7, help='parameters epoch') +parser.add_argument('--model-type', type=str, default='neumf', choices=['neumf', 'gmf', 'mlp'], + help="mdoel type") +parser.add_argument('--layers', default='[256, 256, 128, 64]', + help="list of number hiddens of fc layers in mlp model.") +parser.add_argument('--factor-size-gmf', type=int, default=64, + help="outdim of gmf embedding layers.") +parser.add_argument('--num-hidden', type=int, default=1, + help="num-hidden of neumf fc layer") + +head = '%(asctime)-15s %(message)s' +logging.basicConfig(level=logging.INFO, format=head) + +# arg parser +args = parser.parse_args() +logging.info(args) + +model_prefix = args.model_prefix +model_type = args.model_type +model_layers = eval(args.layers) +factor_size_gmf = args.factor_size_gmf +factor_size_mlp = int(model_layers[0]/2) +num_hidden = args.num_hidden +train_dataset = NCFTrainData((args.path + args.dataset + '/train-ratings.csv'), nb_neg=4) +net = get_model(model_type, factor_size_mlp, factor_size_gmf, + model_layers, num_hidden, train_dataset.nb_users, train_dataset.nb_items, opt=True) + +raw_params, _ = mx.model.load_params(model_prefix, args.epoch) +fc_0_weight_split = mx.nd.split(raw_params['fc_0_weight'], axis=1, num_outputs=2) +fc_0_left = fc_0_weight_split[0] +fc_0_right = fc_0_weight_split[1] + 
+user_weight_fusion = mx.nd.FullyConnected(data = raw_params['mlp_user_weight'], weight=fc_0_left, bias=raw_params['fc_0_bias'], no_bias=False, num_hidden=model_layers[0]) +item_weight_fusion = mx.nd.FullyConnected(data = raw_params['mlp_item_weight'], weight=fc_0_right, no_bias=True, num_hidden=model_layers[0]) + +opt_params = raw_params +del opt_params['mlp_user_weight'] +del opt_params['mlp_item_weight'] +del opt_params['fc_0_bias'] +opt_params['fused_mlp_user_weight'] = user_weight_fusion +opt_params['fused_mlp_item_weight'] = item_weight_fusion + +mx.model.save_checkpoint(model_prefix + '-opt', args.epoch, net, opt_params, {}) + diff --git a/example/neural_collaborative_filtering/ncf.py b/example/neural_collaborative_filtering/ncf.py index 0fd9f733a1bd..b01be01bc8d9 100644 --- a/example/neural_collaborative_filtering/ncf.py +++ b/example/neural_collaborative_filtering/ncf.py @@ -42,20 +42,12 @@ help='max number of item index.') parser.add_argument('--batch-size', type=int, default=256, help='number of examples per batch') -parser.add_argument('--model-type', type=str, default='neumf', choices=['neumf', 'gmf', 'mlp'], - help="mdoel type") -parser.add_argument('--layers', default='[256, 128, 64]', - help="list of number hiddens of fc layers in mlp model.") -parser.add_argument('--factor-size-gmf', type=int, default=64, - help="outdim of gmf embedding layers.") -parser.add_argument('--num-hidden', type=int, default=1, - help="num-hidden of neumf fc layer") parser.add_argument('--topk', type=int, default=10, help="topk for accuracy evaluation.") parser.add_argument('--gpu', type=int, default=None, help="index of gpu to run, e.g. 0 or 1. 
None means using cpu().") parser.add_argument('--benchmark', action='store_true', help="whether to benchmark performance only") -parser.add_argument('--epoch', type=int, default=0, help='model checkpoint index for inference') +parser.add_argument('--epoch', type=int, default=7, help='model checkpoint index for inference') parser.add_argument('--prefix', default='./model/ml-20m/neumf', help="model checkpoint prefix") parser.add_argument('--calibration', action='store_true', help="whether to calibrate model") parser.add_argument('--calib-mode', type=str, choices=['naive', 'entropy'], default='naive', @@ -85,11 +77,6 @@ max_user = args.max_user max_item = args.max_item batch_size = args.batch_size - model_type = args.model_type - model_layers = eval(args.layers) - factor_size_gmf = args.factor_size_gmf - factor_size_mlp = int(model_layers[0]/2) - num_hidden = args.num_hidden benchmark = args.benchmark calibration = args.calibration calib_mode = args.calib_mode @@ -129,7 +116,7 @@ cqsym, cqarg_params, aux_params, collector = quantize_graph(sym=net, arg_params=arg_params, aux_params=aux_params, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, - quantized_dtype=args.quantized_dtype, logger=logging) + quantized_dtype=quantized_dtype, logger=logging) max_num_examples = num_calib_batches * batch_size mod._exec_group.execs[0].set_monitor_callback(collector.collect, monitor_all=True) num_batches = 0 @@ -144,12 +131,17 @@ % (num_batches, batch_size)) cqsym, cqarg_params, aux_params = calib_graph(qsym=cqsym, arg_params=arg_params, aux_params=aux_params, collector=collector, calib_mode=calib_mode, - quantized_dtype=args.quantized_dtype, logger=logging) + quantized_dtype=quantized_dtype, logger=logging) sym_name = '%s-symbol.json' % (args.prefix + '-quantized') cqsym = cqsym.get_backend_symbol('MKLDNN_QUANTIZE') mx.model.save_checkpoint(args.prefix + '-quantized', args.epoch, cqsym, cqarg_params, aux_params) elif benchmark: logging.info('Benchmarking...') + data = 
[mx.random.randint(0, 1000, shape=shape, ctx=ctx) for _, shape in mod.data_shapes] + batch = mx.io.DataBatch(data, []) # empty label + for i in range(2000): + mod.forward(batch, is_train=False) + logging.info('Benchmarking...') num_samples = 0 for ib, batch in enumerate(val_iter): if ib == 5: diff --git a/example/neural_collaborative_filtering/train.py b/example/neural_collaborative_filtering/train.py index 0b0cfad1ef39..c68f271a6f0d 100644 --- a/example/neural_collaborative_filtering/train.py +++ b/example/neural_collaborative_filtering/train.py @@ -45,7 +45,7 @@ help="mdoel type") parser.add_argument('--num-negative', type=int, default=4, help="number of negative samples per positive sample while training.") -parser.add_argument('--layers', default='[256, 128, 64]', +parser.add_argument('--layers', default='[256, 256, 128, 64]', help="list of number hiddens of fc layers in mlp model.") parser.add_argument('--factor-size-gmf', type=int, default=64, help="outdim of gmf embedding layers.")