This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

variables work outside of context manager scope #110

Open
hughperkins opened this issue Jun 22, 2021 · 2 comments

@hughperkins
Contributor

i.e., if I do:

    with higher.innerloop_ctx(model, opt_inner, copy_initial_weights=False) as (f_model, diff_opt_inner):
        pass
    print('f_model.parameters()', list(f_model.parameters()))
    inputs = torch.rand(N, 2)
    targets = torch.rand(N, 2)
    pred = f_model(inputs)
    loss = F.mse_loss(pred, targets)
    diff_opt_inner.step(loss=loss)
    print('f_model.parameters()', list(f_model.parameters()))

What I expect to happen:

  • a crash, since f_model and diff_opt_inner are out of scope

What actually happens:

  • it behaves identically to running the same code inside the context manager block (see the sketch below)

(I wonder if this is connected with, e.g., #105?)
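For reference, here is a minimal pure-Python sketch (no higher involved; the context manager is illustrative) of why no NameError can occur here: a with statement does not introduce a new scope, so the names bound by its as clause stay defined after __exit__ runs.

    import contextlib

    @contextlib.contextmanager
    def ctx():
        yield "still here"   # the value bound by the `as` clause

    with ctx() as value:
        pass
    # `value` survives the block: the with statement only guarantees that
    # __exit__ ran; it does not delete or invalidate the bound name.
    print(value)  # -> still here

So a crash would require innerloop_ctx to explicitly invalidate f_model and diff_opt_inner on exit, which it evidently does not do.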

Full code:

import higher
import torch
from torch import nn, optim
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = nn.Linear(2, 2)

    def forward(self, x):
        x = self.h1(x)
        return x


def run():
    N = 4
    model = Model()

    opt_outer = optim.Adam(lr=0.01, params=model.parameters())
    opt_inner = optim.Adam(lr=0.01, params=model.parameters())

    print('model.parameters()', list(model.parameters()))

    with higher.innerloop_ctx(model, opt_inner, copy_initial_weights=False) as (f_model, diff_opt_inner):
        pass
    print('f_model.parameters()', list(f_model.parameters()))
    inputs = torch.rand(N, 2)
    targets = torch.rand(N, 2)
    pred = f_model(inputs)
    loss = F.mse_loss(pred, targets)
    diff_opt_inner.step(loss=loss)
    print('f_model.parameters()', list(f_model.parameters()))

    print('model.parameters.grad', [p.grad for p in model.parameters()])
    inputs = torch.rand(N, 2)
    targets = torch.rand(N, 2)
    pred = f_model(inputs)
    loss = F.mse_loss(pred, targets)
    loss.backward()
    print('model.parameters.grad', [p.grad for p in model.parameters()])


if __name__ == '__main__':
    run()

@brando90

brando90 commented Nov 4, 2021

I am also interested in this. My experience is that there is no issue with just returning the fmodel and diffopt (sketch below).
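A minimal sketch of that pattern, reusing the Model / optimizer / data setup from the full code above (the helper name inner_adapt is made up for illustration): the functional model and differentiable optimizer are created inside innerloop_ctx, returned after the block exits, and keep working in the caller.

    def inner_adapt(model, opt, inputs, targets, steps=1):
        # Run a few differentiable inner updates, then hand back the patched objects.
        with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
            for _ in range(steps):
                loss = F.mse_loss(fmodel(inputs), targets)
                diffopt.step(loss)
        # Returned after the with block has exited; both remain usable.
        return fmodel, diffopt

    fmodel, diffopt = inner_adapt(model, opt_inner, inputs, targets)
    outer_loss = F.mse_loss(fmodel(inputs), targets)
    outer_loss.backward()   # with copy_initial_weights=False, grads reach model.parameters()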

@brando90

brando90 commented Nov 4, 2021

but there might be some memory issues; see the linked issues below (and the cleanup sketch after them):

#119

#75
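On the memory point: the unrolled inner-loop graph stays alive for as long as something still references it, so one option (a hedged sketch continuing the names from the snippets above, not anything prescribed by higher) is to drop the references once the outer gradients have been computed.

    outer_loss.backward()            # gradients now sit in model.parameters()'s .grad
    opt_outer.step()                 # outer (meta) update with the ordinary optimizer
    del fmodel, diffopt              # drop references so the unrolled graph can be freed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()     # optionally return cached GPU memory blocks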
