Proper way of making a data generator which can handle multiple workers #1638

Closed
parag2489 opened this issue Feb 4, 2016 · 18 comments

@parag2489
Contributor

I am having difficulty in writing a data generator which can work with multiple workers. My data generator works fine with one worker, but with > 1 workers, it gives me the following error:

UnboundLocalError: local variable 'generator_output' referenced before assignment

I have tried many things such as declaring X_train, X_test, y_train, y_test as global. I also tried wrapping the myGenerator() function into a Python mutex. Those solutions didn't work.

My system: Ubuntu 14.04, Python 2.7, Tesla K40 GPU
A script to reproduce the issue and its sample output are given below:

import time
import logging
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution1D, Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras import callbacks

class printbatch(callbacks.Callback):
    def on_epoch_begin(self, epoch, logs={}):
        print(logs)
    def on_epoch_end(self, epoch, logs={}):
        print(logs)

nb_classes = 10
nb_epoch = 12

img_rows, img_cols = 28, 28
nb_filters = 32
nb_pool = 2
nb_conv = 3

def myGenerator():
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    y_train = np_utils.to_categorical(y_train,10)
    X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
    X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255
    X_test /= 255
    while 1:
        for i in range(1875):
            print("came till here")
            yield X_train[i*32:(i+1)*32], y_train[i*32:(i+1)*32]
            print("and here")
        print("and here too")

model = Sequential()

model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                        border_mode='valid',
                        input_shape=(1, img_rows, img_cols)))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adadelta')

pb = printbatch()
my_generator = myGenerator()

print("Built the generator")

t0=time.time()
model.fit_generator(my_generator, samples_per_epoch = 5000, nb_epoch = 2, verbose=2, show_accuracy=True, callbacks=[pb], validation_data=None, class_weight=None, nb_worker=2)
t1=time.time()

print("Training completed in " + str(t1-t0) + " seconds")

Sample output:

Built the generator
Epoch 1/2
{}
Traceback (most recent call last):
  File "/testGenerator_multiWorkers.py", line 72, in <module>
    model.fit_generator(my_generator, samples_per_epoch = 5000, nb_epoch = 2, verbose=2, show_accuracy=True, callbacks=[pb], validation_data=None, class_weight=None, nb_worker=2)
  File "build/bdist.linux-x86_64/egg/keras/models.py", line 966, in fit_generator

UnboundLocalError: local variable 'generator_output' referenced before assignment

came till here

Process finished with exit code 1

P.S. I searched for this issue, but it does not appear to be the same as most other reports, such as this one.

@wongjingping

Hi there,

I have encountered this rather uninformative error message multiple times, and here are some of the possible root causes:

  • Generator is not yielding the batches properly. To check that your generator is working correctly, you can just run my_generator.next() to make sure that it is giving the outputs correctly.
  • Errors are thrown in the generator. The error messages are sometimes suppressed in my spyder console, for example when I couldn't download the pickle data files when I called the load_data() function.
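
A minimal sketch of the first check, for reference. Pulling a few batches by hand with the built-in next() (which works on Python 2.6+ and 3, unlike the .next() method) makes any exception raised inside the generator surface with a full traceback. The names batch_generator and check_generator are made up for illustration, and plain lists stand in for the MNIST arrays.

```python
def batch_generator(data, labels, batch_size):
    """Hypothetical stand-in for a training generator: loops over the data forever."""
    n_batches = len(data) // batch_size
    while True:
        for i in range(n_batches):
            start, stop = i * batch_size, (i + 1) * batch_size
            yield data[start:stop], labels[start:stop]


def check_generator(gen, n_batches=3):
    """Pull a few batches by hand and report their sizes."""
    sizes = []
    for _ in range(n_batches):
        x, y = next(gen)  # any error inside the generator is raised right here
        sizes.append((len(x), len(y)))
    return sizes


data = list(range(100))
labels = [v % 10 for v in data]
sizes = check_generator(batch_generator(data, labels, batch_size=32))
print(sizes)
```

If this loop raises, the problem is in the generator itself rather than in the number of workers.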

Hope this helps!

@parag2489
Contributor Author

When I set nb_worker=1, the code works flawlessly: it trains, prints the logs, etc. That's why I think the problem may not be in the generator. Still, I will check whether my_generator.next() works properly.

The function load_data() actually works properly, always, irrespective of the number of workers - since it always prints "came till here", which is like a checkpoint in my code.

@parag2489
Contributor Author

@wongjingping

I found that making a data generator threadsafe works (of course, you should first check that your data generator has no other errors and that the errors arise solely from running it on two workers). The detailed procedure is given in this link and that is what I have followed. For completeness, the original code snippet can be modified as follows. Only the code between the ###### ... ###### markers is new; the rest is a Keras example.

import time
import logging
import threading
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution1D, Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras import callbacks

class printbatch(callbacks.Callback):
    def on_batch_end(self, batch, logs={}):
        if batch%10 == 0:
            print("Batch " + str(batch) + " ends")
    def on_epoch_begin(self, epoch, logs={}):
        print(logs)
    def on_epoch_end(self, epoch, logs={}):
        print(logs)

nb_classes = 10
nb_epoch = 12

img_rows, img_cols = 28, 28
nb_filters = 32
nb_pool = 2
nb_conv = 3

#################### Now make the data generator threadsafe ####################

class threadsafe_iter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def next(self):
        with self.lock:
            return self.it.next()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe.
    """
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

@threadsafe_generator
def myGenerator():  # write the definition of your data generator
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    y_train = np_utils.to_categorical(y_train,10)
    X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
    X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255
    X_test /= 255
    while 1:
        for i in range(1875):
            yield X_train[i*32:(i+1)*32], y_train[i*32:(i+1)*32]
        # print("Came here")

########## Data generator is now threadsafe and should work with multiple workers ##########

model = Sequential()

model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                        border_mode='valid',
                        input_shape=(1, img_rows, img_cols)))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adadelta')

pb = printbatch()
my_generator = myGenerator()

print("Built the generator")

t0=time.time()
model.fit_generator(my_generator, samples_per_epoch = 5000, nb_epoch = 2, verbose=2, show_accuracy=True, callbacks=[pb], validation_data=None, class_weight=None, nb_worker=2)
t1=time.time()

print("Training completed in " + str(t1-t0) + " seconds")


It works with Python 2.7 and the latest Theano. However, if I apply the same trick to the ImageDataGenerator class, it fails. To be clear, I create an ImageDataGenerator object inside a function, say myGenerator(), and make that function threadsafe as done above. I return the ImageDataGenerator object through a multiprocessing queue. Doing this still gives me the error: UnboundLocalError: local variable 'generator_output' referenced before assignment. The code snippet which tries to make ImageDataGenerator threadsafe is as follows:

@threadsafe_generator
def myGenerator(queue):  # write the definition of your data generator
    datagen = ImageDataGenerator(
        featurewise_center=True,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=False,  # randomly flip images
        vertical_flip=False)  # randomly flip images
    queue.put(datagen)  # hand the configured object back through the queue argument

You can then call the above function as:

q = Queue()
myGenerator(q)
datagen = q.get()

This doesn't work. As far as this issue is concerned, I am closing it.
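
One Python-level detail may be relevant to the ImageDataGenerator attempt above: the decorated function contains no yield, so it is an ordinary function that returns None, and the decorator ends up wrapping None instead of an iterator. The sketch below reproduces only that behaviour with a made-up function (not_a_generator); it does not involve Keras, and whether this is the whole story for ImageDataGenerator is an assumption.

```python
import threading


class threadsafe_iter(object):
    """Same idea as the wrapper above: serialize calls to next()."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

    next = __next__  # Python 2 spelling


def threadsafe_generator(f):
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g


@threadsafe_generator
def not_a_generator(sink):
    # No `yield` here, so calling the function runs the body and returns None.
    sink.append("configured object")


sink = []
wrapped = not_a_generator(sink)

try:
    next(wrapped)
    error = None
except TypeError as exc:
    error = exc  # next(None) fails because None is not an iterator

print(wrapped.it, type(error).__name__)
```

If ImageDataGenerator is the goal, wrapping the iterator returned by datagen.flow(...) rather than the function that builds datagen may be worth trying, though that is untested here.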

@fchollet
Collaborator

I found that making a data generator threadsafe works.

Right. It is specified in the documentation that using a number of workers
higher than one should only be done with a thread-safe generator.
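
To illustrate why a plain generator is not safe to share between workers, here is a small standalone sketch (not Keras code). In CPython, calling next() on a generator that is already running in another thread raises ValueError: generator already executing. A plausible explanation, not verified against the Keras source, is that an exception like this in a worker thread is what leaves generator_output unassigned. The slow_generator function and the two events are invented for the demonstration.

```python
import threading

entered = threading.Event()   # set once the generator body is running
release = threading.Event()   # set to let the generator finish its step


def slow_generator():
    while True:
        entered.set()
        release.wait()        # simulate slow batch preparation
        yield 1


gen = slow_generator()
errors = []


def worker():
    try:
        next(gen)
    except ValueError as exc:
        errors.append(str(exc))


t1 = threading.Thread(target=worker)
t1.start()
entered.wait()                # t1 is now inside the generator body

t2 = threading.Thread(target=worker)
t2.start()
t2.join()                     # t2 fails at once: the generator is busy

release.set()                 # let t1 complete normally
t1.join()
print(errors)
```

Serializing the next() calls with a lock, as in the threadsafe_iter wrapper, avoids this collision.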


@wongjingping

@parag2489 thanks for sharing!

@wjbaibai

@parag2489 Could you tell me why you use 32*1875? I cannot figure out the meaning of these parameters. Thanks a lot; I have the same problem as you.

@wongjingping

@wjbaibai There are a total of 60,000 examples, and since he's using a batch size of 32 images, that makes 60,000/32 = 1875 iterations per epoch.

@parag2489
Contributor Author

You have to write a loop which will continuously fetch the data in the batches of 32. Training set of MNIST has 60000 examples. Also, 32*1875 = 60000 (or 60000/32 = 1875). So, the following piece of code will give you a chunk of 32 examples at a time.

for i in range(1875):
        yield X_train[i*32:(i+1)*32], y_train[i*32:(i+1)*32]
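
The arithmetic above can be sketched for an arbitrary dataset size and batch size. The helper name batch_slices is made up for illustration; it also keeps a shorter final batch when the dataset size is not a multiple of the batch size, which the hard-coded 1875 does not need to handle.

```python
def batch_slices(n_samples, batch_size):
    """Return (start, stop) index pairs covering n_samples in order."""
    n_batches = (n_samples + batch_size - 1) // batch_size  # ceiling division
    return [(i * batch_size, min((i + 1) * batch_size, n_samples))
            for i in range(n_batches)]


mnist_slices = batch_slices(60000, 32)
print(len(mnist_slices), mnist_slices[0], mnist_slices[-1])

# A size that does not divide evenly ends with a shorter batch.
uneven_slices = batch_slices(100, 32)
print(uneven_slices)
```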

@timehaven

The comments and suggestions in this issue and its cousin #1627 were very helpful for me to efficiently process large numbers of images. I wrote it all up in a tutorial fashion that I hope can help others.

https://techblog.appnexus.com/a-keras-multithreaded-dataframe-generator-for-millions-of-image-files-84d3027f6f43

@remiresnap

A python2/3 compatible version of the decorator that @parag2489 posted:

class threadsafe_iter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self): # Py3
        return next(self.it)

    def next(self):     # Py2
        with self.lock:
            return self.it.next()

@pashakovalenko

@remiresnap It looks like your snippet lacks self.lock for Py3 version:

    def __next__(self): # Py3
        with self.lock:
            return next(self.it)
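
Putting the two comments together, a self-contained Python 2/3 sketch with the lock on both code paths might look like this. The class name ThreadSafeIter, the toy count_batches generator and the small thread test are additions for illustration only.

```python
import functools
import threading


class ThreadSafeIter(object):
    """Wrap an iterator so that only one thread at a time can advance it."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):      # Python 3
        with self.lock:
            return next(self.it)

    next = __next__          # Python 2 looks up `next` instead


def threadsafe_generator(f):
    """Decorator: calling the generator function returns a locked iterator."""
    @functools.wraps(f)
    def g(*a, **kw):
        return ThreadSafeIter(f(*a, **kw))
    return g


@threadsafe_generator
def count_batches(n):
    for i in range(n):
        yield i


gen = count_batches(1000)
seen = []


def consume():
    for item in gen:         # stops when the shared generator is exhausted
        seen.append(item)


threads = [threading.Thread(target=consume) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(seen))
```

Every item is delivered to exactly one thread, so the four consumers together see each value once.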

@kbrose

kbrose commented Nov 27, 2017

Does making it threadsafe with the above decorator(s) actually result in any speed up, though? I'm finding that it is the same speed (or maybe even a little slower) with my generator. Which makes sense because it looks like it's locking every time it does anything...
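
One way to limit the effect of the lock, sketched here under assumptions: keep only a cheap step (handing out the next batch index) inside the lock and do the expensive loading outside it, in the worker. LockedIndexIter and load_batch are hypothetical names, and this is not how Keras itself schedules work. Whether it helps in practice depends on whether the expensive step releases the GIL (for example file I/O or large NumPy operations).

```python
import threading


class LockedIndexIter(object):
    """Hand out batch indices one at a time; the locked section is tiny."""
    def __init__(self, n_batches):
        self._indices = iter(range(n_batches))
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self._lock:
            return next(self._indices)

    next = __next__  # Python 2 spelling


def load_batch(idx, batch_size=4):
    """Hypothetical expensive step; runs outside the lock."""
    return [idx * batch_size + k for k in range(batch_size)]


def worker(indices, results):
    for idx in indices:
        results[idx] = load_batch(idx)


indices = LockedIndexIter(10)
results = {}
threads = [threading.Thread(target=worker, args=(indices, results))
           for _ in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sorted(results))
```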

@parag2489
Contributor Author

parag2489 commented Nov 27, 2017 via email

@zippeurfou

zippeurfou commented Dec 9, 2017

@parag2489
Correct me if I am wrong:
If you set up more than one worker, you will end up processing the same data twice with the dummy example given and the thread-safe implementation.
If so, how could we "know" which worker is running in the generator?
I.e., what you would like is something like this:
data_size / (batch_size * number_workers)
The same way multi_gpu_model works.
So you want each worker to work on a different part of the dataset.
I.e., assuming you have two workers and a batch size of 32, here are the indices of the data each would work on:

worker 1:
0 - 16 => i=0
32 - 48 => i=1
64 - 80 => i=2
....

worker 2:
16 - 32 => i=0
48 - 64 => i=1
80 - 96 => i=2
....

In code, this would look something like this (with worker_idx starting at 0):

X_train[(i*n_workers+worker_idx)*(32//n_workers):(i*n_workers+worker_idx+1)*(32//n_workers)], y_train[(i*n_workers+worker_idx)*(32//n_workers):(i*n_workers+worker_idx+1)*(32//n_workers)]

My guess is that I need a generator wrapper.
Or am I missing something?
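
The index layout in the two lists above can be checked with a short sketch. The function name worker_slice is hypothetical and worker_idx is taken to be zero-based (worker 1 in the lists is index 0); this only verifies the arithmetic and says nothing about how Keras assigns batches to workers.

```python
def worker_slice(i, worker_idx, n_workers, batch_size=32):
    """Return the (start, stop) range handled by one worker at step i."""
    share = batch_size // n_workers            # samples per worker per step
    start = (i * n_workers + worker_idx) * share
    return start, start + share


layout = {w: [worker_slice(i, w, n_workers=2) for i in range(3)]
          for w in range(2)}
print(layout[0])
print(layout[1])
```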

@ysyyork

ysyyork commented Dec 29, 2017

In my case, when the batch size is really large and the data augmentation takes a long time, the thread-safe generator solution is not fast enough even with multiple threads, because the lock in the thread-safe generator blocks the generator from yielding the next batch of data. So technically it is not multithreading anything. The Sequence class is a better choice because it uses a thread pool and is real multithreading.
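
A rough sketch of the index-based idea behind Sequence, without importing Keras: the class below only mirrors the __len__/__getitem__ interface of keras.utils.Sequence (in real use it would subclass it), and the ThreadPoolExecutor stands in for whatever pool Keras uses internally. The class name BatchSequence and the toy data are assumptions. Because each batch is addressed by index, there is no shared iterator state and no lock around batch preparation.

```python
import math
from concurrent.futures import ThreadPoolExecutor


class BatchSequence(object):
    """Index-addressable batches, in the style of keras.utils.Sequence."""
    def __init__(self, x, y, batch_size):
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a shorter final batch.
        return int(math.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]


seq = BatchSequence(list(range(100)), [v % 2 for v in range(100)], batch_size=32)

# Several threads can fetch different batches at the same time.
with ThreadPoolExecutor(max_workers=4) as pool:
    batches = list(pool.map(seq.__getitem__, range(len(seq))))

print([len(xb) for xb, _ in batches])
```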

@gledsonmelotti

gledsonmelotti commented May 9, 2018

Hello, how are you? I apologize for the inconvenience, but could you help me? I'm trying to create my own generator based on the comments above. However, when I call model.fit_generator, my network does not seem to use batch_size. For example, with 32676 images and a batch_size of 64, I should get 510 iterations per epoch, but my network runs 32676 iterations per epoch. My dataset is large and consists of two-channel images, so I need to create my own generator. I cannot use ImageDataGenerator, flow_from_directory and model.fit_generator directly from Keras, because my images have two channels and these commands only work with 1- and 3-channel images. Would it be possible for you to help me?

I also wrote a generator for validation. That's why I use validationGenerator().

Here is my generator:

######################## Generator ##################################

def trainingGenerator():
    train_Class1_dir = '/media/HD500/RGB_MIN/train/Class1'
    train_Class2_dir = '/media/HD500/RGB_MIN/train/Class2'

    ############################ Class 1 ###############################
    X_trainP = []
    trainP_ids = next(os.walk(train_Class1_dir))[2]
    for n, id_ in tqdm(enumerate(trainP_ids), total=len(trainP_ids)):
        treinamento = train_Class1_dir + '/' + id_
        X_trainP.append(treinamento)
    Y_trainP = np.ones((len(X_trainP), 1), dtype=np.uint8)

    ############################ Class 2 ###############################
    X_trainPN = []
    trainPN_ids = next(os.walk(train_Class2_dir))[2]
    for n, id_ in tqdm(enumerate(trainPN_ids), total=len(trainPN_ids)):
        treinamento = train_Class2_dir + '/' + id_
        X_trainPN.append(treinamento)
    Y_trainPN = np.zeros((len(X_trainPN), 1), dtype=np.uint8)

    ######################## Training dataset ##########################
    X_trainFinal = X_trainP + X_trainPN
    Y_train = np.concatenate((Y_trainP, Y_trainPN), axis=0)
    num_classes = np.unique(Y_train).shape[0]
    Y_train = np_utils.to_categorical(Y_train, num_classes)  # One-hot encode the labels

    ############################ Image #################################
    img_width, img_height, img_channels = 227, 227, 4
    X_train = np.zeros((len(X_trainFinal), img_width, img_height, img_channels), dtype=np.uint8)
    for n, path in tqdm(enumerate(X_trainFinal), total=len(X_trainFinal)):
        img = imageio.imread(path)[:, :, :img_channels]
        img = resize(img, (img_height, img_width), mode='constant', preserve_range=True)
        X_train[n] = img

    batch_size = 64
    X_train = X_train.astype('float32')
    X_train /= 255  # scale pixel values to [0, 1]
    while 1:
        for i in range(len(X_train) // batch_size):
            yield X_train[i*batch_size:(i+1)*batch_size], Y_train[i*batch_size:(i+1)*batch_size]

MyTrainingGenerator = trainingGenerator()
MyValidationGenerator = validationGenerator()

Results_Train = model.fit_generator(MyTrainingGenerator,
                    steps_per_epoch=nb_train_samples // batch_size,
                    epochs=num_epochs,
                    validation_data=MyValidationGenerator,
                    validation_steps=nb_validation_samples // batch_size,
                    callbacks=[History, checkpointer, csv_logger],
                    verbose=1)

I thank you for your attention,
Gledson Melotti
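
One thing that may be worth checking, as a guess from the snippet alone: batch_size = 64 is a local variable inside trainingGenerator(), while the fit_generator call uses a batch_size defined somewhere else. If that outer value were 1, steps_per_epoch would come out as 32676, which matches the symptom described. The sketch below only shows the arithmetic; the helper name steps_per_epoch_for is made up.

```python
def steps_per_epoch_for(n_samples, batch_size):
    """Number of full batches per epoch, as in nb_train_samples // batch_size."""
    return n_samples // batch_size


steps_expected = steps_per_epoch_for(32676, 64)  # batch size used inside the generator
steps_observed = steps_per_epoch_for(32676, 1)   # what an outer batch_size of 1 would give
print(steps_expected, steps_observed)
```

Passing the same batch_size value to both the generator and the fit_generator call would rule this out.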

chiwanpark added a commit to chiwanpark/shopping-classification that referenced this issue Nov 15, 2018
  - Keras requires a thread-safe generator
  - See keras-team/keras#1638
@bw4sz

bw4sz commented Dec 4, 2018

In my case, when the batch size is really large and the data augmentation takes a long time, the thread-safe generator solution is not fast enough even with multiple threads, because the lock in the thread-safe generator blocks the generator from yielding the next batch of data. So technically it is not multithreading anything. The Sequence class is a better choice because it uses a thread pool and is real multithreading.

@ysyyork what strategy did you end up adopting? I have the same issue and I'm seeing that ADDING workers increases the time to process the same number of images.

sample profile

6332465 function calls (6181010 primitive calls) in 171.387 seconds

Ordered by: cumulative time, call count
List reduced from 1212 to 40 due to restriction <40>

      ncalls  tottime  percall  cumtime  percall filename:lineno(function)
         2/1    0.000    0.000  171.396  171.396 interfaces.py:27(wrapper)
           1    0.000    0.000  171.396  171.396 training.py:1277(fit_generator)
           1    0.000    0.000  171.396  171.396 training_generator.py:21(fit_generator)
170816/99234    0.105    0.000  144.263    0.001 {built-in method builtins.next}
       19079  143.157    0.008  143.157    0.008 {method 'acquire' of '_thread.lock' objects}
          11    0.000    0.000  143.140   13.013 data_utils.py:583(get)
           3    0.000    0.000  143.140   47.713 threading.py:263(wait)
          11    0.000    0.000  143.136   13.012 threading.py:533(wait)
          10    0.000    0.000  143.136   14.314 pool.py:601(get)
          10    0.000    0.000  143.136   14.314 pool.py:598(wait)
          10    0.000    0.000   19.579    1.958 training.py:1171(train_on_batch)

This is with workers > 1 and use_multiprocessing=False on a Sequence object. It looks to me like it is doing a lot of waiting.
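
For anyone who wants to collect a comparable profile, a minimal sketch using the standard library's cProfile and pstats is given here. The train() function is a hypothetical stand-in for the fit_generator call, and the sort keys and the limit of 40 rows are chosen to match the header of the profile above.

```python
import cProfile
import io
import pstats


def train():
    """Hypothetical stand-in for model.fit_generator(...)."""
    return sum(i * i for i in range(10000))


profiler = cProfile.Profile()
profiler.enable()
train()
profiler.disable()

stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
# Sort by cumulative time, then call count, and keep at most 40 rows.
stats.sort_stats("cumulative", "calls").print_stats(40)
report = stream.getvalue()
print(report)
```

Note that cProfile only records the thread it is enabled in, so time spent inside worker threads shows up in the main thread mostly as waiting on locks.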

@MTDzi

MTDzi commented Apr 24, 2019

@parag2489

From my experience, if your data generator takes very little time to prepare and yield the batch, then multithreading is not very effective. In other words, if the overhead of distributing the batches to multiple workers is more than the time being saved in preprocessing by parallelizing that operation, then I did not see much improvement.

Isn't Python's Global Interpreter Lock (GIL) the real culprit here?
