tensorgo

Using the tensorgo API for TensorFlow asynchronous model-parallel training

The system is designed to be simple to use while maintaining an efficient speedup and approximately the same model performance (sometimes better). Three lines are enough to turn your model into a multi-GPU trainer:

from tensorgo.train.multigpu import MultiGpuTrainer
from tensorgo.train.config import TrainConfig
# [Define your own model using the native TensorFlow API]
bow_model = ...

train_config = TrainConfig(dataset=training_dataset, model=bow_model, n_towers=5, commbatch=1500)
trainer = MultiGpuTrainer(train_config)
probs, labels = trainer.run([bow_model.prob, bow_model.label],
                            feed_dict={bow_model.dropout_prob: 0.2,
                                       bow_model.batch_norm_on: True})
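
For reference, here is a minimal sketch of what a model object with the attributes used above (prob, label, dropout_prob, batch_norm_on) might look like in plain TensorFlow 1.x. The class name, layer sizes, and the loss attribute are illustrative assumptions, not part of the tensorgo API:

import tensorflow as tf

class BowModel(object):
    """Hypothetical bag-of-words classifier exposing the attributes
    referenced by the trainer example above."""

    def __init__(self, vocab_size=10000, embed_dim=128, n_classes=2):
        # Inputs and feed-controlled training switches
        self.token_ids = tf.placeholder(tf.int32, [None, None], name='token_ids')
        self.label = tf.placeholder(tf.int64, [None], name='label')
        self.dropout_prob = tf.placeholder_with_default(0.0, [], name='dropout_prob')
        self.batch_norm_on = tf.placeholder_with_default(False, [], name='batch_norm_on')

        # Bag-of-words encoding: average the token embeddings
        embeddings = tf.get_variable('embeddings', [vocab_size, embed_dim])
        bow = tf.reduce_mean(tf.nn.embedding_lookup(embeddings, self.token_ids), axis=1)
        bow = tf.layers.batch_normalization(bow, training=self.batch_norm_on)
        bow = tf.nn.dropout(bow, keep_prob=1.0 - self.dropout_prob)

        # Linear classifier head
        logits = tf.layers.dense(bow, n_classes)
        self.prob = tf.nn.softmax(logits, name='prob')
        # A loss is included for completeness, assuming the trainer optimizes one
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.label, logits=logits))

bow_model = BowModel()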

ToDo list

  • add benchmarks for image models, e.g. the CIFAR-10 benchmark from the official TF benchmarks
  • add unit tests
  • add a model saver
  • add a user-defined API for model outputs
  • sync all parameters between the workers and the parameter server before training
  • add a feed_dict API for dropout/batch-norm parameters

Reference