This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Aligns training and testing data #33

Open

jrbourbeau wants to merge 2 commits into master
Conversation

jrbourbeau
Member

This PR ensures that training and testing data have matching partitions.

Closes #32

@mrocklin
Member

I think that this is a good start. However, we've seen cases where the divisions are the same and yet the number of rows in each partition still differs, and I think in that case we still raise a non-informative error.
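
For concreteness, a small hypothetical example (not from this repository) of that situation: the divisions agree, but the per-partition row counts do not, because the labels are missing one index value.

```python
import pandas as pd
import dask.dataframe as dd

# Hypothetical: data has six rows, labels is missing index value 2.
data = dd.from_pandas(
    pd.DataFrame({"x": range(6)}, index=[0, 1, 2, 3, 4, 5]), npartitions=2
)
labels = dd.from_pandas(
    pd.Series([1, 0, 0, 1, 0], index=[0, 1, 3, 4, 5], name="y"), npartitions=1
).repartition(divisions=data.divisions)

print(data.divisions == labels.divisions)            # True
print(tuple(data.map_partitions(len).compute()))     # (3, 3)
print(tuple(labels.map_partitions(len).compute()))   # (2, 3)
```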

@jrbourbeau
Member Author

Thanks for the feedback @mrocklin!

I've added a new align_training_data function to rechunk/repartition labels so it has the same number of rows per partition as data. Since we can load all the training data into distributed memory, we can compute the chunk sizes for data and labels. If they're different, then .rechunk is called on labels accordingly.
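
Roughly, the idea is something like the sketch below (illustrative only: it assumes data is a dask DataFrame and labels is a one-dimensional dask Array with known chunk sizes; the helper in this PR also takes the client and differs in detail).

```python
def align_training_data(data, labels):
    # Sketch of the idea, not the exact code in this PR.
    # Number of rows in each partition of ``data``; cheap to compute once the
    # data fits in distributed memory.
    data_chunks = tuple(data.map_partitions(len).compute())

    # Current block sizes of ``labels`` (assumed known).
    labels_chunks = labels.chunks[0]

    if data_chunks != labels_chunks:
        # Rechunk the labels so each block lines up with a data partition.
        labels = labels.rechunk(chunks=(data_chunks,))

    return data, labels
```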

@jrbourbeau
Member Author

I've also added some tests, but I'm running into test failures. Some failures seem to be related to the changes in this PR, while others also occur on master.

For example, running pytest dask_xgboost/tests/test_core.py::test_classifier fails with a ChildProcessError in both this PR and master.

test_classifier traceback
[gw0] darwin -- Python 3.6.6 /Users/jbourbeau/miniconda/envs/quansight/bin/python
loop = <tornado.platform.asyncio.AsyncIOLoop object at 0x1c26981048>

    def test_classifier(loop):  # noqa
>       with cluster() as (s, [a, b]):

dask_xgboost/tests/test_core.py:38:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../miniconda/envs/quansight/lib/python3.6/contextlib.py:81: in __enter__
    return next(self.gen)
../../miniconda/envs/quansight/lib/python3.6/site-packages/distributed/utils_test.py:626: in cluster
    scheduler_q = mp_context.Queue()
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/context.py:102: in Queue
    return Queue(maxsize, ctx=self.get_context())
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/queues.py:42: in __init__
    self._rlock = ctx.Lock()
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/context.py:67: in Lock
    return Lock(ctx=self.get_context())
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/synchronize.py:163: in __init__
    SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/synchronize.py:81: in __init__
    register(self._semlock.name)
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/semaphore_tracker.py:83: in register
    self._send('REGISTER', name)
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/semaphore_tracker.py:90: in _send
    self.ensure_running()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <multiprocessing.semaphore_tracker.SemaphoreTracker object at 0xa221cf390>

    def ensure_running(self):
        '''Make sure that semaphore tracker process is running.

        This can be run from any process.  Usually a child process will use
        the semaphore created by its parent.'''
        with self._lock:
            if self._pid is not None:
                # semaphore tracker was launched before, is it still running?
>               pid, status = os.waitpid(self._pid, os.WNOHANG)
E               ChildProcessError: [Errno 10] No child processes

../../miniconda/envs/quansight/lib/python3.6/multiprocessing/semaphore_tracker.py:46: ChildProcessError

Meanwhile, pytest dask_xgboost/tests/test_core.py::test_basic passes on master but fails in this PR with AssertionError: yield from wasn't used with future. Clearly I'm doing something wrong with the futures interface, but I'm not sure where.

test_basic traceback
[gw0] darwin -- Python 3.6.6 /Users/jbourbeau/miniconda/envs/quansight/bin/python
def test_func():
        del _global_workers[:]
        _global_clients.clear()
        active_threads_start = set(threading._active)

        reset_config()

        dask.config.set({'distributed.comm.timeouts.connect': '5s'})
        # Restore default logging levels
        # XXX use pytest hooks/fixtures instead?
        for name, level in logging_levels.items():
            logging.getLogger(name).setLevel(level)

        result = None
        workers = []

        with pristine_loop() as loop:
            with check_active_rpc(loop, active_rpc_timeout):
                @gen.coroutine
                def coro():
                    with dask.config.set(config):
                        s = False
                        for i in range(5):
                            try:
                                s, ws = yield start_cluster(
                                    ncores, scheduler, loop, security=security,
                                    Worker=Worker, scheduler_kwargs=scheduler_kwargs,
                                    worker_kwargs=worker_kwargs)
                            except Exception as e:
                                logger.error("Failed to start gen_cluster, retryng", exc_info=True)
                            else:
                                workers[:] = ws
                                args = [s] + workers
                                break
                        if s is False:
                            raise Exception("Could not start cluster")
                        if client:
                            c = yield Client(s.address, loop=loop, security=security,
                                             asynchronous=True, **client_kwargs)
                            args = [c] + args
                        try:
                            future = func(*args)
                            if timeout:
                                future = gen.with_timeout(timedelta(seconds=timeout),
                                                          future)
                            result = yield future
                            if s.validate:
                                s.validate_state()
                        finally:
                            if client:
                                yield c._close(fast=s.status == 'closed')
                            yield end_cluster(s, workers)
                            yield gen.with_timeout(timedelta(seconds=1),
                                                   cleanup_global_workers())

                        try:
                            c = yield default_client()
                        except ValueError:
                            pass
                        else:
                            yield c._close(fast=True)

                        raise gen.Return(result)

>               result = loop.run_sync(coro, timeout=timeout * 2 if timeout else timeout)

../../miniconda/envs/quansight/lib/python3.6/site-packages/distributed/utils_test.py:909:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/ioloop.py:576: in run_sync
    return future_cell[0].result()
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/gen.py:1147: in run
    yielded = self.gen.send(value)
../../miniconda/envs/quansight/lib/python3.6/site-packages/distributed/utils_test.py:890: in coro
    result = yield future
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/gen.py:1133: in run
    value = future.result()
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/gen.py:326: in wrapper
    yielded = next(result)
dask_xgboost/tests/test_core.py:144: in test_basic
    dbst = yield dxgb.train(c, param, ddf, dlabels)
dask_xgboost/core.py:244: in train
    data, labels = align_training_data(client, data, labels)
dask_xgboost/core.py:191: in align_training_data
    data_chunks = tuple(data.map_partitions(len).compute())
../dask/dask/base.py:156: in compute
    (result,) = compute(self, traverse=False, **kwargs)
../dask/dask/base.py:398: in compute
    return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
../dask/dask/base.py:398: in <listcomp>
    return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
../dask/dask/dataframe/core.py:74: in finalize
    return _concat(results)
../dask/dask/dataframe/core.py:58: in _concat
    if isinstance(first(core.flatten(args)), np.ndarray):
../../miniconda/envs/quansight/lib/python3.6/site-packages/toolz/itertoolz.py:368: in first
    return next(iter(seq))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

seq = <Future finished exception=CancelledError("('len-ebcfd41fecadb4b7f7d33c5221f4960b', 2)",)>
container = <class 'list'>

    def flatten(seq, container=list):
        """

        >>> list(flatten([1]))
        [1]

        >>> list(flatten([[1, 2], [1, 2]]))
        [1, 2, 1, 2]

        >>> list(flatten([[[1], [2]], [[1], [2]]]))
        [1, 2, 1, 2]

        >>> list(flatten(((1, 2), (1, 2)))) # Don't flatten tuples
        [(1, 2), (1, 2)]

        >>> list(flatten((1, 2, [3, 4]))) # support heterogeneous
        [1, 2, 3, 4]
        """
        if isinstance(seq, str):
            yield seq
        else:
>           for item in seq:
E           AssertionError: yield from wasn't used with future

../dask/dask/core.py:272: AssertionError

Any thoughts you may have here would be much appreciated.

jrbourbeau mentioned this pull request on Nov 20, 2018
@mrocklin
Member

It would be good to verify that we compute things only once, otherwise we may load and preprocess our data many times. In practice this can be annoying. There are currently two issues stopping this:

  1. Within align_training_data we call compute twice, once for data and once for labels. In the common case where these share upstream history, that history will be recomputed unnecessarily.
  2. We then call client.compute (which today behaves more like persist) within _train.

We have to persist the data in memory in the _train function. Ideally we would verify alignment only after this stage when we know that it's cheap and won't result in any additional recomputation.

Generally I find things like this by trying them out on a small problem and watching the diagnostic dashboard.
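
As a rough sketch of that ordering (a hypothetical helper, not the code in dask_xgboost/core.py; it assumes both data and labels are dask DataFrame/Series collections, whereas if labels is a dask Array its chunk sizes could be inspected directly): persist first, then make one compute call for both length graphs.

```python
import dask


def persist_then_verify(data, labels):
    # Persist first so the alignment check below reuses in-memory results
    # instead of recomputing any preprocessing shared by data and labels.
    data, labels = dask.persist(data, labels)

    # One compute call for both per-partition length graphs, so tasks shared
    # between the two collections are evaluated only once (point 1 above).
    data_lengths, labels_lengths = dask.compute(
        data.map_partitions(len), labels.map_partitions(len)
    )
    if tuple(data_lengths) != tuple(labels_lengths):
        raise ValueError(
            "Mismatched rows per partition: %s vs %s"
            % (tuple(data_lengths), tuple(labels_lengths))
        )
    return data, labels
```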
