Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Update DeviceScan to pass Thrust's scan tests.
Browse files Browse the repository at this point in the history
Thrust will be switching to `cub::DeviceScan` to replace its custom scan
implementation. This patch addresses some issues found by the Thrust
tests:

- Initialize unused `BlockLoad` items with values known to be in the input
  set. This fixes the `TestInclusiveScanWithIndirection` Thrust test by
  keeping the `plus_mod3` functor indices valid.
- Use `OffsetT` instead of `int` to hold indices in `AgentScan`. This
  fixes the `Test*ScanWithBigIndexes` Thrust tests by not truncating
  the input problem size.
- Use `BLOCK_[STORE|LOAD]_WARP_TRANSPOSE_TIMESLICED` instead of
  `BLOCK_[STORE|LOAD]_WARP_TRANSPOSE` when the intermediate type is
  larger than 128 bytes. This keeps shared memory buffers from growing
  too large in the `TestScanWithLargeTypes` Thrust test.
  • Loading branch information
alliepiper committed Nov 2, 2020
1 parent 200cf19 commit 7e6f33b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
26 changes: 23 additions & 3 deletions cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,19 @@ struct AgentScan
OutputT items[ITEMS_PER_THREAD];

if (IS_LAST_TILE)
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining);
{
// Fill the unused (out-of-range) items with the tile's first element,
// because the block-wide collectives are not suffix guarded.
BlockLoadT(temp_storage.load)
.Load(d_in + tile_offset,
items,
num_remaining,
*(d_in + tile_offset));
}
else
{
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);
}

CTA_SYNC();

Expand Down Expand Up @@ -330,7 +340,7 @@ struct AgentScan
* Scan tiles of items as part of a dynamic chained scan
*/
__device__ __forceinline__ void ConsumeRange(
int num_items, ///< Total number of input items
OffsetT num_items, ///< Total number of input items
ScanTileStateT& tile_state, ///< Global tile state descriptor
int start_tile) ///< The starting tile for the current grid
{
Expand Down Expand Up @@ -371,9 +381,19 @@ struct AgentScan
OutputT items[ITEMS_PER_THREAD];

if (IS_LAST_TILE)
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, valid_items);
{
// Fill the unused (out-of-range) items with the tile's first element,
// because the block-wide collectives are not suffix guarded.
BlockLoadT(temp_storage.load)
.Load(d_in + tile_offset,
items,
valid_items,
*(d_in + tile_offset));
}
else
{
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);
}

CTA_SYNC();

Expand Down
14 changes: 11 additions & 3 deletions cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ template <
typename OutputT> ///< Data type
struct DeviceScanPolicy
{
// For large values, use timesliced loads/stores to fit shared memory.
static constexpr bool LargeValues = sizeof(OutputT) > 128;
static constexpr BlockLoadAlgorithm ScanTransposedLoad =
LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED
: BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm ScanTransposedStore =
LargeValues ? BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED
: BLOCK_STORE_WARP_TRANSPOSE;

/// SM35
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
Expand All @@ -163,7 +171,7 @@ struct DeviceScanPolicy
OutputT,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
BLOCK_STORE_WARP_TRANSPOSE,
ScanTransposedStore,
BLOCK_SCAN_WARP_SCANS>
ScanPolicyT;
};
Expand All @@ -174,9 +182,9 @@ struct DeviceScanPolicy
typedef AgentScanPolicy<
128, 15, ///< Threads per block, items per thread
OutputT,
BLOCK_LOAD_TRANSPOSE,
ScanTransposedLoad,
LOAD_DEFAULT,
BLOCK_STORE_TRANSPOSE,
ScanTransposedStore,
BLOCK_SCAN_WARP_SCANS>
ScanPolicyT;
};
Expand Down

0 comments on commit 7e6f33b

Please sign in to comment.