Skip to content

Commit

Permalink
Fix stft not found (#9924)
Browse files Browse the repository at this point in the history
Fix #9922
  • Loading branch information
mosout authored Mar 2, 2023
1 parent 8ba8eb6 commit a843da1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_
#define ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_

#include <cufft.h>
#include <cufftXt.h>
#include "oneflow/core/framework/framework.h"
Expand All @@ -22,10 +25,13 @@ limitations under the License.
#include "oneflow/core/kernel/kernel.h"

namespace oneflow {

namespace {

constexpr int max_rank = 3;

}

struct CuFFtParams {
int32_t ndim;
int32_t output_shape[max_rank + 1];
Expand All @@ -34,8 +40,8 @@ struct CuFFtParams {
int32_t output_strides[max_rank + 1];
int32_t* rank;
int32_t batch;
CuFFtParams(int32_t dims, int32_t* r, const Stride& in_strides, const Stride& out_strides,
const Shape& in_shape, const Shape& out_shape, int32_t b)
CuFFtParams(int32_t dims, int32_t* r, const Stride& in_strides, // NOLINT
const Stride& out_strides, const Shape& in_shape, const Shape& out_shape, int32_t b)
: ndim(dims), rank(r), batch(b) {
std::copy(in_strides.begin(), in_strides.end(), input_strides);
std::copy(out_strides.begin(), out_strides.end(), output_strides);
Expand All @@ -49,8 +55,9 @@ class CuFFtConfig {
public:
CuFFtConfig(const CuFFtConfig&) = delete;
CuFFtConfig& operator=(CuFFtConfig const&) = delete;
~CuFFtConfig() = default;

explicit CuFFtConfig(CuFFtParams& params) {
explicit CuFFtConfig(CuFFtParams& params) { // NOLINT
infer_cufft_type_();
cufftPlanMany(&plan_handle_, params.ndim, params.rank, params.input_shape,
params.input_strides[0], params.input_strides[1], params.output_shape,
Expand Down Expand Up @@ -83,6 +90,6 @@ class CuFFtConfig {
cufftType exectype_;
};

} // namespace
} // namespace oneflow

} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_
20 changes: 9 additions & 11 deletions oneflow/user/kernels/stft_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,15 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include <cuda.h>

#if CUDA_VERSION >= 11000
#define CUDA_SUPPORT_CUFFT
#endif

#ifdef CUDA_SUPPORT_CUFFT
#include "cufft_plan_cache.h"

#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include <cufft.h>
#include <cufftXt.h>
#include "oneflow/core/kernel/kernel.h"
#include "cufftplancache.h"
namespace oneflow {
namespace {} // namespace

namespace {

template<typename IN, typename OUT>
__global__ void convert_complex_to_real(IN* dst, const OUT* src, size_t n) {
Expand Down Expand Up @@ -73,6 +67,8 @@ __global__ void convert_doublesided(const FFTTYPE* src, FFTTYPE* dst, size_t len
}
}

} // namespace

template<typename IN, typename OUT>
class StftGpuKernel final : public user_op::OpKernel {
public:
Expand Down Expand Up @@ -158,8 +154,10 @@ class StftGpuKernel final : public user_op::OpKernel {
const int64_t output_bytes = GetCudaAlignedSize(output_elem_cnt * sizeof(outtype)); \
return onesided ? output_bytes : 2 * output_bytes; \
});
REGISTER_STFT_GPU_KERNEL(float, cufftComplex)
REGISTER_STFT_GPU_KERNEL(double, cufftDoubleComplex)
} // namespace oneflow
#endif

0 comments on commit a843da1

Please sign in to comment.