This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Why does DifferentiableOptimizer detach parameters when track_higher_grads = False? #102

Open
Renovamen opened this issue Mar 7, 2021 · 7 comments

Comments


Renovamen commented Mar 7, 2021

Hi! Thank you for this awesome library, it helps me a lot.

I am not sure whether I'm missing something, but I'm confused about why DifferentiableOptimizer detaches parameters when track_higher_grads = False:

higher/higher/optim.py

Lines 251 to 257 in 1e20cf9

new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        if self._track_higher_grads:
            new_params[index] = p
        else:
            new_params[index] = p.detach().requires_grad_()

which cuts the gradient path back to the original model parameters, even though copy_initial_weights=False. When we set copy_initial_weights=False, we want gradients to flow back to the original model parameters, but line 257 cuts off that gradient flow.

In my use case, I want to implement something like FOMAML, and here is a simplified version of my code:

def inner_loop(self, fmodel, diffopt, train_input, train_target):
    # ...

def outer_loop(self, task_batch):
    self.out_optim.zero_grad()

    for task_data in task_batch:
        support_input, support_target, query_input, query_target = task_data

        with higher.innerloop_ctx(
            self.model, self.in_optim, copy_initial_weights=False, track_higher_grads=False
        ) as (fmodel, diffopt):
            self.inner_loop(fmodel, diffopt, support_input, support_target)

            query_output = fmodel(query_input)
            query_loss = F.cross_entropy(query_output, query_target)
            query_loss.backward()

    for param in self.model.parameters():
        print(param.grad)  # output: None
    self.out_optim.step()

The gradients were not propagated back to the original parameters. My code works as expected after I edit higher's code to:

new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        new_params[index] = p

I know this problem can be solved by manually mapping the gradients, but I just wonder why detaching the parameters is necessary here. And thank you again for your nice work!
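For reference, by "manually mapping the gradients" I mean roughly the sketch below. With track_higher_grads=False the query-loss gradients land on fmodel's detached fast weights rather than on self.model, so they have to be copied over by hand (this assumes fmodel was created from self.model, so zipping the two parameter lists lines them up):

with higher.innerloop_ctx(
    self.model, self.in_optim, copy_initial_weights=False, track_higher_grads=False
) as (fmodel, diffopt):
    self.inner_loop(fmodel, diffopt, support_input, support_target)

    query_loss = F.cross_entropy(fmodel(query_input), query_target)
    query_loss.backward()  # gradients end up on fmodel's detached fast weights

    # copy/accumulate them onto the original model parameters by hand
    for p, fp in zip(self.model.parameters(), fmodel.parameters()):
        if fp.grad is not None:
            p.grad = fp.grad.clone() if p.grad is None else p.grad + fp.grad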


eric-mitchell commented Jun 16, 2021

As a workaround, I think you can use diff_opt.step(loss, grad_callback=lambda grads: [g.detach() for g in grads]). This gives the same outer loop gradient as when using torch.autograd.grad to compute gradients with track_higher_grads=False, but .backward() still works. As a bonus, you also get first-order gradients for inner loop learning rates (if you're learning those). With track_higher_grads=False, you don't get gradients for learning rates.
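In context, that looks roughly like the following sketch (track_higher_grads is left at its default of True, copy_initial_weights=False so the graph reaches the original parameters; model, inner_opt, support_x, etc. are just placeholder names, and the usual import higher and import torch.nn.functional as F are assumed):

with higher.innerloop_ctx(
    model, inner_opt, copy_initial_weights=False
) as (fmodel, diffopt):
    inner_loss = F.cross_entropy(fmodel(support_x), support_y)
    # Detaching the inner-loop gradients makes the update first-order,
    # but the graph from the updated fast weights back to model is kept,
    # so .backward() on the outer loss still reaches model.parameters().
    diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])

    outer_loss = F.cross_entropy(fmodel(query_x), query_y)
    outer_loss.backward()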


brando90 commented Nov 1, 2022

As a workaround, I think you can use diff_opt.step(loss, grad_callback=lambda grads: [g.detach() for g in grads]). This gives the same outer loop gradient as when using torch.autograd.grad to compute gradients with track_higher_grads=False, but .backward() still works. As a bonus, you also get first-order gradients for inner loop learning rates (if you're learning those). With track_higher_grads=False, you don't get gradients for learning rates.

@eric-mitchell is the right workaround to set track_higher_grads=True but do diff_opt.step(loss, grad_callback=lambda grads: [g.detach() for g in grads])? I think so, based on my experiments. I simply checked whether your suggestion changed the gradient norm value once the code had been deterministically seeded (so your grad_callback would be the only thing changing the behaviour):

With track_higher_grads = True but without Eric's grad_callback trick:

1.111317753791809

The code is deterministic, so if I run it again it should print the same number:

1.1113194227218628

Close enough! 🙂 Now let's change the seed (from 0 to 42, 142, 1142); the grad norm value should change:

1.5447670221328735
1.1538511514663696
1.8301351070404053

Now returning to seed 0:

1.1113179922103882

Close enough again! 🙂

Now, if Eric's trick works (passing a grad_callback), then the gradient value should change, since it would now be using first-order (FO) gradients and no higher-order info. So I will change my code in steps.
First I will leave track_higher_grads = True and use the callback.
This gives this gradient norm:

0.09500227868556976

Running it again I get (to confirm the code's determinism):

0.09500067681074142

confirming that this combination does something different (i.e. the grad_callback changes the behaviour).

Now what if I use Eric's callback but set track_higher_grads=False:

AttributeError: 'NoneType' object has no attribute 'norm'

This gives an error. So it seems that setting track_higher_grads=False is always wrong.

This makes me believe your solution at least changes the behaviour, though I don't know why it works or why higher's original code doesn't.

My self-contained reproducible script:

import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import logging

from collections import OrderedDict

import higher  # tested with higher v0.2

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

logger = logging.getLogger(__name__)


def conv3x3(in_channels, out_channels, **kwargs):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )


class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, in_channels, out_features, hidden_size=64):
        super(ConvolutionalNeuralNetwork, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        self.features = nn.Sequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )

        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        features = self.features(inputs)
        features = features.view((features.size(0), -1))
        logits = self.classifier(features)
        return logits


def get_accuracy(logits, targets):
    """Compute the accuracy (after adaptation) of MAML on the test/query points
    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Outputs/logits of the model on the query points. This tensor has shape
        `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        A tensor containing the targets of the query points. This tensor has
        shape `(num_examples,)`.
    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Mean accuracy on the query points
    """
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())


def train(args):
    logger.warning('This script is an example to showcase the data-loading '
                   'features of Torchmeta in conjunction with using higher to '
                   'make models "unrollable" and optimizers differentiable, '
                   'and as such has been very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download,
                       )
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    inner_optimiser = torch.optim.SGD(model.parameters(), lr=args.step_size)
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)

            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(train_inputs, train_targets,
                                                         test_inputs, test_targets)):
                # track_higher_grads = True
                track_higher_grads = False
                with higher.innerloop_ctx(model, inner_optimiser, track_higher_grads=track_higher_grads, copy_initial_weights=False) as (fmodel, diffopt):
                    train_logit = fmodel(train_input)
                    inner_loss = F.cross_entropy(train_logit, train_target)

                    # diffopt.step(inner_loss)
                    diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])

                    test_logit = fmodel(test_input)
                    outer_loss += F.cross_entropy(test_logit, test_target)

                    # inspired by https://github.com/facebookresearch/higher/blob/15a247ac06cac0d22601322677daff0dcfff062e/examples/maml-omniglot.py#L165
                    # outer_loss = F.cross_entropy(test_logit, test_target)
                    # outer_loss.backward()

                    with torch.no_grad():
                        accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            # print(list(model.parameters()))
            # print(f"{meta_optimizer.param_groups[0]['params'] is list(model.parameters())}")
            # print(f"{meta_optimizer.param_groups[0]['params'][0].grad is not None=}")
            print(f"{meta_optimizer.param_groups[0]['params'][0].grad=}")
            print(f"{meta_optimizer.param_groups[0]['params'][0].grad.norm()}")
            assert meta_optimizer.param_groups[0]['params'][0].grad is not None
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                                    '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)


if __name__ == '__main__':
    seed = 0

    import random
    import numpy as np
    import torch
    import os

    os.environ["PYTHONHASHSEED"] = str(seed)
    # - make pytorch deterministic
    # makes all ops deterministic no matter what. Note this throws an error if your code has an op that doesn't have a deterministic implementation
    torch.manual_seed(seed)
    # if always_use_deterministic_algorithms:
    torch.use_deterministic_algorithms(True)
    # makes convs deterministic
    torch.backends.cudnn.deterministic = True
    # doesn't allow benchmarking to select fastest algorithms for specific ops
    torch.backends.cudnn.benchmark = False
    # - make python deterministic
    np.random.seed(seed)
    random.seed(seed)

    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')

    parser.add_argument('--folder', type=str, default=Path('~/data/torchmeta_data').expanduser(),
                        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
                        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
                        help='Number of classes per task (N in "N-way", default: 5).')

    parser.add_argument('--step-size', type=float, default=0.4,
                        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
                        help='Number of channels for each convolutional layer (default: 64).')

    parser.add_argument('--output-folder', type=str, default=None,
                        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
                        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
                        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_false',
                        help='Do not download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
                        help='Use CUDA if available.')

    args = parser.parse_args()
    args.device = torch.device('cuda' if args.use_cuda
                                         and torch.cuda.is_available() else 'cpu')

    train(args)


brando90 commented Nov 1, 2022

Now I will check how fast the code runs by reading the output of tqdm. If it's truly doing FO (and not using higher-order grads), then there should be some speed-up. Running this on my M1 laptop. The combination for the following run is track_higher_grads = True and diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads]), so this should be FO (the faster one) and should finish quicker than the next run, which uses higher-order grads/Hessians:

0.03890747204422951
100%|██████████| 100/100 [06:32<00:00,  3.92s/it, accuracy=0.5092]

Now with track_higher_grads = True and diffopt.step(inner_loss), which uses higher-order grads (Hessian):

0.08946451544761658
100%|██████████| 100/100 [09:59<00:00,  6.00s/it, accuracy=0.9175]

Since it's taking much longer, I conclude this indeed uses Hessians and is NOT FO MAML. I assume the difference would be more noticeable if the network were larger (due to the ~quadratic size of the Hessian).


brando90 commented Nov 1, 2022

As a workaround, I think you can use diff_opt.step(loss, grad_callback=lambda grads: [g.detach() for g in grads]). This gives the same outer loop gradient as when using torch.autograd.grad to compute gradients with track_higher_grads=False, but .backward() still works. As a bonus, you also get first-order gradients for inner loop learning rates (if you're learning those). With track_higher_grads=False, you don't get gradients for learning rates.

@eric-mitchell hi Eric! Do you mind briefly explaining why your solution works? I must admit it's strange, given that the code already seemed to do a detach, and I would have expected the requires_grad_() not to do anything (but perhaps it clearly does).

Thank you for your time!


brando90 commented Nov 1, 2022

More qualitative sanity checks. First-order MAML in my real script:

-> it=0: train_loss=4.249784290790558, train_acc=0.24000000208616257
-> it=0: val_loss=3.680968999862671, val_acc=0.2666666731238365
-> it=1: train_loss=4.253764450550079, train_acc=2.712197299694683
-> it=1: val_loss=3.5652921199798584, val_acc=0.36666667461395264
  0% (1 of 70000) || Elapsed Time: 0:00:08 | ETA:  6 days, 18:55:28 |   0.1 it/s
  0% (2 of 70000) || Elapsed Time: 0:00:16 | ETA:  6 days, 18:56:48 |   0.1 it/s
-> it=2: train_loss=4.480343401432037, train_acc=3.732449478260403
-> it=2: val_loss=3.6090375185012817, val_acc=0.19999999552965164
  0% (3 of 70000) || Elapsed Time: 0:00:25 | ETA:  6 days, 18:46:19 |   0.1 it/s
-> it=3: train_loss=2.822919726371765, train_acc=0.3426572134620805
-> it=3: val_loss=4.102218151092529, val_acc=0.30666667222976685
  0% (4 of 70000) || Elapsed Time: 0:00:33 | ETA:  6 days, 18:47:29 |   0.1 it/s

Now the non-FO MAML run:

-> it=0: train_loss=4.590916454792023, train_acc=0.23333333432674408
-> it=0: val_loss=3.6842236518859863, val_acc=0.2666666731238365
-> it=1: train_loss=4.803018927574158, train_acc=2.596685569748149
-> it=1: val_loss=3.0977725982666016, val_acc=0.3199999928474426
  0% (1 of 70000) || Elapsed Time: 0:00:16 | ETA:  13 days, 1:18:19 |  16.1 s/it
  0% (2 of 70000) || Elapsed Time: 0:00:32 | ETA:  13 days, 1:09:53 |  16.1 s/it
-> it=2: train_loss=4.257768213748932, train_acc=2.2006314379501504
-> it=2: val_loss=7.144366264343262, val_acc=0.30666665732860565
  0% (3 of 70000) || Elapsed Time: 0:00:48 | ETA:  13 days, 1:00:01 |  16.1 s/it
-> it=3: train_loss=4.1194663643836975, train_acc=1.929317718150093
-> it=3: val_loss=3.4890414476394653, val_acc=0.35333333164453506
  0% (4 of 70000) || Elapsed Time: 0:01:04 | ETA:  13 days, 0:46:34 |  16.1 s/it

The FO run's ETA is ~6 days while the higher-order one's is ~13 days, so it's likely correct!


brando90 commented Nov 1, 2022

The solution is easy: higher detaches the parameters p, not the gradients g, which is of course a totally different thing! Detaching p cuts the graph from the fast weights back to the original parameters entirely, while Eric's callback detaches only the gradients, so the update stays first-order but the path from the updated parameters back to the originals is preserved.
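To make the difference concrete, here is a tiny self-contained sketch (plain PyTorch, no higher; theta just stands in for an original model parameter) contrasting detaching the parameter with detaching only the gradient used in the update:

import torch

theta = torch.tensor([1.0], requires_grad=True)  # stands in for the original model parameter

# What higher does with track_higher_grads=False: detach the *parameter* itself.
w = theta.detach().requires_grad_()   # new leaf; the graph back to theta is gone
loss = (2 * w).sum()
loss.backward()
print(theta.grad)                     # None -> nothing flows back to theta

# What the grad_callback trick does: detach only the *gradient* used in the update.
inner_loss = (theta ** 2).sum()
g = torch.autograd.grad(inner_loss, theta, create_graph=True)[0]
w = theta - 0.1 * g.detach()          # first-order update, but still a function of theta
outer_loss = (2 * w).sum()
outer_loss.backward()
print(theta.grad)                     # tensor([2.]) -> the first-order gradient reaches theta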
