Working example with Keras #2333

Open · dzubo opened this issue Nov 2, 2018 · 13 comments
Labels: documentation (Improve or add to documentation)

dzubo commented Nov 2, 2018

I have issues running Keras models with Dask when using multiple workers.

Is there a minimal working example?

I tried this code:

import numpy as np

import keras
from keras.layers import Input, Dense
from keras.models import Model

import dask
from dask import compute, delayed
from dask.distributed import Client
from distributed.protocol import serialize, deserialize

@delayed
def get_model(id):
    inputs = Input(shape=(10, ))
    x = Dense(20)(inputs)
    predictions = Dense(1, activation='linear')(x)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='RMSProp', loss='mean_absolute_error')
    return model

client = Client()

params = [{'id': 1}, {'id': 2}]

for p in params:
    p['model'] = get_model(p['id'])

print(params)
results = compute(params)
print(results)

This gives the following error message:

Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
^CTraceback (most recent call last):
  File "dask-keras.py", line 35, in <module>
    client = Client()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 628, in __init__
    self.start(timeout=timeout)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 751, in start
    sync(self.loop, self._start, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/utils.py", line 275, in sync
    e.wait(10)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/threading.py", line 551, in wait
    signaled = self._cond.wait(timeout)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/threading.py", line 299, in wait
    gotit = waiter.acquire(True, timeout)
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-11, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-2, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-4, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-10, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-12, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-1, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-9, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-5, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-6, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-8, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-3, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-7, started daemon)>
distributed.nanny - WARNING - Worker process 19972 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19975 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19971 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19980 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19977 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19979 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19970 was killed by unknown signal
...

If I run the scheduler and a worker from the command line:

dask-scheduler
dask-worker --memory-limit 10GB --nprocs 1 --nthreads 6 --name local <scheduler-ip>:8786

and replace the Client creation in the code with

client = Client('<scheduler-ip>:8786')

then I get this:

Using TensorFlow backend.
[{'id': 1, 'model': Delayed('get_model-de411eed-9c7d-49d0-af7d-95da92d39d5d')}, {'id': 2, 'model': Delayed('get_model-452c16b3-85cc-4472-a1e3-35c215a32647')}]
distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/core.py", line 131, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 179, in deserialize
    return loads(header, frames)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 75, in serialization_error_loads
    raise TypeError(msg)
TypeError: Could not serialize object of type Model.
Traceback (most recent call last):
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 139, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 38, in dask_dumps
    header, frames = dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/keras.py", line 22, in serialize_keras_model
    weights = model.get_weights()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/engine/network.py", line 492, in get_weights
    return K.batch_get_value(weights)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2420, in batch_get_value
    return get_session().run(ops)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1025, in _run
    raise RuntimeError('The Session graph is empty.  Add operations to the '
RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

Traceback (most recent call last):
  File "dask-keras.py", line 47, in <module>
    results = compute(params)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/dask/base.py", line 392, in compute
    results = schedule(dsk, keys, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 2308, in get
    direct=direct)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 1647, in gather
    asynchronous=asynchronous)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 665, in sync
    return sync(self.loop, func, *args, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/utils.py", line 277, in sync
    six.reraise(*error[0])
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/six.py", line 693, in reraise
    raise value
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/utils.py", line 262, in f
    result[0] = yield future
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 1518, in _gather
    response = yield future
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 1567, in _gather_remote
    response = yield self.scheduler.gather(keys=keys)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/core.py", line 574, in send_recv_from_rpc
    result = yield send_recv(comm=comm, op=key, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/core.py", line 451, in send_recv
    response = yield comm.read(deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/comm/tcp.py", line 203, in read
    deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 307, in wrapper
    yielded = next(result)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/comm/utils.py", line 79, in from_frames
    res = _from_frames()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/comm/utils.py", line 65, in _from_frames
    deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/core.py", line 131, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 179, in deserialize
    return loads(header, frames)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 75, in serialization_error_loads
    raise TypeError(msg)
TypeError: Could not serialize object of type Model.
Traceback (most recent call last):
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 139, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 38, in dask_dumps
    header, frames = dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/keras.py", line 22, in serialize_keras_model
    weights = model.get_weights()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/engine/network.py", line 492, in get_weights
    return K.batch_get_value(weights)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2420, in batch_get_value
    return get_session().run(ops)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1025, in _run
    raise RuntimeError('The Session graph is empty.  Add operations to the '
RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().
mrocklin (Member) commented

My experience was that TensorFlow graphs didn't like being created in one thread and then executed in another. I'm not sure though. I think that @bnaul may have some experience here.
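
One possible workaround, sketched below purely as an illustration (it assumes the weights can be read back in the same thread that builds the model, and the helper names are made up): have the delayed task return the model's JSON architecture and numpy weights rather than the Model object itself, then rebuild the model on the client.

from dask import compute, delayed
from keras.layers import Dense, Input
from keras.models import Model, model_from_json

@delayed
def build_model_spec(model_id):
    # build and compile entirely inside the task
    inputs = Input(shape=(10,))
    x = Dense(20)(inputs)
    predictions = Dense(1, activation='linear')(x)
    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='RMSprop', loss='mean_absolute_error')
    # only the architecture (a JSON string) and the weights (numpy arrays) are returned,
    # so no Keras/TensorFlow object has to cross the serialization boundary
    return {'id': model_id, 'config': model.to_json(), 'weights': model.get_weights()}

def rebuild(spec):
    # reconstruct the model in the client's own TensorFlow graph/session
    model = model_from_json(spec['config'])
    model.set_weights(spec['weights'])
    return model

specs = compute(*[build_model_spec(i) for i in (1, 2)])
models = [rebuild(spec) for spec in specs]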

bw4sz commented Oct 23, 2019

Anyone coming here should look at this really nice example, which I found super helpful:

https://anaconda.org/defusco/keras-dask/notebook

TomAugspurger (Member) commented

cc @AlbertDeFusco. I think that example doesn't use a distributed cluster yet.

AlbertDeFusco commented

Correct. I have not gotten it working with distributed.

bw4sz commented Oct 24, 2019

@AlbertDeFusco Can you give some intuition about when such a strategy might be useful? I'm on a SLURM cluster, and I'm seeing that my GPU utilization is not at 100% while the CPU utilization (single core) is at 100%. My thought is that by creating a multi-core Dask generator I can feed the Keras queue faster. I'm slightly concerned that the overhead involved might swamp any performance gains. I'd be interested in even wild guesses as to whether this feels like a reasonable use case.

hoangthienan95 commented

@bw4sz did you get it working with SLURM? I'm also on an HPC cluster (LSF). Can you share the code you used to make it work?

AlbertDeFusco commented

Hi @bw4sz, I was using my generator in the case where I wanted to train a model on data larger than available memory.

Much like the Dask-ML Incremental wrapper, my DaskGenerator provides no parallelization. It is an out-of-core streaming technique compatible with Keras .fit(). While .fit() can do multiprocessing, I was not able to get that to work.

If your dataset is on a Distributed cluster, there are some things that may help performance: 1) persist the transformed data that goes into the model, and 2) set the chunk size/partitions to a size that will fit into memory on the client (GPU). I might recommend using the largest possible chunk sizes, but I have no evidence to back this up.

I was not using the Keras .fit() multiprocessing because it caused errors. Here's a helpful quote from a good article on Keras fit_generator:

Note that our implementation enables the use of the multiprocessing argument of fit_generator, where the number of threads specified in n_workers are those that generate batches in parallel. A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation).

If your goal is to predict in parallel with an already trained model, there may be a way to utilize Distributed, but it might require some initialization per worker. This Stack Overflow reply may give you some inspiration to develop a procedure similar to the way dask-xgboost works:

https://stackoverflow.com/a/49133682
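
For anyone wondering what such a generator looks like, a minimal sketch is below. This is not the DaskGenerator from the linked notebook; the class name, chunk sizes, and fit_generator usage are illustrative only. It simply persists the dask arrays and materializes one chunk per Keras batch.

import dask.array as da
import numpy as np
from keras.utils import Sequence

class DaskChunkSequence(Sequence):
    """Feed one persisted dask chunk per Keras batch (out-of-core, no parallel training)."""

    def __init__(self, X, y):
        # persist so chunks are computed once; keep references so the results stay in memory
        # X and y must be chunked identically along the first axis
        self.X = X.persist()
        self.y = y.persist()
        self.X_blocks = self.X.to_delayed().ravel()
        self.y_blocks = self.y.to_delayed().ravel()

    def __len__(self):
        return len(self.X_blocks)

    def __getitem__(self, i):
        # only this one chunk is materialized on the client at a time
        return (np.asarray(self.X_blocks[i].compute()),
                np.asarray(self.y_blocks[i].compute()))

# illustrative usage: pick chunk sizes that fit comfortably in client/GPU memory
X = da.random.random((100000, 10), chunks=(10000, 10))
y = da.random.random((100000, 1), chunks=(10000, 1))
# model.fit_generator(DaskChunkSequence(X, y), epochs=1)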

bw4sz commented Nov 7, 2019 via email

bw4sz commented Jan 16, 2020

I have a working prediction example with Keras for those who come back here.

#Load modules
from __future__ import print_function
import keras
import sys
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import platform
import numpy as np
import glob
import dask
import distributed

print("dask version is {}".format(dask.__version__))
print("distributed version is {}".format(distributed.__version__))
print("keras version is {}".format(keras.__version__))
print(sys.version)
# Define and train a model
def train_model():
    batch_size = 128
    num_classes = 10
    epochs = 1
    
    # input image dimensions
    img_rows, img_cols = 28, 28
    
    # the data, split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    if K.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)
    
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    
    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])
    
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
    
    return model
def load_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
    return x_test

# Assumed step (not shown in the original snippet): train once and save the model
# so that the prediction examples below can load "MNIST.h5"
train_model().save("MNIST.h5")

This fails (Dask tries to serialize the in-memory Keras Model in order to send model.predict_on_batch to the workers):

client=distributed.Client()
model = keras.models.load_model("MNIST.h5")
x_test = load_data()
batch_array = np.split(x_test,100)
results = []
for batch in batch_array:
    prediction = dask.delayed(model.predict_on_batch)(batch)
    results.append(prediction)
#Gather
results = dask.compute(*results)

This succeeds, presumably because the model is loaded inside a delayed task on the workers, so the Keras Model object never has to be serialized from the client:

#Example 2 LocalCluster - load data and model on each worker
model = dask.delayed(keras.models.load_model)("MNIST.h5")
x_test = dask.delayed(load_data)()
# np.split cannot inspect a Delayed object, so split lazily and index the result
split_batches = dask.delayed(np.split)(x_test, 100)
batch_array = [split_batches[i] for i in range(100)]
results = []
for batch in batch_array:
    prediction = dask.delayed(model.predict_on_batch)(batch)
    results.append(prediction)

#Gather
results = dask.compute(*results)
results[0].shape

bw4sz commented Mar 9, 2020

I also wanted to add here that if you are repeatedly loading a model on a worker to perform prediction (I could not get the Keras model serialization to work on GPU), make sure to clear the TensorFlow backend each time, or else you will see a steady, scary growth in memory until it spills.

    from keras import backend as K            
    K.clear_session()

Calling gc.collect() had no effect; you must clear the session.
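
Put together, the pattern described above looks roughly like the sketch below (the model path, batch list, and function name are illustrative):

import dask
import keras
from keras import backend as K

def predict_batch(batch, model_path="MNIST.h5"):
    # load the model inside the task so it never has to be serialized from the client
    model = keras.models.load_model(model_path)
    preds = model.predict_on_batch(batch)
    # drop the TensorFlow graph/session created by this load, otherwise
    # repeated tasks on the same worker keep growing memory
    K.clear_session()
    return preds

# results = dask.compute(*[dask.delayed(predict_batch)(b) for b in batch_array])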

rileyhun commented May 21, 2020

#Example 2 LocalCluster - load data and model on each worker
model = dask.delayed(keras.models.load_model)("MNIST.h5")
x_test = dask.delayed(load_data)()
results = []
for batch in batch_array:
    prediction = dask.delayed(model.predict_on_batch)(batch)
    results.append(prediction)

#Gather
results = dask.compute(*results)
results[0].shape

What is batch_array? You don't define it in your code. Where is x_test used?

bw4sz commented May 24, 2020

I saw this too; it was cut out from an IPython notebook:

https://github.com/weecology/NEON_crown_maps/blob/master/dask_keras_example.ipynb

#Compute prediction in batch loop of size 100 (slightly contrived example)
batch_array = np.split(x_test,100)

edited above.

I was reviewing this and it still needs more thought. Yes, it runs, but the predict function would need to be pretty slow to make Dask useful here.

rileyhun commented May 25, 2020

Thanks very much for confirming! I guessed that was what it was, and I got it working as well. Even if there isn't a speed-up, it helps offload memory usage from an API we deployed with a 2 GB memory restriction.

GenevieveBuckley added the documentation label Oct 18, 2021