-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
仍对ema_update函数存在疑问 #6
Comments
感谢你的关注! momentum/ema 是一个在SSL without using negative pairs算法中经常使用的更新target network的技巧。直觉上,在训练中我们鼓励online encoder(+predictor)和target encoder的表示尽可能相似,因此为了避免collapse,我们需要大的ema decay(eg. 0.996)。也有很多工作探索了ema对于训练稳定性、regularization的影响。事实上一些工作发现大的ema decay反而使得target encoder学习到了相较于online encoder更加stable/generalizable的表示。 但从实验的角度来说,我的建议是将ema decay作为一个超参数来进行调试。在GraphMAE2中,我们发现比较常见的0.996就可以带来不错的效果。也欢迎你探索其他可能的取值对于结果影响。 |
感谢您的回复! |
我不太了解你使用的具体数据,但单独使用GraphMAE2中的reconstruction loss应该可以预期一个和GraphMAE差不多的合理结果。具体可以参考下paper里的ablation studies中GraphMAE2中的两个decoding部分对结果的影响。 如果出现loss归零的情况,我的建议是首先尝试调整一下mask/remask ratio等超参数。以及节点特征是否具有某些特殊的性质。 |
在之前的Issue中讨论了self.momentum参数的问题
def ema_update(self):
def update(student, teacher):
with torch.no_grad():
# m = momentum_schedule[it] # momentum parameter
m = self.momentum
for param_q, param_k in zip(student.parameters(), teacher.parameters()):
param_k.data.mul(m).add((1 - m) * param_q.detach().data)
update(self.encoder, self.encoder_LT)
update(self.projector, self.projector_ema)
注意到self._momentum=0.94是一个接近一的值,为了保持encoder 和 ema_encoder 存在一定的距离。但注意到ema_encoder的参数不计算梯度,低学习率是否会导致ema_encoder得不到充分的训练
The text was updated successfully, but these errors were encountered: