@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
436436 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
437437 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
438438 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
439- LLVM::SqrtOp>();
439+ LLVM::SincosOp, LLVM:: SqrtOp>();
440440
441441 // TODO: Remove once we support replacing non-root ops.
442442 target.addLegalOp <gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
466466 });
467467}
468468
469+ struct SincosOpLowering : public ConvertOpToLLVMPattern <math::SincosOp> {
470+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
471+
472+ LogicalResult
473+ matchAndRewrite (math::SincosOp op, OpAdaptor adaptor,
474+ ConversionPatternRewriter &rewriter) const override {
475+ Location loc = op.getLoc ();
476+ Value input = adaptor.getOperand ();
477+ Type inputType = input.getType ();
478+ auto convertedInput = maybeExt (input, rewriter);
479+ auto computeType = convertedInput.getType ();
480+
481+ StringRef sincosFunc;
482+ if (isa<Float32Type>(computeType)) {
483+ const arith::FastMathFlags flag = op.getFastmath ();
484+ const bool useApprox =
485+ mlir::arith::bitEnumContainsAny (flag, arith::FastMathFlags::afn);
486+ sincosFunc = useApprox ? " __nv_fast_sincosf" : " __nv_sincosf" ;
487+ } else if (isa<Float64Type>(computeType)) {
488+ sincosFunc = " __nv_sincos" ;
489+ } else {
490+ return rewriter.notifyMatchFailure (op,
491+ " unsupported operand type for sincos" );
492+ }
493+
494+ auto ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
495+
496+ Value sinPtr, cosPtr;
497+ {
498+ OpBuilder::InsertionGuard guard (rewriter);
499+ auto *scope =
500+ op->getParentWithTrait <mlir::OpTrait::AutomaticAllocationScope>();
501+ assert (scope && " Expected op to be inside automatic allocation scope" );
502+ rewriter.setInsertionPointToStart (&scope->getRegion (0 ).front ());
503+ auto one = rewriter.create <LLVM::ConstantOp>(
504+ loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (1 ));
505+ sinPtr =
506+ rewriter.create <LLVM::AllocaOp>(loc, ptrType, computeType, one, 0 );
507+ cosPtr =
508+ rewriter.create <LLVM::AllocaOp>(loc, ptrType, computeType, one, 0 );
509+ }
510+
511+ createSincosCall (rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
512+ op);
513+
514+ auto sinResult = rewriter.create <LLVM::LoadOp>(loc, computeType, sinPtr);
515+ auto cosResult = rewriter.create <LLVM::LoadOp>(loc, computeType, cosPtr);
516+
517+ rewriter.replaceOp (op, {maybeTrunc (sinResult, inputType, rewriter),
518+ maybeTrunc (cosResult, inputType, rewriter)});
519+ return success ();
520+ }
521+
522+ private:
523+ Value maybeExt (Value operand, PatternRewriter &rewriter) const {
524+ if (isa<Float16Type, BFloat16Type>(operand.getType ()))
525+ return rewriter.create <LLVM::FPExtOp>(
526+ operand.getLoc (), Float32Type::get (rewriter.getContext ()), operand);
527+ return operand;
528+ }
529+
530+ Value maybeTrunc (Value operand, Type type, PatternRewriter &rewriter) const {
531+ if (operand.getType () != type)
532+ return rewriter.create <LLVM::FPTruncOp>(operand.getLoc (), type, operand);
533+ return operand;
534+ }
535+
536+ void createSincosCall (ConversionPatternRewriter &rewriter, Location loc,
537+ StringRef funcName, Value input, Value sinPtr,
538+ Value cosPtr, Operation *op) const {
539+ auto voidType = LLVM::LLVMVoidType::get (rewriter.getContext ());
540+ auto ptrType = sinPtr.getType ();
541+
542+ SmallVector<Type> operandTypes = {input.getType (), ptrType, ptrType};
543+ auto funcType = LLVM::LLVMFunctionType::get (voidType, operandTypes);
544+
545+ auto funcAttr = StringAttr::get (op->getContext (), funcName);
546+ auto funcOp =
547+ SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
548+
549+ if (!funcOp) {
550+ auto parentFunc = op->getParentOfType <FunctionOpInterface>();
551+ assert (parentFunc && " expected there to be a parent function" );
552+ OpBuilder b (parentFunc);
553+
554+ auto globalloc = loc->findInstanceOfOrUnknown <FileLineColLoc>();
555+ funcOp = LLVM::LLVMFuncOp::create (b, globalloc, funcName, funcType);
556+ }
557+
558+ SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
559+ rewriter.create <LLVM::CallOp>(loc, funcOp, callOperands);
560+ }
561+ };
562+
469563template <typename OpTy>
470564static void populateOpPatterns (const LLVMTypeConverter &converter,
471565 RewritePatternSet &patterns,
@@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns(
589683 " __nv_tan" , " __nv_fast_tanf" );
590684 populateOpPatterns<math::TanhOp>(converter, patterns, benefit, " __nv_tanhf" ,
591685 " __nv_tanh" );
686+
687+ // Custom pattern for sincos since it returns two values
688+ patterns.add <SincosOpLowering>(converter, benefit);
592689}
593690
594691void mlir::populateGpuToNVVMConversionPatterns (
0 commit comments