
Error when using Adafactor without a learning rate #11612

Closed · oliverguhr opened this issue May 6, 2021 · 11 comments · Fixed by #12123

oliverguhr (Contributor) commented May 6, 2021

Hi,
I get these strange errors when I use the Adafactor optimizer. This code results in the following (expected) error:

optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=1e-4)

ValueError: Cannot combine manual lr and relative_step=True options

However, if I do not set a manual learning rate, I get a different error. By the way, this code is recommended in the documentation.

optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
# same for 
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True)

Both return this error:

TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'

Environment info

  • transformers version: 4.5.1
  • Platform: Linux
  • Python version: 3.7.1
  • PyTorch version (GPU?): 1.8.0+cu111 and 1.8.1+cu111
  • Tensorflow version (GPU?): -
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help

Trainer: @sgugger

sgugger (Collaborator) commented May 6, 2021

This was added by @jsrozner and @stas00 in #10526, so pinging them here.

oliverguhr (Contributor, Author) commented May 6, 2021

Thank you @sgugger for the feedback. I installed the latest transformers version from source using:
pip install git+https://github.com/huggingface/transformers

and set the recommended parameters from the patch:

optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None)

TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'

However, the error message remains the same. Can you give me a hint where I can address this issue?

For reference, this is the code that I am using:

from transformers import (AutoModelForTokenClassification, TrainingArguments, Trainer,
                          DataCollatorForTokenClassification, Adafactor,
                          get_constant_schedule_with_warmup)

# model_checkpoint, label_2_id, run_name, batch_size, tokenizer, the tokenized datasets
# and compute_metrics_sklearn are defined earlier in the notebook
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_2_id))

args = TrainingArguments(
    output_dir=f"models/{run_name}/checkpoints",
    run_name=run_name,
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1,
    num_train_epochs=2,
    report_to=["tensorboard"],
    logging_dir="runs/" + run_name,
    logging_first_step=True,
    logging_steps=100,
    save_steps=10000,
    save_total_limit=10,
    seed=16,
    fp16=True,
)

optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None)
lrs = get_constant_schedule_with_warmup(optimizer, 100)

data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_sklearn,
    optimizers=(optimizer, lrs),
)

stas00 (Contributor) commented May 6, 2021

@oliverguhr, please always post a full traceback for errors. Otherwise it is impossible to know where the error came from; please refer to https://github.com/huggingface/transformers/blob/master/ISSUES.md#the-github-issues, item (3).

The actual recommendation is:

Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)

I kept the alternative one because others said it worked well for them.

Once you post the full traceback then we can see why it fails.

Thank you!

P.S. A Colab notebook reproducing the problem is even better.
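For reference, a minimal sketch of how this fixed-lr recipe can be paired with an explicit scheduler (model and the remaining Trainer arguments are the placeholders from your snippet above):

from transformers.optimization import Adafactor, get_constant_schedule_with_warmup

# with relative_step=False and an explicit lr, Adafactor does no internal lr scheduling,
# so a regular scheduler such as get_constant_schedule_with_warmup behaves as expected
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
lrs = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=100)

# then pass optimizers=(optimizer, lrs) to the Trainer exactly as in your snippet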

stas00 self-assigned this May 6, 2021
oliverguhr (Contributor, Author) commented May 6, 2021

Thanks for looking at this @stas00.

Here is a traceback, and here is a Colab notebook to reproduce the issue.

Hint: depending on whether I set

lrs = get_constant_schedule_with_warmup(optimizer, 100)

or

lrs = None

either get_constant_schedule_with_warmup fails directly or trainer.train() does.

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-18-031302865887> in <module>()
     12 
     13 optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None)
---> 14 lrs = get_constant_schedule_with_warmup(optimizer,100)
     15 
     16 training_args = TrainingArguments(

5 frames

/usr/local/lib/python3.7/dist-packages/transformers/optimization.py in get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch)
     67         return 1.0
     68 
---> 69     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
     70 
     71 

/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py in __init__(self, optimizer, lr_lambda, last_epoch, verbose)
    201                     len(optimizer.param_groups), len(lr_lambda)))
    202             self.lr_lambdas = list(lr_lambda)
--> 203         super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)
    204 
    205     def state_dict(self):

/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py in __init__(self, optimizer, last_epoch, verbose)
     75         self.verbose = verbose
     76 
---> 77         self.step()
     78 
     79     def state_dict(self):

/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py in step(self, epoch)
    150             if epoch is None:
    151                 self.last_epoch += 1
--> 152                 values = self.get_lr()
    153             else:
    154                 warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)

/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py in get_lr(self)
    249 
    250         return [base_lr * lmbda(self.last_epoch)
--> 251                 for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
    252 
    253 

/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py in <listcomp>(.0)
    249 
    250         return [base_lr * lmbda(self.last_epoch)
--> 251                 for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
    252 
    253 

TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'

stas00 (Contributor) commented May 6, 2021

Thank you for creating the reproducible colab notebook, @oliverguhr - that's very helpful.

So when you use Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None), the learning rate scheduling is performed internally by the optimizer, so there is no need for a scheduler.

But I see that the barebones HF Trainer doesn't support training without a scheduler. So we aren't quite supporting this option yet, and perhaps we should.

Regardless of the outcome we should document the conclusion of this thread in the Adafactor docstring.

So here are a few ideas meanwhile:

  1. Create a dummy scheduler that always returns a fixed lr, example:
from torch.optim.lr_scheduler import LambdaLR
class DummyLR(LambdaLR):
    def __init__(self, optimizer, lr=0):
        for group in optimizer.param_groups:
            group['initial_lr'] = lr
        lr_lambda = lambda x: lr
        super().__init__(optimizer, lr_lambda)
        for group in optimizer.param_groups:
            del group['initial_lr']

def get_dummy_schedule(optimizer):
    return DummyLR(optimizer)

lrs = get_dummy_schedule(optimizer) 

Let me know if this unblocks you a bit.

  2. Alternatively, if you want to be able to access the lr outside of the optimizer, here is a proxy scheduler that pulls the lr out of the optimizer at run time, rather than feeding it to the optimizer.
from torch.optim.lr_scheduler import LambdaLR
class AdafactorSchedule(LambdaLR):
    def __init__(self, optimizer, initial_lr=0):
        for group in optimizer.param_groups:
            group['initial_lr'] = initial_lr
        lr_lambda = lambda x: initial_lr
        super().__init__(optimizer, lr_lambda)
        for group in optimizer.param_groups:
            del group['initial_lr']

    def get_lr(self):
        opt = self.optimizer
        lrs = [opt._get_lr(group, opt.state[group["params"][0]]) for group in opt.param_groups if group["params"][0].grad is not None]
        if len(lrs) == 0:
            lrs = self.base_lrs # if called before stepping
        # print(f"lr={lrs}")
        return lrs

def get_adafactor_schedule(optimizer):
    return AdafactorSchedule(optimizer)


optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None)
scheduler = get_adafactor_schedule(optimizer) 

Clearly this is a quick hack, but it seems to work: it returns initial_lr during startup and the actual lr during stepping (uncomment the debug print to see). A small standalone check of this behaviour is sketched after this list.

As you can see, I had to hack initial_lr into it since the optimizer doesn't have any lr until it starts stepping.

If this is desired, then we could add Adafactor.get_scheduler(), which would return the above. Perhaps it needs to assert that lr is None; I haven't looked that closely.

If you like the 2nd solution, feel free to clean it up and make a PR, perhaps getting rid of LambdaLR so the group['initial_lr'] hack isn't needed, going straight for the _LRScheduler superclass.

  3. Make HF Trainer support scheduler=None. That would be hard for the loggers and other places that expect to be able to get the value of the lr. I think a clean version of the 2nd solution is probably more suitable.
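Here is the standalone check of option 2 mentioned above: a minimal sketch, assuming a toy torch.nn.Linear model purely for illustration, showing that get_last_lr() returns the placeholder before any step and the internally computed Adafactor lr once stepping starts.

import torch
from transformers.optimization import Adafactor

# toy model purely for illustration
model = torch.nn.Linear(4, 2)
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None)
scheduler = get_adafactor_schedule(optimizer)  # the proxy scheduler defined in option 2 above

print(scheduler.get_last_lr())  # [0] - the initial_lr placeholder, no grads or steps yet

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()    # Adafactor computes its relative step size internally
scheduler.step()    # the proxy pulls that lr back out of the optimizer

print(scheduler.get_last_lr())  # now reflects the lr Adafactor actually used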

stas00 (Contributor) commented May 6, 2021

So @sgugger suggests the 3rd option. For that, we will have to track down all the places where the scheduler is used and condition them on scheduler != None.

Not sure about back-compat though since we auto-create a scheduler if it's not passed:

if self.lr_scheduler is None:
    warmup_steps = (
        self.args.warmup_steps
        if self.args.warmup_steps > 0
        else math.ceil(num_training_steps * self.args.warmup_ratio)
    )
    self.lr_scheduler = get_scheduler(
        self.args.lr_scheduler_type,
        self.optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )

stas00 (Contributor) commented May 7, 2021

Another proposition from @sgugger is:

this can be handled with the lr_scheduler_type argument: we could add an acceptable value "no" that would leave the scheduler at None.

oliverguhr (Contributor, Author) commented

@stas00 Sorry for the late reply, and thanks for your feedback. The DummyLR worked for me, but this parameter combination did not improve my results; maybe these parameter settings are kind of an edge case.

Regarding the 3rd option: would it be possible to check whether lr_scheduler is None and the optimizer is Adafactor, and then auto-create an instance of the "AdafactorScheduler"? This could eliminate the need to check all the other parts of the code that rely on the lr value from the optimizer.

stas00 (Contributor) commented May 31, 2021

Regarding the 3rd option: would it be possible to check whether lr_scheduler is None and the optimizer is Adafactor, and then auto-create an instance of the "AdafactorScheduler"? This could eliminate the need to check all the other parts of the code that rely on the lr value from the optimizer.

@sgugger, what is your take - AdafactorScheduler is the hack I posted here: #11612 (comment)

I'm happy with either way, but let's resolve it one way or another.

sgugger (Collaborator) commented Jun 1, 2021

Mmm, the lr_scheduler is always None by default. Can we add a value "adafactor" that would use that AdafactorScheduler?

stas00 (Contributor) commented Jun 11, 2021

@oliverguhr, we went with the AdafactorSchedule; please check that it works for you:

#12123
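A sketch of the intended usage with the merged helper, assuming the AdafactorSchedule API added in that PR (model, args and the dataset/collator/metrics variables are the placeholders from the snippet earlier in this thread):

from transformers import Trainer
from transformers.optimization import Adafactor, AdafactorSchedule

# with lr=None and relative_step=True Adafactor schedules the lr internally;
# AdafactorSchedule only proxies that value so the Trainer can log it
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_sklearn,
    optimizers=(optimizer, lr_scheduler),
)
trainer.train()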
