
use torch native amp #3128

Merged
merged 11 commits into from
Aug 8, 2023

Conversation

helpmefindaname
Collaborator

@helpmefindaname helpmefindaname commented Feb 28, 2023

nvidia-apex amp has been deprecated for some time now, as pytorch has provided torch.cuda.amp since pytorch 1.6.

This PR upgrades the usage to the newer version, so setting use_amp=True on trainer.train or trainer.fine_tune works out of the box (see the sketch below).
Also, use_amp will be used in all tests that train a model, so those tests should be faster (less time spent waiting for the training pipeline to finish before evaluating it).
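A minimal sketch of the user-facing change (assuming a tagger model and a corpus are already set up in the usual flair way; the output path is just a placeholder):

from flair.trainers import ModelTrainer

trainer = ModelTrainer(tagger, corpus)
trainer.fine_tune(
    "resources/taggers/amp_example",  # placeholder output path
    use_amp=True,  # mixed-precision training via torch native amp
)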

I did two training runs:
Transformer (distilbert) with use_amp reduces the time from 440 s/epoch to 217 s/epoch.
Flair embeddings + word embeddings (without a BiLSTM layer) with use_amp reduce the time for the first epoch (no embedding storage) from 221 s/epoch to 70 s/epoch.

This PR also fixes the creation of tensors so that their dtypes are aligned with the layers' dtypes; hence you can speed up inference by using tagger.half().
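For example, a hedged sketch of half-precision inference (the model name "ner" and the example sentence are placeholders, not taken from this PR):

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("ner")
tagger.half()  # cast parameters to float16; with this PR, input tensors are created with the matching dtype

sentence = Sentence("George Washington went to Washington.")
tagger.predict(sentence)
print(sentence)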

pytorch generally selects the dtype to cast to depending on your device. However, one can select the dtype themselves by using torch.set_autocast_cpu_dtype or torch.set_autocast_gpu_dtype. Notice that CPU currently only supports torch.bfloat16.
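A short sketch of overriding the autocast dtype with the torch functions mentioned above (to be called before training or inference starts):

import torch

# On GPU, autocast defaults to torch.float16; bfloat16 can be selected on supported hardware.
torch.set_autocast_gpu_dtype(torch.bfloat16)

# On CPU, only torch.bfloat16 is currently supported.
torch.set_autocast_cpu_dtype(torch.bfloat16)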

However, be aware that using amp reduces the accuracy of gradients and can therefore lead to a lower score or higher loss.

@helpmefindaname
Collaborator Author

update: running AMP on CPU is possible, however I experienced extreme slowdowns compared to non-amp on CPU.
E.g. training distilbert without CRF/LSTM on the WNUT_17 corpus takes 218 seconds per epoch without amp. With amp it takes 320 seconds for the first two batches alone, which puts the estimated training time at 17182 seconds per epoch.

Hence, I removed AMP from the testing scripts again.

@helpmefindaname helpmefindaname force-pushed the use_torch_amp branch 2 times, most recently from 4d375a7 to ff6806d on March 20, 2023 18:48
@helpmefindaname
Collaborator Author

After feedback from @dchaplinsky I also upgraded the language model trainer. I have verified that amp works; however, I haven't gathered information about the speed boost for language modelling.

@alanakbik
Collaborator

@helpmefindaname have you checked if this gives speedups? At least the following script becomes slower on my local machine (trained on cuda:0) if use_amp=True.

import flair
from flair.datasets import CONLL_03
from flair.embeddings import TransformerWordEmbeddings
from flair.models import TokenClassifier
from flair.trainers import ModelTrainer

flair.set_seed(123)

# set this to True or False
use_amp = True

# get downsampled corpus
corpus = CONLL_03(in_memory=False).downsample(0.05)

# make label dictionary
label_dict = corpus.make_label_dictionary("ner")

# init embeddings
embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)

# init simple tagger
tagger = TokenClassifier(
    embeddings=embeddings,
    label_dictionary=label_dict,
    label_type="ner",
)

# train model
trainer = ModelTrainer(tagger, corpus)

trainer.fine_tune(
    f"resources/taggers/test_tagger_{use_amp}-chunk_4",
    monitor_test=True,
    shuffle=False,
    max_epochs=1,
    use_amp=use_amp,
)

Without amp I get 400 samples / sec, but with use_amp=True I get only 160 samples / sec.

Additionally, it throws the following warning:

UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "

@helpmefindaname
Collaborator Author

That is interesting, I get:

with amp:

2023-04-30 14:08:37,927 epoch 1 - iter 162/188 - loss 0.61785687 - time (sec): 14.64 - samples/sec: 579.71 - lr: 0.000008 - momentum: 0.000000
2023-04-30 14:08:39,605 epoch 1 - iter 180/188 - loss 0.56930119 - time (sec): 16.32 - samples/sec: 596.59 - lr: 0.000003 - momentum: 0.000000
2023-04-30 14:08:40,285 ----------------------------------------------------------------------------------------------------
2023-04-30 14:08:40,285 EPOCH 1 done: loss 0.5559 - lr: 0.000003

without amp:

2023-04-30 14:11:57,709 epoch 1 - iter 162/188 - loss 0.57335274 - time (sec): 55.76 - samples/sec: 152.20 - lr: 0.000008 - momentum: 0.000000
2023-04-30 14:12:03,893 epoch 1 - iter 180/188 - loss 0.53085163 - time (sec): 61.95 - samples/sec: 157.15 - lr: 0.000003 - momentum: 0.000000
2023-04-30 14:12:06,460 ----------------------------------------------------------------------------------------------------
2023-04-30 14:12:06,460 EPOCH 1 done: loss 0.5190 - lr: 0.000003

So amp is faster for me, at almost the opposite ratio.
Note that if I run with amp while my laptop is on battery power, I get ~100 samples/sec, which is worse than amp with power.

About the warning:
this seems to be an open problem in pytorch, see pytorch/pytorch#67590. I implemented the same workaround that is also used in huggingface/transformers#11144 (sketched below).
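For reference, the workaround follows this pattern (a sketch assuming loss, optimizer, scheduler and a torch.cuda.amp.GradScaler named scaler already exist): the LR scheduler is only stepped if the GradScaler did not skip optimizer.step() because of inf/nan gradients.

scaler.scale(loss).backward()
scale_before = scaler.get_scale()
scaler.step(optimizer)  # skipped internally if the scaled gradients contain inf/nan
scaler.update()         # decreases the scale if the step was skipped
if scaler.get_scale() >= scale_before:  # optimizer.step() actually ran
    scheduler.step()
optimizer.zero_grad()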

With that fix applied to the amp run, I get:

2023-04-30 14:54:11,327 epoch 1 - iter 162/188 - loss 0.75855084 - time (sec): 14.72 - samples/sec: 576.75 - lr: 0.000011 - momentum: 0.000000
2023-04-30 14:54:13,192 epoch 1 - iter 180/188 - loss 0.69159031 - time (sec): 16.58 - samples/sec: 587.15 - lr: 0.000005 - momentum: 0.000000
2023-04-30 14:54:13,925 ----------------------------------------------------------------------------------------------------
2023-04-30 14:54:13,925 EPOCH 1 done: loss 0.6729 - lr: 0.000005

So in that specific case the fix performs worse; however, I don't think that generalizes to models with more training epochs, etc.

@alanakbik
Collaborator

@helpmefindaname thanks for adding this! (Locally, I am still seeing slowdowns when using AMP, but as discussed offline, there may be a CPU/GPU tradeoff that is causing this.)

@alanakbik alanakbik merged commit c96660c into master Aug 8, 2023
@alanakbik alanakbik deleted the use_torch_amp branch August 8, 2023 09:25