Skip to content

Commit

Permalink
Implement multi-argument @min/@max and notice bounds
Browse files Browse the repository at this point in the history
Resolves: ziglang#14039
  • Loading branch information
mlugg authored and andrewrk committed May 2, 2023
1 parent a2e2e25 commit 26e94c7
Show file tree
Hide file tree
Showing 8 changed files with 388 additions and 69 deletions.
63 changes: 44 additions & 19 deletions src/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7907,6 +7907,48 @@ fn typeOf(
return rvalue(gz, ri, typeof_inst, node);
}

/// Lowers a `@min` / `@max` builtin call to ZIR.
///
/// `node` is the builtin call node, `args` are its argument nodes, and `op`
/// selects which of the two builtins is being lowered.
///
/// * Fewer than two arguments: a compile error is reported on `node`.
/// * Exactly two arguments: the binary `min` / `max` instruction is emitted
///   with a `Zir.Inst.Bin` payload.
/// * Three or more arguments: the extended `min_multi` / `max_multi`
///   instruction is emitted; its payload is a `Zir.Inst.NodeMultiOp` header
///   followed by one `extra` slot per argument.
///
/// Every argument is evaluated with no result location, and the final
/// instruction is passed through `rvalue` with the caller's `ri`.
fn minMax(
    gz: *GenZir,
    scope: *Scope,
    ri: ResultInfo,
    node: Ast.Node.Index,
    args: []const Ast.Node.Index,
    comptime op: enum { min, max },
) InnerError!Zir.Inst.Ref {
    const astgen = gz.astgen;
    if (args.len < 2) {
        // Both zero and one argument end up here, so report the real count
        // instead of a hard-coded "found 0".
        return astgen.failNode(node, "expected at least 2 arguments, found {d}", .{args.len});
    }
    if (args.len == 2) {
        // Two operands keep using the plain binary instruction.
        const tag: Zir.Inst.Tag = switch (op) {
            .min => .min,
            .max => .max,
        };
        const a = try expr(gz, scope, .{ .rl = .none }, args[0]);
        const b = try expr(gz, scope, .{ .rl = .none }, args[1]);
        const result = try gz.addPlNode(tag, node, Zir.Inst.Bin{
            .lhs = a,
            .rhs = b,
        });
        return rvalue(gz, ri, result, node);
    }
    // Write the header first, then reserve one slot per operand right after it.
    const payload_index = try addExtra(astgen, Zir.Inst.NodeMultiOp{
        .src_node = gz.nodeIndexToRelative(node),
    });
    var extra_index = try reserveExtra(astgen, args.len);
    for (args) |arg| {
        // The operand is lowered before `astgen.extra.items` is indexed, so the
        // store always goes through the current `items` slice.
        // NOTE(review): presumably required because `expr` can grow `extra` - confirm.
        const arg_ref = try expr(gz, scope, .{ .rl = .none }, arg);
        astgen.extra.items[extra_index] = @enumToInt(arg_ref);
        extra_index += 1;
    }
    const tag: Zir.Inst.Extended = switch (op) {
        .min => .min_multi,
        .max => .max_multi,
    };
    const result = try gz.addExtendedMultiOpPayloadIndex(tag, payload_index, args.len);
    return rvalue(gz, ri, result, node);
}

fn builtinCall(
gz: *GenZir,
scope: *Scope,
Expand Down Expand Up @@ -7997,6 +8039,8 @@ fn builtinCall(
.TypeOf => return typeOf( gz, scope, ri, node, params),
.union_init => return unionInit(gz, scope, ri, node, params),
.c_import => return cImport( gz, scope, node, params[0]),
.min => return minMax( gz, scope, ri, node, params, .min),
.max => return minMax( gz, scope, ri, node, params, .max),
// zig fmt: on

.@"export" => {
Expand Down Expand Up @@ -8358,25 +8402,6 @@ fn builtinCall(
return rvalue(gz, ri, result, node);
},

.max => {
const a = try expr(gz, scope, .{ .rl = .none }, params[0]);
const b = try expr(gz, scope, .{ .rl = .none }, params[1]);
const result = try gz.addPlNode(.max, node, Zir.Inst.Bin{
.lhs = a,
.rhs = b,
});
return rvalue(gz, ri, result, node);
},
.min => {
const a = try expr(gz, scope, .{ .rl = .none }, params[0]);
const b = try expr(gz, scope, .{ .rl = .none }, params[1]);
const result = try gz.addPlNode(.min, node, Zir.Inst.Bin{
.lhs = a,
.rhs = b,
});
return rvalue(gz, ri, result, node);
},

.add_with_overflow => return overflowArithmetic(gz, scope, ri, node, params, .add_with_overflow),
.sub_with_overflow => return overflowArithmetic(gz, scope, ri, node, params, .sub_with_overflow),
.mul_with_overflow => return overflowArithmetic(gz, scope, ri, node, params, .mul_with_overflow),
Expand Down
4 changes: 2 additions & 2 deletions src/BuiltinFn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ pub const list = list: {
"@max",
.{
.tag = .max,
.param_count = 2,
.param_count = null,
},
},
.{
Expand All @@ -629,7 +629,7 @@ pub const list = list: {
"@min",
.{
.tag = .min,
.param_count = 2,
.param_count = null,
},
},
.{
Expand Down
241 changes: 201 additions & 40 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ fn analyzeBodyInner(
.asm_expr => try sema.zirAsm( block, extended, true),
.typeof_peer => try sema.zirTypeofPeer( block, extended),
.compile_log => try sema.zirCompileLog( extended),
.min_multi => try sema.zirMinMaxMulti( block, extended, .min),
.max_multi => try sema.zirMinMaxMulti( block, extended, .max),
.add_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode),
.sub_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode),
.mul_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode),
Expand Down Expand Up @@ -12143,7 +12145,7 @@ fn zirShl(
lhs_ty,
try lhs_ty.maxInt(sema.arena, target),
);
const rhs_limited = try sema.analyzeMinMax(block, rhs_src, rhs, max_int, .min, rhs_src, rhs_src);
const rhs_limited = try sema.analyzeMinMax(block, rhs_src, .min, &.{ rhs, max_int }, &.{ rhs_src, rhs_src });
break :rhs try sema.intCast(block, src, lhs_ty, rhs_src, rhs_limited, rhs_src, false);
} else {
break :rhs rhs;
Expand Down Expand Up @@ -21752,64 +21754,223 @@ fn zirMinMax(
const rhs = try sema.resolveInst(extra.rhs);
try sema.checkNumericType(block, lhs_src, sema.typeOf(lhs));
try sema.checkNumericType(block, rhs_src, sema.typeOf(rhs));
return sema.analyzeMinMax(block, src, lhs, rhs, air_tag, lhs_src, rhs_src);
return sema.analyzeMinMax(block, src, air_tag, &.{ lhs, rhs }, &.{ lhs_src, rhs_src });
}

/// Analyzes an extended `min_multi` / `max_multi` ZIR instruction.
///
/// The instruction's payload is a `Zir.Inst.NodeMultiOp` header followed by
/// `extended.small` operand refs. Each operand is resolved to an AIR ref and
/// checked with `checkNumericType`, then the whole list is handed to
/// `analyzeMinMax` together with one source location per operand.
fn zirMinMaxMulti(
    sema: *Sema,
    block: *Block,
    extended: Zir.Inst.Extended.InstData,
    comptime air_tag: Air.Inst.Tag,
) CompileError!Air.Inst.Ref {
    const payload = sema.code.extraData(Zir.Inst.NodeMultiOp, extended.operand);
    const call_node = payload.data.src_node;
    const call_src = LazySrcLoc.nodeOffset(call_node);
    const zir_args = sema.code.refSlice(payload.end, extended.small);

    const resolved_args = try sema.arena.alloc(Air.Inst.Ref, zir_args.len);
    const arg_srcs = try sema.arena.alloc(LazySrcLoc, zir_args.len);

    for (zir_args, 0..) |zir_arg, arg_index| {
        // Only the first six arguments have a dedicated source location tag;
        // anything beyond that points at the call as a whole.
        const arg_src: LazySrcLoc = switch (arg_index) {
            0 => .{ .node_offset_builtin_call_arg0 = call_node },
            1 => .{ .node_offset_builtin_call_arg1 = call_node },
            2 => .{ .node_offset_builtin_call_arg2 = call_node },
            3 => .{ .node_offset_builtin_call_arg3 = call_node },
            4 => .{ .node_offset_builtin_call_arg4 = call_node },
            5 => .{ .node_offset_builtin_call_arg5 = call_node },
            else => call_src, // TODO: better source location
        };
        arg_srcs[arg_index] = arg_src;

        const resolved = try sema.resolveInst(zir_arg);
        resolved_args[arg_index] = resolved;
        try sema.checkNumericType(block, arg_src, sema.typeOf(resolved));
    }

    return sema.analyzeMinMax(block, call_src, air_tag, resolved_args, arg_srcs);
}

fn analyzeMinMax(
sema: *Sema,
block: *Block,
src: LazySrcLoc,
lhs: Air.Inst.Ref,
rhs: Air.Inst.Ref,
comptime air_tag: Air.Inst.Tag,
lhs_src: LazySrcLoc,
rhs_src: LazySrcLoc,
operands: []const Air.Inst.Ref,
operand_srcs: []const LazySrcLoc,
) CompileError!Air.Inst.Ref {
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
assert(operands.len == operand_srcs.len);
assert(operands.len > 0);

// TODO @max(max_int, undefined) should return max_int
if (operands.len == 1) return operands[0];

const runtime_src = if (simd_op.lhs_val) |lhs_val| rs: {
if (lhs_val.isUndef()) return sema.addConstUndef(simd_op.result_ty);
const mod = sema.mod;
const target = mod.getTarget();
const opFunc = switch (air_tag) {
.min => Value.numberMin,
.max => Value.numberMax,
else => unreachable,
};

const rhs_val = simd_op.rhs_val orelse break :rs rhs_src;
// First, find all comptime-known arguments, and get their min/max
var runtime_known = try std.DynamicBitSet.initFull(sema.arena, operands.len);
var cur_minmax: ?Air.Inst.Ref = null;
var cur_minmax_src: LazySrcLoc = undefined; // defined if cur_minmax not null
for (operands, operand_srcs, 0..) |operand, operand_src, operand_idx| {
// Resolve the value now to avoid redundant calls to `checkSimdBinOp` - we'll have to call
// it in the runtime path anyway since the result type may have been refined
const uncasted_operand_val = (try sema.resolveMaybeUndefVal(operand)) orelse continue;
if (cur_minmax) |cur| {
const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const cur_val = simd_op.lhs_val.?; // cur_minmax is comptime-known
const operand_val = simd_op.rhs_val.?; // we checked the operand was resolvable above

runtime_known.unset(operand_idx);

if (cur_val.isUndef()) continue; // result is also undef
if (operand_val.isUndef()) {
cur_minmax = try sema.addConstUndef(simd_op.result_ty);
continue;
}

if (rhs_val.isUndef()) return sema.addConstUndef(simd_op.result_ty);
try sema.resolveLazyValue(cur_val);
try sema.resolveLazyValue(operand_val);

try sema.resolveLazyValue(lhs_val);
try sema.resolveLazyValue(rhs_val);
const vec_len = simd_op.len orelse {
const result_val = opFunc(cur_val, operand_val, target);
cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
continue;
};
var lhs_buf: Value.ElemValueBuffer = undefined;
var rhs_buf: Value.ElemValueBuffer = undefined;
const elems = try sema.arena.alloc(Value, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = cur_val.elemValueBuffer(mod, i, &lhs_buf);
const rhs_elem_val = operand_val.elemValueBuffer(mod, i, &rhs_buf);
elem.* = opFunc(lhs_elem_val, rhs_elem_val, target);
}
cur_minmax = try sema.addConstant(
simd_op.result_ty,
try Value.Tag.aggregate.create(sema.arena, elems),
);
} else {
runtime_known.unset(operand_idx);
cur_minmax = try sema.addConstant(sema.typeOf(operand), uncasted_operand_val);
cur_minmax_src = operand_src;
}
}

const comptime_refined_ty: ?Type = if (cur_minmax) |ct_minmax_ref| refined: {
// Refine the comptime-known result type based on the operation
const val = (try sema.resolveMaybeUndefVal(ct_minmax_ref)).?;
const orig_ty = sema.typeOf(ct_minmax_ref);
const refined_ty = if (orig_ty.zigTypeTag() == .Vector) blk: {
const elem_ty = orig_ty.childType();
const len = orig_ty.vectorLen();

if (len == 0) break :blk orig_ty;
if (elem_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats

var cur_min: Value = try val.elemValue(mod, sema.arena, 0);
var cur_max: Value = cur_min;
for (1..len) |idx| {
const elem_val = try val.elemValue(mod, sema.arena, idx);
if (elem_val.isUndef()) break :blk orig_ty; // can't refine undef
if (Value.order(elem_val, cur_min, target).compare(.lt)) cur_min = elem_val;
if (Value.order(elem_val, cur_max, target).compare(.gt)) cur_max = elem_val;
}

const refined_elem_ty = try Type.intFittingRange(target, sema.arena, cur_min, cur_max);
break :blk try Type.vector(sema.arena, len, refined_elem_ty);
} else blk: {
if (orig_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
if (val.isUndef()) break :blk orig_ty; // can't refine undef
break :blk try Type.intFittingRange(target, sema.arena, val, val);
};

// Apply the refined type to the current value - this isn't strictly necessary in the
// runtime case since we'll refine again afterwards, but keeping things as small as possible
// will allow us to emit more optimal AIR (if all the runtime operands have smaller types
// than the non-refined comptime type).
if (!refined_ty.eql(orig_ty, mod)) {
if (std.debug.runtime_safety) {
assert(try sema.intFitsInType(val, refined_ty, null));
}
cur_minmax = try sema.addConstant(refined_ty, val);
}

break :refined refined_ty;
} else null;

const runtime_idx = runtime_known.findFirstSet() orelse return cur_minmax.?;
const runtime_src = operand_srcs[runtime_idx];
try sema.requireRuntimeBlock(block, src, runtime_src);

// Now, iterate over runtime operands, emitting a min/max instruction for each. We'll refine the
// type again at the end, based on the comptime-known bound.

// If the comptime-known part is undef we can avoid emitting actual instructions later
const known_undef = if (cur_minmax) |operand| blk: {
const val = (try sema.resolveMaybeUndefVal(operand)).?;
break :blk val.isUndef();
} else false;

if (cur_minmax == null) {
// No comptime operands - use the first operand as the starting value
assert(runtime_idx == 0);
cur_minmax = operands[0];
cur_minmax_src = runtime_src;
runtime_known.unset(0); // don't look at this operand in the loop below
}

var it = runtime_known.iterator(.{});
while (it.next()) |idx| {
const lhs = cur_minmax.?;
const lhs_src = cur_minmax_src;
const rhs = operands[idx];
const rhs_src = operand_srcs[idx];
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
if (known_undef) {
cur_minmax = try sema.addConstant(simd_op.result_ty, Value.undef);
} else {
cur_minmax = try block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
}
}

if (comptime_refined_ty) |comptime_ty| refine: {
// Finally, refine the type based on the comptime-known bound.
if (known_undef) break :refine; // can't refine undef
const unrefined_ty = sema.typeOf(cur_minmax.?);
const is_vector = unrefined_ty.zigTypeTag() == .Vector;
const comptime_elem_ty = if (is_vector) comptime_ty.childType() else comptime_ty;
const unrefined_elem_ty = if (is_vector) unrefined_ty.childType() else unrefined_ty;

if (unrefined_elem_ty.isAnyFloat()) break :refine; // we can't refine floats

const opFunc = switch (air_tag) {
.min => Value.numberMin,
.max => Value.numberMax,
// Compute the final bounds based on the runtime type and the comptime-known bound type
const min_val = switch (air_tag) {
.min => try unrefined_elem_ty.minInt(sema.arena, target),
.max => try comptime_elem_ty.minInt(sema.arena, target), // @max(ct, rt) >= ct
else => unreachable,
};
const target = sema.mod.getTarget();
const vec_len = simd_op.len orelse {
const result_val = opFunc(lhs_val, rhs_val, target);
return sema.addConstant(simd_op.result_ty, result_val);
const max_val = switch (air_tag) {
.min => try comptime_elem_ty.maxInt(sema.arena, target), // @min(ct, rt) <= ct
.max => try unrefined_elem_ty.maxInt(sema.arena, target),
else => unreachable,
};
var lhs_buf: Value.ElemValueBuffer = undefined;
var rhs_buf: Value.ElemValueBuffer = undefined;
const elems = try sema.arena.alloc(Value, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = lhs_val.elemValueBuffer(sema.mod, i, &lhs_buf);
const rhs_elem_val = rhs_val.elemValueBuffer(sema.mod, i, &rhs_buf);
elem.* = opFunc(lhs_elem_val, rhs_elem_val, target);
}
return sema.addConstant(
simd_op.result_ty,
try Value.Tag.aggregate.create(sema.arena, elems),
);
} else rs: {
if (simd_op.rhs_val) |rhs_val| {
if (rhs_val.isUndef()) return sema.addConstUndef(simd_op.result_ty);

// Find the smallest type which can contain these bounds
const final_elem_ty = try Type.intFittingRange(target, sema.arena, min_val, max_val);

const final_ty = if (is_vector)
try Type.vector(sema.arena, unrefined_ty.vectorLen(), final_elem_ty)
else
final_elem_ty;

if (!final_ty.eql(unrefined_ty, mod)) {
// We've reduced the type - cast the result down
return block.addTyOp(.intcast, final_ty, cur_minmax.?);
}
break :rs lhs_src;
};
}

try sema.requireRuntimeBlock(block, src, runtime_src);
return block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
return cur_minmax.?;
}

fn upgradeToArrayPtr(sema: *Sema, block: *Block, ptr: Air.Inst.Ref, len: u64) !Air.Inst.Ref {
Expand Down
Loading

0 comments on commit 26e94c7

Please sign in to comment.