Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multi-argument @min/@max and notice bounds #15522

Merged
merged 1 commit into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions src/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7907,6 +7907,48 @@ fn typeOf(
return rvalue(gz, ri, typeof_inst, node);
}

/// Lowers a `@min` / `@max` builtin call (selected by `op`) with two or more
/// arguments to ZIR.
///
/// * `node` is the builtin call node; it is used for error reporting and as the
///   source node of the emitted instruction.
/// * `args` are the call's argument nodes. Fewer than two is a compile error.
///
/// Exactly two arguments emit the binary `.min` / `.max` instruction with a
/// `Zir.Inst.Bin` payload. Three or more emit the extended `.min_multi` /
/// `.max_multi` instruction, whose payload is a `Zir.Inst.NodeMultiOp` followed
/// by one operand ref per argument.
fn minMax(
    gz: *GenZir,
    scope: *Scope,
    ri: ResultInfo,
    node: Ast.Node.Index,
    args: []const Ast.Node.Index,
    comptime op: enum { min, max },
) InnerError!Zir.Inst.Ref {
    const astgen = gz.astgen;
    if (args.len < 2) {
        // The guard admits both 0 and 1 arguments, so report the real count
        // instead of a hard-coded "0".
        return astgen.failNode(node, "expected at least 2 arguments, found {d}", .{args.len});
    }
    if (args.len == 2) {
        const tag: Zir.Inst.Tag = switch (op) {
            .min => .min,
            .max => .max,
        };
        const a = try expr(gz, scope, .{ .rl = .none }, args[0]);
        const b = try expr(gz, scope, .{ .rl = .none }, args[1]);
        const result = try gz.addPlNode(tag, node, Zir.Inst.Bin{
            .lhs = a,
            .rhs = b,
        });
        return rvalue(gz, ri, result, node);
    }
    const payload_index = try addExtra(astgen, Zir.Inst.NodeMultiOp{
        .src_node = gz.nodeIndexToRelative(node),
    });
    // Reserve the operand slots right after the payload header before lowering
    // any argument, then fill them in as each argument is lowered.
    var extra_index = try reserveExtra(astgen, args.len);
    for (args) |arg| {
        const arg_ref = try expr(gz, scope, .{ .rl = .none }, arg);
        astgen.extra.items[extra_index] = @enumToInt(arg_ref);
        extra_index += 1;
    }
    const tag: Zir.Inst.Extended = switch (op) {
        .min => .min_multi,
        .max => .max_multi,
    };
    const result = try gz.addExtendedMultiOpPayloadIndex(tag, payload_index, args.len);
    return rvalue(gz, ri, result, node);
}

fn builtinCall(
gz: *GenZir,
scope: *Scope,
Expand Down Expand Up @@ -7997,6 +8039,8 @@ fn builtinCall(
.TypeOf => return typeOf( gz, scope, ri, node, params),
.union_init => return unionInit(gz, scope, ri, node, params),
.c_import => return cImport( gz, scope, node, params[0]),
.min => return minMax( gz, scope, ri, node, params, .min),
.max => return minMax( gz, scope, ri, node, params, .max),
// zig fmt: on

.@"export" => {
Expand Down Expand Up @@ -8358,25 +8402,6 @@ fn builtinCall(
return rvalue(gz, ri, result, node);
},

.max => {
const a = try expr(gz, scope, .{ .rl = .none }, params[0]);
const b = try expr(gz, scope, .{ .rl = .none }, params[1]);
const result = try gz.addPlNode(.max, node, Zir.Inst.Bin{
.lhs = a,
.rhs = b,
});
return rvalue(gz, ri, result, node);
},
.min => {
const a = try expr(gz, scope, .{ .rl = .none }, params[0]);
const b = try expr(gz, scope, .{ .rl = .none }, params[1]);
const result = try gz.addPlNode(.min, node, Zir.Inst.Bin{
.lhs = a,
.rhs = b,
});
return rvalue(gz, ri, result, node);
},

.add_with_overflow => return overflowArithmetic(gz, scope, ri, node, params, .add_with_overflow),
.sub_with_overflow => return overflowArithmetic(gz, scope, ri, node, params, .sub_with_overflow),
.mul_with_overflow => return overflowArithmetic(gz, scope, ri, node, params, .mul_with_overflow),
Expand Down
4 changes: 2 additions & 2 deletions src/BuiltinFn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ pub const list = list: {
"@max",
.{
.tag = .max,
.param_count = 2,
.param_count = null,
},
},
.{
Expand All @@ -629,7 +629,7 @@ pub const list = list: {
"@min",
.{
.tag = .min,
.param_count = 2,
.param_count = null,
},
},
.{
Expand Down
241 changes: 201 additions & 40 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ fn analyzeBodyInner(
.asm_expr => try sema.zirAsm( block, extended, true),
.typeof_peer => try sema.zirTypeofPeer( block, extended),
.compile_log => try sema.zirCompileLog( extended),
.min_multi => try sema.zirMinMaxMulti( block, extended, .min),
.max_multi => try sema.zirMinMaxMulti( block, extended, .max),
.add_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode),
.sub_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode),
.mul_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode),
Expand Down Expand Up @@ -12143,7 +12145,7 @@ fn zirShl(
lhs_ty,
try lhs_ty.maxInt(sema.arena, target),
);
const rhs_limited = try sema.analyzeMinMax(block, rhs_src, rhs, max_int, .min, rhs_src, rhs_src);
const rhs_limited = try sema.analyzeMinMax(block, rhs_src, .min, &.{ rhs, max_int }, &.{ rhs_src, rhs_src });
break :rhs try sema.intCast(block, src, lhs_ty, rhs_src, rhs_limited, rhs_src, false);
} else {
break :rhs rhs;
Expand Down Expand Up @@ -21752,64 +21754,223 @@ fn zirMinMax(
const rhs = try sema.resolveInst(extra.rhs);
try sema.checkNumericType(block, lhs_src, sema.typeOf(lhs));
try sema.checkNumericType(block, rhs_src, sema.typeOf(rhs));
return sema.analyzeMinMax(block, src, lhs, rhs, air_tag, lhs_src, rhs_src);
return sema.analyzeMinMax(block, src, air_tag, &.{ lhs, rhs }, &.{ lhs_src, rhs_src });
}

/// Analyzes an extended `min_multi` / `max_multi` ZIR instruction.
///
/// Decodes the `Zir.Inst.NodeMultiOp` payload and its trailing operand refs
/// (`extended.small` gives the operand count), resolves every operand, checks
/// that each one has a numeric type, and hands the resolved operands together
/// with their source locations to `analyzeMinMax`.
fn zirMinMaxMulti(
    sema: *Sema,
    block: *Block,
    extended: Zir.Inst.Extended.InstData,
    comptime air_tag: Air.Inst.Tag,
) CompileError!Air.Inst.Ref {
    const extra = sema.code.extraData(Zir.Inst.NodeMultiOp, extended.operand);
    const src_node = extra.data.src_node;
    const call_src = LazySrcLoc.nodeOffset(src_node);
    const zir_operands = sema.code.refSlice(extra.end, extended.small);
    const count = zir_operands.len;

    const resolved = try sema.arena.alloc(Air.Inst.Ref, count);
    const srcs = try sema.arena.alloc(LazySrcLoc, count);

    var idx: usize = 0;
    while (idx < count) : (idx += 1) {
        // Only the first six arguments have a dedicated source location; the
        // rest fall back to the location of the whole call.
        const arg_src: LazySrcLoc = switch (idx) {
            0 => .{ .node_offset_builtin_call_arg0 = src_node },
            1 => .{ .node_offset_builtin_call_arg1 = src_node },
            2 => .{ .node_offset_builtin_call_arg2 = src_node },
            3 => .{ .node_offset_builtin_call_arg3 = src_node },
            4 => .{ .node_offset_builtin_call_arg4 = src_node },
            5 => .{ .node_offset_builtin_call_arg5 = src_node },
            else => call_src, // TODO: better source location
        };
        srcs[idx] = arg_src;

        const ref = try sema.resolveInst(zir_operands[idx]);
        resolved[idx] = ref;
        try sema.checkNumericType(block, arg_src, sema.typeOf(ref));
    }

    return sema.analyzeMinMax(block, call_src, air_tag, resolved, srcs);
}

// Computes the minimum or maximum (selected by `air_tag`) over `operands`;
// `operand_srcs[i]` is the source location of `operands[i]`.
//
// NOTE(review): this listing is a rendered diff without +/- markers. Lines of
// the previous two-operand implementation (the `lhs` / `rhs` / `lhs_src` /
// `rhs_src` parameters and the `simd_op.lhs_val` / `simd_op.rhs_val` path) are
// interleaved with the new multi-operand implementation, so the text below is
// a change description rather than compilable code.
//
// As far as the added lines show, the new implementation:
//   1. folds every comptime-known operand into a single value with
//      `Value.numberMin` / `Value.numberMax` (element-wise for vectors); an
//      undef operand makes the folded value undef;
//   2. refines the type of that comptime value to the smallest integer type
//      that fits it via `Type.intFittingRange` (floats and undef are left
//      unrefined);
//   3. returns the comptime value if no runtime operand remains, otherwise
//      requires a runtime block and emits one `block.addBinOp` per runtime
//      operand (or an undef constant when the comptime part is undef);
//   4. narrows the runtime result type using the comptime-known bound and
//      emits an `.intcast` when the type actually changed.
fn analyzeMinMax(
sema: *Sema,
block: *Block,
src: LazySrcLoc,
lhs: Air.Inst.Ref,
rhs: Air.Inst.Ref,
comptime air_tag: Air.Inst.Tag,
lhs_src: LazySrcLoc,
rhs_src: LazySrcLoc,
operands: []const Air.Inst.Ref,
operand_srcs: []const LazySrcLoc,
) CompileError!Air.Inst.Ref {
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
assert(operands.len == operand_srcs.len);
assert(operands.len > 0);

// TODO @max(max_int, undefined) should return max_int
if (operands.len == 1) return operands[0];

const runtime_src = if (simd_op.lhs_val) |lhs_val| rs: {
if (lhs_val.isUndef()) return sema.addConstUndef(simd_op.result_ty);
const mod = sema.mod;
const target = mod.getTarget();
const opFunc = switch (air_tag) {
.min => Value.numberMin,
.max => Value.numberMax,
else => unreachable,
};

const rhs_val = simd_op.rhs_val orelse break :rs rhs_src;
// First, find all comptime-known arguments, and get their min/max
// `runtime_known` starts with every bit set; a bit is cleared for each
// operand that resolves to a comptime-known value.
var runtime_known = try std.DynamicBitSet.initFull(sema.arena, operands.len);
var cur_minmax: ?Air.Inst.Ref = null;
var cur_minmax_src: LazySrcLoc = undefined; // defined if cur_minmax not null
for (operands, operand_srcs, 0..) |operand, operand_src, operand_idx| {
// Resolve the value now to avoid redundant calls to `checkSimdBinOp` - we'll have to call
// it in the runtime path anyway since the result type may have been refined
const uncasted_operand_val = (try sema.resolveMaybeUndefVal(operand)) orelse continue;
if (cur_minmax) |cur| {
const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const cur_val = simd_op.lhs_val.?; // cur_minmax is comptime-known
const operand_val = simd_op.rhs_val.?; // we checked the operand was resolvable above

runtime_known.unset(operand_idx);

if (cur_val.isUndef()) continue; // result is also undef
if (operand_val.isUndef()) {
cur_minmax = try sema.addConstUndef(simd_op.result_ty);
continue;
}

if (rhs_val.isUndef()) return sema.addConstUndef(simd_op.result_ty);
try sema.resolveLazyValue(cur_val);
try sema.resolveLazyValue(operand_val);

try sema.resolveLazyValue(lhs_val);
try sema.resolveLazyValue(rhs_val);
const vec_len = simd_op.len orelse {
const result_val = opFunc(cur_val, operand_val, target);
cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
continue;
};
var lhs_buf: Value.ElemValueBuffer = undefined;
var rhs_buf: Value.ElemValueBuffer = undefined;
const elems = try sema.arena.alloc(Value, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = cur_val.elemValueBuffer(mod, i, &lhs_buf);
const rhs_elem_val = operand_val.elemValueBuffer(mod, i, &rhs_buf);
elem.* = opFunc(lhs_elem_val, rhs_elem_val, target);
}
cur_minmax = try sema.addConstant(
simd_op.result_ty,
try Value.Tag.aggregate.create(sema.arena, elems),
);
} else {
runtime_known.unset(operand_idx);
cur_minmax = try sema.addConstant(sema.typeOf(operand), uncasted_operand_val);
cur_minmax_src = operand_src;
}
}

const comptime_refined_ty: ?Type = if (cur_minmax) |ct_minmax_ref| refined: {
// Refine the comptime-known result type based on the operation
const val = (try sema.resolveMaybeUndefVal(ct_minmax_ref)).?;
const orig_ty = sema.typeOf(ct_minmax_ref);
const refined_ty = if (orig_ty.zigTypeTag() == .Vector) blk: {
const elem_ty = orig_ty.childType();
const len = orig_ty.vectorLen();

if (len == 0) break :blk orig_ty;
if (elem_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats

// NOTE(review): element 0 seeds both bounds without the undef check that
// elements 1..len get below - confirm whether that is intentional.
var cur_min: Value = try val.elemValue(mod, sema.arena, 0);
var cur_max: Value = cur_min;
for (1..len) |idx| {
const elem_val = try val.elemValue(mod, sema.arena, idx);
if (elem_val.isUndef()) break :blk orig_ty; // can't refine undef
if (Value.order(elem_val, cur_min, target).compare(.lt)) cur_min = elem_val;
if (Value.order(elem_val, cur_max, target).compare(.gt)) cur_max = elem_val;
}

const refined_elem_ty = try Type.intFittingRange(target, sema.arena, cur_min, cur_max);
break :blk try Type.vector(sema.arena, len, refined_elem_ty);
} else blk: {
if (orig_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
if (val.isUndef()) break :blk orig_ty; // can't refine undef
break :blk try Type.intFittingRange(target, sema.arena, val, val);
};

// Apply the refined type to the current value - this isn't strictly necessary in the
// runtime case since we'll refine again afterwards, but keeping things as small as possible
// will allow us to emit more optimal AIR (if all the runtime operands have smaller types
// than the non-refined comptime type).
if (!refined_ty.eql(orig_ty, mod)) {
if (std.debug.runtime_safety) {
assert(try sema.intFitsInType(val, refined_ty, null));
}
cur_minmax = try sema.addConstant(refined_ty, val);
}

break :refined refined_ty;
} else null;

// If no bit is left set, every operand was comptime-known and the folded
// value is the final result.
const runtime_idx = runtime_known.findFirstSet() orelse return cur_minmax.?;
const runtime_src = operand_srcs[runtime_idx];
try sema.requireRuntimeBlock(block, src, runtime_src);

// Now, iterate over runtime operands, emitting a min/max instruction for each. We'll refine the
// type again at the end, based on the comptime-known bound.

// If the comptime-known part is undef we can avoid emitting actual instructions later
const known_undef = if (cur_minmax) |operand| blk: {
const val = (try sema.resolveMaybeUndefVal(operand)).?;
break :blk val.isUndef();
} else false;

if (cur_minmax == null) {
// No comptime operands - use the first operand as the starting value
assert(runtime_idx == 0);
cur_minmax = operands[0];
cur_minmax_src = runtime_src;
runtime_known.unset(0); // don't look at this operand in the loop below
}

var it = runtime_known.iterator(.{});
while (it.next()) |idx| {
const lhs = cur_minmax.?;
const lhs_src = cur_minmax_src;
const rhs = operands[idx];
const rhs_src = operand_srcs[idx];
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
if (known_undef) {
cur_minmax = try sema.addConstant(simd_op.result_ty, Value.undef);
} else {
cur_minmax = try block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
}
}

if (comptime_refined_ty) |comptime_ty| refine: {
// Finally, refine the type based on the comptime-known bound.
if (known_undef) break :refine; // can't refine undef
const unrefined_ty = sema.typeOf(cur_minmax.?);
const is_vector = unrefined_ty.zigTypeTag() == .Vector;
const comptime_elem_ty = if (is_vector) comptime_ty.childType() else comptime_ty;
const unrefined_elem_ty = if (is_vector) unrefined_ty.childType() else unrefined_ty;

if (unrefined_elem_ty.isAnyFloat()) break :refine; // we can't refine floats

const opFunc = switch (air_tag) {
.min => Value.numberMin,
.max => Value.numberMax,
// Compute the final bounds based on the runtime type and the comptime-known bound type
const min_val = switch (air_tag) {
.min => try unrefined_elem_ty.minInt(sema.arena, target),
.max => try comptime_elem_ty.minInt(sema.arena, target), // @max(ct, rt) >= ct
else => unreachable,
};
const target = sema.mod.getTarget();
const vec_len = simd_op.len orelse {
const result_val = opFunc(lhs_val, rhs_val, target);
return sema.addConstant(simd_op.result_ty, result_val);
const max_val = switch (air_tag) {
.min => try comptime_elem_ty.maxInt(sema.arena, target), // @min(ct, rt) <= ct
.max => try unrefined_elem_ty.maxInt(sema.arena, target),
else => unreachable,
};
var lhs_buf: Value.ElemValueBuffer = undefined;
var rhs_buf: Value.ElemValueBuffer = undefined;
const elems = try sema.arena.alloc(Value, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = lhs_val.elemValueBuffer(sema.mod, i, &lhs_buf);
const rhs_elem_val = rhs_val.elemValueBuffer(sema.mod, i, &rhs_buf);
elem.* = opFunc(lhs_elem_val, rhs_elem_val, target);
}
return sema.addConstant(
simd_op.result_ty,
try Value.Tag.aggregate.create(sema.arena, elems),
);
} else rs: {
if (simd_op.rhs_val) |rhs_val| {
if (rhs_val.isUndef()) return sema.addConstUndef(simd_op.result_ty);

// Find the smallest type which can contain these bounds
const final_elem_ty = try Type.intFittingRange(target, sema.arena, min_val, max_val);

const final_ty = if (is_vector)
try Type.vector(sema.arena, unrefined_ty.vectorLen(), final_elem_ty)
else
final_elem_ty;

if (!final_ty.eql(unrefined_ty, mod)) {
// We've reduced the type - cast the result down
return block.addTyOp(.intcast, final_ty, cur_minmax.?);
}
break :rs lhs_src;
};
}

try sema.requireRuntimeBlock(block, src, runtime_src);
return block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
return cur_minmax.?;
}

fn upgradeToArrayPtr(sema: *Sema, block: *Block, ptr: Air.Inst.Ref, len: u64) !Air.Inst.Ref {
Expand Down
Loading