Skip to content

Commit

Permalink
TA::host_allocator is serializable, so that btas::Tensor can be used …
Browse files Browse the repository at this point in the history
…as a tile again
  • Loading branch information
evaleev committed Sep 22, 2024
1 parent 8ab356e commit f294db3
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 238 deletions.
2 changes: 0 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ TiledArray/external/btas.h
TiledArray/external/madness.h
TiledArray/external/umpire.h
TiledArray/host/env.h
TiledArray/host/allocator.h
TiledArray/math/blas.h
TiledArray/math/gemm_helper.h
TiledArray/math/outer.h
Expand Down Expand Up @@ -223,7 +222,6 @@ if(CUDA_FOUND OR HIP_FOUND)
TiledArray/device/kernel/thrust/reduce_kernel.h
TiledArray/device/platform.h
TiledArray/device/thrust.h
TiledArray/device/allocators.h
TiledArray/device/um_storage.h)
if(CUDA_FOUND)
list(APPEND TILEDARRAY_HEADER_FILES
Expand Down
138 changes: 0 additions & 138 deletions src/TiledArray/device/allocators.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/TiledArray/device/um_storage.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
*/


#include <TiledArray/device/allocators.h>
#include <TiledArray/external/device.h>
#include <TiledArray/device/thrust.h>

#ifdef TILEDARRAY_HAS_CUDA
Expand Down
2 changes: 1 addition & 1 deletion src/TiledArray/device/um_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TILEDARRAY_DEVICE_UM_VECTOR_H__INCLUDED
#define TILEDARRAY_DEVICE_UM_VECTOR_H__INCLUDED

#include <TiledArray/device/allocators.h>
#include <TiledArray/external/device.h>

#ifdef TILEDARRAY_HAS_DEVICE

Expand Down
15 changes: 14 additions & 1 deletion src/TiledArray/external/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,9 +798,22 @@ class Env {
static std::unique_ptr<Env> instance_{nullptr};
return instance_;
}
};
}; // class Env

namespace detail {

struct get_um_allocator {
umpire::Allocator& operator()() {
return deviceEnv::instance()->um_allocator();
}
};

struct get_pinned_allocator {
umpire::Allocator& operator()() {
return deviceEnv::instance()->pinned_allocator();
}
};

// in a madness device task point to its local optional stream to use by
// madness_task_stream_opt; set to nullptr after task callable finished
inline std::optional<Stream>*& madness_task_stream_opt_ptr_accessor() {
Expand Down
83 changes: 80 additions & 3 deletions src/TiledArray/external/umpire.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,54 @@ bool operator!=(
return !(lhs == rhs);
}

template <class T, class StaticLock, typename UmpireAllocatorAccessor>
class umpire_based_allocator
: public umpire_based_allocator_impl<T, StaticLock> {
public:
using base_type = umpire_based_allocator_impl<T, StaticLock>;
using typename base_type::const_pointer;
using typename base_type::const_reference;
using typename base_type::pointer;
using typename base_type::reference;
using typename base_type::value_type;

umpire_based_allocator() noexcept : base_type(&UmpireAllocatorAccessor{}()) {}

template <class U>
umpire_based_allocator(
const umpire_based_allocator<U, StaticLock, UmpireAllocatorAccessor>&
rhs) noexcept
: base_type(
static_cast<const umpire_based_allocator_impl<U, StaticLock>&>(
rhs)) {}

template <typename T1, typename T2, class StaticLock_,
typename UmpireAllocatorAccessor_>
friend bool operator==(
const umpire_based_allocator<T1, StaticLock_, UmpireAllocatorAccessor_>&
lhs,
const umpire_based_allocator<T2, StaticLock_, UmpireAllocatorAccessor_>&
rhs) noexcept;
}; // class umpire_based_allocator

template <class T1, class T2, class StaticLock,
typename UmpireAllocatorAccessor>
bool operator==(
const umpire_based_allocator<T1, StaticLock, UmpireAllocatorAccessor>& lhs,
const umpire_based_allocator<T2, StaticLock, UmpireAllocatorAccessor>&
rhs) noexcept {
return lhs.umpire_allocator() == rhs.umpire_allocator();
}

template <class T1, class T2, class StaticLock,
typename UmpireAllocatorAccessor>
bool operator!=(
const umpire_based_allocator<T1, StaticLock, UmpireAllocatorAccessor>& lhs,
const umpire_based_allocator<T2, StaticLock, UmpireAllocatorAccessor>&
rhs) noexcept {
return !(lhs == rhs);
}

/// see
/// https://stackoverflow.com/questions/21028299/is-this-behavior-of-vectorresizesize-type-n-under-c11-and-boost-container/21028912#21028912
template <typename T, typename A>
Expand Down Expand Up @@ -202,7 +250,7 @@ struct ArchiveLoadImpl<Archive,
const Archive& ar,
TiledArray::umpire_based_allocator_impl<T, StaticLock>& allocator) {
std::string allocator_name;
ar& allocator_name;
ar & allocator_name;
allocator = TiledArray::umpire_based_allocator_impl<T, StaticLock>(
umpire::ResourceManager::getInstance().getAllocator(allocator_name));
}
Expand All @@ -214,7 +262,7 @@ struct ArchiveStoreImpl<
static inline void store(
const Archive& ar,
const TiledArray::umpire_based_allocator_impl<T, StaticLock>& allocator) {
ar& allocator.umpire_allocator()->getName();
ar & allocator.umpire_allocator()->getName();
}
};

Expand All @@ -224,7 +272,7 @@ struct ArchiveLoadImpl<Archive, TiledArray::default_init_allocator<T, A>> {
TiledArray::default_init_allocator<T, A>& allocator) {
if constexpr (!std::allocator_traits<A>::is_always_equal::value) {
A base_allocator;
ar& base_allocator;
ar & base_allocator;
allocator = TiledArray::default_init_allocator<T, A>(base_allocator);
}
}
Expand All @@ -244,4 +292,33 @@ struct ArchiveStoreImpl<Archive, TiledArray::default_init_allocator<T, A>> {
} // namespace archive
} // namespace madness

namespace madness {
namespace archive {

template <class Archive, class T, class StaticLock,
typename UmpireAllocatorAccessor>
struct ArchiveLoadImpl<Archive, TiledArray::umpire_based_allocator<
T, StaticLock, UmpireAllocatorAccessor>> {
static inline void load(
const Archive& ar,
TiledArray::umpire_based_allocator<T, StaticLock,
UmpireAllocatorAccessor>& allocator) {
allocator = TiledArray::umpire_based_allocator<T, StaticLock,
UmpireAllocatorAccessor>{};
}
};

template <class Archive, class T, class StaticLock,
typename UmpireAllocatorAccessor>
struct ArchiveStoreImpl<Archive, TiledArray::umpire_based_allocator<
T, StaticLock, UmpireAllocatorAccessor>> {
static inline void store(
const Archive& ar,
const TiledArray::umpire_based_allocator<
T, StaticLock, UmpireAllocatorAccessor>& allocator) {}
};

} // namespace archive
} // namespace madness

#endif // TILEDARRAY_EXTERNAL_UMPIRE_H___INCLUDED
32 changes: 19 additions & 13 deletions src/TiledArray/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,27 @@ class aligned_allocator;

// fwddecl host_allocator
namespace TiledArray {
template <class T>
class host_allocator_impl;
template <typename T, typename A>
namespace detail {
struct get_host_allocator;
struct NullLock;
template <typename Tag = void>
class MutexLock;
} // namespace detail

template <class T, class StaticLock, typename UmpireAllocatorAccessor>
class umpire_based_allocator;

template <typename T, typename A = std::allocator<T>>
class default_init_allocator;

class hostEnv;

/// pooled thread-safe host memory allocator
template <typename T>
using host_allocator = default_init_allocator<T, host_allocator_impl<T>>;
using host_allocator =
default_init_allocator<T,
umpire_based_allocator<T, detail::MutexLock<hostEnv>,
detail::get_host_allocator>>;
} // namespace TiledArray

namespace madness {
Expand Down Expand Up @@ -87,18 +102,9 @@ class Env;
}
using deviceEnv = device::Env;

template <class T, class StaticLock, typename UmpireAllocatorAccessor>
class umpire_based_allocator;

template <typename T, typename A = std::allocator<T>>
class default_init_allocator;

namespace detail {
struct get_um_allocator;
struct get_pinned_allocator;
struct NullLock;
template <typename Tag = void>
class MutexLock;
} // namespace detail

/// pooled thread-safe unified memory (UM) allocator for device computing
Expand Down
Loading

0 comments on commit f294db3

Please sign in to comment.