-
Notifications
You must be signed in to change notification settings - Fork 1
/
mmoe_diffgating.py
257 lines (221 loc) · 11.7 KB
/
mmoe_diffgating.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2019/8/21 10:17
# @Author : Daishijun
# @Site :
# @File : mmoe_diffgating.py
# @Software : PyCharm
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Layer, InputSpec
# from keras import backend as K
from tensorflow.python.keras import activations, initializers, regularizers, constraints
# from keras import activations, initializers, regularizers, constraints
# from keras.engine.topology import Layer, InputSpec
import os
# Default to GPU 2, but only if the caller has not already chosen the visible
# devices: importing this module must not override an explicit setting.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
class MMoEdiffGate(Layer):
    """
    Multi-gate Mixture-of-Experts (MMoE) layer with a separate gate input.

    The layer takes two inputs: ``inputs[0]`` is fed to the shared experts and
    ``inputs[1]`` is fed to every task-specific gate. It returns a list with
    one tensor per task, each being a gate-weighted sum of the expert outputs.
    """

    def __init__(self,
                 units,
                 num_experts,
                 num_tasks,
                 use_expert_bias=True,
                 use_gate_bias=True,
                 expert_activation='relu',
                 gate_activation='softmax',
                 expert_bias_initializer='zeros',
                 gate_bias_initializer='zeros',
                 expert_bias_regularizer=None,
                 gate_bias_regularizer=None,
                 expert_bias_constraint=None,
                 gate_bias_constraint=None,
                 expert_kernel_initializer='VarianceScaling',
                 gate_kernel_initializer='VarianceScaling',
                 expert_kernel_regularizer=None,
                 gate_kernel_regularizer=None,
                 expert_kernel_constraint=None,
                 gate_kernel_constraint=None,
                 activity_regularizer=None,
                 **kwargs):
        """
        Method for instantiating the MMoE layer.

        :param units: Number of hidden units per expert
        :param num_experts: Number of experts
        :param num_tasks: Number of tasks (one gate and one output per task)
        :param use_expert_bias: Boolean to indicate the usage of bias in the expert weights
        :param use_gate_bias: Boolean to indicate the usage of bias in the gate weights
        :param expert_activation: Activation function of the expert weights
        :param gate_activation: Activation function of the gate weights
        :param expert_bias_initializer: Initializer for the expert bias
        :param gate_bias_initializer: Initializer for the gate bias
        :param expert_bias_regularizer: Regularizer for the expert bias
        :param gate_bias_regularizer: Regularizer for the gate bias
        :param expert_bias_constraint: Constraint for the expert bias
        :param gate_bias_constraint: Constraint for the gate bias
        :param expert_kernel_initializer: Initializer for the expert weights
        :param gate_kernel_initializer: Initializer for the gate weights
        :param expert_kernel_regularizer: Regularizer for the expert weights
        :param gate_kernel_regularizer: Regularizer for the gate weights
        :param expert_kernel_constraint: Constraint for the expert weights
        :param gate_kernel_constraint: Constraint for the gate weights
        :param activity_regularizer: Regularizer for the layer output; it is
            stored by the base Layer class
        :param kwargs: Additional keyword arguments for the Layer class
        """
        # Initialise the base Layer before assigning attributes. Its
        # constructor (re)initialises activity_regularizer, input_spec and
        # supports_masking, so values assigned before it were silently lost.
        # Hence the activity regularizer is handed to the base class directly.
        super(MMoEdiffGate, self).__init__(
            activity_regularizer=regularizers.get(activity_regularizer), **kwargs)

        # Hidden nodes parameters
        self.units = units
        self.num_experts = num_experts
        self.num_tasks = num_tasks

        # Weight parameters (the variables themselves are created in build())
        self.expert_kernels = None
        self.gate_kernels = None
        self.expert_kernel_initializer = initializers.get(expert_kernel_initializer)
        self.gate_kernel_initializer = initializers.get(gate_kernel_initializer)
        self.expert_kernel_regularizer = regularizers.get(expert_kernel_regularizer)
        self.gate_kernel_regularizer = regularizers.get(gate_kernel_regularizer)
        self.expert_kernel_constraint = constraints.get(expert_kernel_constraint)
        self.gate_kernel_constraint = constraints.get(gate_kernel_constraint)

        # Activation parameters
        self.expert_activation = activations.get(expert_activation)
        self.gate_activation = activations.get(gate_activation)

        # Bias parameters
        self.expert_bias = None
        self.gate_bias = None
        self.use_expert_bias = use_expert_bias
        self.use_gate_bias = use_gate_bias
        self.expert_bias_initializer = initializers.get(expert_bias_initializer)
        self.gate_bias_initializer = initializers.get(gate_bias_initializer)
        self.expert_bias_regularizer = regularizers.get(expert_bias_regularizer)
        self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer)
        self.expert_bias_constraint = constraints.get(expert_bias_constraint)
        self.gate_bias_constraint = constraints.get(gate_bias_constraint)

        # Keras parameters: one spec per input (expert input, gate input),
        # because call() consumes exactly two tensors.
        self.input_spec = [InputSpec(min_ndim=2), InputSpec(min_ndim=2)]
        self.supports_masking = True

    def build(self, input_shape):
        """
        Method for creating the layer weights.

        :param input_shape: List of shapes ``[expert_input_shape, gate_input_shape]``
        :raises ValueError: If ``input_shape`` is not a list of at least two
            shapes, or the expert input shape is empty
        """
        # Two inputs are required (experts and gates), so input_shape must be a list.
        if (not isinstance(input_shape, list) or len(input_shape) < 2
                or not len(input_shape[0])):
            raise ValueError(
                'MMoEdiffGate expects a list of input shapes '
                '[expert_input_shape, gate_input_shape], got: {}'.format(input_shape))
        input_expert_dimension = int(input_shape[0][-1])
        input_gate_dimension = int(input_shape[1][-1])

        # Expert weights: (expert input features, units per expert, number of experts)
        self.expert_kernels = self.add_weight(
            name='expert_kernel',
            shape=(input_expert_dimension, self.units, self.num_experts),
            initializer=self.expert_kernel_initializer,
            regularizer=self.expert_kernel_regularizer,
            constraint=self.expert_kernel_constraint,
        )

        # Expert bias: (units per expert, number of experts)
        if self.use_expert_bias:
            self.expert_bias = self.add_weight(
                name='expert_bias',
                shape=(self.units, self.num_experts),
                initializer=self.expert_bias_initializer,
                regularizer=self.expert_bias_regularizer,
                constraint=self.expert_bias_constraint,
            )

        # One gate kernel per task: (gate input features, number of experts)
        self.gate_kernels = [self.add_weight(
            name='gate_kernel_task_{}'.format(i),
            shape=(input_gate_dimension, self.num_experts),
            initializer=self.gate_kernel_initializer,
            regularizer=self.gate_kernel_regularizer,
            constraint=self.gate_kernel_constraint
        ) for i in range(self.num_tasks)]

        # One gate bias per task: (number of experts,)
        if self.use_gate_bias:
            self.gate_bias = [self.add_weight(
                name='gate_bias_task_{}'.format(i),
                shape=(self.num_experts,),
                initializer=self.gate_bias_initializer,
                regularizer=self.gate_bias_regularizer,
                constraint=self.gate_bias_constraint
            ) for i in range(self.num_tasks)]

        # Now that the input sizes are known, also validate the last dimension
        # of each input against the kernels created above.
        self.input_spec = [
            InputSpec(min_ndim=2, axes={-1: input_expert_dimension}),
            InputSpec(min_ndim=2, axes={-1: input_gate_dimension}),
        ]

        super(MMoEdiffGate, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """
        Method for the forward function of the layer.

        :param inputs: List ``[expert_input, gate_input]``
        :param kwargs: Additional keyword arguments for the base method (unused)
        :return: List with ``num_tasks`` tensors, the expert mixture for each task
        :raises ValueError: If ``inputs`` is not a list
        """
        if not isinstance(inputs, list):
            raise ValueError('MMoEdiffGate expects a list of inputs '
                             '[expert_input, gate_input].')
        expert_input, gate_input = inputs

        # f_{i}(x) = activation(W_{i} * x + b), where activation is ReLU according to the paper.
        # tensordot contracts the last axis of expert_input with the first axis of the
        # kernel, so a 2-D expert input yields (batch, units, num_experts).
        expert_outputs = tf.tensordot(a=expert_input, b=self.expert_kernels, axes=1)
        if self.use_expert_bias:
            expert_outputs = K.bias_add(x=expert_outputs, bias=self.expert_bias)
        expert_outputs = self.expert_activation(expert_outputs)

        final_outputs = []
        for index, gate_kernel in enumerate(self.gate_kernels):
            # g^{k}(x) = activation(W_{gk} * x + b), where activation is softmax according to the paper.
            # This yields one gating weight per expert for task k.
            gate_output = K.dot(x=gate_input, y=gate_kernel)
            if self.use_gate_bias:
                gate_output = K.bias_add(x=gate_output, bias=self.gate_bias[index])
            gate_output = self.gate_activation(gate_output)

            # f^{k}(x) = sum_{i=1}^{n}(g^{k}(x)_{i} * f_{i}(x)).
            # The inserted size-1 axis broadcasts the gate weights over the units
            # axis of expert_outputs, which is equivalent to repeating them
            # `units` times but avoids materialising the repeated tensor.
            expanded_gate_output = K.expand_dims(gate_output, axis=1)
            final_outputs.append(K.sum(expert_outputs * expanded_gate_output, axis=2))

        # A list with num_tasks elements: the expert layer's output for each task.
        return final_outputs

    def compute_output_shape(self, input_shape):
        """
        Method for computing the output shape of the MMoE layer.

        :param input_shape: List of shapes ``[expert_input_shape, gate_input_shape]``
        :return: List of ``num_tasks`` shape tuples; each is the expert input shape
            with its last dimension replaced by ``units``
        :raises ValueError: If ``input_shape`` is not a list of at least two shapes
        """
        if not isinstance(input_shape, list) or len(input_shape) < 2:
            raise ValueError(
                'MMoEdiffGate expects a list of input shapes '
                '[expert_input_shape, gate_input_shape], got: {}'.format(input_shape))
        output_shape = list(input_shape[0])
        output_shape[-1] = self.units
        output_shape = tuple(output_shape)
        return [output_shape for _ in range(self.num_tasks)]

    def get_config(self):
        """
        Method for returning the configuration of the MMoE layer.

        :return: Config dictionary (base Layer config merged with this layer's arguments)
        """
        config = {
            'units': self.units,
            'num_experts': self.num_experts,
            'num_tasks': self.num_tasks,
            'use_expert_bias': self.use_expert_bias,
            'use_gate_bias': self.use_gate_bias,
            'expert_activation': activations.serialize(self.expert_activation),
            'gate_activation': activations.serialize(self.gate_activation),
            'expert_bias_initializer': initializers.serialize(self.expert_bias_initializer),
            'gate_bias_initializer': initializers.serialize(self.gate_bias_initializer),
            'expert_bias_regularizer': regularizers.serialize(self.expert_bias_regularizer),
            'gate_bias_regularizer': regularizers.serialize(self.gate_bias_regularizer),
            'expert_bias_constraint': constraints.serialize(self.expert_bias_constraint),
            'gate_bias_constraint': constraints.serialize(self.gate_bias_constraint),
            'expert_kernel_initializer': initializers.serialize(self.expert_kernel_initializer),
            'gate_kernel_initializer': initializers.serialize(self.gate_kernel_initializer),
            'expert_kernel_regularizer': regularizers.serialize(self.expert_kernel_regularizer),
            'gate_kernel_regularizer': regularizers.serialize(self.gate_kernel_regularizer),
            'expert_kernel_constraint': constraints.serialize(self.expert_kernel_constraint),
            'gate_kernel_constraint': constraints.serialize(self.gate_kernel_constraint),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer)
        }
        base_config = super(MMoEdiffGate, self).get_config()
        # Later items win, so values in `config` override same-named base entries.
        return dict(list(base_config.items()) + list(config.items()))