Skip to content

Commit

Permalink
added basic pytorch tests for msallreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
vaeksare committed Jul 26, 2019
1 parent c5b1a7f commit 9c0a7ac
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
1 change: 0 additions & 1 deletion horovod/common/ops/msallreduce_operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// =============================================================================

#include "msallreduce_operations.h"
//#include "fusion_buffer_manager.h"
#include <boost/asio/post.hpp>

namespace horovod {
Expand Down
48 changes: 48 additions & 0 deletions test/test_torchmsallreduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from distutils.version import LooseVersion
import collections
import inspect
import itertools
import numpy as np
import os
import tempfile
import torch
import torch.nn.functional as F
import unittest
import warnings

import horovod.torch as hvd

class TorchTests(unittest.TestCase):

def __init__(self, *args, **kwargs):
super(TorchTests, self).__init__(*args, **kwargs)
warnings.simplefilter('module')

def test_horovod_multiple_allreduce_cpu(self):
hvd.init()
size = hvd.size()
if hvd.rank() == 0:
tensors = [torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]),torch.FloatTensor([[9.0, 10.0], [11.0, 12.0]])]
else:
tensors = [torch.FloatTensor([[5.0, 6.0], [7.0, 8.0]]), torch.FloatTensor([[13.0, 14.0], [15.0, 16.0]])]
summed = 0
for tensor in tensors:
summed += hvd.allreduce(tensor, average=False)
print(summed)

def test_horovod_single_allreduce_cpu(self):
hvd.init()
size = hvd.size()
if hvd.rank() == 0:
tensor = torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]])
else:
tensor = torch.FloatTensor([[5.0, 6.0], [7.0, 8.0]])
summed = hvd.allreduce(tensor, average=False)
print(summed)

if __name__ == "__main__":
unittest.main()

0 comments on commit 9c0a7ac

Please sign in to comment.