diff --git a/lib/Dialect/Tpu/Interfaces/Common/Softmax.cpp b/lib/Dialect/Tpu/Interfaces/Common/Softmax.cpp index cd74e255e..8c7a50413 100644 --- a/lib/Dialect/Tpu/Interfaces/Common/Softmax.cpp +++ b/lib/Dialect/Tpu/Interfaces/Common/Softmax.cpp @@ -212,17 +212,16 @@ LogicalResult tpu::SoftmaxOp::AllowDataSplit(int64_t axis, ArrayAttr tpu::SoftmaxOp::getIndexingMaps() { MLIRContext *ctx = getContext(); - auto out_shape = module::getShape(getOutput()); - auto num_dims = out_shape.size(); int axis = getAxis(); - // shape < 4 not support - if ( num_dims < 3 and axis != num_dims -1) { - return Builder(ctx).getAffineMapArrayAttr({}); + + auto inputMap = AffineMap::getMultiDimIdentityMap(axis, ctx); + auto empty = AffineMap::get(axis, 0, ctx); + SmallVector indexingMaps{inputMap}; + for (int i = 1, n = getNumOperands(); i < n; ++i) { + indexingMaps.push_back(empty); } - AffineMap outMap = AffineMap::getMultiDimIdentityMap(num_dims, ctx); - auto empty_map = AffineMap::get(num_dims, 0, ctx); - SmallVector indexingMaps{outMap, empty_map, empty_map, empty_map, empty_map, empty_map, outMap}; - return Builder(ctx).getAffineMapArrayAttr(indexingMaps); + indexingMaps.push_back(inputMap); + return Builder(getContext()).getAffineMapArrayAttr(indexingMaps); }; bool tpu::SoftmaxOp::support_multi_core() { return module::isSG2380(); }