Skip to content

Commit

Permalink
Integrate LLVM at llvm/llvm-project@058222b23166
Browse files Browse the repository at this point in the history
Updates LLVM usage to match
[058222b23166](llvm/llvm-project@058222b23166)

PiperOrigin-RevId: 567628854
  • Loading branch information
tensorflower-gardener authored and TensorFlow MLIR Team committed Sep 22, 2023
1 parent d5a17b8 commit 4abaffe
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 56 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ http_archive(
],
)

LLVM_COMMIT = "3b0f812b9af459c4f857e4a7ffffa01f7a21446e"
LLVM_COMMIT = "058222b2316615194c089f2bc68d11341f39d26e"

LLVM_SHA256 = "47f8c81275437fb4ebcf0286125d683fb946cbddf0b725dc61d786141b32bc08"
LLVM_SHA256 = "99d3c38eb11dee8f00bd74b69152d961ab73cf4488842f6120e81342eeb94a3b"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
3b0f812b9af459c4f857e4a7ffffa01f7a21446e
058222b2316615194c089f2bc68d11341f39d26e

6 changes: 2 additions & 4 deletions gml_st/IR/gml_st_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ def GMLST_FusionOp : GMLST_Op<"fusion", [
YieldOp getTerminator();

// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t numOperands = this->getNumOperands();
int64_t numInits = this->getInits().size();
return {numOperands - numInits, numOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitsMutable();
}
}];
}
Expand Down
4 changes: 2 additions & 2 deletions gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ bool allowedToFuse(Operation* consumerOp, Operation* producerOp) {
auto dstStyleOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
if (!dstStyleOp) return false;

if (llvm::any_of(dstStyleOp.getDpsInitOperands(), [&](OpOperand* operand) {
return operand->get().getDefiningOp() == producerOp;
if (llvm::any_of(dstStyleOp.getDpsInits(), [&](Value operand) {
return operand.getDefiningOp() == producerOp;
}))
return true;
}
Expand Down
10 changes: 6 additions & 4 deletions gml_st/transforms/cpu_tiling/transform_elementwise_for_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ Operation *findRootElementwiseOp(Operation *op, FusionFilterFn fusionFilterFn) {
if (hasLabel(owner, kTransformedLabel)) continue;
if (hasLabel(owner, kFusionPlanningLabel)) continue;
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(owner)) {
if (llvm::is_contained(dpsOp.getDpsInitOperands(), &use)) continue;
SmallVector<OpOperand *> opOperands = llvm::to_vector(llvm::map_range(
dpsOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
if (llvm::is_contained(opOperands, &use)) continue;
}
curOp = owner;
rootOp = curOp;
Expand Down Expand Up @@ -205,9 +207,9 @@ FusionCluster findElementwiseCluster(Operation *rootOp,
// Add tensor.empty ops to the cluster.
for (auto *op : resultOps) {
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
for (auto &operand : dpsOp.getDpsInitOperands()) {
if (auto emptyOp = dyn_cast_or_null<tensor::EmptyOp>(
operand->get().getDefiningOp()))
for (auto operand : dpsOp.getDpsInits()) {
if (auto emptyOp =
dyn_cast_or_null<tensor::EmptyOp>(operand.getDefiningOp()))
fusionCluster.operations.insert(emptyOp);
}
}
Expand Down
2 changes: 1 addition & 1 deletion gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ LogicalResult validateOp(linalg::ReduceOp reduceOp, PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(
reduceOp, "expects 1 reduction dimension element. 0 or > 1 received.");
}
OpOperandVector operands = reduceOp.getDpsInputOperands();
SmallVector<OpOperand *> operands = reduceOp.getDpsInputOperands();
if (operands.size() != 1) {
return rewriter.notifyMatchFailure(reduceOp,
"expects 1 operand. 0 or > 1 received.");
Expand Down
4 changes: 2 additions & 2 deletions gml_st/transforms/fusion/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ SmallVector<Value> getRootOpInitOperands(PatternRewriter& rewriter,

SmallVector<Value> initOperands;

for (auto* operand : dstStyleOp.getDpsInitOperands()) {
initOperands.push_back(getTiedSourceOp(rewriter, operand, fusionCluster));
for (OpOperand& operand : dstStyleOp.getDpsInitsMutable()) {
initOperands.push_back(getTiedSourceOp(rewriter, &operand, fusionCluster));
}

return initOperands;
Expand Down
11 changes: 7 additions & 4 deletions gml_st/transforms/scalarization/scalarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,10 +505,13 @@ LogicalResult scalarizeLinalgOp(LinalgOp linalgOp, PatternRewriter &rewriter) {
if (isa<linalg::FillOp>(linalgOp)) {
if (llvm::all_of(linalgOp->getUses(), [&](OpOperand &use) {
Operation *user = use.getOwner();
return isa<DestinationStyleOpInterface>(user) &&
llvm::is_contained(cast<DestinationStyleOpInterface>(user)
.getDpsInitOperands(),
&use);
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
SmallVector<OpOperand *> opOperands = llvm::to_vector(
llvm::map_range(dpsOp.getDpsInitsMutable(),
[](OpOperand &o) { return &o; }));
return llvm::is_contained(opOperands, &use);
}
return false;
}))
return failure();
}
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3b0f812b9af459c4f857e4a7ffffa01f7a21446e
058222b2316615194c089f2bc68d11341f39d26e
29 changes: 14 additions & 15 deletions thlo/IR/thlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,17 @@ SmallVector<Range> getIterationDomainForTensor(OpBuilder &b, Location loc,
static void getDstStyleOpEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
ValueRange results, const OpOperandVector &inputOperands,
const OpOperandVector &outputOperands) {
for (auto *operand : inputOperands) {
if (!operand->get().getType().isa<MemRefType>()) continue;
effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
ValueRange results, ValueRange inputOperands, ValueRange outputOperands) {
for (auto operand : inputOperands) {
if (!operand.getType().isa<MemRefType>()) continue;
effects.emplace_back(MemoryEffects::Read::get(), operand,
SideEffects::DefaultResource::get());
}
for (auto *operand : outputOperands) {
if (!operand->get().getType().isa<MemRefType>()) continue;
effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
for (auto operand : outputOperands) {
if (!operand.getType().isa<MemRefType>()) continue;
effects.emplace_back(MemoryEffects::Read::get(), operand,
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), operand->get(),
effects.emplace_back(MemoryEffects::Write::get(), operand,
SideEffects::DefaultResource::get());
}
}
Expand Down Expand Up @@ -557,7 +556,7 @@ void ConcatenateOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
getDpsInputOperands(), getDpsInitOperands());
getDpsInputs(), getDpsInits());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -713,7 +712,7 @@ void DynamicBroadcastInDimOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
getDpsInputOperands(), getDpsInitOperands());
getDpsInputs(), getDpsInits());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -879,7 +878,7 @@ void ScatterOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
getDpsInputOperands(), getDpsInitOperands());
getDpsInputs(), getDpsInits());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -981,7 +980,7 @@ void GatherOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
getDpsInputOperands(), getDpsInitOperands());
getDpsInputs(), getDpsInits());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1231,7 +1230,7 @@ void SortOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
getDpsInputOperands(), getDpsInitOperands());
getDpsInputs(), getDpsInits());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1354,7 +1353,7 @@ void ReverseOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
getDpsInputOperands(), getDpsInitOperands());
getDpsInputs(), getDpsInits());
}

} // namespace thlo
Expand Down
30 changes: 12 additions & 18 deletions thlo/IR/thlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def THLO_ConcatenateOp : THLO_DstStyleOp<"concatenate", [

let extraClassDeclaration = [{
// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitMutable();
}
}];
}
Expand Down Expand Up @@ -130,9 +129,8 @@ def THLO_DynamicBroadcastInDimOp : THLO_DstStyleOp<"dynamic_broadcast_in_dim", [

let extraClassDeclaration = [{
// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitMutable();
}
}];
}
Expand Down Expand Up @@ -176,9 +174,8 @@ def THLO_GatherOp : THLO_DstStyleOp<"gather", [

let extraClassDeclaration = [{
// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitMutable();
}
}];
}
Expand Down Expand Up @@ -240,9 +237,8 @@ def THLO_ScatterOp : THLO_DstStyleOp<"scatter", [
int64_t getIndicesCount() { return getIndices().getType().getDimSize(0); }

// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitMutable();
}
}];
}
Expand Down Expand Up @@ -298,9 +294,8 @@ def THLO_SortOp : THLO_DstStyleOp<"sort", [

let extraClassDeclaration = [{
// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - getInits().size(), getNumOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitsMutable();
}
}];
}
Expand Down Expand Up @@ -334,9 +329,8 @@ def THLO_ReverseOp : THLO_DstStyleOp<"reverse", [

let extraClassDeclaration = [{
// Implement method necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
mlir::MutableOperandRange getDpsInitsMutable() {
return getInitMutable();
}
}];
}
Expand Down
4 changes: 2 additions & 2 deletions tools/mlir_interpreter/dialects/vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ InterpreterValue extract(InterpreterState& state, vector::ExtractOp extract,
const InterpreterValue& vector) {
auto result = vector;
auto& resultView = result.view();
for (int64_t offset : extract.getPosition()) {
for (int64_t offset : extract.getStaticPosition()) {
state.checkSuccess(resultView.slice(0, offset), "index out of bounds");
}
return resultView.rank() == 0 ? result.extractElement({}) : result;
Expand Down Expand Up @@ -374,7 +374,7 @@ InterpreterValue insert(InterpreterState& state, vector::InsertOp insert,
auto result = dst.clone();
auto resultSlice = result;
auto& resultSliceView = resultSlice.view();
for (int64_t offset : insert.getPosition()) {
for (int64_t offset : insert.getStaticPosition()) {
state.checkSuccess(resultSliceView.slice(0, offset), "index out of bounds");
}
resultSlice.fill([&](auto indices) { return src.extractElement(indices); });
Expand Down

0 comments on commit 4abaffe

Please sign in to comment.