Skip to content

Commit

Permalink
Add strongly typed index support for TensorIndex.
Browse files Browse the repository at this point in the history
  • Loading branch information
rchen20 committed Sep 27, 2024
1 parent 8709d86 commit 887d3d3
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions include/RAJA/pattern/tensor/TensorIndex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace expt
{


template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM, IDX INDEX_VALUE, strip_index_type_t<IDX> LENGTH_VALUE>
template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM, strip_index_type_t<IDX> INDEX_VALUE, strip_index_type_t<IDX> LENGTH_VALUE>
struct StaticTensorIndexInner;

template<typename INNER_TYPE>
Expand All @@ -56,8 +56,8 @@ namespace expt
RAJA_HOST_DEVICE
static
constexpr
StaticTensorIndex<StaticTensorIndexInner<IDX,TENSOR_TYPE,DIM,index_type(-1),value_type(-1)>> static_all(){
return StaticTensorIndex<StaticTensorIndexInner<IDX,TENSOR_TYPE,DIM,index_type(-1),value_type(-1)>>();
StaticTensorIndex<StaticTensorIndexInner<IDX,TENSOR_TYPE,DIM,value_type(-1),value_type(-1)>> static_all(){
return StaticTensorIndex<StaticTensorIndexInner<IDX,TENSOR_TYPE,DIM,value_type(-1),value_type(-1)>>();
}

RAJA_INLINE
Expand Down Expand Up @@ -103,7 +103,7 @@ namespace expt
TensorIndex(TensorIndex<IDX, T, D> const &c) : m_index(*c), m_length(c.size()) {}


template<IDX IDX_VAL, strip_index_type_t<IDX> LEN_VAL>
template<strip_index_type_t<IDX> IDX_VAL, strip_index_type_t<IDX> LEN_VAL>
RAJA_INLINE
RAJA_HOST_DEVICE
constexpr
Expand Down Expand Up @@ -156,7 +156,7 @@ namespace expt
};


template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM, IDX INDEX_VALUE, strip_index_type_t<IDX> LENGTH_VALUE>
template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM, strip_index_type_t<IDX> INDEX_VALUE, strip_index_type_t<IDX> LENGTH_VALUE>
struct StaticTensorIndex<StaticTensorIndexInner<IDX,TENSOR_TYPE,DIM,INDEX_VALUE,LENGTH_VALUE>> {

using base_type = TensorIndex<IDX,TENSOR_TYPE,DIM>;
Expand Down

0 comments on commit 887d3d3

Please sign in to comment.