
Commit

'proto'
horpto committed Jan 28, 2019
1 parent 0bc2ca6 commit 16fa8fa
Showing 2 changed files with 113 additions and 66 deletions.
88 changes: 22 additions & 66 deletions gensim/models/ldamulticore.py
@@ -248,76 +248,41 @@ def update(self, corpus, chunks_as_numpy=False):
"consider increasing the number of passes or iterations to improve accuracy"
)

job_queue = Queue(maxsize=2 * self.workers)
result_queue = Queue()

# rho is the "speed" of updating; TODO try other fncs
# pass_ + num_updates handles increasing the starting t for each pass,
# while allowing it to "reset" on the first pass of each update
def rho():
return pow(self.offset + pass_ + (self.num_updates / self.chunksize), -self.decay)

logger.info("training LDA model using %i processes", self.workers)
pool = Pool(self.workers, worker_e_step, (job_queue, result_queue,))
pool = utils.MapPool(worker_e_step, self.workers)

+        other = LdaState(self.eta, self.state.sstats.shape)  # temporary LdaState, reset on every pass
+        eval_every = self.eval_every or 0
+
         for pass_ in range(self.passes):
-            queue_size, reallen = [0], 0
-            other = LdaState(self.eta, self.state.sstats.shape)
-
-            def process_result_queue(force=False):
-                """
-                Clear the result queue, merging all intermediate results, and update the
-                LDA model if necessary.
-
-                """
-                merged_new = False
-                while not result_queue.empty():
-                    other.merge(result_queue.get())
-                    queue_size[0] -= 1
-                    merged_new = True
-                if (force and merged_new and queue_size[0] == 0) or (not self.batch and (other.numdocs >= updateafter)):
+            other.reset()
+
+            mapper = pool.map(corpus, args=(self,), chunksize=self.chunksize, as_numpy=chunks_as_numpy)
+            for processed_state in mapper:
+                other.merge(processed_state)
+                if other.numdocs >= updateafter:
                     self.do_mstep(rho(), other, pass_ > 0)
                     other.reset()
-                    if self.eval_every is not None \
-                            and ((force and queue_size[0] == 0)
-                                 or (self.eval_every != 0 and (self.num_updates / updateafter) % self.eval_every == 0)):
-                        self.log_perplexity(chunk, total_docs=lencorpus)
-
-            chunk_stream = utils.grouper(corpus, self.chunksize, as_numpy=chunks_as_numpy)
-            for chunk_no, chunk in enumerate(chunk_stream):
-                reallen += len(chunk)  # keep track of how many documents we've processed so far
-
-                # put the chunk into the workers' input job queue
-                chunk_put = False
-                while not chunk_put:
-                    try:
-                        job_queue.put((chunk_no, chunk, self), block=False, timeout=0.1)
-                        chunk_put = True
-                        queue_size[0] += 1
-                        logger.info(
-                            "PROGRESS: pass %i, dispatched chunk #%i = documents up to #%i/%i, "
-                            "outstanding queue size %i",
-                            pass_, chunk_no, chunk_no * self.chunksize + len(chunk), lencorpus, queue_size[0]
-                        )
-                    except queue.Full:
-                        # in case the input job queue is full, keep clearing the
-                        # result queue, to make sure we don't deadlock
-                        process_result_queue()
-
-                process_result_queue()
-            # endfor single corpus pass
-
-            # wait for all outstanding jobs to finish
-            while queue_size[0] > 0:
-                process_result_queue(force=True)
-
-            if reallen != lencorpus:
+                    if eval_every > 0 and (self.num_updates / updateafter) % eval_every == 0:
+                        self.log_perplexity(pool.chunk, total_docs=lencorpus)
+
+            self.do_mstep(rho(), other, pass_ > 0)
+            self.log_perplexity(pool.chunk, total_docs=lencorpus)
+
+            if pool.processed != lencorpus:
                 raise RuntimeError("input corpus size changed during training (don't use generators as input)")
         # endfor entire update
 
         pool.terminate()
 

-def worker_e_step(input_queue, result_queue):
+def worker_e_step(chunk, worker_lda):
     """Perform E-step for each job.
 
     Parameters
@@ -329,15 +294,6 @@ def worker_e_step(input_queue, result_queue):
         After the worker finished the job, the state of the resulting (trained) worker model is appended to this queue.
 
     """
-    logger.debug("worker process entering E-step loop")
-    while True:
-        logger.debug("getting a new job")
-        chunk_no, chunk, worker_lda = input_queue.get()
-        logger.debug("processing chunk #%i of %i documents", chunk_no, len(chunk))
-        worker_lda.state.reset()
-        worker_lda.do_estep(chunk)  # TODO: auto-tune alpha?
-        del chunk
-        logger.debug("processed chunk, queuing the result")
-        result_queue.put(worker_lda.state)
-        del worker_lda  # free up some memory
-        logger.debug("result put")
+    worker_lda.state.reset()
+    worker_lda.do_estep(chunk)  # TODO: auto-tune alpha?
+    return worker_lda.state
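
With the queue plumbing moved into utils.MapPool, worker_e_step becomes an ordinary function: it takes a (chunk, model) pair and returns an LdaState, which is what lets update() simply iterate over pool.map(...) and decide itself when to run the M-step. The sketch below is not part of the commit and assumes the refactored signature from this diff; the tiny corpus is invented, and the deepcopy stands in for the pickling that MapPool's job queue performs when shipping the model to a worker process.

    import copy

    from gensim.corpora import Dictionary
    from gensim.models.ldamodel import LdaState
    from gensim.models.ldamulticore import LdaMulticore, worker_e_step

    texts = [["human", "interface", "computer"], ["graph", "trees"], ["graph", "minors", "survey"]]
    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]

    lda = LdaMulticore(corpus, id2word=dictionary, num_topics=2, workers=1, passes=1)

    # One E-step on one chunk, the way a MapPool worker would run it.
    other = LdaState(lda.eta, lda.state.sstats.shape)   # temporary accumulator, as in update()
    worker_copy = copy.deepcopy(lda)                     # stands in for the job queue's pickling
    other.merge(worker_e_step(corpus, worker_copy))      # worker_e_step now just returns an LdaState
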
91 changes: 91 additions & 0 deletions gensim/utils.py
@@ -29,6 +29,7 @@
 import tempfile
 from functools import wraps
 import multiprocessing
+import queue
 import shutil
 import sys
 import subprocess
@@ -1309,6 +1310,96 @@ def chunkize(corpus, chunksize, maxsize=0, as_numpy=False):
         for chunk in chunkize_serial(corpus, chunksize, as_numpy=as_numpy):
             yield chunk
 
+class MapPool(object):
+
+    def __init__(self, func, workers=None, initializer=None, initargs=None, queuesize=None):
+        if workers is None:
+            workers = cpu_count()
+        if queuesize is None:
+            queuesize = workers * 2
+
+        self.job_queue = multiprocessing.Queue(maxsize=queuesize)
+        self.result_queue = multiprocessing.Queue()
+        self.pool = multiprocessing.Pool(workers,
+            initializer=self._worker_starter,
+            initargs=(self.job_queue, self.result_queue, func, initializer, initargs))
+
+        # keep track of how many documents we've processed so far
+        # reset before each map() call
+        self.processed = 0
+        self.queue_watermark = 0
+        self.chunk = None
+
+    @staticmethod
+    def _worker_starter(job_queue, result_queue, func, initializer, initargs):
+        logger.debug("worker process started")
+        if initializer is not None:
+            if initargs is None:
+                initargs = ()
+            initializer(*initargs)
+
+        while True:
+            logger.debug("getting a new job")
+            chunk_no, chunk, args = job_queue.get()
+            logger.debug("processing chunk #%i of %i documents", chunk_no, len(chunk))
+            result = func(chunk, *args)
+            del chunk
+            logger.debug("processed chunk, queuing the result")
+            result_queue.put(result)
+            del result  # free up some memory
+            logger.debug("result put")
+
+    def map(self, iterable, args=None, chunksize=None, as_numpy=False):
+        self.queue_watermark = 0
+        self.processed = 0
+
+        if args is None:
+            args = ()
+
+        chunk_stream = enumerate(grouper(iterable, chunksize, as_numpy=as_numpy))
+        for chunk_no, self.chunk in chunk_stream:
+            self.processed += len(self.chunk)
+            # put the chunk into the workers' input job queue
+            job = (chunk_no, self.chunk, args)
+            while True:
+                try:
+                    self.job_queue.put(job, block=False)
+                    self.queue_watermark += 1
+                    logger.info(
+                        "PROGRESS: dispatched chunk #%i = documents up to #%i, "
+                        "outstanding queue size %i",
+                        chunk_no, chunk_no * chunksize + len(self.chunk), self.queue_watermark
+                    )
+                    break
+                except queue.Full:
+                    # in case the input job queue is full, keep clearing the
+                    # result queue, to make sure we don't deadlock
+                    while not self.result_queue.empty():
+                        self.queue_watermark -= 1
+                        yield self.result_queue.get()
+
+            if not self.result_queue.empty():
+                self.queue_watermark -= 1
+                yield self.result_queue.get()
+
+        # wait for all outstanding jobs to finish
+        while self.queue_watermark > 0:
+            self.queue_watermark -= 1
+            yield self.result_queue.get()
+        assert self.queue_watermark == 0, "All chunks should be processed, left %s chunks" % self.queue_watermark
+
+    def __del__(self):
+        self.terminate()
+
+    def terminate(self):
+        self.pool.terminate()
+
+    def close(self):
+        self.pool.close()
+
+    def join(self):
+        self.pool.join()
+
 
 def smart_extension(fname, ext):
     """Append a file extension `ext` to `fname`, while keeping compressed extensions like `.bz2` or
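As a usage sketch for the new helper itself, assuming MapPool ends up importable from gensim.utils exactly as added above: the worker function and corpus here are invented for illustration, and the __main__ guard keeps the example safe on platforms that spawn rather than fork worker processes.

    from gensim import utils

    def count_tokens(chunk):
        # toy stand-in for an E-step: summarize one chunk of documents
        return sum(len(doc) for doc in chunk)

    if __name__ == "__main__":
        corpus = [["a", "b"], ["c"], ["d", "e", "f"], ["g"]]

        pool = utils.MapPool(count_tokens, workers=2)
        try:
            # map() is a generator: partial results are yielded as workers finish,
            # not necessarily in chunk order.
            totals = list(pool.map(corpus, chunksize=2))
            print("chunk totals:", totals, "documents seen:", pool.processed)
        finally:
            pool.terminate()

The generator protocol is the point of the refactoring: the caller (here list(), in LdaMulticore.update the merge/do_mstep loop) pulls results at its own pace instead of juggling the job and result queues by hand.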
