
Background Preprocessing #3527

Merged
merged 21 commits into master from realbackground on Apr 6, 2021

Conversation

Contributor

@stephenroller stephenroller commented Mar 16, 2021

Patch description
The fabled background preprocessing is here. You can enable it by setting --num-workers to a value greater than 0!
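
For example, something like this enables it from the Python API (a minimal sketch; the task, model, and path values below are placeholders, not a recommended config):

```python
from parlai.scripts.train_model import TrainModel

TrainModel.main(
    task='convai2',                   # placeholder task
    model='transformer/generator',    # placeholder model
    model_file='/tmp/bg_demo/model',  # placeholder path
    num_workers=4,                    # any value > 0 turns on background preprocessing
)
```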

Some rough estimates of performance increases:

  • Fine-tuning BlenderBot: 5-10%
  • Fine-tuning BART: 15-25%
  • Training a BART-sized model on CommonCrawl-like data: 110% speedup (2.1x faster)

Testing steps
Sooooo much manual testing. There is also new CI.

@stephenroller stephenroller changed the base branch from master to yactb March 26, 2021 17:38
@stephenroller stephenroller marked this pull request as ready for review March 26, 2021 19:03
Contributor

@EricMichaelSmith EricMichaelSmith left a comment

Cool! Makes sense to me - will defer to Kurt to approve because I think he has more context on this but happy to do so if desired. What value of --num-workers would you recommend for different cases? How much additional mem does this typically use?


def receive_data(self, future):
def receive_data(self, future, direct_result=False):
Contributor

Hmm maybe it'd be worth adding a line of documentation on future? I understand from the line above that it's a chunk, but offhand I don't even know what a chunk is...


self.re = re
except ImportError:
if regex is None:
Contributor

Yeah, this was always weird
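
For context, the pattern being touched here is (roughly) the optional-dependency fallback sketched below. This is a hedged reconstruction from the visible lines, not the actual diff, and the class name is made up:

```python
import re

# Prefer the third-party `regex` module if it is installed; otherwise fall
# back to the stdlib `re` module.
try:
    import regex
except ImportError:
    regex = None


class SomeTokenizer:  # hypothetical class name
    def __init__(self):
        # self.re ends up pointing at whichever engine is available.
        self.re = re if regex is None else regex
```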

@stephenroller
Contributor Author

Number of workers gives diminishing returns, but I have still found that 8 is faster than 4. 10 is the "max" we can do on our cluster based on CPUs-per-GPU.

@stephenroller
Contributor Author

As far as memory goes, each worker loads the dataset separately. My chunk teachers haven't had issues, though.

@@ -2277,12 +2304,23 @@ def __init__(self, opt, shared=None):
c for c in self.fold_chunks if c % self.dws == self.rank
]

# deal with --num-workers
self.threading = not (opt.get('num_workers', 0) > 0 and self.is_train)
Contributor

so we still thread during validation?

Contributor Author

Added some clarifying comments
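
Roughly along these lines (a paraphrase of the intent, not necessarily the exact comment text that landed):

```python
# deal with --num-workers: background (process-based) preprocessing is only
# used for training. During validation/test, or when --num-workers is 0, we
# keep the existing thread-based chunk loading, so self.threading stays True.
self.threading = not (opt.get('num_workers', 0) > 0 and self.is_train)
```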

Contributor

thanks for the clarification

"""
Loads data.

Load data into self.samples until buffersize is reached.
"""
chunk_output, chunk_reset_cnt = future.result()
output = future if direct_result else future.result()
Contributor

nit: direct_result is a bit ambiguous --> if true, we return the future, not the future.result(), which is somewhat counterintuitive?

Contributor Author

Okay, refactored to be more clear I hope
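
The shape of the refactor is roughly the split below; the method and helper names (`receive_direct_data`, `_enqueue_chunk`) are illustrative stand-ins, not necessarily what was merged:

```python
def receive_data(self, future):
    """
    Enqueue a finished chunk.

    ``future`` is a concurrent.futures.Future whose result is the
    ``(chunk_output, chunk_reset_cnt)`` pair produced by a loader thread.
    """
    self._enqueue_chunk(*future.result())

def receive_direct_data(self, chunk_output, chunk_reset_cnt):
    """Enqueue a chunk that was computed inline, with no Future involved."""
    self._enqueue_chunk(chunk_output, chunk_reset_cnt)
```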

@@ -162,20 +165,26 @@ def to(self, dev):
"""
Move all tensors in the batch to a device.

NOT in place.
Contributor

just curious, why not?

Contributor Author

Turns out, attrdict doesn't link batch.value to batch['value'] after initialization so it was like real weird lol
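
The pitfall, in miniature (a toy stand-in that reproduces the behavior described, not ParlAI's actual Batch/AttrDict code):

```python
class InitOnlyAttrDict(dict):
    """Attribute access is only wired up at construction time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():
            setattr(self, key, value)  # a snapshot, not a live link

batch = InitOnlyAttrDict(text_vec=[1, 2, 3])
batch.text_vec = "moved to device"
print(batch.text_vec)     # -> "moved to device"
print(batch["text_vec"])  # -> [1, 2, 3]  (the dict view never saw the update)
```

Which is why `to()` builds and returns a new batch rather than assigning attributes on the existing one.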

import torch.multiprocessing as mp

self._num_workers = self.opt['num_workers']
self._process_queue = mp.Queue(maxsize=4 * self._num_workers)
Contributor

arbitrary?

Contributor Author

ya seemed good to me lol
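
Whatever the exact multiplier, the main effect of the bound is backpressure: `put()` blocks once the queue is full, so workers can only run a fixed number of chunks ahead of the trainer. A toy demo of that behavior (not ParlAI code):

```python
import multiprocessing as mp
import time

def producer(q):
    for i in range(20):
        q.put(i)  # blocks whenever the queue already holds `maxsize` items
        print(f"produced {i}")

if __name__ == "__main__":
    q = mp.Queue(maxsize=8)  # e.g. 4 * num_workers with num_workers = 2
    p = mp.Process(target=producer, args=(q,))
    p.start()
    time.sleep(1)  # slow consumer: the producer stalls around item 8
    for _ in range(20):
        print(f"consumed {q.get()}")
        time.sleep(0.1)
    p.join()
```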


def launch_process(cls, index, opt, model_agent, process_queue):
import torch

torch.set_num_threads(1) # prevent threads from spawning in this worker
Contributor

is this local in scope?

Contributor Author

ya i think so but not 100% sure
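
For what it's worth, `torch.set_num_threads` sets the intra-op thread count for the calling process only, so in a forked worker it shouldn't leak back into the parent. A quick check (a sketch, not part of the PR; fork requires a Unix-like OS):

```python
import torch
import torch.multiprocessing as mp

def worker(rank):
    torch.set_num_threads(1)  # affects only this worker process
    print(f"worker {rank} threads:", torch.get_num_threads())  # -> 1

if __name__ == "__main__":
    print("parent threads before:", torch.get_num_threads())
    mp.start_processes(worker, nprocs=2, start_method="fork")
    print("parent threads after:", torch.get_num_threads())  # unchanged
```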

Base automatically changed from yactb to master March 31, 2021 16:37
Contributor

@klshuster klshuster left a comment

really appreciate the clarifications, sorry for the super delayed re-review

these tricks, or simply use these options in the first place.

This tutorial is only for illustration purposes of how to speed up training,
and may not get you the best _performing_ model.
Contributor

assuming this is just a standard disclaimer? or do you actually see worse performance?

Contributor Author

The LR isn't optimized for this, and many of them don't train for remotely long enough.


This tutorial is only for illustration purposes of how to speed up training,
and may not get you the best _performing_ model.
:::

A summary of the speedups is in this table:
Contributor

maybe indicate here, summary of speedups "for 90m blenderbot model" or something like that

| FP16 | 212s | 11s | 223s | 2.60x |
| Larger batchsize (FP16) | 168s | 10s | 178s | 3.59x |
| Background preprocessing | 144s | 10s | 154s | 4.15x |
| Using 4 GPUs | 63s | 4s | 67s | 8.64x |
Contributor

gimme 8x all day

--dynamic-batching full \
--fp16 true --batchsize 128
--fp16 true --fp16-impl mem_efficient --batchsize 128 --eval-batchsize 256
Contributor

do we need to explicitly specify which --fp16-impl for max efficiency?

tokenization and dialogue state tracking, in the background thread. This can be
enabled by setting `--num-workers` to a value greater than 0. A good rule of
thumb is to set `--num-workers` to the number of CPU cores you have PER GPU.
Sometimes you can go a bit over, but you will need to play. On my server, there
Contributor

nit: "you will need to play..." "around with it"?

--skip-generation true \
--dynamic-batching full \
--fp16 true --fp16-impl mem_efficient --batchsize 128 --eval-batchsize 256 \
--num-workers 8
Contributor

nit: you set it to 8 here, though the text says your setting is 4

training runs like in this document. If we train for 4 epochs instead, our FP16
large-batch run takes 614 seconds and the multi-GPU training takes 222 seconds.
Also, if you have access to a SLURM cluster, distributed_train is sometimes
faster than multiprocessin_train. With SLURM, multi-GPU training takes 167 seconds
Contributor

nit: multiprocessin


@stephenroller stephenroller merged commit a9afe29 into master Apr 6, 2021
@stephenroller stephenroller deleted the realbackground branch April 6, 2021 20:18
@sjscotti

Hi
I am finetuning the BlenderBot2 (400M) model with some domain-specific data, and to speed things up I tried to use the --num-workers flag on a Windows 10 machine where I am running in an Anaconda environment. I tried values of 4 and 8 (I have an 8-core machine), but both gave me errors like the one shown below. The training works OK without this flag. Is there an issue with using this flag on Windows?
Thanks!

<prior output is not shown>
...
14:48:29 | creating task(s): jsonfile
14:48:29 | [loading data from json file into task:new_custom_task.txt]
14:48:29 | Metadata does not exist. Please double check your datapath.
Traceback (most recent call last):
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\Scripts\parlai.exe\__main__.py", line 7, in <module>
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\__main__.py", line 14, in main
    superscript_main()
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\core\script.py", line 325, in superscript_main
    return SCRIPT_REGISTRY[cmd].klass._run_from_parser_and_opt(opt, parser)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\core\script.py", line 108, in _run_from_parser_and_opt
    return script.run()
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\scripts\train_model.py", line 932, in run
    self.train_loop = TrainLoop(self.opt)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\scripts\train_model.py", line 349, in __init__
    self.world = create_task(opt, self.agent)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\core\worlds.py", line 1444, in create_task
    world = BackgroundDriverWorld(opt, world)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\core\worlds.py", line 1246, in __init__
    self._process_pool = self._start_processes()
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\parlai\core\worlds.py", line 1254, in _start_processes
    return mp.start_processes(
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\site-packages\torch\multiprocessing\spawn.py", line 169, in start_processes
    mp = multiprocessing.get_context(start_method)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\multiprocessing\context.py", line 239, in get_context
    return super().get_context(method)
  File "C:\Users\Steve\anaconda3\envs\ParlAI\lib\multiprocessing\context.py", line 193, in get_context
    raise ValueError('cannot find context for %r' % method) from None
ValueError: cannot find context for 'fork'

@stephenroller
Contributor Author

Sorry, we don't support Windows at this time. You can try without bg preprocessing.

@sjscotti

sjscotti commented Aug 1, 2021

Thanks for the speedy reply @stephenroller
I am using about 3GB of "shared" GPU memory right now, and I was hoping that background preprocessing would offload enough of the work that my GPU only used the 8GB of memory on the card and didn't spill over into "shared" GPU memory. Do you think the changes to support Windows would be major? Any suggestions for what I would need to do to give it a try?
For example, in worlds.py there is a line in the call to mp.start_processes that has start_method='fork', on line 1265.
Would changing it to start_method='spawn', which is more Windows-compatible, cause any problems elsewhere?
Thanks again!

@stephenroller
Contributor Author

I don't think Bg preprocessing will do anything to save you GPU memory. It just does tokenization etc in a background thread so the GPU is always being fed examples.

Unfortunately, our implementation heavily depends on the semantics of fork, and won't work with other start methods.
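
A toy illustration of the fork-vs-spawn difference being relied on here (not ParlAI code): under fork, workers inherit state the parent built before launching them; under spawn, the module is re-imported from scratch in each child.

```python
import torch.multiprocessing as mp

# Stand-in for state the parent builds before starting workers
# (in ParlAI's case, the already-initialized agent/teacher to share).
SHARED = {"ready": False}

def worker(rank):
    # fork: the child inherits the parent's memory, so this prints True.
    # spawn: the module is re-imported and __main__ isn't re-run, so it's False.
    print(f"worker {rank}: ready = {SHARED['ready']}")

if __name__ == "__main__":
    SHARED["ready"] = True
    mp.start_processes(worker, nprocs=2, start_method="fork")    # True, True
    # mp.start_processes(worker, nprocs=2, start_method="spawn") # False, False
```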

@sjscotti

sjscotti commented Aug 1, 2021

Ok, thanks! I've tried all the other tricks I can find to reduce GPU memory, such as setting blocklength to 1, fp16 with mem_efficient, and the adafactor optimizer (I didn't want to reduce the various truncation settings from the BlenderBot2 defaults since I'm not sure whether they would affect the results of running BlenderBot2), so I guess I'll have to live with the training speed I've got.

@stephenroller
Contributor Author

If you have multiple GPUs, --ddp-backend zero2 can help.

I've got an intern who is also lamenting how slow training similar models is. I'll let you know how we speed it up.

@sjscotti

sjscotti commented Aug 1, 2021

Thanks, any help is appreciated. And unfortunately I only have 1 GPU …
