p(t) in Algorithm 2 #3

Open
boxaio opened this issue Jan 21, 2024 · 1 comment

boxaio commented Jan 21, 2024

In Algorithm 2 (in practice), you sample the time points t \sim p(t), where p(t) can be viewed as a proposal distribution for importance sampling. As mentioned in Appendix C, p(t) can be estimated using Eq. (85). But in this code repository (see losses.py) you have
p_t = time_sampler.invdensity(t)
which corresponds to a uniform distribution (see dynamics.utils).
So I wonder how you actually implemented the claim of Algorithm 2 and Appendix C.
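
To state my understanding concretely, here is a minimal sketch of the importance-sampling step I mean (time_integral_estimate, integrand, and sample_from_p are placeholder names, not the repository's API):

import torch

def time_integral_estimate(integrand, sample_from_p, batch_size):
    # Draw time points from the proposal p(t), together with their densities p(t).
    t, p_t = sample_from_p(batch_size)
    # Dividing by p(t) keeps the Monte Carlo estimate of \int_0^1 integrand(t) dt
    # unbiased even though t is not sampled uniformly.
    return (integrand(t) / p_t).mean()

def sample_from_p(n):
    # Toy proposal p(t) = 2t on [0, 1], sampled by inverting its CDF F(t) = t^2.
    u = torch.rand(n)
    t = torch.sqrt(u)
    return t, 2 * t

print(time_integral_estimate(lambda t: t**2, sample_from_p, 100_000))  # ~ 1/3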

necludov (Owner) commented Jan 23, 2024

Hi @boxaio!
I haven't implemented this in the JAX version. I have a PyTorch implementation of the same algorithm, which is not public.
Here is the code for the class implementing the loss function:

# get_rank, get_world_size, get_q, get_s, DDPAverageMeter, and gather come from
# the surrounding (non-public) training code and are not defined in this snippet.
import numpy as np
import scipy.interpolate
import torch


class AdaptiveLoss:
    def __init__(self, net, config, n=100, beta=0.99):
        self.t0, self.t1 = config.model.t0, config.model.t1
        self.alpha, self.beta = config.train.alpha, beta
        self.timesteps = np.linspace(self.t0, self.t1, n)
        self.dt = (self.t1-self.t0)/(n-1)
        self.rank = get_rank()
        self.ws = get_world_size()
        
        self.q_t, self.w, self.dwdt = get_q(config)
        self.boundary_conditions = (self.w(torch.tensor(self.t0)).item() != 0.0,
                                    self.w(torch.tensor(self.t1)).item() != 0.0)
        print('boundary conditions are: ', self.boundary_conditions)
        config.train.boundary_conditions = self.boundary_conditions
        
        self.s = get_s(net, config)
        
        # Running statistics of the per-time loss on the time grid; 'p' is the
        # current proposal density p(t), 'u0' the state of the quasi-random sampler.
        self.buffer = {'values': [],
                       'times': [],
                       'size': 100,
                       'mean': np.zeros_like(self.timesteps),
                       'var': np.ones_like(self.timesteps),
                       'p': np.ones_like(self.timesteps),
                       'u0': 0.5}
        self.construct_dist()
        
        meters = [DDPAverageMeter('train_loss'),
                  DDPAverageMeter('dsdx_std'),
                  DDPAverageMeter('dsdt_std'),
                  DDPAverageMeter('s_1_std'),
                  DDPAverageMeter('s_0_std'),
                  DDPAverageMeter('s_std')]
        self.meters = dict((m.name,m) for m in meters)
        
    def load_state_dict(self, buffer_dict):
        self.buffer = buffer_dict
        self.construct_dist()
        
    def state_dict(self):
        return self.buffer
        
    def construct_dist(self):
        # Build a piecewise-linear interpolant of p(t) on the time grid, its
        # (piecewise-quadratic) CDF F, and the inverse CDF F_inv used for
        # inverse-transform sampling of t ~ p(t).
        dt, t = self.dt, self.timesteps
        p = self.buffer['p']
        self.fp = scipy.interpolate.interp1d(t, p, kind='linear')
        self.dpdt = scipy.interpolate.interp1d(t, np.concatenate([p[1:]-p[:-1], p[-1:]-p[-2:-1]])/dt, kind='zero')
        intercept = lambda t: self.fp(t)-self.dpdt(t)*t
        t0_interval = scipy.interpolate.interp1d(t, t, kind='zero')
        mass = np.concatenate([np.zeros([1]), ((p[1:]+p[:-1])*dt/2).cumsum()[:-1], np.ones([1])])
        F0_interval = scipy.interpolate.interp1d(t, mass, kind='zero')
        F0_inv = scipy.interpolate.interp1d(mass, t, kind='zero')
        def F(t):
            # CDF of p: mass accumulated up to the left grid point plus the exact
            # integral of the linear density over the remaining sub-interval.
            t0_ = t0_interval(t)
            F0_ = F0_interval(t)
            k, b = self.dpdt(t), intercept(t)
            output = 0.5*k*(t**2-t0_**2) + b*(t-t0_)
            return F0_ + output 

        def F_inv(y):
            # Invert the piecewise-quadratic CDF: solve 0.5*k*t^2 + b*t = c for t
            # (quadratic root when k != 0, linear solution c/b when k == 0).
            t0_ = F0_inv(y)
            F0_ = F0_interval(t0_)
            k, b = self.dpdt(t0_), intercept(t0_)
            c = y - F0_
            c = c + 0.5*k*t0_**2 + b*t0_
            D = np.sqrt(b**2 + 2*k*c)
            output = (-b + D) * (np.abs(k) > 0)  + c/b * (np.abs(k) == 0.0)
            output[np.abs(k) > 0] /= k[np.abs(k) > 0]
            return output
        
        self.F_inv = F_inv
        
    def sample_t(self, n, device):
        # Generate a low-discrepancy (additive, sqrt(2) mod 1) sequence in [0, 1),
        # let each DDP rank take its own slice, and map it through the inverse CDF
        # to obtain t ~ p(t) together with p(t) and dp/dt at the sampled points.
        u = (self.buffer['u0'] + np.sqrt(2)*np.arange(n*self.ws)) % 1
        self.buffer['u0'] = (self.buffer['u0'] + np.sqrt(2)*n*self.ws) % 1
        u = u[self.rank*n:(self.rank+1)*n]
        t = self.F_inv(u)
        assert ((t < self.t0).sum() == 0) and ((t > self.t1).sum() == 0)
        p_t, dpdt = self.fp(t), self.dpdt(t)
        p_0, p_1 = self.fp(self.t0*np.ones_like(t)), self.fp(self.t1*np.ones_like(t))
        t = torch.from_numpy(t).to(device).float()
        p_t, dpdt = torch.from_numpy(p_t).to(device).float(), torch.from_numpy(dpdt).to(device).float()
        p_0, p_1 = torch.from_numpy(p_0).to(device).float(), torch.from_numpy(p_1).to(device).float()
        return t, p_t, dpdt

    def update_history(self, new_p, t, p_t):
        # Kernel-smooth the freshly observed per-sample losses onto the time grid,
        # update exponential moving averages of their mean and variance, and set
        # the new proposal p(t) proportional to the loss standard deviation, mixed
        # with a uniform component of weight alpha.
        new_p, t, p_t = new_p.cpu().numpy().flatten(), t.cpu().numpy().flatten(), p_t.cpu().numpy().flatten()
        weights = np.exp(-np.abs(self.timesteps.reshape(-1, 1) - t.reshape(1,-1))*1e2)
        weights = weights/weights.sum(1,keepdims=True)
        self.buffer['mean'] = self.beta*self.buffer['mean'] + (1-self.beta)*(weights@new_p)
        mean_func = scipy.interpolate.interp1d(self.timesteps, self.buffer['mean'], kind='linear')
        self.buffer['var'] = self.beta*self.buffer['var'] + (1-self.beta)*(weights@((mean_func(t) - new_p)**2))
        p = np.sqrt(self.buffer['var'])
        p = (1.0-self.alpha)*p/((p[1:]+p[:-1])*self.dt/2).sum() + self.alpha/(self.t1-self.t0)
        self.buffer['p'] = p
        self.construct_dist()
    
    def eval_loss(self, x):
        # Monte Carlo estimate of the time integral: sample t ~ p(t), evaluate the
        # integrand, and divide by p(t) (importance weighting); add boundary terms
        # at t0/t1 where w does not vanish, then refresh the statistics behind p(t).
        q_t, w, dwdt, s = self.q_t, self.w, self.dwdt, self.s
        assert (2 == x.dim())
        t_0, t_1 = self.t0, self.t1
        device = x.device
        bs = x.shape[0]
        t, p_t, dpdt = self.sample_t(bs, device)
        while (x.dim() > t.dim()): t = t.unsqueeze(-1)
        x_t, _ = q_t(x, t)
        x_t.requires_grad, t.requires_grad = True, True
        s_t = s(t, x_t)
        assert (2 == s_t.dim())
        dsdt, dsdx = torch.autograd.grad(s_t.sum(), [t, x_t], create_graph=True, retain_graph=True)
        x_t.requires_grad, t.requires_grad = False, False
        
        loss = (0.5*(dsdx**2).sum(1, keepdim=True) + dsdt.sum(1, keepdim=True))*w(t)
        self.meters['dsdx_std'].update((0.5*(dsdx**2).sum(1)*w(t).squeeze()).detach().cpu().std())
        self.meters['dsdt_std'].update((dsdt.sum(1)*w(t).squeeze()).detach().cpu().std())
        loss = loss + s_t*dwdt(t)
        self.meters['s_std'].update((s_t*dwdt(t)).squeeze().detach().cpu().std())
        loss = loss.squeeze()/p_t
        time_loss = loss.detach()*p_t
            
        s_1_std, s_0_std = 0.0, 0.0
        if self.boundary_conditions[0]:
            t_0 = t_0*torch.ones([bs, 1], device=device)
            x_0, _ = q_t(x, t_0)
            left_bound = (s(t_0,x_0)*w(t_0)).squeeze()
            loss = loss + left_bound
            self.meters['s_0_std'].update(left_bound.detach().cpu().std())
        if self.boundary_conditions[1]:
            t_1 = t_1*torch.ones([bs, 1], device=device)
            x_1, _ = q_t(x, t_1)
            right_bound = (-s(t_1,x_1)*w(t_1)).squeeze()
            loss = loss + right_bound
            self.meters['s_1_std'].update(right_bound.detach().cpu().std())
            
        self.meters['train_loss'].update(loss.detach().mean().cpu())
        self.update_history(gather(time_loss), gather(t), gather(p_t))
        return loss.mean(), self.meters
    
    def get_dxdt(self):
        # The learned vector field is the gradient of s: dx/dt = grad_x s(t, x).
        def dxdt(t, x):
            return torch.autograd.grad(self.s(t, x).sum(), x, create_graph=True, retain_graph=True)[0]
        return dxdt

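For completeness, a rough and purely illustrative sketch of how such a class could be wired into a training loop; net, config, loader, and device below are placeholders, not names from the actual (non-public) code:

loss_obj = AdaptiveLoss(net, config)
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
for x in loader:                      # x: a batch of shape [batch_size, dim]
    loss, meters = loss_obj.eval_loss(x.to(device))
    opt.zero_grad()
    loss.backward()
    opt.step()

# The adapted proposal p(t) lives in loss_obj.state_dict(), so it can be saved and
# restored with the model checkpoint; get_dxdt() gives the learned vector field
# for integrating samples at inference time.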