@@ -1372,54 +1372,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
1372
1372
return mlir::success ();
1373
1373
}
1374
1374
1375
- LogicalResult tosa::TransposeOp::getConstantPerms (SmallVector<int32_t > &perms) {
1376
- // Perms must be constants.
1377
- DenseIntElementsAttr permsAttr;
1378
- if (!matchPattern (getPerms (), m_Constant (&permsAttr)))
1379
- return failure ();
1380
-
1381
- perms.clear ();
1382
- for (auto v : permsAttr.getValues <APInt>())
1383
- perms.push_back (v.getSExtValue ());
1384
-
1385
- return success ();
1386
- }
1387
-
1388
1375
LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
1389
1376
MLIRContext *context, ::std::optional<Location> location,
1390
1377
TransposeOp::Adaptor adaptor,
1391
1378
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1392
1379
ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1393
- ShapeAdaptor permsShape (adaptor.getPerms ().getType ());
1394
-
1395
- // We cannot infer anything from a rank-0 "permutation" tensor.
1396
- if (permsShape.hasRank () && permsShape.getRank () == 0 )
1397
- return failure ();
1398
1380
1399
1381
// If input rank and permutation length is unknown, the output rank is
1400
1382
// unknown.
1401
- if (!inputShape.hasRank () || !permsShape.hasRank () ||
1402
- permsShape.isDynamicDim (0 )) {
1383
+ if (!inputShape.hasRank ()) {
1403
1384
inferredReturnShapes.push_back (ShapedTypeComponents ());
1404
1385
return success ();
1405
1386
}
1406
1387
1388
+ const auto inputRank = inputShape.getRank ();
1389
+
1407
1390
// This would imply the number of permutations does not match the rank of
1408
1391
// the input which is illegal.
1409
- if (permsShape. getDimSize ( 0 ) != inputShape. getRank ( )) {
1392
+ if (adaptor. getPerms (). size () != static_cast < size_t >(inputRank )) {
1410
1393
return failure ();
1411
1394
}
1412
1395
1413
1396
SmallVector<int64_t > outputShape;
1414
1397
// Rank-0 means no permutations matter.
1415
- if (inputShape. getRank () == 0 ) {
1398
+ if (inputRank == 0 ) {
1416
1399
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1417
1400
return success ();
1418
1401
}
1419
1402
1420
1403
// Check whether the input dimensions are all the same.
1421
1404
bool allTheSame = true ;
1422
- for (int i = 1 , s = inputShape. getRank () ; i < s; i++) {
1405
+ for (int i = 1 , s = inputRank ; i < s; i++) {
1423
1406
if (inputShape.getDimSize (0 ) != inputShape.getDimSize (i)) {
1424
1407
allTheSame = false ;
1425
1408
break ;
@@ -1429,34 +1412,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1429
1412
// If all of the input dimensions are the same we don't care about the
1430
1413
// permutation.
1431
1414
if (allTheSame) {
1432
- outputShape.resize (inputShape. getRank () , inputShape.getDimSize (0 ));
1415
+ outputShape.resize (inputRank , inputShape.getDimSize (0 ));
1433
1416
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1434
1417
return success ();
1435
1418
}
1436
1419
1437
- outputShape.resize (inputShape.getRank (), ShapedType::kDynamic );
1438
- // If the permuations are a constant we can directly determine the output
1439
- // shape.
1440
- DenseIntElementsAttr attr;
1441
- if (matchPattern (adaptor.getPerms (), m_Constant (&attr)) &&
1442
- attr.getType ().getRank () == 1 ) {
1443
- ShapeAdaptor permShape = attr;
1444
- // Constant permutation must be the same length as the input rank.
1445
- if (inputShape.getRank () != permShape.getRank ())
1446
- return emitOptionalError (location,
1447
- " constant permutation must be the same length"
1448
- " as the input rank" );
1449
-
1450
- // Constant permutation values must be within the input rank.
1451
- for (int i = 0 , e = inputShape.getRank (); i < e; i++) {
1452
- if (inputShape.getRank () <= permShape.getDimSize (i))
1453
- return failure ();
1454
- }
1420
+ outputShape.resize (inputRank, ShapedType::kDynamic );
1455
1421
1456
- outputShape.reserve (inputShape.getRank ());
1457
- for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1458
- outputShape[i] = inputShape.getDimSize (permShape.getDimSize (i));
1459
- }
1422
+ // Constant permutation values must be within the input rank.
1423
+ if (llvm::any_of (adaptor.getPerms (),
1424
+ [inputRank](const auto i) { return i >= inputRank; }))
1425
+ return failure ();
1426
+
1427
+ outputShape.reserve (inputRank);
1428
+ for (int i = 0 , s = inputRank; i < s; i++) {
1429
+ outputShape[i] = inputShape.getDimSize (adaptor.getPerms ()[i]);
1460
1430
}
1461
1431
1462
1432
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
@@ -1465,75 +1435,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1465
1435
1466
1436
LogicalResult tosa::TransposeOp::verify () {
1467
1437
TensorType inputType = getInput1 ().getType ();
1468
- TensorType permType = getPerms ().getType ();
1469
1438
TensorType outputType = getOutput ().getType ();
1439
+ const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
1470
1440
1471
- if (permType.hasRank () && permType.getRank () != 1 )
1472
- return emitOpError ()
1473
- << " expected permutation tensor to be rank 1 but got rank "
1474
- << permType.getRank ();
1475
- if (inputType.hasRank () && permType.hasRank ())
1476
- if (!permType.isDynamicDim (0 ) &&
1477
- permType.getDimSize (0 ) != inputType.getRank ())
1478
- return emitOpError () << " expected permutation tensor dim 0 to have size "
1479
- << inputType.getRank ()
1480
- << " (input rank) but got size "
1481
- << permType.getDimSize (0 );
1441
+ if (inputType.hasRank () &&
1442
+ constantPerms.size () != static_cast <size_t >(inputType.getRank ()))
1443
+ return emitOpError () << " expected perms attribute to have size "
1444
+ << inputType.getRank () << " (input rank) but got size "
1445
+ << constantPerms.size ();
1482
1446
if (inputType.hasRank () && outputType.hasRank () &&
1483
1447
inputType.getRank () != outputType.getRank ())
1484
1448
return emitOpError ()
1485
1449
<< " expected input tensor rank to equal result tensor rank" ;
1486
- if (outputType.hasRank () && permType.hasRank ())
1487
- if (!permType.isDynamicDim (0 ) &&
1488
- permType.getDimSize (0 ) != outputType.getRank ())
1489
- return emitOpError () << " expected permutation tensor dim 0 to have size "
1490
- << outputType.getRank ()
1491
- << " (output rank) but got size "
1492
- << permType.getDimSize (0 );
1493
-
1494
- SmallVector<int32_t > constantPerms;
1495
- if (succeeded (getConstantPerms (constantPerms))) {
1496
- // Assert that the permutation tensor has a rank, which means that the
1497
- // rank has been verified above.
1498
- assert (permType.hasRank () &&
1499
- " Unexpectedly found permutation tensor without rank" );
1500
- if (!llvm::all_of (constantPerms,
1501
- [&constantPerms](int32_t s) {
1502
- return s >= 0 &&
1503
- static_cast <size_t >(s) < constantPerms.size ();
1504
- }) ||
1505
- !isPermutationVector (llvm::to_vector (llvm::map_range (
1506
- constantPerms, [](int32_t v) -> int64_t { return v; }))))
1507
- return emitOpError () << " expected valid permutation tensor" ;
1508
-
1509
- // Verify that the types of the input and output tensors are properly
1510
- // permuted.
1511
- if (inputType.hasRank () && outputType.hasRank ()) {
1512
- assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1513
- inputType.getRank () == outputType.getRank ());
1514
-
1515
- for (auto i = 0 ; i < outputType.getRank (); i++) {
1516
- if (inputType.isDynamicDim (constantPerms[i]) ||
1517
- outputType.isDynamicDim (i))
1518
- continue ;
1519
-
1520
- if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1521
- return emitOpError ()
1522
- << " expected output tensor dim " << i << " to match "
1523
- << " input dim " << constantPerms[i] << " with value of "
1524
- << inputType.getDimSize (constantPerms[i]);
1525
- }
1450
+ if (outputType.hasRank () &&
1451
+ constantPerms.size () != static_cast <size_t >(outputType.getRank ()))
1452
+ return emitOpError () << " expected perms attribute to have size "
1453
+ << outputType.getRank ()
1454
+ << " (output rank) but got size "
1455
+ << constantPerms.size ();
1456
+
1457
+ if (!llvm::all_of (constantPerms,
1458
+ [&constantPerms](int32_t s) {
1459
+ return s >= 0 &&
1460
+ static_cast <size_t >(s) < constantPerms.size ();
1461
+ }) ||
1462
+ !isPermutationVector (llvm::to_vector (llvm::map_range (
1463
+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
1464
+ return emitOpError () << " expected valid permutation indices" ;
1465
+
1466
+ // Verify that the types of the input and output tensors are properly
1467
+ // permuted.
1468
+ if (inputType.hasRank () && outputType.hasRank ()) {
1469
+ assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1470
+ inputType.getRank () == outputType.getRank ());
1471
+
1472
+ for (auto i = 0 ; i < outputType.getRank (); i++) {
1473
+ if (inputType.isDynamicDim (constantPerms[i]) ||
1474
+ outputType.isDynamicDim (i))
1475
+ continue ;
1476
+
1477
+ if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1478
+ return emitOpError ()
1479
+ << " expected output tensor dim " << i << " to match "
1480
+ << " input dim " << constantPerms[i] << " with value of "
1481
+ << inputType.getDimSize (constantPerms[i]);
1526
1482
}
1527
1483
}
1484
+
1528
1485
return success ();
1529
1486
}
1530
1487
1531
1488
LogicalResult TransposeOp::reifyResultShapes (
1532
1489
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1533
1490
1534
- SmallVector<int32_t > transposePerms;
1535
- if (getConstantPerms (transposePerms).failed ())
1536
- return failure ();
1491
+ const llvm::ArrayRef<int32_t > transposePerms = getPerms ();
1537
1492
1538
1493
Value input = getInput1 ();
1539
1494
auto inputType = cast<TensorType>(input.getType ());
0 commit comments