Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 3 】为 PaddleScience 增加损失函数权重自适应功能 #142

Merged
merged 11 commits into from
Aug 23, 2022
Merged

Conversation

Asthestarsfalll
Copy link
Contributor

@Asthestarsfalll Asthestarsfalll commented Jul 27, 2022

PR types

New features

PR changes

APIs

Describe

添加Grad Norm以实现多loss均衡,目前尚未添加测试代码,需要进一步考虑如何进行测试

@paddle-bot
Copy link

paddle-bot bot commented Jul 27, 2022

Thanks for your contribution!

@paddle-bot
Copy link

paddle-bot bot commented Jul 27, 2022

✅ This PR's description meets the template requirements!
Please wait for other CI results.

rightpeach
rightpeach previously approved these changes Jul 29, 2022
Copy link
Contributor

@rightpeach rightpeach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请添加测试代码,如有任何问题请及时沟通

@Asthestarsfalll
Copy link
Contributor Author

你好,grad norm中需要进行微分计算,如何使用numpy进行检验呢?

@Asthestarsfalll
Copy link
Contributor Author

测试结果:
image

@pytest.mark.api_network_GradNorm
def test_GradNorm0():
xy_data = np.array([[0.1, 0.5, 0.3, 0.4, 0.2]])
u = np.array([1.138526], dtype=np.float32)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些数值是在相同的初始化方法、随机种子、输入的情况下,使用该仓库的逻辑通过Pytorch计算得来,代码如下:

import torch
import torch.nn as nn
from functools import partial
import numpy as np
from torch.nn.init import constant_

class FCNet(nn.Module):
    def __init__(self,
                 num_ins,
                 num_outs,
                 num_layers,
                 hidden_size,
                 activation='tanh',
                 n_loss=1):
        super(FCNet, self).__init__()

        self.num_ins = num_ins
        self.num_outs = num_outs
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.weights = nn.Parameter(torch.ones(n_loss).float())
        # self.weights = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]).float())
        if activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'tanh':
            self.activation = torch.tanh
        else:
            assert 0, "Unsupported activation type."
        w = []
        self.num_layers = num_layers
        for i in range(num_layers):
            if i == 0:
                lsize = num_ins
                rsize = hidden_size
            elif i == (num_layers - 1):
                lsize = hidden_size
                rsize = num_outs
            else:
                lsize = hidden_size
                rsize = hidden_size
            w.append(nn.Linear(lsize, rsize, bias=False))
        self.fc = nn.ModuleList(w)
        self._init_weights()

    def _init_weights(self):
        for i in self.fc:
            if isinstance(i, nn.Linear):
                constant_(i.weight, 1)

    def forward(self, inp):
        u = inp
        for i in range(self.num_layers - 1):
            u = self.fc[i](u)
            u = self.activation(u)
        return self.fc[-1](u)

loss_func = [torch.sum, torch.mean, partial(torch.norm, p=2), partial(torch.norm, p=3)]

def cal_gradnorm(ins,
                num_ins,
                num_outs,
                num_layers,
                hidden_size,
                n_loss,
                alpha,
                activation='tanh',
                weight_attr=None):
    net = FCNet(
        num_ins=num_ins,
        num_outs=num_outs,
        num_layers=num_layers,
        hidden_size=hidden_size,
        activation=activation,
        n_loss=n_loss)

    res = net(ins)
    print(res)
    losses = []
    for idx in range(n_loss):
        losses.append(loss_func[idx](res))
    losses = torch.stack(losses)
    weighted_loss = losses * net.weights
    loss = torch.sum(weighted_loss)
    loss.backward(retain_graph=True)
    initial_task_loss = losses.detach().numpy()
    net.weights.grad.data = net.weights.grad.data * 0.0
    W = net.fc[-1]
    norms = []
    for i in range(n_loss):
        # get the gradient of this task loss with respect to the shared parameters
        gygw = torch.autograd.grad(losses[i], W.parameters(), retain_graph=True)
        # compute the norm
        norms.append(torch.norm(torch.mul(net.weights[i], gygw[0])))
    norms = torch.stack(norms)
    print("norms: ", norms)

    if torch.cuda.is_available():
        loss_ratio = losses.data.cpu().numpy() / initial_task_loss
    else:
        loss_ratio = losses.data.numpy() / initial_task_loss

    inverse_train_rate = loss_ratio / np.mean(loss_ratio)
    print("inverse_train_rate: ", inverse_train_rate)

    if torch.cuda.is_available():
        mean_norm = np.mean(norms.data.cpu().numpy())
    else:
        mean_norm = np.mean(norms.data.numpy())
    
    constant_term = torch.tensor(mean_norm * (inverse_train_rate ** alpha), requires_grad=False)

    print("constant_term: ", constant_term)

    if torch.cuda.is_available():
        constant_term = constant_term.cuda()
    grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
    net.weights.grad = torch.autograd.grad(grad_norm_loss, net.weights)[0]
    print(net.weights.grad)
    return grad_norm_loss


def randtool(dtype, low, high, shape):
    """
    np random tools
    """
    if dtype == "int":
        return np.random.randint(low, high, shape)

    elif dtype == "float":
        return low + (high - low) * np.random.random(shape)


if __name__ == '__main__':
    np.random.seed(22)
    xy_data = randtool("float", 0, 10, (9, 2))
    print(xy_data)
    # xy_data = torch.tensor(np.array([[0.1, 0.5, 0.2, 0.4]]), dtype=torch.float32)
    # xy_data = torch.tensor(np.array([[0.1, 0.5, 0.3, 0.4, 0.2]]), dtype=torch.float32)
    # res = cal_gradnorm(xy_data, 4, 3, 5, 20, activation='sigmoid', n_loss=3, alpha=0.5)
    res = cal_gradnorm(torch.tensor(xy_data, dtype=torch.float32), 2, 3, 2, 1, activation='tanh', n_loss=4, alpha=0.5)
    print(res.item())
    

@Asthestarsfalll
Copy link
Contributor Author

2.3版本可正常运行,develop版本运行出错,正在尝试修复

@Asthestarsfalll
Copy link
Contributor Author

Asthestarsfalll commented Aug 11, 2022

@rightpeach 你好,CI已通过

@Asthestarsfalll
Copy link
Contributor Author

@rightpeach 你好,可以开始review吗

Copy link
Contributor

@rightpeach rightpeach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否可以解释一下原论文的技术思路,并提供代码注释及说明?此外请问是否跑通过原始论文代码?如跑通可以提供一下paddle复现版本与原论文代码的结果差异。并辛苦说明一下替换的API有哪些,以及遇到或遗留的问题。

@Asthestarsfalll
Copy link
Contributor Author

@rightpeach
GradNorm的主要思想就是平衡各个loss的梯度,从而平衡网络对不同任务的学习。
具体实现方式上 ,需要一个trainable的weights作为各个loss的权重,当计算出所有的loss之后,需要将其与weights对应相乘,先进行一次反向传播,求出所有参数的梯度,下面再进入grad norm loss的计算,并以此来更新weights的梯度。
GradNorm的优化目标就是最小化grad norm loss, grad norm loss被定义为当前各个loss的梯度和目标梯度的l1范数之和,这里的梯度作者选取的是weights相对于网络当中的最后一个共享层的梯度,在FCNet中即为最后一个全连接层。
而目标梯度则是由平均梯度和各个loss的逆学习速率计算得来,其中平均梯度就是当前所有loss的梯度均值,学习速率可以被定义为当前loss与初始loss的比值,该值越小则认为学习速率越快,逆学习率被定义为当前loss除以所有loss学习速率的均值,同样是值越小学习速率越快 。
将逆学习率和平均梯度相乘就可以得到目标梯度,这样,学习速率更高的loss就会拥有更小的梯度,从而达到平衡的效果。
同时这里有一个超参数alpha,用于控制grad norm loss的学习速率。
在训练时,每个step都需要renormalize,使得weights的总和不变,保证loss的变化是因为weights间数值的调整,而不是因为weighs整体变小导致loss变小。
伪代码如下:
image

@Asthestarsfalll
Copy link
Contributor Author

原始论文提供的是一个很简单的样例,我将其修改了一下,代码在上面,结果上没有很大差异,可以保证grad norm loss的相对误差在1e-7次方。
API并未做什么替换,因为其涉及到的计算都是加乘乘方一类常用API。
遇到的问题是grad的设置问题,网络初始化时grad为None,无法调用set_value方法,因此需要先backward一下获得grad。

Copy link
Contributor

@rightpeach rightpeach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请通过CI

@Asthestarsfalll
Copy link
Contributor Author

@rightpeach 已通过

@rightpeach rightpeach merged commit 93c91c5 into PaddlePaddle:develop Aug 23, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants