Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebuild the type of constants during evaluation. #4138

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 94 additions & 60 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ struct EvalContext {
return GetInContext(context.types().GetConstantId(type_id));
}

// Gets the constant value of the specified type in this context.
auto GetConstantValueAsType(SemIR::TypeId id) -> SemIR::TypeId {
return context.GetTypeIdForTypeConstant(GetConstantValue(id));
}

// Gets the instruction describing the constant value of the specified type in
// this context.
auto GetConstantValueAsInst(SemIR::TypeId id) -> SemIR::Inst {
Expand Down Expand Up @@ -150,16 +155,6 @@ static auto GetPhase(SemIR::ConstantId constant_id) -> Phase {
}
}

// Gets the earliest possible phase for a constant whose type is `type_id`. The
// type of a constant is effectively treated as an operand of that constant when
// determining its phase. For example, an empty struct with a symbolic type is a
// symbolic constant, not a template constant.
static auto GetTypePhase(EvalContext& eval_context, SemIR::TypeId type_id)
-> Phase {
CARBON_CHECK(type_id.is_valid());
return GetPhase(eval_context.GetConstantValue(type_id));
}

// Returns the later of two phases.
static auto LatestPhase(Phase a, Phase b) -> Phase {
return static_cast<Phase>(
Expand Down Expand Up @@ -227,12 +222,12 @@ static auto GetConstantValue(EvalContext& eval_context, SemIR::InstId inst_id,
return eval_context.context.constant_values().GetInstId(const_id);
}

// A type is always constant, but we still need to extract its phase.
// Gets the type corresponding to the specified type in this evaluation context.
zygoloid marked this conversation as resolved.
Show resolved Hide resolved
static auto GetConstantValue(EvalContext& eval_context, SemIR::TypeId type_id,
Phase* phase) -> SemIR::TypeId {
auto const_id = eval_context.GetConstantValue(type_id);
*phase = LatestPhase(*phase, GetPhase(const_id));
return type_id;
return eval_context.context.GetTypeIdForTypeConstant(const_id);
}

// If the given instruction block contains only constants, returns a
Expand Down Expand Up @@ -324,57 +319,77 @@ static auto ReplaceFieldWithConstantValue(EvalContext& eval_context,
//
// The constant value is then checked by calling `validate_fn(typed_inst)`,
// which should return a `bool` indicating whether the new constant is valid. If
// validation passes, a corresponding ConstantId for the new constant is
// validation passes, `transform_fn(typed_inst)` is called to produce the final
// constant instruction, and a corresponding ConstantId for the new constant is
// returned. If validation fails, it should produce a suitable error message.
// `ConstantId::Error` is returned.
template <typename InstT, typename ValidateFn, typename... EachFieldIdT>
static auto RebuildAndValidateIfFieldsAreConstant(
template <typename InstT, typename ValidateFn, typename TransformFn,
typename... EachFieldIdT>
static auto RebuildIfFieldsAreConstantImpl(
EvalContext& eval_context, SemIR::Inst inst, ValidateFn validate_fn,
EachFieldIdT InstT::*... each_field_id) -> SemIR::ConstantId {
TransformFn transform_fn, EachFieldIdT InstT::*... each_field_id)
-> SemIR::ConstantId {
// Build a constant instruction by replacing each non-constant operand with
// its constant value.
auto typed_inst = inst.As<InstT>();
// Some instruction kinds don't have a `type_id` field. For those that do, the
// type contributes to the phase.
Phase phase = inst.type_id().is_valid()
? GetTypePhase(eval_context, inst.type_id())
: Phase::Template;
Phase phase = Phase::Template;
if ((ReplaceFieldWithConstantValue(eval_context, &typed_inst, each_field_id,
&phase) &&
...)) {
if (phase == Phase::UnknownDueToError || !validate_fn(typed_inst)) {
return SemIR::ConstantId::Error;
}
return MakeConstantResult(eval_context.context, typed_inst, phase);
return MakeConstantResult(eval_context.context, transform_fn(typed_inst),
phase);
}
return MakeNonConstantResult(phase);
}

// Same as above but with an identity transform function.
template <typename InstT, typename ValidateFn, typename... EachFieldIdT>
static auto RebuildAndValidateIfFieldsAreConstant(
EvalContext& eval_context, SemIR::Inst inst, ValidateFn validate_fn,
EachFieldIdT InstT::*... each_field_id) -> SemIR::ConstantId {
return RebuildIfFieldsAreConstantImpl(eval_context, inst, validate_fn,
std::identity{}, each_field_id...);
}

// Same as above but with no validation step.
template <typename InstT, typename TransformFn, typename... EachFieldIdT>
static auto TransformIfFieldsAreConstant(EvalContext& eval_context,
SemIR::Inst inst,
TransformFn transform_fn,
EachFieldIdT InstT::*... each_field_id)
-> SemIR::ConstantId {
return RebuildIfFieldsAreConstantImpl(
eval_context, inst, [](...) { return true; }, transform_fn,
each_field_id...);
}

// Same as above but with no validation or transform step.
template <typename InstT, typename... EachFieldIdT>
static auto RebuildIfFieldsAreConstant(EvalContext& eval_context,
SemIR::Inst inst,
EachFieldIdT InstT::*... each_field_id)
-> SemIR::ConstantId {
return RebuildAndValidateIfFieldsAreConstant(
eval_context, inst, [](...) { return true; }, each_field_id...);
return RebuildIfFieldsAreConstantImpl(
eval_context, inst, [](...) { return true; }, std::identity{},
each_field_id...);
}

// Rebuilds the given aggregate initialization instruction as a corresponding
// constant aggregate value, if its elements are all constants.
static auto RebuildInitAsValue(EvalContext& eval_context, SemIR::Inst inst,
SemIR::InstKind value_kind)
-> SemIR::ConstantId {
auto init_inst = inst.As<SemIR::AnyAggregateInit>();
Phase phase = GetTypePhase(eval_context, init_inst.type_id);
auto elements_id =
GetConstantValue(eval_context, init_inst.elements_id, &phase);
return MakeConstantResult(
eval_context.context,
SemIR::AnyAggregateValue{.kind = value_kind,
.type_id = init_inst.type_id,
.elements_id = elements_id},
phase);
return TransformIfFieldsAreConstant(
eval_context, inst,
[&](SemIR::AnyAggregateInit result) {
return SemIR::AnyAggregateValue{.kind = value_kind,
.type_id = result.type_id,
.elements_id = result.elements_id};
},
&SemIR::AnyAggregateInit::type_id, &SemIR::AnyAggregateInit::elements_id);
}

// Performs an access into an aggregate, retrieving the specified element.
Expand Down Expand Up @@ -409,8 +424,6 @@ static auto PerformAggregateIndex(EvalContext& eval_context, SemIR::Inst inst)
-> SemIR::ConstantId {
auto index_inst = inst.As<SemIR::AnyAggregateIndex>();
Phase phase = Phase::Template;
auto aggregate_id =
GetConstantValue(eval_context, index_inst.aggregate_id, &phase);
auto index_id = GetConstantValue(eval_context, index_inst.index_id, &phase);

if (!index_id.is_valid()) {
Expand All @@ -423,10 +436,11 @@ static auto PerformAggregateIndex(EvalContext& eval_context, SemIR::Inst inst)
return MakeNonConstantResult(phase);
}

// Array indexing is invalid if the index is constant and out of range.
auto aggregate_type_id =
eval_context.insts().Get(index_inst.aggregate_id).type_id();
// Array indexing is invalid if the index is constant and out of range,
// regardless of whether the array itself is constant.
const auto& index_val = eval_context.ints().Get(index->int_id);
auto aggregate_type_id = eval_context.GetConstantValueAsType(
eval_context.insts().Get(index_inst.aggregate_id).type_id());
if (auto array_type =
eval_context.types().TryGetAs<SemIR::ArrayType>(aggregate_type_id)) {
if (auto bound = eval_context.insts().TryGetAs<SemIR::IntLiteral>(
Expand All @@ -448,6 +462,8 @@ static auto PerformAggregateIndex(EvalContext& eval_context, SemIR::Inst inst)
}
}

auto aggregate_id =
GetConstantValue(eval_context, index_inst.aggregate_id, &phase);
if (!aggregate_id.is_valid()) {
return MakeNonConstantResult(phase);
}
Expand Down Expand Up @@ -986,9 +1002,11 @@ static auto MakeConstantForCall(EvalContext& eval_context, SemIRLoc loc,
return SemIR::ConstantId::Error;
}

// If the callee isn't constant, this is not a constant call.
// If the callee or return type isn't constant, this is not a constant call.
if (!ReplaceFieldWithConstantValue(eval_context, &call,
&SemIR::Call::callee_id, &phase)) {
&SemIR::Call::callee_id, &phase) ||
!ReplaceFieldWithConstantValue(eval_context, &call, &SemIR::Call::type_id,
&phase)) {
return SemIR::ConstantId::NotConstant;
}

Expand Down Expand Up @@ -1071,6 +1089,7 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
// These cases are constants if their operands are.
case SemIR::AddrOf::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::AddrOf::type_id,
&SemIR::AddrOf::lvalue_id);
case CARBON_KIND(SemIR::ArrayType array_type): {
return RebuildAndValidateIfFieldsAreConstant(
Expand Down Expand Up @@ -1110,13 +1129,17 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
},
&SemIR::ArrayType::bound_id, &SemIR::ArrayType::element_type_id);
}
case SemIR::AssociatedEntity::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::AssociatedEntity::type_id);

case SemIR::AssociatedEntityType::Kind:
return RebuildIfFieldsAreConstant(
eval_context, inst, &SemIR::AssociatedEntityType::entity_type_id);
case SemIR::BoundMethod::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::BoundMethod::object_id,
&SemIR::BoundMethod::function_id);
return RebuildIfFieldsAreConstant(
eval_context, inst, &SemIR::BoundMethod::type_id,
&SemIR::BoundMethod::object_id, &SemIR::BoundMethod::function_id);
case SemIR::ClassType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::ClassType::instance_id);
Expand Down Expand Up @@ -1155,12 +1178,14 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
&SemIR::StructTypeField::field_type_id);
case SemIR::StructValue::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::StructValue::type_id,
&SemIR::StructValue::elements_id);
case SemIR::TupleType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::TupleType::elements_id);
case SemIR::TupleValue::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::TupleValue::type_id,
&SemIR::TupleValue::elements_id);
case SemIR::UnboundElementType::Kind:
return RebuildIfFieldsAreConstant(
Expand All @@ -1181,7 +1206,6 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
case SemIR::TupleInit::Kind:
return RebuildInitAsValue(eval_context, inst, SemIR::TupleValue::Kind);

case SemIR::AssociatedEntity::Kind:
case SemIR::BuiltinInst::Kind:
case SemIR::FunctionType::Kind:
case SemIR::GenericClassType::Kind:
Expand All @@ -1190,22 +1214,27 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
return MakeConstantResult(eval_context.context, inst, Phase::Template);

case CARBON_KIND(SemIR::FunctionDecl fn_decl): {
return MakeConstantResult(
eval_context.context,
SemIR::StructValue{.type_id = fn_decl.type_id,
.elements_id = SemIR::InstBlockId::Empty},
GetTypePhase(eval_context, fn_decl.type_id));
return TransformIfFieldsAreConstant(
eval_context, fn_decl,
[&](SemIR::FunctionDecl result) {
return SemIR::StructValue{.type_id = result.type_id,
.elements_id = SemIR::InstBlockId::Empty};
},
&SemIR::FunctionDecl::type_id);
}

case CARBON_KIND(SemIR::ClassDecl class_decl): {
// If the class has generic parameters, we don't produce a class type, but
// a callable whose return value is a class type.
if (eval_context.classes().Get(class_decl.class_id).is_generic()) {
return MakeConstantResult(
eval_context.context,
SemIR::StructValue{.type_id = class_decl.type_id,
.elements_id = SemIR::InstBlockId::Empty},
GetTypePhase(eval_context, class_decl.type_id));
return TransformIfFieldsAreConstant(
eval_context, class_decl,
[&](SemIR::ClassDecl result) {
return SemIR::StructValue{
.type_id = result.type_id,
.elements_id = SemIR::InstBlockId::Empty};
},
&SemIR::ClassDecl::type_id);
}
// A non-generic class declaration evaluates to the class type.
return MakeConstantResult(
Expand All @@ -1221,11 +1250,14 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
if (eval_context.interfaces()
.Get(interface_decl.interface_id)
.is_generic()) {
return MakeConstantResult(
eval_context.context,
SemIR::StructValue{.type_id = interface_decl.type_id,
.elements_id = SemIR::InstBlockId::Empty},
GetTypePhase(eval_context, interface_decl.type_id));
return TransformIfFieldsAreConstant(
eval_context, interface_decl,
[&](SemIR::InterfaceDecl result) {
return SemIR::StructValue{
.type_id = result.type_id,
.elements_id = SemIR::InstBlockId::Empty};
},
&SemIR::InterfaceDecl::type_id);
}
// A non-generic interface declaration evaluates to the interface type.
return MakeConstantResult(
Expand Down Expand Up @@ -1261,6 +1293,8 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
// TODO: Convert literals into a canonical form. Currently we can form two
// different `i32` constants with the same value if they are represented
// by `APInt`s with different bit widths.
// TODO: Can the type of an IntLiteral or FloatLiteral be symbolic? If so,
// we may need to rebuild.
return MakeConstantResult(eval_context.context, inst, Phase::Template);

// The elements of a constant aggregate can be accessed.
Expand Down
1 change: 0 additions & 1 deletion toolchain/check/testdata/array/generic_empty.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ fn G(T:! type) {
// CHECK:STDOUT: %.3: type = array_type %.2, %T [symbolic]
// CHECK:STDOUT: %.4: type = ptr_type %.3 [symbolic]
// CHECK:STDOUT: %array: %.3 = tuple_value () [symbolic]
// CHECK:STDOUT: %.5: type = ptr_type @G.%.loc13_17 (%.3) [symbolic]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand Down
1 change: 0 additions & 1 deletion toolchain/check/testdata/eval/symbolic.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ fn F(T:! type) {
// CHECK:STDOUT: %.9: i32 = int_literal 5 [template]
// CHECK:STDOUT: %.10: type = array_type %.9, %T [symbolic]
// CHECK:STDOUT: %.11: type = ptr_type %.10 [symbolic]
// CHECK:STDOUT: %.12: type = ptr_type @F.%.loc15_15 (%.10) [symbolic]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ var b: i32 = a[2.6];
// CHECK:STDOUT: assign file.%a.var, %.loc11_24
// CHECK:STDOUT: %a.ref: ref %.3 = name_ref a, file.%a
// CHECK:STDOUT: %.loc15_16: f64 = float_literal 2.6000000000000001 [template = constants.%.8]
// CHECK:STDOUT: %.loc15_19.1: ref i32 = array_index %a.ref, <error>
// CHECK:STDOUT: %.loc15_19.1: ref i32 = array_index %a.ref, <error> [template = <error>]
// CHECK:STDOUT: %.loc15_19.2: i32 = bind_value %.loc15_19.1
// CHECK:STDOUT: assign file.%b.var, %.loc15_19.2
// CHECK:STDOUT: return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn Run() {
// CHECK:STDOUT: %.loc17_7: i32 = int_literal 0 [template = constants.%.2]
// CHECK:STDOUT: %.loc17_4.1: ref %.1 = temporary_storage
// CHECK:STDOUT: %.loc17_4.2: ref %.1 = temporary %.loc17_4.1, %F.call
// CHECK:STDOUT: %.loc17_8: ref <error> = tuple_index %.loc17_4.2, <error>
// CHECK:STDOUT: %.loc17_8: ref <error> = tuple_index %.loc17_4.2, <error> [template = <error>]
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ var b: i32 = a[-10];
// CHECK:STDOUT: assign file.%a.var, %.loc11_28
// CHECK:STDOUT: %a.ref: ref %.3 = name_ref a, file.%a
// CHECK:STDOUT: %.loc15_17: i32 = int_literal 10 [template = constants.%.7]
// CHECK:STDOUT: %.loc15_19: ref <error> = tuple_index %a.ref, <error>
// CHECK:STDOUT: %.loc15_19: ref <error> = tuple_index %a.ref, <error> [template = <error>]
// CHECK:STDOUT: assign file.%b.var, <error>
// CHECK:STDOUT: return
// CHECK:STDOUT: }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ var c: i32 = a[b];
// CHECK:STDOUT: %a.ref: ref %.3 = name_ref a, file.%a
// CHECK:STDOUT: %b.ref: ref i32 = name_ref b, file.%b
// CHECK:STDOUT: %.loc16_16: i32 = bind_value %b.ref
// CHECK:STDOUT: %.loc16_17: ref <error> = tuple_index %a.ref, <error>
// CHECK:STDOUT: %.loc16_17: ref <error> = tuple_index %a.ref, <error> [template = <error>]
// CHECK:STDOUT: assign file.%c.var, <error>
// CHECK:STDOUT: return
// CHECK:STDOUT: }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ var b: i32 = a[{.index = 2}.index];
// CHECK:STDOUT: %struct: %.8 = struct_value (%.loc15_26) [template = constants.%struct]
// CHECK:STDOUT: %.loc15_27.2: %.8 = converted %.loc15_27.1, %struct [template = constants.%struct]
// CHECK:STDOUT: %.loc15_28: i32 = struct_access %.loc15_27.2, element0 [template = constants.%.7]
// CHECK:STDOUT: %.loc15_34: ref <error> = tuple_index %a.ref, <error>
// CHECK:STDOUT: %.loc15_34: ref <error> = tuple_index %a.ref, <error> [template = <error>]
// CHECK:STDOUT: assign file.%b.var, <error>
// CHECK:STDOUT: return
// CHECK:STDOUT: }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ var b: i32 = a[oops];
// CHECK:STDOUT: assign file.%a.var, %.loc11_28
// CHECK:STDOUT: %a.ref: ref %.3 = name_ref a, file.%a
// CHECK:STDOUT: %oops.ref: <error> = name_ref oops, <error> [template = <error>]
// CHECK:STDOUT: %.loc15: ref <error> = tuple_index %a.ref, <error>
// CHECK:STDOUT: %.loc15: ref <error> = tuple_index %a.ref, <error> [template = <error>]
// CHECK:STDOUT: assign file.%b.var, <error>
// CHECK:STDOUT: return
// CHECK:STDOUT: }
Expand Down
4 changes: 2 additions & 2 deletions toolchain/check/testdata/index/fail_tuple_large_index.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ var d: i32 = b[0x7FFF_FFFF];
// CHECK:STDOUT: assign file.%b.var, %.loc12_18
// CHECK:STDOUT: %b.ref.loc17: ref %.3 = name_ref b, file.%b
// CHECK:STDOUT: %.loc17_16: i32 = int_literal 1 [template = constants.%.5]
// CHECK:STDOUT: %.loc17_17: ref <error> = tuple_index %b.ref.loc17, <error>
// CHECK:STDOUT: %.loc17_17: ref <error> = tuple_index %b.ref.loc17, <error> [template = <error>]
// CHECK:STDOUT: assign file.%c.var, <error>
// CHECK:STDOUT: %b.ref.loc21: ref %.3 = name_ref b, file.%b
// CHECK:STDOUT: %.loc21_16: i32 = int_literal 2147483647 [template = constants.%.6]
// CHECK:STDOUT: %.loc21_27: ref <error> = tuple_index %b.ref.loc21, <error>
// CHECK:STDOUT: %.loc21_27: ref <error> = tuple_index %b.ref.loc21, <error> [template = <error>]
// CHECK:STDOUT: assign file.%d.var, <error>
// CHECK:STDOUT: return
// CHECK:STDOUT: }
Expand Down
Loading
Loading