Skip to content

Commit

Permalink
fix(torch): class_weights with multigpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Oct 11, 2022
1 parent 17b8cbb commit 9c1ed4c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -881,12 +881,15 @@ namespace dd
if (_devices[i] != _main_device)
{
r.module = _module.clone(_devices[i]);

r.module->train();
torch::Tensor class_weights_i
= torch::numel(class_weights) == 0
? class_weights
: class_weights.to(_devices[i]);
r.loss = std::make_shared<TorchLoss>(
_loss, r.module->has_model_loss(), _seq_training, _timeserie,
_regression, _classification, _segmentation, _ctc,
class_weights, _reg_weight, *r.module, this->_logger);
class_weights_i, _reg_weight, *r.module, this->_logger);
}
}
_module.train();
Expand Down

0 comments on commit 9c1ed4c

Please sign in to comment.