Skip to content

Commit

Permalink
WIP factorization fix half but not complex half yet
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 4, 2024
1 parent b96aa71 commit 88967e6
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 9 deletions.
13 changes: 13 additions & 0 deletions common/cuda_hip/base/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ GKO_ATTRIBUTES GKO_INLINE __half abs<__half>(const complex<__half>& z)


namespace gko {
using thrust::sqrt;
// It is required by NVHPC 23.3, isnan is undefined when NVHPC are only as host
// compiler.
#if defined(__CUDACC__) || defined(GKO_COMPILING_HIP)
Expand Down Expand Up @@ -156,6 +157,18 @@ __device__ __forceinline__ __half sqrt(const __half& val)
}


// using overload here. Otherwise, compiler still think the is_finite
// specialization is still __host__ __device__ function.
__device__ __forceinline__ bool is_finite(const __half& value)
{
return abs(value) < device_numeric_limits<__half>::inf();
}

__device__ __forceinline__ bool is_finite(const thrust::complex<__half>& value)
{
return is_finite(value.real()) && is_finite(value.imag());
}

#endif


Expand Down
6 changes: 5 additions & 1 deletion common/cuda_hip/factorization/par_ic_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,13 @@ __global__ __launch_bounds__(default_block_size) void ic_sweep(
}
auto to_write =
row == col
? sqrt(a_val - sum)
? gko::sqrt(a_val - sum)
: (a_val - sum) / load_relaxed(l_vals + (l_row_ptrs[col + 1] - 1));
// if (row == col && row < 30) {
// printf("%d: %lf\n", row, static_cast<double>(real(to_write)));
// }
if (is_finite(to_write)) {
printf("write?!!\n");
store_relaxed(l_vals + l_nz, to_write);
}
}
Expand Down
2 changes: 2 additions & 0 deletions common/cuda_hip/factorization/par_ilu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ __global__ __launch_bounds__(default_block_size) void compute_l_u_factors(
auto to_write =
sum / load_relaxed(u_values + (u_row_ptrs[col + 1] - 1));
if (is_finite(to_write)) {
printf("write!\n");
store_relaxed(l_values + (l_idx - 1), to_write);
}
} else {
auto to_write = sum;
if (is_finite(to_write)) {
printf("write!\n");
store_relaxed(u_values + (u_idx - 1), to_write);
}
}
Expand Down
14 changes: 10 additions & 4 deletions test/factorization/par_ic_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ParIc : public CommonTestFixture {
mtx_l = gko::test::generate_random_lower_triangular_matrix<Csr>(
mtx_size[0], false,
std::uniform_int_distribution<index_type>(10, mtx_size[0]),
std::normal_distribution<>(0, 10.0), rand_engine, ref);
std::normal_distribution<>(0, 5.0), rand_engine, ref);
dmtx_ani = Csr::create(exec);
dmtx_l_ani = Csr::create(exec);
dmtx_l_ani_init = Csr::create(exec);
Expand Down Expand Up @@ -108,16 +108,22 @@ TYPED_TEST(ParIc, KernelComputeFactorIsEquivalentToRef)
using Csr = typename TestFixture::Csr;
using Coo = typename TestFixture::Coo;
using value_type = typename TestFixture::value_type;
SKIP_IF_HALF(value_type);
auto square_size = this->mtx_ani->get_size();
auto mtx_l_coo = Coo::create(this->ref, square_size);
this->mtx_l_ani->convert_to(mtx_l_coo);
auto dmtx_l_coo = gko::clone(this->exec, mtx_l_coo);

auto mtx_init = this->dmtx_l_ani_init->clone(this->ref);
gko::kernels::reference::par_ic_factorization::compute_factor(
this->ref, 1, mtx_l_coo.get(), this->mtx_l_ani_init.get());
gko::kernels::GKO_DEVICE_NAMESPACE::par_ic_factorization::compute_factor(
this->exec, 100, dmtx_l_coo.get(), this->dmtx_l_ani_init.get());

GKO_ASSERT_MTX_NEAR(this->mtx_l_ani_init, this->dmtx_l_ani_init, 1e-4);
GKO_EXPECT_MTX_NEAR(this->mtx_l_ani_init, this->dmtx_l_ani_init, 1e-4);

// gko::kernels::reference::par_ic_factorization::compute_factor(
// this->ref, 1, mtx_l_coo.get(), mtx_init.get());
// gko::kernels::GKO_DEVICE_NAMESPACE::par_ic_factorization::compute_factor(
// this->exec, 2000, dmtx_l_coo.get(), this->dmtx_l_ani_init.get());
// GKO_EXPECT_MTX_NEAR(this->mtx_l_ani_init, mtx_init, 1e-4);
GKO_EXPECT_MTX_NEAR(this->dmtx_l_ani_init, mtx_init, 0);
}
2 changes: 1 addition & 1 deletion test/factorization/par_ict_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ TYPED_TEST(ParIct, KernelComputeFactorIsEquivalentToRef)
using Csr = typename TestFixture::Csr;
using Coo = typename TestFixture::Coo;
using value_type = typename TestFixture::value_type;
SKIP_IF_HALF(value_type);
// SKIP_IF_HALF(value_type);
auto square_size = this->mtx_ani->get_size();
auto mtx_l_coo = Coo::create(this->ref, square_size);
this->mtx_l_ani->convert_to(mtx_l_coo);
Expand Down
6 changes: 3 additions & 3 deletions test/factorization/par_ilu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ TYPED_TEST(ParIlu, KernelComputeParILUIsEquivalentToRef)
{
using Csr = typename TestFixture::Csr;
using value_type = typename TestFixture::value_type;
SKIP_IF_HALF(value_type);
// SKIP_IF_HALF(value_type);
std::unique_ptr<Csr> l_mtx{};
std::unique_ptr<Csr> u_mtx{};
std::unique_ptr<Csr> dl_mtx{};
Expand All @@ -257,12 +257,12 @@ TYPED_TEST(ParIlu, KernelComputeParILUWithMoreIterationsIsEquivalentToRef)
{
using Csr = typename TestFixture::Csr;
using value_type = typename TestFixture::value_type;
SKIP_IF_HALF(value_type);
// SKIP_IF_HALF(value_type);
std::unique_ptr<Csr> l_mtx{};
std::unique_ptr<Csr> u_mtx{};
std::unique_ptr<Csr> dl_mtx{};
std::unique_ptr<Csr> du_mtx{};
gko::size_type iterations{200};
gko::size_type iterations{500};

this->compute_lu(l_mtx, u_mtx, dl_mtx, du_mtx, iterations);

Expand Down

0 comments on commit 88967e6

Please sign in to comment.