
Rewrite parallel sampling using multiprocessing #3011

Merged (11 commits) on Jun 14, 2018

Conversation

@aseyboldt (Member) commented Jun 7, 2018

This PR gets rid of joblib altogether and replaces it with a custom implementation.
This should solve some long-standing problems:

  • Pickling models is not always possible, and with this approach we only need to do it on Windows.
  • Problems pickling large traces. Right now we send the finished trace to the main process by pickling it and sending it through a pipe. This causes problems if the trace is very large, as on many systems there is a limit on the size of a single pipe.send. Very large traces can't use multiprocessing right now.
  • joblib doesn't seem to be particularly reliable. I've seen lots of cases where processes got stuck and I had to kill them manually, or where interrupting sampling only works after sending several interrupts. (One caveat with the proposed method is that we can only interrupt traces between samples. So very slow samplers might be stuck for a few seconds. We can still kill them, though.)
  • More control over how exceptions are reported. We can format the remote exceptions nicely. This is not implemented yet, but there is code for it here: https://hg.python.org/cpython/rev/c4f92b597074
  • Better progress bars, as the main process always knows how far the individual processes are. (But there seem to be tqdm bugs that still mess that up...)
  • We can implement trace backends so that the main process writes all samples to a file as they come in; we never need to keep many of them in memory, and we can avoid parallel IO.

We start one process per chain, which communicates with the main process by sending messages through a pipe. Once a draw is ready, it is stored in shared memory, where the main process can access it. Then the sampler process is asked to write the next sample to that shared memory.
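The handshake described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: all names (`sample_chain`, `_chain_worker`, the string messages) are made up, and the `fork` start method assumed here is POSIX-only.

```python
import multiprocessing as mp

import numpy as np

def _chain_worker(remote, shared, n_draws, seed):
    """Runs in the child process: writes each draw into shared memory, then waits."""
    rng = np.random.RandomState(seed)
    buf = np.frombuffer(shared.get_obj())  # NumPy view onto the shared buffer
    for _ in range(n_draws):
        buf[:] = rng.randn(len(buf))       # stand-in for drawing one sample
        remote.send("draw_ready")          # tell the main process via the pipe
        if remote.recv() == "abort":       # wait until the draw has been copied
            break
    remote.send("done")

def sample_chain(n_draws=5, dim=3, seed=0):
    ctx = mp.get_context("fork")  # POSIX-only; Windows would need spawn + pickling
    shared = ctx.Array("d", dim)  # shared memory holding exactly one draw
    remote, local = ctx.Pipe()
    proc = ctx.Process(target=_chain_worker, args=(remote, shared, n_draws, seed))
    proc.start()
    draws = []
    while local.recv() != "done":
        # Copy the draw out before telling the worker to overwrite the buffer.
        draws.append(np.frombuffer(shared.get_obj()).copy())
        local.send("continue")
    proc.join()
    return np.array(draws)
```

Because the buffer only ever holds one draw at a time, the main process can stream each draw straight into a trace backend, instead of receiving one huge pickled trace at the end.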

Performance-wise it shouldn't be much different; the pipes seem to be fast enough to keep up with our samplers. It is possible that we gain a bit of speed for large models with cheap computations, as we have a dedicated process for sampling that doesn't also have to deal with storing data in the trace. For very small models it might be a bit slower, as there is some small constant overhead for sending messages.

Since I use a couple of py3-only features, the old code is still used on Python < 3.

TODO

  • Add warnings (they are not working yet)
  • Fix progress bar
  • Better formatting of exceptions
  • Test on windows
  • Optional: improve exception printing in Jupyter via set_custom_exc

@fonnesbeck (Member)

I was under the impression that it's simpler to use concurrent.futures instead of multiprocessing. Not true?

@aseyboldt (Member, Author)

For computations that don't need to communicate with the main process and are never aborted, I think that is true. But in our use case we benefit quite a bit from keeping a channel open to the main process: it makes status updates (the progress bar), writing results to files, and interrupting the sampler much easier. That kind of ongoing communication is not well supported by concurrent.futures.
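As a toy illustration of that difference (hypothetical code, not from the PR): with a persistent pipe the parent sees every completed step and can ask the worker to stop mid-run, whereas a concurrent.futures future only hands back the final result.

```python
import multiprocessing as mp
import time

def worker(conn, n_steps):
    """Reports progress after every step and checks for a stop request."""
    for i in range(n_steps):
        time.sleep(0.01)                   # stand-in for computing one sample
        conn.send(i)                       # incremental progress update
        if conn.poll() and conn.recv() == "stop":
            conn.send("stopped")           # acknowledge the early abort
            return
    conn.send("finished")

def run(n_steps=100, stop_after=5):
    ctx = mp.get_context("fork")           # POSIX-only start method
    parent, child = ctx.Pipe()
    proc = ctx.Process(target=worker, args=(child, n_steps))
    proc.start()
    completed = 0
    while True:
        msg = parent.recv()
        if msg in ("stopped", "finished"):
            break
        completed = msg + 1                # live progress-bar style update
        if completed == stop_after:        # e.g. the user hit Ctrl-C
            parent.send("stop")
    proc.join()
    return completed, msg
```

Here `run()` stops the worker after roughly 5 of 100 steps; a future-based API would have had to wait for all 100 or kill the process outright.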

@aseyboldt (Member, Author)

Another nice bonus: We can return a partial trace now if parallel sampling is interrupted with a KeyboardInterrupt.
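Schematically (illustrative only; `collect_draws` and the test helper below are made-up names), a partial trace falls out naturally because draws arrive in the main process one at a time:

```python
def collect_draws(draw_iter, n_draws):
    """Gather up to n_draws items, keeping what we have on Ctrl-C."""
    draws = []
    try:
        for _ in range(n_draws):
            draws.append(next(draw_iter))
    except KeyboardInterrupt:
        pass  # partial trace: keep the draws gathered so far
    return draws

def interrupting_iter(values, raise_after):
    """Test helper that simulates Ctrl-C after `raise_after` draws."""
    for i, v in enumerate(values):
        if i == raise_after:
            raise KeyboardInterrupt
        yield v
```

Since the interrupt lands between draws in the main process, everything already copied out of shared memory is intact and can be returned as-is.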

@aseyboldt force-pushed the multiproc branch 2 times, most recently from ec02da7 to c6876c8 on June 12, 2018 18:25
@twiecki (Member) commented Jun 14, 2018

One other option to entertain: since we plan to drop Python 2 support anyhow, we could already start by deprecating parallel sampling for Python 2. I suppose it's not really worth it, because we have something that sorta works and we would just take it away for seemingly no good reason. Yet the aggregate cost of dealing with it is quite high.

@junpenglao (Member)

What is our timeline on py27?

@aseyboldt (Member, Author)

There are still some tests on py2.7/float32 that are failing (due to issues with the sqlite and text backends), but I think that is just the unrelated #3018. So this is ready for review/merge.

@junpenglao (Member)

This is the one-progress-bar version, right (i.e., no snail race)?

@aseyboldt (Member, Author)

Yes

```python
shape_dtypes[var.name] = (shape, dtype)
return shape_dtypes

def stop_tuning(self):
```
Member:

Oh this is much better.

@junpenglao (Member) left a comment

LGTM, need release note.


```python
self._progress = None
if progressbar:
    self._progress = tqdm_(
```
Contributor:

You have to add the position argument here, in order to keep the tqdm bars from interfering with each other.

@aseyboldt (Member, Author):

There is only one progress bar now, but that progress bar counts samples from all chains.
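A sketch of that aggregation (hypothetical names; presumably the actual code feeds this running count into a single tqdm bar with total = chains × draws):

```python
def total_progress(events, n_chains, n_draws):
    """events: an iterable of chain ids, one per completed draw,
    in whatever interleaved order the worker processes report them."""
    done = {chain: 0 for chain in range(n_chains)}
    total = n_chains * n_draws
    for chain in events:
        done[chain] += 1
        # A single bar would show this aggregate count, not one bar per chain.
        yield sum(done.values()), total
```

With two chains of two draws each, interleaved reports advance one shared counter from 1/4 to 4/4.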

5 participants