Skip to content

Commit 20e5838

Browse files
tlongeri and jax authors
authored and committed
[Mosaic] apply_vector_layout C++: Be consistent about using "Not implemented" as a prefix for error messages
I want to rely on this in the Python bindings to raise `NotImplementedError` exceptions.

PiperOrigin-RevId: 575897758
1 parent 9bc0439 commit 20e5838

File tree

1 file changed

+44
-32
lines changed

1 file changed

+44
-32
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@
6464
// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
6565
// consistent about this for layout null checks in rules.
6666

67-
#define NYI(msg) \
68-
op->emitOpError("not implemented: " msg); \
69-
return failure();
70-
7167
namespace mlir::tpu {
7268
// TODO(tlongeri): Maybe just roll our own multi-dimensional array instead of
7369
// using XLA's? There's too much glue for going from/to ArrayRef.
@@ -354,7 +350,8 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
354350
Block &entry_block = ctx.func.getBody().front();
355351
auto value_ty = cast<VectorType>(value.getType());
356352
if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
357-
return ctx.func.emitOpError("Only 32-bit constants supported");
353+
return ctx.func.emitOpError(
354+
"Not implemented: Only 32-bit constants supported");
358355
}
359356
if (ctx.func->getAttr("scratch_operands")) {
360357
return ctx.func.emitOpError(
@@ -514,7 +511,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
514511
if (!llvm::all_of(layouts_in, [&](const Layout &l) {
515512
return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape);
516513
})) {
517-
return op.emitOpError("Incompatible layouts in elementwise operation");
514+
return op.emitOpError(
515+
"Not implemented: Incompatible layouts in elementwise operation");
518516
}
519517
const unsigned num_operands = op.getNumOperands();
520518
SmallVector<xla::Array<Value>> in_vreg_arrays;
@@ -585,7 +583,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
585583
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
586584
auto result_ty = cast<VectorType>(op.getResult().getType());
587585
if (layout_out.bitwidth() != 32) {
588-
return op.emitOpError("Only extensions to 32-bit supported");
586+
return op.emitOpError(
587+
"Not implemented: Only extensions to 32-bit supported");
589588
}
590589
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
591590
disassemble(ctx, builder, layout_in, op.getIn()));
@@ -600,7 +599,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
600599
switch (layout_in.implicit_dim()) {
601600
case VectorLayout::ImplicitDim::kNone: {
602601
if (layout_in.tiling() != layout_out.tiling()) {
603-
return op.emitOpError("Changing tiling during extension");
602+
return op.emitOpError(
603+
"Not implemented: Changing tiling during the cast");
604604
}
605605
auto tiling = layout_in.tiling();
606606
if (ctx.target_shape[0] % tiling[0] != 0 ||
@@ -648,7 +648,8 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
648648
auto extf_op = cast<arith::ExtFOp>(op);
649649
if (layouts_in.front()->bitwidth() != 16 ||
650650
layouts_out.front()->bitwidth() != 32) {
651-
return op.emitOpError("Only 16-bit to 32-bit conversion supported");
651+
return op.emitOpError(
652+
"Not implemented: Only 16-bit to 32-bit conversion supported");
652653
}
653654
return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(),
654655
*layouts_out.front());
@@ -677,7 +678,7 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
677678
xla::Array<Value> output_vregs(
678679
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
679680
if (layout_in.bitwidth() != 32) {
680-
return op.emitOpError("Only 32-bit truncation supported");
681+
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
681682
}
682683
FAILUREOR_ASSIGN_OR_RETURN(
683684
VectorType res_vreg_ty,
@@ -1257,8 +1258,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
12571258
if (!layout.hasNaturalTopology(ctx.target_shape) ||
12581259
layout.offsets() != LayoutOffsets{0, 0}) {
12591260
return op.emitOpError(
1260-
"Only native tiling with offset (0, 0) is supported when "
1261-
"concatenation along tiling dims.");
1261+
"Not implemented: Only native tiling with offset (0, 0) is supported "
1262+
"when concatenation along tiling dims.");
12621263
}
12631264
// Check if shapes of src and res are aligned to native tiling.
12641265
auto check_aligned = [&](const VectorType &vty) {
@@ -1274,8 +1275,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
12741275
}
12751276
if (!is_aligned) {
12761277
return op.emitOpError(
1277-
"Only aligned shapes are supported when concatenation along tiling "
1278-
"dims");
1278+
"Not implemented: Only aligned shapes are supported when "
1279+
"concatenation along tiling dims");
12791280
}
12801281
}
12811282

@@ -1582,12 +1583,14 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
15821583
AffineMap load_map;
15831584
arith::ConstantOp padding;
15841585
if (offsets[1] == std::nullopt) {
1585-
return op.emitOpError("Load replicated along lanes is unsupported");
1586+
return op.emitOpError(
1587+
"Not implemented: Load replicated along lanes is unsupported");
15861588
}
15871589
if (offsets[0] == std::nullopt) {
15881590
if (ss != 1) {
15891591
return op.emitOpError(
1590-
"Sublane-replicated load with size > 1 is unsupported");
1592+
"Not implemented: Sublane-replicated load with size > 1 is "
1593+
"unsupported");
15911594
}
15921595
if (!layout_out.hasNativeTiling(ctx.target_shape)) {
15931596
return op.emitOpError("Not implemented");
@@ -1692,7 +1695,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
16921695
getNativeVregType(vty.getElementType(), ctx.target_shape));
16931696
if (value.isSplat()) {
16941697
if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) {
1695-
return op.emitOpError("Non-replicated splat constants");
1698+
return op.emitOpError(
1699+
"Not implemented: Non-replicated splat constants");
16961700
}
16971701
auto new_value =
16981702
DenseElementsAttr::get(target_vty, value.getSplatValue<Attribute>());
@@ -1708,7 +1712,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
17081712
}
17091713
// !value.isSplat()
17101714
if (getTypeBitwidth<true>(vty.getElementType()) != 32) {
1711-
return op.emitOpError("Only 32-bit non-splat constants are supported");
1715+
return op.emitOpError(
1716+
"Not implemented: Only 32-bit non-splat constants are supported");
17121717
}
17131718
FAILUREOR_ASSIGN_OR_RETURN(const BlockArgument ref,
17141719
appendConstant(ctx, value));
@@ -1722,7 +1727,7 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
17221727
{VectorLayout(/*bitwidth=*/32, /*offsets=*/{0, 0},
17231728
/*tiling=*/ctx.target_shape)});
17241729
}
1725-
return op.emitOpError("Unsupported arith.const type: ")
1730+
return op.emitOpError("Not implemented: Unsupported arith.const type: ")
17261731
<< op.getResult(0).getType();
17271732
}
17281733

@@ -1959,7 +1964,7 @@ LogicalResult vector_contract_rule(RewriteContext &ctx, Operation &op,
19591964
if (indexing_maps != matmul_indexing_maps &&
19601965
indexing_maps != matmul_indexing_maps_transposed) {
19611966
return vector_contract_op->emitOpError(
1962-
"Non-matmul or unsupported indexing_maps");
1967+
"Not implemented: Non-matmul or unsupported indexing_maps");
19631968
}
19641969
const bool transpose_rhs = indexing_maps == matmul_indexing_maps_transposed;
19651970
const ArrayAttr matmul_iterator_types =
@@ -2125,15 +2130,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
21252130
"Not implemented: unsupported kind");
21262131
}
21272132
if (val != neutral.getValueAsDouble()) {
2128-
return multi_reduction_op.emitOpError("Only neutral accumulator supported");
2133+
return multi_reduction_op.emitOpError(
2134+
"Not implemented: Only neutral accumulator supported");
21292135
}
21302136

21312137
if (src_layout.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
21322138
src_layout.hasNaturalTopology(ctx.target_shape)) {
21332139
auto [sublane_offset, lane_offset] = src_layout.offsets();
21342140
if (dim < 0) {
21352141
return multi_reduction_op.emitOpError(
2136-
"Negative reduction dimension unsupported");
2142+
"Not implemented: Negative reduction dimension unsupported");
21372143
}
21382144
int64_t vdim;
21392145
Direction reduce_over;
@@ -2252,7 +2258,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
22522258
multi_reduction_op->erase();
22532259
return success();
22542260
}
2255-
return multi_reduction_op->emitOpError("Unsupported layout: ") << src_layout;
2261+
return multi_reduction_op->emitOpError(
2262+
"Not implemented: Unsupported layout: ")
2263+
<< src_layout;
22562264
}
22572265

22582266
LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
@@ -2826,7 +2834,7 @@ FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
28262834
val.getLoc(), SmallVector<Type>(num_vectors, vreg_ty), val);
28272835
return XlaArrayFromShapeAndValues<Value>(layout_shape, u->getResults());
28282836
}
2829-
return op->emitOpError("unimplemented: ") << val;
2837+
return op->emitOpError("Not implemented: ") << val;
28302838
}
28312839

28322840
// Assembles a destination tile using partial data from rotated vregs using a
@@ -3337,7 +3345,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
33373345
const auto &tiling = src.tiling();
33383346
// TODO(apaszke): Changing an offset might add or remove one vreg.
33393347
if (dst_tiles_shape != src_tiles.dimensions()) {
3340-
return emitError(v.getLoc(), "Offsets changing the vreg array shape");
3348+
return emitError(
3349+
v.getLoc(), "Not implemented: Offsets changing the vreg array shape");
33413350
}
33423351
xla::Array<Value> dst_tiles = src_tiles;
33433352

@@ -3346,7 +3355,7 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
33463355
if (!src.offsets()[0].has_value()) {
33473356
row_diff = 0;
33483357
} else if (!dst.offsets()[0].has_value()) {
3349-
return emitError(v.getLoc(), "Sublane broadcast not implemented");
3358+
return emitError(v.getLoc(), "Not implemented: Sublane broadcast");
33503359
} else {
33513360
row_diff = *dst.offsets()[0] - *src.offsets()[0];
33523361
}
@@ -3356,7 +3365,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
33563365
const SmallVector<int64_t> implicit_shape =
33573366
src.implicitShape(vty.getShape());
33583367
if (implicit_shape[implicit_shape.size() - 2] != 1) {
3359-
return emitError(v.getLoc(), "Row shifts for multi-row values");
3368+
return emitError(v.getLoc(),
3369+
"Not implemented: Row shifts for multi-row values");
33603370
}
33613371
const int64_t src_sublane = *src.offsets()[0] / packing;
33623372
const int64_t dst_sublane = *dst.offsets()[0] / packing;
@@ -3454,7 +3464,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
34543464
return assemble(ctx, builder, vty, dst, std::move(dst_tiles)).getResult();
34553465
}
34563466
// TODO(apaszke): Implement general relayout
3457-
return emitError(v.getLoc(), "unsupported layout change for ")
3467+
return emitError(v.getLoc(),
3468+
"Not implemented: Unsupported layout change for ")
34583469
<< vty << ": " << src << " -> " << dst;
34593470
}
34603471

@@ -3497,7 +3508,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
34973508
auto vty = dyn_cast<VectorType>(operand.getType());
34983509
if ((vty == nullptr) == li.has_value()) {
34993510
return op.emitError(
3500-
"layout should be none iff operand is not a vector");
3511+
"Layout should be none iff operand is not a vector");
35013512
}
35023513
if (vty == nullptr) {
35033514
continue;
@@ -3508,7 +3519,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
35083519
// arguments.
35093520
auto op_result = dyn_cast<OpResult>(operand);
35103521
if (op_result == nullptr) {
3511-
return op.emitError("expected operand to be an operation result");
3522+
return op.emitError("Expected operand to be an operation result");
35123523
}
35133524
Operation *const def_op = op_result.getOwner();
35143525
CHECK(def_op);
@@ -3517,7 +3528,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
35173528
getOutLayout(*def_op));
35183529
const Layout lo = def_layouts[res_idx];
35193530
if (!lo.has_value()) {
3520-
return op.emitError() << "vector result should have a defined layout";
3531+
return op.emitError() << "Vector result should have a defined layout";
35213532
}
35223533
if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) {
35233534
continue;
@@ -3553,7 +3564,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
35533564
if (OpTrait::hasElementwiseMappableTraits(&op)) {
35543565
return elementwise_op_rule(ctx, op, layout_in, layout_out);
35553566
}
3556-
return op.emitError("Unsupported operation: ") << op.getName();
3567+
return op.emitError("Not implemented: Unsupported operation: ")
3568+
<< op.getName();
35573569
}
35583570

35593571
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block) {

0 commit comments

Comments
 (0)