Skip to content

Commit

Permalink
amends 577fda2 to always use the old implementation of ta_tensor_to_u…
Browse files Browse the repository at this point in the history
…m_tensor when element conversion is needed
  • Loading branch information
evaleev committed Jun 26, 2024
1 parent d2dd697 commit 9341917
Showing 1 changed file with 44 additions and 25 deletions.
69 changes: 44 additions & 25 deletions src/TiledArray/device/btas_um_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct ArchiveLoadImpl<Archive, TiledArray::btasUMTensorVarray<T>> {
TiledArray::btasUMTensorVarray<T> &t) {
TiledArray::Range range{};
TiledArray::device_um_btas_varray<T> store{};
ar &range &store;
ar & range & store;
t = TiledArray::btasUMTensorVarray<T>(std::move(range), std::move(store));
// device::setDevice(TiledArray::deviceEnv::instance()->default_device_id());
// auto &stream = device::stream_for(range);
Expand All @@ -83,7 +83,7 @@ struct ArchiveStoreImpl<Archive, TiledArray::btasUMTensorVarray<T>> {
auto stream = TiledArray::device::stream_for(t.range());
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
t.storage(), stream);
ar &t.range() & t.storage();
ar & t.range() & t.storage();
}
};

Expand Down Expand Up @@ -674,25 +674,12 @@ template <typename UMTensor, typename TATensor, typename Policy>
typename std::enable_if<!std::is_same<UMTensor, TATensor>::value,
TiledArray::DistArray<UMTensor, Policy>>::type
ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
auto convert_tile_memcpy = [](const TATensor &tile) {
/// UMTensor must be wrapped into TA::Tile

using Tensor = typename UMTensor::tensor_type;

auto stream = device::stream_for(tile.range());
typename Tensor::storage_type storage;
make_device_storage(storage, tile.range().area(), stream);
Tensor result(tile.range(), std::move(storage));

DeviceSafeCall(
device::memcpyAsync(result.data(), tile.data(),
tile.size() * sizeof(typename Tensor::value_type),
device::MemcpyDefault, stream));

device::sync_madness_task_with(stream);
return TiledArray::Tile<Tensor>(std::move(result));
};
using inT = typename TATensor::value_type;
using outT = typename UMTensor::value_type;
// check if element conversion is necessary
constexpr bool T_conversion = !std::is_same_v<inT, outT>;

// this is safe even when need to convert element types, but less efficient
auto convert_tile_um = [](const TATensor &tile) {
/// UMTensor must be wrapped into TA::Tile

Expand All @@ -711,14 +698,46 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
result.storage(), stream);

// N.B. move! without it have D-to-H transfer due to calling UM
// allocator construct() on the host
return TiledArray::Tile<Tensor>(std::move(result));
};

const char *use_legacy_conversion =
std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION");
auto um_array = use_legacy_conversion
? to_new_tile_type(array, convert_tile_um)
: to_new_tile_type(array, convert_tile_memcpy);
if constexpr (T_conversion) {
auto um_array = to_new_tile_type(array, convert_tile_um);
} else {
// this is more efficient for copying:
// - avoids copy on host followed by UM transfer, instead uses direct copy
// - replaced unneeded copy (which also caused D-to-H transfer due to
// calling UM allocator construct() on the host) by move
// This eliminates all spurious UM traffic in (T) W3 contractions
auto convert_tile_memcpy = [](const TATensor &tile) {
/// UMTensor must be wrapped into TA::Tile

using Tensor = typename UMTensor::tensor_type;

auto stream = device::stream_for(tile.range());
typename Tensor::storage_type storage;
make_device_storage(storage, tile.range().area(), stream);
Tensor result(tile.range(), std::move(storage));

DeviceSafeCall(
device::memcpyAsync(result.data(), tile.data(),
tile.size() * sizeof(typename Tensor::value_type),
device::MemcpyDefault, stream));

device::sync_madness_task_with(stream);
// N.B. move! without it have D-to-H transfer due to calling UM
// allocator construct() on the host
return TiledArray::Tile<Tensor>(std::move(result));
};

const char *use_legacy_conversion =
std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION");
auto um_array = use_legacy_conversion
? to_new_tile_type(array, convert_tile_um)
: to_new_tile_type(array, convert_tile_memcpy);
}

array.world().gop.fence();
return um_array;
Expand Down

0 comments on commit 9341917

Please sign in to comment.