Skip to content

Commit

Permalink
Merge pull request #9780 from chengduoZH/feature/fix_batch_size_is_li…
Browse files Browse the repository at this point in the history
…ttler_than_gpu_count

Crash training, if the number of samples is less than the count of devices.
  • Loading branch information
chengduo authored Apr 10, 2018
2 parents b1224da + 7e7611d commit e0babe7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions paddle/fluid/framework/parallel_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ void ParallelExecutor::SplitTensorToPlaces(
const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
for (auto it : feed_tensors) {
auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
PADDLE_ENFORCE_EQ(
member_->places_.size(), lod_tensors.size(),
"The number of samples of current batch is less than the count of "
"devices, currently, it is not allowed. (%d vs %d)",
member_->places_.size(), lod_tensors.size());
for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var?
member_->local_scopes_[j]
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/parallel_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def __init__(self,
# performance. Worth tunning for other models in the future.
num_threads = len(self._places)
else:
min(len(self._places) * 2, multiprocessing.cpu_count())
num_threads = min(
len(self._places) * 2, multiprocessing.cpu_count())

main = main_program
main = main if main else framework.default_main_program()
Expand Down

0 comments on commit e0babe7

Please sign in to comment.