@@ -135,6 +135,7 @@ static Constant *constantFoldWrRegion(Type *RetTy,
135135 const CMRegion &R, const DataLayout &DL) {
136136 Constant *OldValue = Operands[GenXIntrinsic::GenXRegion::OldValueOperandNum];
137137 Constant *NewValue = Operands[GenXIntrinsic::GenXRegion::NewValueOperandNum];
138+ Constant *Mask = Operands[GenXIntrinsic::GenXRegion::PredicateOperandNum];
138139 // The inputs can be ConstantExpr if we are being called from
139140 // CallAnalyzer.
140141 if (isa<ConstantExpr>(OldValue) || isa<ConstantExpr>(NewValue))
@@ -147,7 +148,8 @@ static Constant *constantFoldWrRegion(Type *RetTy,
147148
148149 const int RetElemSize = DL.getTypeSizeInBits (RetTy->getScalarType ()) / 8 ;
149150 unsigned Offset = OffsetC->getSExtValue () / RetElemSize;
150- if (isa<UndefValue>(OldValue) && R.isContiguous () && (Offset == 0 )) {
151+ if (isa<UndefValue>(OldValue) && R.isContiguous () && Offset == 0 &&
152+ Mask->isAllOnesValue ()) {
151153 // If old value is undef and new value is splat, and the result vector
152154 // is no bigger than 2 GRFs, then just return a splat of the right type.
153155 Constant *Splat = NewValue;
@@ -172,7 +174,7 @@ static Constant *constantFoldWrRegion(Type *RetTy,
172174 return UndefValue::get (RetTy); // out of range index
173175 if (!isa<VectorType>(NewValue->getType ()))
174176 Values[Offset] = NewValue;
175- else {
177+ else if (!Mask-> isZeroValue ()) {
176178 unsigned RowIdx = Offset;
177179 unsigned Idx = RowIdx;
178180 unsigned NextRow = R.Width ;
@@ -185,7 +187,10 @@ static Constant *constantFoldWrRegion(Type *RetTy,
185187 if (Idx >= WholeNumElements)
186188 // return collected values even if idx is out of bounds
187189 return ConstantVector::get (Values);
188- Values[Idx] = NewValue->getAggregateElement (i);
190+ if (Mask->isAllOnesValue () ||
191+ (Mask->getType ()->isVectorTy () &&
192+ !cast<ConstantVector>(Mask)->getAggregateElement (i)->isZeroValue ()))
193+ Values[Idx] = NewValue->getAggregateElement (i);
189194 Idx += R.Stride ;
190195 }
191196 }
0 commit comments