Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
WIP Adding async scan algorithms.
Browse files Browse the repository at this point in the history
  • Loading branch information
alliepiper committed Aug 6, 2020
1 parent 945cd09 commit 45c8380
Show file tree
Hide file tree
Showing 9 changed files with 1,223 additions and 2 deletions.
2 changes: 2 additions & 0 deletions testing/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ add_subdirectory(unittest)
# List of tests that aren't implemented for all backends, but are implemented for CUDA.
set(partially_implemented_CUDA
async_copy
async_exclusive_scan
async_for_each
async_reduce
async_reduce_into
async_inclusive_scan
async_sort
async_transform
event
Expand Down
130 changes: 130 additions & 0 deletions testing/async_exclusive_scan.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include <thrust/detail/config.h>

#if THRUST_CPP_DIALECT >= 2014

#include <unittest/unittest.h>
#include <unittest/util_async.h>

#include <thrust/async/scan.h>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

// TODO Finish implementing tests. Draw from other async algos, as well as
// the older scan tests.

namespace
{

template <typename value_type>
struct async_exclusive_scan_def
{
public:
// Input and output types for the algorithms:
using input_type = thrust::device_vector<value_type>;
using output_type = thrust::device_vector<value_type>;

using postfix_args_type = std::tuple< // List any extra arg overloads:
std::tuple<>, // - no extra args
std::tuple<value_type>, // - initial_value
std::tuple<value_type, thrust::maximum<>> // - initial_value, binary_op
>;

// Create instances of the extra arguments to use when invoking the
// algorithm:
static postfix_args_type generate_postfix_args()
{
return {
{}, // no extra args
{42}, // initial_value
{57, thrust::maximum<>{}} // initial_value, binary_op
};
}

// Generate an instance of the input:
static input_type generate_input()
{
input_type input(1024);
thrust::sequence(input.begin(), input.end(), 25, 3);
return input;
}

// Invoke a reference implementation for a single overload as described by
// postfix_tuple. This tuple contains instances of any additional arguments
// to pass to the algorithm. The tuple/index_sequence pattern is used to
// support the "no extra args" overload, since the parameter pack expansion
// will do exactly what we want in all cases.
template <typename PostfixArgTuple, std::size_t... PostfixArgIndices>
static void invoke_reference(PostfixArgTuple &&postfix_tuple,
std::index_sequence<PostfixArgIndices...>,
input_type const &input,
output_type &output)
{
// Create host versions of the input/output:
thrust::host_vector<value_type> host_input(input);
thrust::host_vector<value_type> host_output(input.size());

// Run host synchronous algorithm to generate reference.
thrust::exclusive_scan(host_input.cbegin(),
host_input.cend(),
host_output.begin(),
std::get<PostfixArgIndices>(
THRUST_FWD(postfix_tuple))...);

// Copy back to device.
output = host_output;
}

// Invoke the async algorithm for a single overload as described by
// the prefix and postfix tuples. These tuple contains instances of any
// additional arguments to pass to the algorithm. The tuple/index_sequence
// pattern is used to support the "no extra args" overload, since the
// parameter pack expansion will do exactly what we want in all cases.
// Prefix args are included here (but not for invoke_reference) to allow the
// test framework to change the execution policy.
// This method must return an event or future.
template <typename PrefixArgTuple,
std::size_t... PrefixArgIndices,
typename PostfixArgTuple,
std::size_t... PostfixArgIndices>
static auto invoke_async(PrefixArgTuple &&prefix_tuple,
std::index_sequence<PrefixArgIndices...>,
PostfixArgTuple &&postfix_tuple,
std::index_sequence<PostfixArgIndices...>,
input_type const &input,
output_type &output)
{
output.resize(input.size());
auto e = thrust::async::exclusive_scan(
std::get<PrefixArgIndices>(THRUST_FWD(prefix_tuple))...,
input.cbegin(),
input.cend(),
output.begin(),
std::get<PostfixArgIndices>(THRUST_FWD(postfix_tuple))...);
return e;
}

// Wait on and validate the event/future (usually with TEST_EVENT_WAIT /
// TEST_FUTURE_VALUE_RETRIEVAL), then check that the reference output matches
// the testing output.
template <typename EventType>
static void compare_outputs(EventType &e,
output_type const &ref,
output_type const &test)
{
TEST_EVENT_WAIT(e);
ASSERT_EQUAL_QUIET(ref, test);
}
};

} // namespace

void TestPolicyOverloads()
{
// Only ints are tested here because we just want to check that the policies
// are propagated correctly, so keep codegen to a minimum.
unittest::test_async_policy_overloads<async_exclusive_scan_def<int>>::run();
}
DECLARE_UNITTEST(TestPolicyOverloads);

#endif // C++14
108 changes: 108 additions & 0 deletions testing/async_inclusive_scan.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include <thrust/detail/config.h>

#if THRUST_CPP_DIALECT >= 2014

#include <unittest/unittest.h>
#include <unittest/util_async.h>

#include <thrust/async/scan.h>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

// TODO Finish implementing tests. Draw from other async algos, as well as
// the older scan tests.

namespace
{

template <typename value_type>
struct async_inclusive_scan_def
{
public:
using input_type = thrust::device_vector<value_type>;
using output_type = thrust::device_vector<value_type>;

using postfix_args_type = std::tuple< // List any extra arg overloads:
std::tuple<>, // - no extra args
std::tuple<thrust::maximum<>> // - Non-default binary-op
>;

static postfix_args_type generate_postfix_args()
{
return {
{}, // - no extra args
{thrust::maximum<>{}} // - non-default binary_op
};
}

static input_type generate_input()
{
input_type input(1024);
thrust::sequence(input.begin(), input.end(), 25, 3);
return input;
}

template <typename PostfixArgTuple, std::size_t... PostfixArgIndices>
static void invoke_reference(PostfixArgTuple &&postfix_tuple,
std::index_sequence<PostfixArgIndices...>,
input_type const &input,
output_type &output)
{
// Create host versions of the input/output:
thrust::host_vector<value_type> host_input(input);
thrust::host_vector<value_type> host_output(input.size());

// Run host synchronous algorithm to generate reference.
thrust::inclusive_scan(host_input.cbegin(),
host_input.cend(),
host_output.begin(),
std::get<PostfixArgIndices>(
THRUST_FWD(postfix_tuple))...);

// Copy back to device.
output = host_output;
}

template <typename PrefixArgTuple,
std::size_t... PrefixArgIndices,
typename PostfixArgTuple,
std::size_t... PostfixArgIndices>
static auto invoke_async(PrefixArgTuple &&prefix_tuple,
std::index_sequence<PrefixArgIndices...>,
PostfixArgTuple &&postfix_tuple,
std::index_sequence<PostfixArgIndices...>,
input_type const &input,
output_type &output)
{
output.resize(input.size());
auto e = thrust::async::inclusive_scan(
std::get<PrefixArgIndices>(THRUST_FWD(prefix_tuple))...,
input.cbegin(),
input.cend(),
output.begin(),
std::get<PostfixArgIndices>(THRUST_FWD(postfix_tuple))...);
return e;
}

template <typename EventType>
static void compare_outputs(EventType &e,
output_type const &ref,
output_type const &test)
{
TEST_EVENT_WAIT(e);
ASSERT_EQUAL_QUIET(ref, test);
}
};

} // namespace

void TestPolicyOverloads()
{
// Only ints are tested here because we just want to check that the policies
// are propagated correctly, so keep codegen to a minimum.
unittest::test_async_policy_overloads<async_inclusive_scan_def<int>>::run();
}
DECLARE_UNITTEST(TestPolicyOverloads);

#endif // C++14
Loading

0 comments on commit 45c8380

Please sign in to comment.