Skip to content

Commit

Permalink
Fold 64-bit int operations (#5561)
Browse files Browse the repository at this point in the history
Adds folding rules that will fold basic arithmetic for signed and
unsigned integers of all sizes, including 64-bit.

Also folds OpSConvert and OpUConvert.
  • Loading branch information
s-perron authored Feb 9, 2024
1 parent 80926d9 commit a8959dc
Show file tree
Hide file tree
Showing 2 changed files with 534 additions and 23 deletions.
220 changes: 219 additions & 1 deletion source/opt/const_folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,59 @@ namespace opt {
namespace {
constexpr uint32_t kExtractCompositeIdInIdx = 0;

// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and sign-extending it to 64-bits.
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits == 64) return value;

  const uint64_t sign_bit = 1ull << (number_of_bits - 1);
  const uint64_t low_bits_mask = (sign_bit << 1) - 1ull;
  // Replicate the sign bit into every position above it; clear any garbage
  // above the significant bits when the sign bit is zero.
  return (value & sign_bit) ? (value | ~low_bits_mask) : (value & low_bits_mask);
}

// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and zero-extending it to 64-bits.
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits == 64) return value;

  // Keep only the requested low bits; everything above them becomes zero.
  const uint64_t keep_mask = (1ull << number_of_bits) - 1ull;
  return value & keep_mask;
}

// Returns a constant whose value is |result| and whose type is
// |integer_type|. The constant is created through |const_mgr|. The type must
// be a scalar integer type.
const analysis::Constant* GenerateIntegerConstant(
    const analysis::Integer* integer_type, uint64_t result,
    analysis::ConstantManager* const_mgr) {
  assert(integer_type != nullptr);

  std::vector<uint32_t> words;
  if (integer_type->width() != 64) {
    assert(integer_type->width() <= 32);
    // Narrow values fit in a single word. The excess bits are sign- or
    // zero-extended according to the signedness of the type.
    result = integer_type->IsSigned()
                 ? SignExtendValue(result, integer_type->width())
                 : ZeroExtendValue(result, integer_type->width());
    words = {static_cast<uint32_t>(result)};
  } else {
    // A 64-bit value needs two words: low word first, then high word.
    words = {static_cast<uint32_t>(result), static_cast<uint32_t>(result >> 32)};
  }
  return const_mgr->GetConstant(integer_type, words);
}

// Returns a constant with the value NaN of the given type. Only works for
// 32-bit and 64-bit floating point types. Returns |nullptr| if an error occurs.
const analysis::Constant* GetNan(const analysis::Type* type,
Expand Down Expand Up @@ -676,7 +729,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
return [scalar_rule](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
Expand Down Expand Up @@ -716,6 +768,64 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
};
}

// Returns a |ConstantFoldingRule| that folds binary scalar ops
// using |scalar_rule| and binary vector ops by applying
// |scalar_rule| to each component of the vector. The folding rule assumes the
// op has two inputs. For a regular instruction, those are in operands 0 and 1.
// For an extended instruction, they are in operands 1 and 2. If an element in
// |constants| is not nullptr, then the constant's type is |Float|, |Integer|,
// or |Vector| whose element type is |Float| or |Integer|.
ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
  return [scalar_rule](IRContext* context, Instruction* inst,
                       const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    assert(constants.size() == inst->NumInOperands());
    assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
    const analysis::Vector* vector_type = result_type->AsVector();

    // For OpExtInst the first in-operand is the instruction set, so the real
    // arguments start one operand later.
    const bool is_ext_inst = inst->opcode() == spv::Op::OpExtInst;
    const analysis::Constant* lhs = is_ext_inst ? constants[1] : constants[0];
    const analysis::Constant* rhs = is_ext_inst ? constants[2] : constants[1];

    if (lhs == nullptr || rhs == nullptr) {
      return nullptr;
    }

    // Scalar result: apply the rule directly.
    if (vector_type == nullptr) {
      return scalar_rule(result_type, lhs, rhs, const_mgr);
    }

    // Vector result: fold component-wise, bailing out if any component fails.
    std::vector<const analysis::Constant*> lhs_components =
        lhs->GetVectorComponents(const_mgr);
    std::vector<const analysis::Constant*> rhs_components =
        rhs->GetVectorComponents(const_mgr);
    assert(lhs_components.size() == rhs_components.size());

    std::vector<uint32_t> component_ids;
    for (uint32_t i = 0; i < lhs_components.size(); ++i) {
      const analysis::Constant* folded =
          scalar_rule(vector_type->element_type(), lhs_components[i],
                      rhs_components[i], const_mgr);
      if (folded == nullptr) {
        return nullptr;
      }
      // Record the id of the defining instruction for the folded component.
      component_ids.push_back(
          const_mgr->GetDefiningInstruction(folded)->result_id());
    }
    return const_mgr->GetConstant(vector_type, component_ids);
  };
}

// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
// using |scalar_rule| and unary float point vectors ops by applying
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
Expand Down Expand Up @@ -1587,6 +1697,72 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
return nullptr;
};
}

enum Sign { Signed, Unsigned };

// Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
// The `signedness` template argument determines how the operands are widened
// to 64 bits before `op` is applied: sign extension for `Signed`, zero
// extension for `Unsigned`.
template <Sign signedness>
BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
                                                                  uint64_t)) {
  return
      [op](const analysis::Type* result_type, const analysis::Constant* a,
           const analysis::Constant* b,
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
        assert(result_type != nullptr && a != nullptr && b != nullptr);
        const analysis::Integer* integer_type = result_type->AsInteger();
        assert(integer_type != nullptr);
        assert(integer_type == a->type()->AsInteger());
        assert(integer_type == b->type()->AsInteger());

        // In SPIR-V, all operations support unsigned types, but the way they
        // are interpreted depends on the opcode. This is why the template
        // argument, rather than the operand type, selects the interpretation.
        const bool extend_with_sign = (signedness == Signed);
        const uint64_t lhs = extend_with_sign ? a->GetSignExtendedValue()
                                              : a->GetZeroExtendedValue();
        const uint64_t rhs = extend_with_sign ? b->GetSignExtendedValue()
                                              : b->GetZeroExtendedValue();
        return GenerateIntegerConstant(integer_type, op(lhs, rhs), const_mgr);
      };
}

// A scalar folding rule that folds OpSConvert.
//
// |result_type| must be a scalar integer type; |a| is the operand constant.
// The operand is sign extended to 64 bits, and GenerateIntegerConstant
// truncates or extends it to the width of the result type.
const analysis::Constant* FoldScalarSConvert(
    const analysis::Type* result_type, const analysis::Constant* a,
    analysis::ConstantManager* const_mgr) {
  assert(result_type != nullptr);
  assert(a != nullptr);
  assert(const_mgr != nullptr);
  const analysis::Integer* integer_type = result_type->AsInteger();
  assert(integer_type && "The result type of an SConvert must be an integer.");
  int64_t value = a->GetSignExtendedValue();
  return GenerateIntegerConstant(integer_type, value, const_mgr);
}

// A scalar folding rule that folds OpUConvert.
//
// |result_type| must be a scalar integer type; |a| is the operand constant.
// The operand is zero extended to 64 bits, and GenerateIntegerConstant
// truncates or extends it to the width of the result type.
const analysis::Constant* FoldScalarUConvert(
    const analysis::Type* result_type, const analysis::Constant* a,
    analysis::ConstantManager* const_mgr) {
  assert(result_type != nullptr);
  assert(a != nullptr);
  assert(const_mgr != nullptr);
  const analysis::Integer* integer_type = result_type->AsInteger();
  assert(integer_type && "The result type of a UConvert must be an integer.");
  uint64_t value = a->GetZeroExtendedValue();

  // The operand's 64-bit representation may carry sign-extended upper bits
  // from earlier folding of narrow values; clear every bit above the
  // operand's width so the conversion is a pure zero extension.
  const analysis::Integer* operand_type = a->type()->AsInteger();
  assert(operand_type && "The operand of a UConvert must be an integer.");
  value = ZeroExtendValue(value, operand_type->width());
  return GenerateIntegerConstant(integer_type, value, const_mgr);
}
} // namespace

void ConstantFoldingRules::AddFoldingRules() {
Expand All @@ -1604,6 +1780,8 @@ void ConstantFoldingRules::AddFoldingRules() {
rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));

rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
Expand Down Expand Up @@ -1662,6 +1840,46 @@ void ConstantFoldingRules::AddFoldingRules() {
rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());

rules_[spv::Op::OpIAdd].push_back(
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
[](uint64_t a, uint64_t b) { return a + b; })));
rules_[spv::Op::OpISub].push_back(
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
[](uint64_t a, uint64_t b) { return a - b; })));
rules_[spv::Op::OpIMul].push_back(
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
[](uint64_t a, uint64_t b) { return a * b; })));
rules_[spv::Op::OpUDiv].push_back(
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
[](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
static_cast<int64_t>(b))
: 0);
})));
rules_[spv::Op::OpUMod].push_back(
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
[](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));

rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
static_cast<int64_t>(b))
: 0);
})));

rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
if (b == 0) return static_cast<uint64_t>(0ull);

int64_t signed_a = static_cast<int64_t>(a);
int64_t signed_b = static_cast<int64_t>(b);
int64_t result = signed_a % signed_b;
if ((signed_b < 0) != (result < 0)) result += signed_b;
return static_cast<uint64_t>(result);
})));

// Add rules for GLSLstd450
FeatureManager* feature_manager = context_->get_feature_mgr();
uint32_t ext_inst_glslstd450_id =
Expand Down
Loading

0 comments on commit a8959dc

Please sign in to comment.