Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

仍对ema_update函数存在疑问 #6

Open
jiafengren opened this issue Mar 18, 2024 · 3 comments
Open

仍对ema_update函数存在疑问 #6

jiafengren opened this issue Mar 18, 2024 · 3 comments

Comments

@jiafengren
Copy link

在之前的Issue中讨论了self.momentum参数的问题
def ema_update(self):
def update(student, teacher):
with torch.no_grad():
# m = momentum_schedule[it] # momentum parameter
m = self.momentum
for param_q, param_k in zip(student.parameters(), teacher.parameters()):
param_k.data.mul
(m).add
((1 - m) * param_q.detach().data)
update(self.encoder, self.encoder_LT)
update(self.projector, self.projector_ema)
注意到self._momentum=0.94是一个接近一的值,为了保持encoder 和 ema_encoder 存在一定的距离。但注意到ema_encoder的参数不计算梯度,低学习率是否会导致ema_encoder得不到充分的训练

@yf-he
Copy link
Contributor

yf-he commented Mar 18, 2024

在之前的Issue中讨论了self.momentum参数的问题 def ema_update(self): def update(student, teacher): with torch.no_grad(): # m = momentum_schedule[it] # momentum parameter m = self.momentum for param_q, param_k in zip(student.parameters(), teacher.parameters()): param_k.data.mul(m).add((1 - m) * param_q.detach().data) update(self.encoder, self.encoder_LT) update(self.projector, self.projector_ema) 注意到self._momentum=0.94是一个接近一的值,为了保持encoder 和 ema_encoder 存在一定的距离。但注意到ema_encoder的参数不计算梯度,低学习率是否会导致ema_encoder得不到充分的训练

感谢你的关注!

momentum/ema 是一个在SSL without using negative pairs算法中经常使用的更新target network的技巧。直觉上,在训练中我们鼓励online encoder(+predictor)和target encoder的表示尽可能相似,因此为了避免collapse,我们需要大的ema decay(eg. 0.996)。也有很多工作探索了ema对于训练稳定性、regularization的影响。事实上一些工作发现大的ema decay反而使得target encoder学习到了相较于online encoder更加stable/generalizable的表示。

但从实验的角度来说,我的建议是将ema decay作为一个超参数来进行调试。在GraphMAE2中,我们发现比较常见的0.996就可以带来不错的效果。也欢迎你探索其他可能的取值对于结果影响。

@jiafengren
Copy link
Author

jiafengren commented Mar 19, 2024

感谢您的回复!
我尝试分别使用两种策略验证GraphMAE2对于节点分类的影响,当只应用多视图随机重掩码解码策略时,loss在短短的几个epoch后归零,这是否意味着发生了特征坍缩?
当完整的应用GraphMAE2后,假阳性有明显的降低,我能否这样理解:GraphMAE的固定的重掩码策略导致encoder过多的关注哪些具有高辨识性的节点特征从而导致对于具有微小变化的节点特征不敏感,多视图随机重掩码策略意在将encoder的均匀地分配给不同的节点特征,但导致了过拟合的问题,因而采用潜在表征预测策略保持online encoder和target encoder的表示相似性,来避免过拟合导致的梯度消失问题(特征坍缩)。

@yf-he
Copy link
Contributor

yf-he commented Mar 21, 2024

感谢您的回复! 我尝试分别使用两种策略验证GraphMAE2对于节点分类的影响,当只应用多视图随机重掩码解码策略时,loss在短短的几个epoch后归零,这是否意味着发生了特征坍缩? 当完整的应用GraphMAE2后,假阳性有明显的降低,我能否这样理解:GraphMAE的固定的重掩码策略导致encoder过多的关注哪些具有高辨识性的节点特征从而导致对于具有微小变化的节点特征不敏感,多视图随机重掩码策略意在将encoder的均匀地分配给不同的节点特征,但导致了过拟合的问题,因而采用潜在表征预测策略保持online encoder和target encoder的表示相似性,来避免过拟合导致的梯度消失问题(特征坍缩)。

我不太了解你使用的具体数据,但单独使用GraphMAE2中的reconstruction loss应该可以预期一个和GraphMAE差不多的合理结果。具体可以参考下paper里的ablation studies中GraphMAE2中的两个decoding部分对结果的影响。

如果出现loss归零的情况,我的建议是首先尝试调整一下mask/remask ratio等超参数。以及节点特征是否具有某些特殊的性质。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants