diff --git a/.gitignore b/.gitignore index d7f8b11..73c0a45 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .idea .DS_Store +*.py[cod] *.ipynb *ipynb* criteo_sample.txt diff --git a/core/__pycache__/__init__.cpython-36.pyc b/core/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index eba73a6..0000000 Binary files a/core/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/core/__pycache__/blocks.cpython-36.pyc b/core/__pycache__/blocks.cpython-36.pyc deleted file mode 100644 index 8e4773d..0000000 Binary files a/core/__pycache__/blocks.cpython-36.pyc and /dev/null differ diff --git a/core/__pycache__/features.cpython-36.pyc b/core/__pycache__/features.cpython-36.pyc deleted file mode 100644 index b01872e..0000000 Binary files a/core/__pycache__/features.cpython-36.pyc and /dev/null differ diff --git a/core/blocks.py b/core/blocks.py index b8f70b7..e745f25 100644 --- a/core/blocks.py +++ b/core/blocks.py @@ -1,4 +1,4 @@ -from collections import Iterable +from typing import Iterable import itertools import tensorflow as tf @@ -174,6 +174,22 @@ def __init__(self, self.kernel_initializer = kernel_initializer self.kernel_regularizer = kernel_regularizer + # weights to be built + self.kernels = {} + + def build(self, input_shape): + # type(input_shape) == list + + for i in range(len(input_shape) - 1): + for j in range(i + 1, len(input_shape)): + kernel = self.add_weight(shape=(int(input_shape[i][1]), int(input_shape[j][1])), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True) + self.kernels[i, j] = kernel + + super(OuterProduct, self).build(input_shape) # Be sure to call this at the end + def call(self, inputs, **kwargs): outer_products_list = list() @@ -182,10 +198,7 @@ def call(self, inputs, **kwargs): for j in range(i + 1, len(inputs)): inp_i = tf.expand_dims(inputs[i], axis=1) inp_j = tf.expand_dims(inputs[j], axis=-1) - kernel = self.add_weight(shape=(inp_i.shape[2], inp_j.shape[1]), - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - trainable=True) + kernel = self.kernels[i, j] product = tf.reduce_sum(tf.matmul(tf.matmul(inp_i, kernel), inp_j), axis=-1, keepdims=False) outer_products_list.append(product) @@ -197,6 +210,7 @@ def call(self, inputs, **kwargs): class CrossNetwork(tf.keras.Model): def __init__(self, + layer_nums=3, kernel_initializer='glorot_uniform', kernel_regularizer=tf.keras.regularizers.l2(1e-5), bias_initializer='zeros', @@ -204,38 +218,114 @@ def __init__(self, **kwargs): super(CrossNetwork, self).__init__(**kwargs) - + self.layer_nums = layer_nums self.kernel_initializer = kernel_initializer self.kernel_regularizer = kernel_regularizer self.bias_initializer = bias_initializer self.bias_regularizer = bias_regularizer - def call(self, inputs, layers_num=3, require_logit=True, **kwargs): + # weights to be built + self.W = [] + self.kernel = None - x0 = tf.expand_dims(tf.concat(inputs, axis=1), axis=-1) - x = tf.transpose(x0, [0, 2, 1]) - - for i in range(layers_num): - kernel = self.add_weight(shape=(x0.shape[1], 1), + def build(self, input_shape): + # type(input_shape) == list, input_shape[0] == TensorShape(?, embedding_size) + m = len(input_shape) * int(input_shape[0][1]) + for i in range(self.layer_nums): + kernel = self.add_weight(shape=(m, 1), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, trainable=True) - bias = self.add_weight(shape=(x0.shape[1], 1), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, + bias 
= self.add_weight(shape=(m, 1), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, trainable=True) + self.W.append([kernel, bias]) + self.kernel = self.add_weight(shape=(m, 1), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True) + # Be sure to call this at the end + super(CrossNetwork, self).build(input_shape) + + def call(self, inputs, require_logit=True, **kwargs): + + x0 = tf.expand_dims(tf.concat(inputs, axis=1), axis=-1) + x = tf.transpose(x0, [0, 2, 1]) + + for i in range(self.layer_nums): + kernel, bias = self.W[i] x = tf.matmul(tf.matmul(x0, x), kernel) + bias + tf.transpose(x, [0, 2, 1]) x = tf.transpose(x, [0, 2, 1]) x = tf.squeeze(x, axis=1) if require_logit: - kernel = self.add_weight(shape=(x0.shape[1], 1), + kernel = self.kernel + x = tf.matmul(x, kernel) + + return x + + +class CIN_(tf.keras.layers.Layer): + def __init__(self, + hidden_width=(128, 64), + kernel_initializer='glorot_uniform', + kernel_regularizer=tf.keras.regularizers.l2(1e-5), + **kwargs): + super(CIN_, self).__init__(**kwargs) + self.hidden_width = hidden_width + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + + # weights to be built + self.conv_kernels = [] + self.kernel = None + + def build(self, input_shape): + # input_shape = (batch_size, field_num, D) + _, m, D = input_shape + m, D = int(m), int(D) + hidden_width = list(self.hidden_width) # H for every level + self.conv_kernels = [] + field_nums = [m] + hidden_width + for idx, layer_size in enumerate(hidden_width): + kernel = self.add_weight(shape=(1, m * field_nums[idx], layer_size), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, trainable=True) - x = tf.matmul(x, kernel) + self.conv_kernels.append(kernel) - return x + self.kernel = self.add_weight(shape=(sum(field_nums), 1), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True) + + super(CIN_, self).build(input_shape) + + def call(self, inputs, require_logit=True, **kwargs): + # inputs [b, m, D] + m = int(inputs.shape[1]) + D = int(inputs.shape[-1]) + x0 = inputs + finals = [x0] + x0 = tf.split(finals[-1], D * [1], 2) # (D, batch_size, m, 1) + for idx, layer_size in enumerate(self.hidden_width): + x = tf.split(finals[-1], D * [1], 2) # (D, batch_size, field_num, 1) + dot = tf.matmul(x0, x, transpose_b=True) # (D, batch_size, m, field_num) + dot = tf.reshape(dot, shape=[D, -1, m * (self.hidden_width[idx - 1] if idx > 0 else m)]) + dot = tf.transpose(dot, perm=[1, 0, 2]) # (batch_size, D, m * field_num) + conv = tf.nn.conv1d(dot, filters=self.conv_kernels[idx], stride=1, padding='VALID') + # activation + out = tf.nn.relu(conv) # (batch_size, D, layer_size) + out = tf.transpose(out, perm=[0, 2, 1]) # (batch_size, layer_size, D) + finals.append(out) + + finals = tf.concat(finals, axis=1) # (?, m+sum(hidden_width), D) + finals = tf.reduce_sum(finals, -1) # (?, m+sum(hidden_width)) + logits = tf.matmul(finals, self.kernel) + + return logits class CIN(tf.keras.Model): @@ -250,6 +340,9 @@ def __init__(self, self.kernel_initializer = kernel_initializer self.kernel_regularizer = kernel_regularizer + def build(self, input_shape): + pass + def call(self, inputs, hidden_width=(128, 64), require_logit=True, **kwargs): # [b, n, m] @@ -345,28 +438,37 @@ def call(self, inputs, **kwargs): class AutoIntInteraction(tf.keras.Model): def __init__(self, att_embedding_size=8, heads=2, use_res=True, seed=2333, **kwargs): -
super(AutoIntInteraction, self).__init__(**kwargs) self.att_embedding_size = att_embedding_size self.heads = heads self.use_res = use_res self.seed = seed + # weights to be built + self.W_Query = None + self.W_Key = None + self.W_Value = None + self.W_Res = None + + def build(self, input_shape): + m = int(input_shape[-1]) + self.W_Query = self.add_weight(shape=[m, self.att_embedding_size * self.heads], + initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) + self.W_Key = self.add_weight(shape=[m, self.att_embedding_size * self.heads], + initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) + self.W_Value = self.add_weight(shape=[m, self.att_embedding_size * self.heads], + initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) - def call(self, inputs, **kwargs): - - m = inputs.shape[-1] - - W_Query = self.add_weight(shape=[m, self.att_embedding_size * self.heads], - initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) - W_key = self.add_weight(shape=[m, self.att_embedding_size * self.heads], - initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) - W_Value = self.add_weight(shape=[m, self.att_embedding_size * self.heads], - initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) + if self.use_res: + self.W_Res = self.add_weight(shape=[m, self.att_embedding_size * self.heads], + initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) + # don't forget to call this + super(AutoIntInteraction, self).build(input_shape) - queries = tf.matmul(inputs, W_Query) - keys = tf.matmul(inputs, W_key) - values = tf.matmul(inputs, W_Value) + def call(self, inputs, **kwargs): + queries = tf.matmul(inputs, self.W_Query) + keys = tf.matmul(inputs, self.W_Key) + values = tf.matmul(inputs, self.W_Value) queries = tf.stack(tf.split(queries, self.heads, axis=2)) keys = tf.stack(tf.split(keys, self.heads, axis=2)) @@ -380,9 +482,7 @@ def call(self, inputs, **kwargs): result = tf.squeeze(result, axis=0) if self.use_res: - W_Res = self.add_weight(shape=[m, self.att_embedding_size * self.heads], - initializer=tf.keras.initializers.RandomNormal(seed=self.seed)) - result = result + tf.matmul(inputs, W_Res) + result = result + tf.matmul(inputs, self.W_Res) result = tf.keras.activations.relu(result) @@ -392,34 +492,44 @@ def call(self, inputs, **kwargs): class FGCNNlayer(tf.keras.layers.Layer): def __init__(self, filters, kernel_width, new_feat_filters, pool_width, **kwargs): - super(FGCNNlayer, self).__init__(**kwargs) self.filters = filters self.kernel_width = kernel_width self.new_feat_filters = new_feat_filters self.pool_width = pool_width + # modules to be built + self.conv2d = None + self.max_pooling2d = None + self.dense = None - def call(self, inputs, **kwargs): - - output = inputs - output = tf.keras.layers.Conv2D( + def build(self, input_shape): + self.conv2d = tf.keras.layers.Conv2D( filters=self.filters, strides=(1, 1), kernel_size=(self.kernel_width, 1), padding='same', activation='tanh', use_bias=True - )(output) - output = tf.keras.layers.MaxPooling2D( - pool_size=(self.pool_width, 1) - )(output) - new_feat_output = tf.keras.layers.Flatten()(output) - new_feat_output = tf.keras.layers.Dense( - units=output.shape[1] * output.shape[2] * self.new_feat_filters, + ) + self.max_pooling2d = tf.keras.layers.MaxPooling2D(pool_size=(self.pool_width, 1)) + conv_out_shape = self.conv2d.compute_output_shape(input_shape) + pool_out_shape = self.max_pooling2d.compute_output_shape(conv_out_shape) + + self.dense = tf.keras.layers.Dense( + units=int(pool_out_shape[1]) * 
int(pool_out_shape[2]) * self.new_feat_filters, activation='tanh', use_bias=True - )(new_feat_output) + ) + + super(FGCNNlayer, self).build(input_shape) # Be sure to call this at the end + + def call(self, inputs, **kwargs): + output = inputs + output = self.conv2d(output) + output = self.max_pooling2d(output) + new_feat_output = tf.keras.layers.Flatten()(output) + new_feat_output = self.dense(new_feat_output) new_feat_output = tf.reshape(new_feat_output, shape=(-1, output.shape[1] * self.new_feat_filters, output.shape[2])) @@ -428,24 +538,55 @@ def call(self, inputs, **kwargs): class BiInteraction(tf.keras.Model): - def __init__(self, mode='all', **kwargs): - + def __init__(self, mode='all', kernel_initializer='glorot_uniform', + kernel_regularizer=tf.keras.regularizers.l2(1e-5), **kwargs): super(BiInteraction, self).__init__(**kwargs) - self.mode = mode + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + + # weights to be built + self.W = None + self.Ws = [] + self.Wss = {} + + def build(self, input_shape): + # type(input_shape) == list, input_shape[0] == TensorShape(?, embedding_size) + embedding_size = int(input_shape[0][-1]) + if self.mode == 'all': + self.W = self.add_weight(shape=(embedding_size, embedding_size), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True) + + elif self.mode == 'each': + for i in range(len(input_shape) - 1): + W = self.add_weight(shape=(embedding_size, embedding_size), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True) + self.Ws.append(W) + + elif self.mode == 'interaction': + for i in range(len(input_shape) - 1): + for j in range(i, len(input_shape)): + W = self.add_weight(shape=(embedding_size, embedding_size), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True) + self.Wss[i, j] = W + else: + raise ValueError('Expected mode (all, each, interaction), got {} instead'.format(self.mode)) + + # don't forget to call this + super(BiInteraction, self).build(input_shape) def call(self, inputs, **kwargs): output = list() - embedding_size = inputs[0].shape[-1] if self.mode == 'all': - W = self.add_weight( - shape=(embedding_size, embedding_size), - initializer='glorot_uniform', - regularizer=tf.keras.regularizers.l2(1e-5), - trainable=True - ) + W = self.W for i in range(len(inputs) - 1): for j in range(i, len(inputs)): inter = tf.tensordot(inputs[i], W, axes=(-1, 0)) * inputs[j] @@ -453,12 +594,7 @@ def call(self, inputs, **kwargs): elif self.mode == 'each': for i in range(len(inputs) - 1): - W = self.add_weight( - shape=(embedding_size, embedding_size), - initializer='glorot_uniform', - regularizer=tf.keras.regularizers.l2(1e-5), - trainable=True - ) + W = self.Ws[i] for j in range(i, len(inputs)): inter = tf.tensordot(inputs[i], W, axes=(-1, 0)) * inputs[j] output.append(inter) @@ -466,12 +602,7 @@ def call(self, inputs, **kwargs): elif self.mode == 'interaction': for i in range(len(inputs) - 1): for j in range(i, len(inputs)): - W = self.add_weight( - shape=(embedding_size, embedding_size), - initializer='glorot_uniform', - regularizer=tf.keras.regularizers.l2(1e-5), - trainable=True - ) + W = self.Wss[i, j] inter = tf.tensordot(inputs[i], W, axes=(-1, 0)) * inputs[j] output.append(inter) @@ -479,31 +610,37 @@ def call(self, inputs, **kwargs): return output -class SENet(tf.keras.Model): +class SENet(tf.keras.layers.Layer): def __init__(self, axis=-1, reduction=4, **kwargs): - super(SENet, 
self).__init__(**kwargs) self.axis = axis self.reduction = reduction - - def call(self, inputs, **kwargs): - - # inputs [batch_size, feats_num, embedding_size] - feats_num = inputs.shape[1] - - weights = tf.reduce_mean(inputs, axis=self.axis, keepdims=False) # [batch_size, feats_num] - W1 = self.add_weight( + # weights to be built + self.W1 = None + self.W2 = None + + def build(self, input_shape): + assert len(input_shape) == 3 + feats_num = int(input_shape[1]) + self.W1 = self.add_weight( shape=(feats_num, self.reduction), trainable=True, initializer='glorot_normal' ) - W2 = self.add_weight( + self.W2 = self.add_weight( shape=(self.reduction, feats_num), trainable=True, initializer='glorot_normal' ) + + super(SENet, self).build(input_shape) + + def call(self, inputs, **kwargs): + # inputs [batch_size, feats_num, embedding_size] + weights = tf.reduce_mean(inputs, axis=self.axis, keepdims=False) # [batch_size, feats_num] + W1, W2 = self.W1, self.W2 weights = tf.keras.activations.relu(tf.tensordot(weights, W1, axes=(-1, 0))) weights = tf.keras.activations.relu(tf.tensordot(weights, W2, axes=(-1, 0))) diff --git a/examples/ctr_predict.py b/examples/ctr_predict.py index f5b95dc..e30d126 100644 --- a/examples/ctr_predict.py +++ b/examples/ctr_predict.py @@ -15,6 +15,7 @@ from models.AFM import AFM from models.AutoInt import AutoInt from models.CCPM import CCPM +from models.NFFM import NFFM from core.features import FeatureMetas if __name__ == "__main__": @@ -24,6 +25,8 @@ # Get columns' names sparse_features = list(data.columns) + sparse_features.remove('click') + sparse_features.remove('id') target = ['click'] # Preprocess your data @@ -37,14 +40,14 @@ train_0 = train[train.click == 0] train_1 = train[train.click == 1] train = pd.concat([train_1, train_0[0:len(train_1)]]) - train_model_input = {name: train[name] for name in sparse_features} - test_model_input = {name: test[name] for name in sparse_features} + train_model_input = {name: train[name].values for name in sparse_features} + test_model_input = {name: test[name].values for name in sparse_features} # Instantiate a FeatureMetas object, add your features' meta information to it feature_metas = FeatureMetas() for feat in sparse_features: feature_metas.add_sparse_feature(name=feat, one_hot_dim=data[feat].nunique(), embedding_dim=32) - + # TODO: a warning needs to be fixed; see https://stackoverflow.com/questions/35892412/tensorflow-dense-gradient-explanation # Instantiate a model and compile it model = DeepFM( feature_metas=feature_metas, @@ -60,7 +63,7 @@ history = model.fit(x=train_model_input, y=train[target].values, batch_size=128, - epochs=1, + epochs=3, verbose=2, validation_split=0.2) diff --git a/models/__init__.py b/models/__init__.py index e69de29..4d61003 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -0,0 +1,15 @@ + +from .AFM import AFM +from .AutoInt import AutoInt +from .CCPM import CCPM +from .DCN import DCN +from .DeepFM import DeepFM +from .FGCNN import FGCNN +from .FiBiNet import FiBiNet +from .FNN import FNN +from .MLR import MLR +from .NFFM import NFFM +from .NFM import NFM +from .PNN import PNN +from .WideAndDeep import WideAndDeep +from .xDeepFM import xDeepFM diff --git a/models/__pycache__/DeepFM.cpython-36.pyc b/models/__pycache__/DeepFM.cpython-36.pyc deleted file mode 100644 index f5f5f1b..0000000 Binary files a/models/__pycache__/DeepFM.cpython-36.pyc and /dev/null differ diff --git a/models/__pycache__/FNN.cpython-36.pyc b/models/__pycache__/FNN.cpython-36.pyc deleted file mode 100644 index 9b83cdc..0000000 
Binary files a/models/__pycache__/FNN.cpython-36.pyc and /dev/null differ diff --git a/models/__pycache__/PNN.cpython-36.pyc b/models/__pycache__/PNN.cpython-36.pyc deleted file mode 100644 index e87b9a9..0000000 Binary files a/models/__pycache__/PNN.cpython-36.pyc and /dev/null differ diff --git a/models/__pycache__/WideAndDeep.cpython-36.pyc b/models/__pycache__/WideAndDeep.cpython-36.pyc deleted file mode 100644 index 90f8064..0000000 Binary files a/models/__pycache__/WideAndDeep.cpython-36.pyc and /dev/null differ diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index 4107667..0000000 Binary files a/models/__pycache__/__init__.cpython-36.pyc and /dev/null differ
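
Note on the recurring pattern in core/blocks.py above: every add_weight() call is moved out of call() and into build(), so each layer's variables are created exactly once, when the input shape first becomes known, instead of being re-allocated on every forward pass. A minimal sketch of the pattern, using a hypothetical ToyProjection layer and assuming the TF 1.x-style tf.keras API used in this repo:

import tensorflow as tf


class ToyProjection(tf.keras.layers.Layer):
    """Toy layer: learns one (d, d) kernel and applies it as inputs @ kernel."""

    def __init__(self, **kwargs):
        super(ToyProjection, self).__init__(**kwargs)
        # weight to be built
        self.kernel = None

    def build(self, input_shape):
        d = int(input_shape[-1])
        # created exactly once, the first time the layer sees its input shape
        self.kernel = self.add_weight(shape=(d, d),
                                      initializer='glorot_uniform',
                                      trainable=True)
        super(ToyProjection, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs, **kwargs):
        # only reuse the already-built weight; no add_weight() inside call()
        return tf.matmul(inputs, self.kernel)

This one-time construction is what keeps the CrossNetwork, CIN_, AutoIntInteraction, FGCNNlayer, BiInteraction and SENet changes consistent across calls; it is also why AutoIntInteraction.build must forward input_shape to super().build().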