diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 4890da7e28783..86834ab1240c1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -3047,13 +3047,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, // instruction cost. return 0; case Instruction::Call: { - if (!isSingleScalar()) { - // TODO: Handle remaining call costs here as well. - if (VF.isScalable()) - return InstructionCost::getInvalid(); - break; - } - auto *CalledFn = cast(getOperand(getNumOperands() - 1)->getLiveInIRValue()); if (CalledFn->isIntrinsic()) @@ -3063,7 +3056,42 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, for (VPValue *ArgOp : drop_end(operands())) Tys.push_back(Ctx.Types.inferScalarType(ArgOp)); Type *ResultTy = Ctx.Types.inferScalarType(this); - return Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind); + InstructionCost ScalarCallCost = + Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind); + if (isSingleScalar()) + return ScalarCallCost; + + if (VF.isScalable()) + return InstructionCost::getInvalid(); + + // Compute the cost of scalarizing the result and operands if needed. + InstructionCost ScalarizationCost = 0; + if (VF.isVector()) { + if (!ResultTy->isVoidTy()) { + for (Type *VectorTy : + to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) { + ScalarizationCost += Ctx.TTI.getScalarizationOverhead( + cast(VectorTy), APInt::getAllOnes(VF.getFixedValue()), + /*Insert=*/true, + /*Extract=*/false, Ctx.CostKind); + } + } + // Skip operands that do not require extraction/scalarization and do not + // incur any overhead. + SmallPtrSet UniqueOperands; + Tys.clear(); + for (auto *Op : drop_end(operands())) { + if (Op->isLiveIn() || isa(Op) || + !UniqueOperands.insert(Op).second) + continue; + Tys.push_back(toVectorizedTy(Ctx.Types.inferScalarType(Op), VF)); + } + ScalarizationCost += + Ctx.TTI.getOperandsScalarizationOverhead(Tys, Ctx.CostKind); + } + + return ScalarCallCost * (isSingleScalar() ? 1 : VF.getFixedValue()) + + ScalarizationCost; } case Instruction::Add: case Instruction::Sub: