Skip to content

Commit

Permalink
Merge Port rest of matrix format to dpcpp
Browse files Browse the repository at this point in the history
This pr ports the rest of matrix format to dpcpp (except for dense and ell, which are in another prs)

Related PR: #845
  • Loading branch information
yhmtsai authored Aug 6, 2021
2 parents 67296df + 994b7bb commit d9789d7
Show file tree
Hide file tree
Showing 77 changed files with 6,080 additions and 1,372 deletions.
78 changes: 78 additions & 0 deletions common/matrix/coo_kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*******************************<GINKGO LICENSE>******************************
Copyright (c) 2017-2021, the Ginkgo authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

#include "core/matrix/coo_kernels.hpp"


#include <ginkgo/core/base/math.hpp>


#include "common/base/kernel_launch.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
/**
* @brief The Coo matrix format namespace.
*
* @ingroup coo
*/
namespace coo {


template <typename ValueType, typename IndexType>
void extract_diagonal(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Coo<ValueType, IndexType> *orig,
matrix::Diagonal<ValueType> *diag)
{
run_kernel(
exec,
[] GKO_KERNEL(auto tidx, auto orig_values, auto orig_row_idxs,
auto orig_col_idxs, auto diag) {
if (orig_row_idxs[tidx] == orig_col_idxs[tidx]) {
diag[orig_row_idxs[tidx]] = orig_values[tidx];
}
},
orig->get_num_stored_elements(), orig->get_const_values(),
orig->get_const_row_idxs(), orig->get_const_col_idxs(),
diag->get_values());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_COO_EXTRACT_DIAGONAL_KERNEL);


} // namespace coo
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
16 changes: 0 additions & 16 deletions common/matrix/coo_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -272,20 +272,4 @@ __global__ __launch_bounds__(default_block_size) void fill_in_dense(
}


template <typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void extract_diagonal(
size_type nnz, const ValueType *__restrict__ orig_values,
const IndexType *__restrict__ orig_row_idxs,
const IndexType *__restrict__ orig_col_idxs, ValueType *__restrict__ diag)
{
const auto tidx = thread::get_thread_id_flat();

if (tidx < nnz) {
if (orig_row_idxs[tidx] == orig_col_idxs[tidx]) {
diag[orig_row_idxs[tidx]] = orig_values[tidx];
}
}
}


} // namespace kernel
108 changes: 108 additions & 0 deletions common/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*******************************<GINKGO LICENSE>******************************
Copyright (c) 2017-2021, the Ginkgo authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

#include "core/matrix/csr_kernels.hpp"


#include <algorithm>


#include <ginkgo/core/base/math.hpp>


#include "common/base/kernel_launch.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
/**
* @brief The Csr matrix format namespace.
*
* @ingroup csr
*/
namespace csr {


template <typename IndexType>
void invert_permutation(std::shared_ptr<const DefaultExecutor> exec,
size_type size, const IndexType *permutation_indices,
IndexType *inv_permutation)
{
run_kernel(
exec,
[] GKO_KERNEL(auto tid, auto permutation, auto inv_permutation) {
inv_permutation[permutation[tid]] = tid;
},
size, permutation_indices, inv_permutation);
}

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_INVERT_PERMUTATION_KERNEL);


template <typename ValueType, typename IndexType>
void inverse_column_permute(std::shared_ptr<const DefaultExecutor> exec,
const IndexType *perm,
const matrix::Csr<ValueType, IndexType> *orig,
matrix::Csr<ValueType, IndexType> *column_permuted)
{
auto num_rows = orig->get_size()[0];
auto nnz = orig->get_num_stored_elements();
auto size = std::max(num_rows, nnz);
run_kernel(
exec,
[] GKO_KERNEL(auto tid, auto num_rows, auto num_nonzeros,
auto permutation, auto in_row_ptrs, auto in_col_idxs,
auto in_vals, auto out_row_ptrs, auto out_col_idxs,
auto out_vals) {
if (tid < num_nonzeros) {
out_col_idxs[tid] = permutation[in_col_idxs[tid]];
out_vals[tid] = in_vals[tid];
}
if (tid <= num_rows) {
out_row_ptrs[tid] = in_row_ptrs[tid];
}
},
size, num_rows, nnz, perm, orig->get_const_row_ptrs(),
orig->get_const_col_idxs(), orig->get_const_values(),
column_permuted->get_row_ptrs(), column_permuted->get_col_idxs(),
column_permuted->get_values());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL);


} // namespace csr
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
59 changes: 8 additions & 51 deletions common/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ __device__ __forceinline__ void find_next_row(
if (ind >= *row_end) {
*row = row_predict;
*row_end = row_predict_end;
for (; ind >= *row_end; *row_end = row_ptr[++*row + 1])
;
while (ind >= *row_end) {
*row_end = row_ptr[++*row + 1];
}
}

} else {
Expand Down Expand Up @@ -140,8 +141,8 @@ template <typename IndexType>
__device__ __forceinline__ IndexType get_warp_start_idx(
const IndexType nwarps, const IndexType nnz, const IndexType warp_idx)
{
const long long cache_lines = ceildivT<IndexType>(nnz, wsize);
return (warp_idx * cache_lines / nwarps) * wsize;
const long long cache_lines = ceildivT<IndexType>(nnz, config::warp_size);
return (warp_idx * cache_lines / nwarps) * config::warp_size;
}


Expand All @@ -160,6 +161,7 @@ __device__ __forceinline__ void spmv_kernel(
}
const IndexType data_size = row_ptrs[num_rows];
const IndexType start = get_warp_start_idx(nwarps, data_size, warp_idx);
constexpr IndexType wsize = config::warp_size;
const IndexType end =
min(get_warp_start_idx(nwarps, data_size, warp_idx + 1),
ceildivT<IndexType>(data_size, wsize) * wsize);
Expand Down Expand Up @@ -218,17 +220,6 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_spmv(
}


template <typename ValueType>
__global__ __launch_bounds__(default_block_size) void set_zero(
const size_type nnz, ValueType *__restrict__ val)
{
const auto ind = thread::get_thread_id_flat();
if (ind < nnz) {
val[ind] = zero<ValueType>();
}
}


template <typename IndexType>
__forceinline__ __device__ void merge_path_search(
const IndexType diagonal, const IndexType a_len, const IndexType b_len,
Expand Down Expand Up @@ -359,8 +350,7 @@ __device__ void merge_path_spmv(
tmp_val[threadIdx.x] = value;
tmp_ind[threadIdx.x] = row_i;
group::this_thread_block().sync();
bool last = block_segment_scan_reverse(static_cast<IndexType *>(tmp_ind),
static_cast<ValueType *>(tmp_val));
bool last = block_segment_scan_reverse(tmp_ind, tmp_val);
if (threadIdx.x == spmv_block_size - 1) {
row_out[blockIdx.x] = min(end_x, num_rows - 1);
val_out[blockIdx.x] = tmp_val[threadIdx.x];
Expand Down Expand Up @@ -948,39 +938,6 @@ __global__ __launch_bounds__(default_block_size) void conjugate_kernel(
} // namespace


template <typename IndexType>
__global__ __launch_bounds__(default_block_size) void inv_permutation_kernel(
size_type size, const IndexType *__restrict__ permutation,
IndexType *__restrict__ inv_permutation)
{
auto tid = thread::get_thread_id_flat();
if (tid >= size) {
return;
}
inv_permutation[permutation[tid]] = tid;
}


template <typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void col_permute_kernel(
size_type num_rows, size_type num_nonzeros,
const IndexType *__restrict__ permutation,
const IndexType *__restrict__ in_row_ptrs,
const IndexType *__restrict__ in_cols,
const ValueType *__restrict__ in_vals, IndexType *__restrict__ out_row_ptrs,
IndexType *__restrict__ out_cols, ValueType *__restrict__ out_vals)
{
auto tid = thread::get_thread_id_flat();
if (tid < num_nonzeros) {
out_cols[tid] = permutation[in_cols[tid]];
out_vals[tid] = in_vals[tid];
}
if (tid <= num_rows) {
out_row_ptrs[tid] = in_row_ptrs[tid];
}
}


template <typename IndexType>
__global__ __launch_bounds__(default_block_size) void row_ptr_permute_kernel(
size_type num_rows, const IndexType *__restrict__ permutation,
Expand Down Expand Up @@ -1088,4 +1045,4 @@ __global__ __launch_bounds__(default_block_size) void inv_symm_permute_kernel(
out_cols[out_begin + i] = permutation[in_cols[in_begin + i]];
out_vals[out_begin + i] = in_vals[in_begin + i];
}
}
}
Loading

0 comments on commit d9789d7

Please sign in to comment.