Update the volume(DistArray<Tile,Policy>) function. #457

Merged
merged 4 commits on Jun 25, 2024
Changes from all commits
83 changes: 54 additions & 29 deletions src/TiledArray/dist_array.h
@@ -869,9 +869,10 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// first minimally contains the same number of elements as
/// the tile.
/// \throw TiledArray::Exception if the tile is already initialized.
-  template <typename Integer, typename InIter,
-            typename = std::enable_if_t<(std::is_integral_v<Integer>)&&detail::
-                                            is_input_iterator<InIter>::value>>
+  template <
+      typename Integer, typename InIter,
+      typename = std::enable_if_t<(std::is_integral_v<Integer>) &&
+                                  detail::is_input_iterator<InIter>::value>>
typename std::enable_if<detail::is_input_iterator<InIter>::value>::type set(
const std::initializer_list<Integer>& i, InIter first) {
set<std::initializer_list<Integer>>(i, first);
@@ -964,10 +965,9 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// \throw TiledArray::Exception if index \c i has the wrong rank. Strong
/// throw guarantee.
/// \throw TiledArray::Exception if tile \c i is already set.
-  template <
-      typename Index, typename Value,
-      typename = std::enable_if_t<
-          (std::is_integral_v<Index>)&&is_value_or_future_to_value_v<Value>>>
+  template <typename Index, typename Value,
+            typename = std::enable_if_t<(std::is_integral_v<Index>) &&
+                                        is_value_or_future_to_value_v<Value>>>
void set(const std::initializer_list<Index>& i, Value&& v) {
set<std::initializer_list<Index>>(i, std::forward<Value>(v));
}
@@ -1459,7 +1459,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
shape() & typeid(pmap().get()).hash_code();
int64_t count = 0;
for (auto it = begin(); it != end(); ++it) ++count;
-    ar& count;
+    ar & count;
for (auto it = begin(); it != end(); ++it) ar & it->get();
}

@@ -1476,39 +1476,39 @@ class DistArray : public madness::archive::ParallelSerializableObject {
auto& world = TiledArray::get_default_world();

std::size_t typeid_hash = 0l;
-    ar& typeid_hash;
+    ar & typeid_hash;
if (typeid_hash != typeid(*this).hash_code())
TA_EXCEPTION(
"DistArray::serialize: source DistArray type != this DistArray type");

ProcessID world_size = -1;
ProcessID world_rank = -1;
-    ar& world_size& world_rank;
+    ar & world_size & world_rank;
if (world_size != world.size() || world_rank != world.rank())
TA_EXCEPTION(
"DistArray::serialize: source DistArray world != this DistArray "
"world");

trange_type trange;
shape_type shape;
-    ar& trange& shape;
+    ar & trange & shape;

// use default pmap, ensure it's the same pmap used to serialize
auto volume = trange.tiles_range().volume();
auto pmap = detail::policy_t<DistArray>::default_pmap(world, volume);
size_t pmap_hash_code = 0;
-    ar& pmap_hash_code;
+    ar & pmap_hash_code;
if (pmap_hash_code != typeid(pmap.get()).hash_code())
TA_EXCEPTION(
"DistArray::serialize: source DistArray pmap != this DistArray pmap");
pimpl_.reset(
new impl_type(world, std::move(trange), std::move(shape), pmap));

int64_t count = 0;
-    ar& count;
+    ar & count;
for (auto it = begin(); it != end(); ++it, --count) {
Tile tile;
-      ar& tile;
+      ar & tile;
this->set(it.ordinal(), std::move(tile));
}
if (count != 0)
@@ -1541,27 +1541,27 @@ class DistArray : public madness::archive::ParallelSerializableObject {
// make sure source data matches the expected type
// TODO would be nice to be able to convert the data upon reading
std::size_t typeid_hash = 0l;
-    localar& typeid_hash;
+    localar & typeid_hash;
if (typeid_hash != typeid(*this).hash_code())
TA_EXCEPTION(
"DistArray::load: source DistArray type != this DistArray type");

// make sure same number of clients for every I/O node
int num_io_clients = 0;
-    localar& num_io_clients;
+    localar & num_io_clients;
if (num_io_clients != ar.num_io_clients())
TA_EXCEPTION("DistArray::load: invalid parallel archive");

trange_type trange;
shape_type shape;
-    localar& trange& shape;
+    localar & trange & shape;

// send trange and shape to every client
for (ProcessID p = 0; p < world.size(); ++p) {
if (p != me && ar.io_node(p) == me) {
world.mpi.Send(int(1), p, tag); // Tell client to expect the data
madness::archive::MPIOutputArchive dest(world, p);
-        dest& trange& shape;
+        dest & trange & shape;
dest.flush();
}
}
@@ -1573,13 +1573,13 @@ class DistArray : public madness::archive::ParallelSerializableObject {
new impl_type(world, std::move(trange), std::move(shape), pmap));

int64_t count = 0;
-    localar& count;
+    localar & count;
for (size_t ord = 0; ord != volume; ++ord) {
if (!is_zero(ord)) {
auto owner_rank = pmap->owner(ord);
if (ar.io_node(owner_rank) == me) {
Tile tile;
-          localar& tile;
+          localar & tile;
this->set(ord, std::move(tile));
--count;
}
@@ -1598,7 +1598,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
world.mpi.Recv(flag, p, tag);
TA_ASSERT(flag == 1);
madness::archive::MPIInputArchive source(world, p);
-      source& trange& shape;
+      source & trange & shape;

// use default pmap
auto volume = trange.tiles_range().volume();
@@ -1643,7 +1643,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
}
}
}
-    localar& count;
+    localar & count;
for (size_t ord = 0; ord != volume; ++ord) {
if (!is_zero(ord)) {
auto owner_rank = pmap()->owner(ord);
@@ -1857,12 +1857,37 @@ auto rank(const DistArray<Tile, Policy>& a) {
return a.trange().tiles_range().rank();
}

+///
+/// \brief Get the total elements in the non-zero tiles of an array.
+/// For tensor-of-tensor tiles, the total is the sum of the number of
+/// elements in the inner tensors of non-zero tiles.
+///
template <typename Tile, typename Policy>
-size_t volume(const DistArray<Tile, Policy>& a) {
-  // this is the number of tiles
-  if (a.size() > 0)  // assuming dense shape
-    return a.trange().elements_range().volume();
-  return 0;
+size_t volume(const DistArray<Tile, Policy>& array) {
+  std::atomic<size_t> vol = 0;
+
+  auto local_vol = [&vol](Tile const& in_tile) {
+    if constexpr (detail::is_tensor_of_tensor_v<Tile>) {
+      auto reduce_op = [](size_t& MADNESS_RESTRICT result, auto&& arg) {
+        result += arg->total_size();
+      };
+      auto join_op = [](auto& MADNESS_RESTRICT result, size_t count) {
+        result += count;
+      };
+      vol += in_tile.reduce(reduce_op, join_op, size_t{0});
+    } else
+      vol += in_tile.total_size();
+  };
+
+  for (auto&& local_tile_future : array)
+    array.world().taskq.add(local_vol, local_tile_future.get());
+
+  array.world().gop.fence();
+
+  size_t vol_ = vol;
+  array.world().gop.sum(&vol_, 1);
+
+  return vol_;
}

template <typename Tile, typename Policy>
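For orientation, a minimal usage sketch of the updated volume() follows (not part of the diff; `world` and `trange` are assumed to be an existing madness::World and TiledRange):

// Hedged usage sketch. With a dense-policy array every tile is non-zero,
// so volume(), which now sums Tile::total_size() over non-zero tiles,
// matches the element count of the whole range.
TA::TArrayD a(world, trange);
a.fill(1.0);
const std::size_t n = TA::volume(a);
// n == trange.elements_range().volume()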
@@ -2002,13 +2027,13 @@ template <class Tile, class Policy>
void save(const TiledArray::DistArray<Tile, Policy>& x,
const std::string name) {
archive::ParallelOutputArchive<> ar2(x.world(), name.c_str(), 1);
-  ar2& x;
+  ar2 & x;
}

template <class Tile, class Policy>
void load(TiledArray::DistArray<Tile, Policy>& x, const std::string name) {
archive::ParallelInputArchive<> ar2(x.world(), name.c_str(), 1);
-  ar2& x;
+  ar2 & x;
}

} // namespace madness
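For reference, a hedged sketch of how the madness::save/load helpers above are typically used. Note that load() reads x.world() before deserializing, so the destination array must already be associated with a world; the file name here is illustrative.

// Sketch only: checkpoint an initialized DistArray `a` and read it back
// later in the same world layout.
madness::save(a, "pr457_checkpoint");  // writes via ParallelOutputArchive<>
madness::load(a, "pr457_checkpoint");  // reads via ParallelInputArchive<>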
53 changes: 53 additions & 0 deletions tests/dist_array.cpp
@@ -830,4 +830,57 @@ BOOST_AUTO_TEST_CASE(rebind) {
std::is_same_v<TiledArray::detail::complex_t<SpArrayTD>, SpArrayTZ>);
}

BOOST_AUTO_TEST_CASE(volume) {
using T = Tensor<double>;
using ToT = Tensor<T>;
using Policy = SparsePolicy;
using ArrayToT = DistArray<ToT, Policy>;

size_t constexpr nrows = 3;
size_t constexpr ncols = 4;
TiledRange const trange({{0, 2, 5, 7}, {0, 5, 7, 10, 12}});
TA_ASSERT(trange.tiles_range().extent().at(0) == nrows &&
trange.tiles_range().extent().at(1) == ncols,
"Following code depends on this condition.");

// this Range is used to construct all inner tensors of the tile with
// tile index @c tix.
auto inner_dims = [nrows, ncols](Range::index_type const& tix) -> Range {
static std::array<size_t, nrows> const rows{7, 8, 9};
static std::array<size_t, ncols> const cols{7, 8, 9, 10};

TA_ASSERT(tix.size() == 2, "Only rank-2 tensor expected.");
return Range({rows[tix.at(0) % nrows], cols[tix.at(1) % ncols]});
};

// let's make all 'diagonal' tiles zero
auto zero_tile = [](Range::index_type const& tix) -> bool {
return tix.at(0) == tix.at(1);
};

auto make_tile = [inner_dims, zero_tile, &trange](auto& tile,
auto const& rng) {
auto&& tix = trange.element_to_tile(rng.lobound());
if (zero_tile(tix))
return 0.;
else {
tile = ToT(rng, [inner_rng = inner_dims(tix)](auto&&) {
return T(inner_rng, 0.1);
});
return tile.norm();
}
};

auto& world = get_default_world();
auto array = make_array<ArrayToT>(world, trange, make_tile);

// manually compute the volume of array
size_t vol = 0;
for (auto&& tix : trange.tiles_range())
if (!zero_tile(tix))
vol += trange.tile(tix).volume() * inner_dims(tix).volume();

BOOST_REQUIRE(vol == TA::volume(array));
}
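A companion check for plain (non-nested) tiles could look like the sketch below (not part of this PR); with every tile non-zero, TA::volume should equal the volume of the whole element range:

BOOST_AUTO_TEST_CASE(volume_plain_sketch) {
  // sparse array of plain Tensor<double> tiles, all non-zero
  using Array = DistArray<Tensor<double>, SparsePolicy>;
  TiledRange const trange({{0, 2, 5}, {0, 3, 6}});

  auto make_tile = [](auto& tile, auto const& rng) {
    tile = Tensor<double>(rng, 0.1);
    return tile.norm();  // positive norm, so no tile is truncated to zero
  };

  auto& world = get_default_world();
  auto array = make_array<Array>(world, trange, make_tile);

  BOOST_REQUIRE(TA::volume(array) == trange.elements_range().volume());
}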

BOOST_AUTO_TEST_SUITE_END()