Skip to content

Commit

Permalink
[Mosaic][NFC] Prefer mlir aliases for llvm classes/functions within m…
Browse files Browse the repository at this point in the history
…lir namespace for consistency

(also fix a missing cstdint header to fix linter error)

PiperOrigin-RevId: 609826731
  • Loading branch information
tlongeri authored and jax authors committed Feb 23, 2024
1 parent 8a43140 commit 75cdef7
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 47 deletions.
37 changes: 17 additions & 20 deletions jaxlib/mosaic/dialect/tpu/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ limitations under the License.

#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -79,7 +77,7 @@ FailureOr<TypedValue<VectorType>> RectangularVregBounds::getVectorMask(

DenseBoolArrayAttr RectangularVregBounds::getSublaneMask(
MLIRContext* mlir_ctx, const std::array<int64_t, 2> target_shape) const {
llvm::SmallVector<bool, 8> sublane_mask(target_shape[0], false);
SmallVector<bool, 8> sublane_mask(target_shape[0], false);
for (int64_t i = starts_[0]; i < ends_[0]; ++i) {
sublane_mask[i] = true;
}
Expand Down Expand Up @@ -178,7 +176,7 @@ class SingleRowVRegBounds : public VRegDataBounds {
const int64_t end_sublane = llvm::divideCeil(
llvm::divideCeil(stop_offset_, layout_.packing()), target_shape[1]);

llvm::SmallVector<bool> sublane_mask(target_shape[0], false);
SmallVector<bool> sublane_mask(target_shape[0], false);
for (int64_t i = start_sublane; i < end_sublane; ++i) {
sublane_mask[i] = true;
}
Expand Down Expand Up @@ -382,7 +380,7 @@ class TiledRectangularVregBounds : public VRegDataBounds {
DenseBoolArrayAttr getSublaneMask(
MLIRContext* mlir_ctx,
const std::array<int64_t, 2> target_shape) const override {
llvm::SmallVector<bool> mask(target_shape[0], false);
SmallVector<bool> mask(target_shape[0], false);
const int64_t start = start_offsets_[0] / layout_.packing();
const int64_t end = llvm::divideCeil(end_offsets_[0], layout_.packing());
const int64_t sublanes_per_tile = layout_.sublanesPerTile(target_shape);
Expand All @@ -403,8 +401,7 @@ class TiledRectangularVregBounds : public VRegDataBounds {
std::array<int64_t, 2> end_offsets_;
};

mlir::ParseResult parseOffset(llvm::StringRef* data,
std::optional<int64_t>* result) {
mlir::ParseResult parseOffset(StringRef* data, std::optional<int64_t>* result) {
int64_t int_result;
if (data->consume_front("*")) {
*result = std::nullopt;
Expand Down Expand Up @@ -441,21 +438,21 @@ bool VectorLayout::hasNativeTiling(
return tiling_ == nativeTiling(bitwidth_, target_shape);
}

llvm::SmallVector<int64_t> VectorLayout::implicitShape(
SmallVector<int64_t> VectorLayout::implicitShape(
ArrayRef<int64_t> shape) const {
CHECK(!shape.empty());
switch (implicit_dim_) {
case ImplicitDim::kNone:
return llvm::SmallVector<int64_t>(shape);
return SmallVector<int64_t>(shape);
case ImplicitDim::kMinor: {
llvm::SmallVector<int64_t> implicit_shape;
SmallVector<int64_t> implicit_shape;
implicit_shape.reserve(shape.size() + 1);
implicit_shape.append(shape.begin(), shape.end());
implicit_shape.push_back(1);
return implicit_shape;
}
case ImplicitDim::kSecondMinor: {
llvm::SmallVector<int64_t> implicit_shape;
SmallVector<int64_t> implicit_shape;
implicit_shape.reserve(shape.size() + 1);
implicit_shape.append(shape.begin(), std::prev(shape.end()));
implicit_shape.push_back(1);
Expand All @@ -465,11 +462,11 @@ llvm::SmallVector<int64_t> VectorLayout::implicitShape(
}
}

llvm::SmallVector<int64_t> VectorLayout::tileArrayImplicitShape(
SmallVector<int64_t> VectorLayout::tileArrayImplicitShape(
const ArrayRef<int64_t> shape,
const std::array<int64_t, 2> target_shape) const {
const std::array<int64_t, 2> vreg_slice = vregSlice(target_shape);
llvm::SmallVector<int64_t> tiles_shape = implicitShape(shape);
SmallVector<int64_t> tiles_shape = implicitShape(shape);
tiles_shape[tiles_shape.size() - 2] = llvm::divideCeil(
offsets_[0].value_or(0) + tiles_shape[tiles_shape.size() - 2],
vreg_slice[0]);
Expand All @@ -479,10 +476,10 @@ llvm::SmallVector<int64_t> VectorLayout::tileArrayImplicitShape(
return tiles_shape;
}

llvm::SmallVector<int64_t> VectorLayout::tileArrayShape(
SmallVector<int64_t> VectorLayout::tileArrayShape(
const ArrayRef<int64_t> shape,
const std::array<int64_t, 2> target_shape) const {
llvm::SmallVector<int64_t> tiles_shape =
SmallVector<int64_t> tiles_shape =
tileArrayImplicitShape(shape, target_shape);
// Remove the implicit dimension --- it's always of size 1.
switch (implicit_dim_) {
Expand Down Expand Up @@ -521,11 +518,11 @@ std::unique_ptr<VRegDataBounds> VectorLayout::tileDataBounds(
break;
}

const llvm::SmallVector<int64_t> tiles_implicit_shape =
const SmallVector<int64_t> tiles_implicit_shape =
tileArrayImplicitShape(full_shape, target_shape);
const int64_t ns = tiles_implicit_shape[tiles_implicit_shape.size() - 2];
const int64_t nl = tiles_implicit_shape[tiles_implicit_shape.size() - 1];
const llvm::SmallVector<int64_t> implicit_shape = implicitShape(full_shape);
const SmallVector<int64_t> implicit_shape = implicitShape(full_shape);
const int64_t is = implicit_shape[implicit_shape.size() - 2];
const int64_t il = implicit_shape[implicit_shape.size() - 1];

Expand Down Expand Up @@ -718,8 +715,8 @@ std::optional<VectorLayout> VectorLayout::join(const VectorLayout& l,
return VectorLayout(l.bitwidth_, offsets, l.tiling_, l.implicit_dim_);
}

std::optional<VectorLayout> VectorLayout::parse(llvm::StringRef* data) {
llvm::StringRef local(*data);
std::optional<VectorLayout> VectorLayout::parse(StringRef* data) {
StringRef local(*data);
int8_t bitwidth;
LayoutOffsets offsets;
std::array<int64_t, 2> tiling;
Expand Down Expand Up @@ -797,7 +794,7 @@ std::optional<Layout> parseLayout(mlir::AsmParser& parser) {
if (layout_str == "none") {
return kNoLayout;
}
llvm::StringRef ref(layout_str);
StringRef ref(layout_str);
if (auto layout = VectorLayout::parse(&ref); ref.empty()) {
return *layout;
}
Expand Down
10 changes: 4 additions & 6 deletions jaxlib/mosaic/dialect/tpu/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ limitations under the License.

#include "absl/log/check.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -270,10 +268,10 @@ class VectorLayout {
return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
}

llvm::SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;
SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;

private:
llvm::SmallVector<int64_t> tileArrayImplicitShape(
SmallVector<int64_t> tileArrayImplicitShape(
ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;

public:
Expand All @@ -288,7 +286,7 @@ class VectorLayout {
//
// Args:
// shape: The shape of the full vector this layout applies to.
llvm::SmallVector<int64_t> tileArrayShape(
SmallVector<int64_t> tileArrayShape(
ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;

// Returns the bounds of the given tile that hold useful data.
Expand Down Expand Up @@ -383,7 +381,7 @@ class VectorLayout {
const VectorLayout &r,
ArrayRef<int64_t> shape);

static std::optional<VectorLayout> parse(llvm::StringRef *data);
static std::optional<VectorLayout> parse(StringRef *data);

// Check conditions that depend on the target shape. Invariants that are
// independent of it are checked in the constructor.
Expand Down
5 changes: 3 additions & 2 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/hash/hash.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -103,7 +104,7 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess())) {
return {};
}
llvm::SmallVector<xla::Tile, 2> tiles;
SmallVector<xla::Tile, 2> tiles;
int64_t size;
while (succeeded(parser.parseOptionalLParen())) {
xla::Tile &tile = tiles.emplace_back();
Expand All @@ -121,7 +122,7 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
tile.add_dimensions(size);
}
}
llvm::SmallVector<int64_t, 2> tile_strides;
SmallVector<int64_t, 2> tile_strides;
int64_t stride;
if (failed(parser.parseComma())) {
return {};
Expand Down
9 changes: 5 additions & 4 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -32,7 +33,7 @@ namespace tpu {
LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op,
PatternRewriter &rewriter) {
RollVectorsOp roll_op =
llvm::dyn_cast_or_null<RollVectorsOp>(op.getOperand().getDefiningOp());
dyn_cast_or_null<RollVectorsOp>(op.getOperand().getDefiningOp());
if (!roll_op) {
return failure();
}
Expand Down Expand Up @@ -150,8 +151,8 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op,
int target_index = target_shape.size() - 1;
auto old_layout = dyn_cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
auto target_strides = old_layout.getTileStrides();
llvm::SmallVector<int64_t> tile_strides(target_strides.begin(),
target_strides.end());
SmallVector<int64_t> tile_strides(target_strides.begin(),
target_strides.end());
// We want to remove all strides that correspond to squeezed dimensions and
// update the corresponding output layout.
while (source_index >= 0 || target_index >= 0) {
Expand Down
4 changes: 1 addition & 3 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/iterator_range.h"
Expand Down Expand Up @@ -3020,7 +3018,7 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
// replicated result
) {
// First, insert the new singleton lane dimension.
llvm::SmallVector<int64_t> s(src_shape);
SmallVector<int64_t> s(src_shape);
s.push_back(1);
xla::Array<Value> dst_vregs_local(
layout_out.tileArrayShape(s, ctx.target_shape));
Expand Down
9 changes: 3 additions & 6 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ limitations under the License.
#include <utility>
#include <variant>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
Expand Down Expand Up @@ -144,8 +142,7 @@ class VectorLayoutInferer {
has_vector_io |= r.getType().isa<VectorType>();
}
if (!has_vector_io && any_op.getRegions().empty()) {
llvm::SmallVector<Layout, 4> in_layout(any_op.getNumOperands(),
kNoLayout);
SmallVector<Layout, 4> in_layout(any_op.getNumOperands(), kNoLayout);
if (any_op.getNumResults() == 0) {
setInLayout(&any_op, in_layout);
} else if (any_op.getNumResults() == 1) {
Expand Down Expand Up @@ -412,7 +409,7 @@ class VectorLayoutInferer {
auto then_yield = op.thenBlock()->getTerminator();
TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(),
"scf if results and then branch yield operands do not match");
llvm::SmallVector<Layout, 4> result_layout;
SmallVector<Layout, 4> result_layout;
result_layout.reserve(then_yield->getNumOperands());
for (const auto &operand : then_yield->getOperands()) {
if (operand.getType().isSignlessIntOrIndexOrFloat()) {
Expand Down Expand Up @@ -482,7 +479,7 @@ class VectorLayoutInferer {
op->getNumOperands() == 3 + op.getNumResults(),
"expected num_operands is equal to 3 + num_results in scf.for");

llvm::SmallVector<Layout, 4> in_layouts;
SmallVector<Layout, 4> in_layouts;
in_layouts.reserve(op->getNumOperands());
in_layouts.push_back(kNoLayout); // Lower bound.
in_layouts.push_back(kNoLayout); // Upper bound.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct LinalgVectorizationPass

// We do not want to apply the vector patterns above to the ops that are
// unrelated to the original linalg op.
llvm::SmallVector<Operation *> linalgOps;
SmallVector<Operation *> linalgOps;
func.walk([&](linalg::LinalgOp op) { linalgOps.push_back(op); });
if (failed(applyOpPatternsAndFold(linalgOps, std::move(patterns)))) {
return signalPassFailure();
Expand Down
9 changes: 4 additions & 5 deletions jaxlib/mosaic/dialect/tpu/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <cstdint>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -63,12 +62,12 @@ FailureOr<int8_t> getTypeBitwidth(Type ty) {
}

template <typename T>
llvm::ArrayRef<std::remove_const_t<T>> toArrayRef(absl::Span<T> span) {
return llvm::ArrayRef<std::remove_const_t<T>>(span.data(), span.size());
ArrayRef<std::remove_const_t<T>> toArrayRef(absl::Span<T> span) {
return ArrayRef<std::remove_const_t<T>>(span.data(), span.size());
}
template <typename T, std::size_t N>
llvm::ArrayRef<std::remove_const_t<T>> toArrayRef(std::array<T, N> array) {
return llvm::ArrayRef<std::remove_const_t<T>>(array.data(), array.size());
ArrayRef<std::remove_const_t<T>> toArrayRef(std::array<T, N> array) {
return ArrayRef<std::remove_const_t<T>>(array.data(), array.size());
}

inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder,
Expand Down

0 comments on commit 75cdef7

Please sign in to comment.