@@ -64,6 +64,14 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
6464 patterns.getContext (), " __ocml_cabs_f32" );
6565 patterns.add <ComplexOpToROCDLLibraryCalls<complex ::AbsOp, Float64Type>>(
6666 patterns.getContext (), " __ocml_cabs_f64" );
67+ patterns.add <ComplexOpToROCDLLibraryCalls<complex ::AngleOp, Float32Type>>(
68+ patterns.getContext (), " __ocml_carg_f32" );
69+ patterns.add <ComplexOpToROCDLLibraryCalls<complex ::AngleOp, Float64Type>>(
70+ patterns.getContext (), " __ocml_carg_f64" );
71+ patterns.add <ComplexOpToROCDLLibraryCalls<complex ::ConjOp, Float32Type>>(
72+ patterns.getContext (), " __ocml_conj_f32" );
73+ patterns.add <ComplexOpToROCDLLibraryCalls<complex ::ConjOp, Float64Type>>(
74+ patterns.getContext (), " __ocml_conj_f64" );
6775 patterns.add <ComplexOpToROCDLLibraryCalls<complex ::CosOp, Float32Type>>(
6876 patterns.getContext (), " __ocml_ccos_f32" );
6977 patterns.add <ComplexOpToROCDLLibraryCalls<complex ::CosOp, Float64Type>>(
@@ -76,6 +84,10 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
7684 patterns.getContext (), " __ocml_clog_f32" );
7785 patterns.add <ComplexOpToROCDLLibraryCalls<complex ::LogOp, Float64Type>>(
7886 patterns.getContext (), " __ocml_clog_f64" );
87+ patterns.add <ComplexOpToROCDLLibraryCalls<complex ::PowOp, Float32Type>>(
88+ patterns.getContext (), " __ocml_cpow_f32" );
89+ patterns.add <ComplexOpToROCDLLibraryCalls<complex ::PowOp, Float64Type>>(
90+ patterns.getContext (), " __ocml_cpow_f64" );
7991 patterns.add <ComplexOpToROCDLLibraryCalls<complex ::SinOp, Float32Type>>(
8092 patterns.getContext (), " __ocml_csin_f32" );
8193 patterns.add <ComplexOpToROCDLLibraryCalls<complex ::SinOp, Float64Type>>(
@@ -110,8 +122,9 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
110122
111123 ConversionTarget target (getContext ());
112124 target.addLegalDialect <func::FuncDialect>();
113- target.addIllegalOp <complex ::AbsOp, complex ::CosOp, complex ::ExpOp,
114- complex ::LogOp, complex ::SinOp, complex ::SqrtOp,
125+ target.addIllegalOp <complex ::AbsOp, complex ::AngleOp, complex ::ConjOp,
126+ complex ::CosOp, complex ::ExpOp, complex ::LogOp,
127+ complex ::PowOp, complex ::SinOp, complex ::SqrtOp,
115128 complex ::TanOp, complex ::TanhOp>();
116129 if (failed (applyPartialConversion (op, target, std::move (patterns))))
117130 signalPassFailure ();
0 commit comments