@@ -380,15 +380,16 @@ class MergerTest3T1LD : public MergerTestBase {
380
380
// /
381
381
// / Tests with both undef and dense input.
382
382
// /
383
- class MergerTest3T1LU : public MergerTestBase {
383
+
384
+ class MergerTest4T1LU : public MergerTestBase {
384
385
protected:
385
386
// Our four tensors (three inputs, one output).
386
- const unsigned t0 = 0 , t1 = 1 , t2 = 2 ;
387
+ const unsigned t0 = 0 , t1 = 1 , t2 = 2 , t3 = 3 ;
387
388
388
389
// Our single loop.
389
390
const unsigned l0 = 0 ;
390
391
391
- MergerTest3T1LU () : MergerTestBase(3 , 1 ) {
392
+ MergerTest4T1LU () : MergerTestBase(4 , 1 ) {
392
393
// Tensor 0: undef input vector.
393
394
merger.addExp (Kind::kTensor , t0, -1u );
394
395
merger.setDimLevelFormat (t0, l0, DimLevelFormat (DimLvlType::kUndef ));
@@ -397,43 +398,110 @@ class MergerTest3T1LU : public MergerTestBase {
397
398
merger.addExp (Kind::kTensor , t1, -1u );
398
399
merger.setDimLevelFormat (t1, l0, DimLevelFormat (DimLvlType::kDense ));
399
400
400
- // Tensor 2: dense output vector.
401
+ // Tensor 2: undef input vector.
401
402
merger.addExp (Kind::kTensor , t2, -1u );
402
- merger.setDimLevelFormat (t2, l0, DimLevelFormat (DimLvlType::kDense ));
403
+ merger.setDimLevelFormat (t2, l0, DimLevelFormat (DimLvlType::kUndef ));
404
+
405
+ // Tensor 3: dense output vector.
406
+ merger.addExp (Kind::kTensor , t3, -1u );
407
+ merger.setDimLevelFormat (t3, l0, DimLevelFormat (DimLvlType::kDense ));
408
+ }
409
+ };
410
+
411
+ // /
412
+ // / Tests with operation on sparse output.
413
+ // /
414
+
415
+ class MergerTest3T1L_SO : public MergerTestBase {
416
+ protected:
417
+ // Our three tensors (two inputs, one output), plus one synthetic tensor.
418
+ const unsigned t0 = 0 , t1 = 1 , t2 = 2 , t3 = 3 ;
419
+
420
+ // Our single loop.
421
+ const unsigned l0 = 0 ;
422
+
423
+ MergerTest3T1L_SO () : MergerTestBase(3 , 1 ) {
424
+ merger.setHasSparseOut (true );
425
+
426
+ // Tensor 0: undef input vector.
427
+ merger.addExp (Kind::kTensor , t0, -1u );
428
+ merger.setDimLevelFormat (t0, l0, DimLevelFormat (DimLvlType::kUndef ));
429
+
430
+ // Tensor 1: undef input vector.
431
+ merger.addExp (Kind::kTensor , t1, -1u );
432
+ merger.setDimLevelFormat (t1, l0, DimLevelFormat (DimLvlType::kUndef ));
433
+
434
+ // Tensor 2: sparse output vector.
435
+ merger.addExp (Kind::kTensor , t2, -1u );
436
+ merger.setDimLevelFormat (t2, l0, DimLevelFormat (DimLvlType::kCompressed ));
403
437
}
404
438
};
439
+
405
440
} // namespace
406
441
407
- // / Vector multiplication (conjunction) of 2 vectors, i.e.;
408
- // / a(i) = b(i) * c(i)
442
+ // / Vector multiplication (conjunction) of 3 vectors, i.e.;
443
+ // / a(i) = b(i) * c(i) * d(i)
409
444
// / which should form the single lattice point
410
445
// / {
411
- // / lat( i_00_U i_01_D / (tensor_0 * tensor_1) )
446
+ // / lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor_2) )
412
447
// / }
413
448
// / after optimization, the dense dimension should be kept, even though it appears
414
- // / after the undef dimension
449
+ // / in the middle
415
450
// / {
416
- // / lat( i_01_D / (tensor_0 * tensor_1) )
451
+ // / lat( i_01_D / (tensor_0 * tensor_1 * tensor_2) )
417
452
// / }
418
- #define IMPL_MERGER_TEST_CONJ (OP ) \
419
- TEST_F (MergerTest3T1LU, vector_##OP) { \
420
- auto e = OP##Expr (t0, t1); \
453
+ #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF (CONJ1, CONJ2 ) \
454
+ TEST_F (MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
455
+ auto em = CONJ1##Expr (t0, t1); \
456
+ auto e = CONJ2##Expr (em, t2); \
421
457
auto p0 = tensorPattern (t0); \
422
458
auto p1 = tensorPattern (t1); \
459
+ auto p2 = tensorPattern (t2); \
423
460
auto s = merger.buildLattices (e, l0); \
424
- \
425
461
expectNumLatPoints (s, 1 ); \
426
- expectLatPoint (s, lat (0 ), OP##Pattern (p0, p1), \
427
- loopsToBits ({{l0, t0}, {l0, t1}})); \
428
- \
462
+ expectLatPoint (s, lat (0 ), CONJ2##Pattern (CONJ1##Pattern (p0, p1), p2), \
463
+ loopsToBits ({{l0, t0}, {l0, t1}, {l0, t2}})); \
429
464
s = merger.optimizeSet (s); \
430
465
expectNumLatPoints (s, 1 ); \
431
- expectLatPoint (s, lat (0 ), OP ##Pattern (p0, p1), loopsToBits ({{l0, t1}}), \
432
- true ); \
466
+ expectLatPoint (s, lat (0 ), CONJ2 ##Pattern (CONJ1## Pattern ( p0, p1), p2), \
467
+ loopsToBits ({{l0, t1}}), true ); \
433
468
}
434
- FOREVERY_COMMON_CONJ_BINOP (IMPL_MERGER_TEST_CONJ)
435
469
436
- #undef IMPL_MERGER_TEST_CONJ
470
+ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP (IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
471
+
472
+ #undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
473
+
474
+ // / Vector multiplication (conjunction) of 2 vectors with a sparse output, i.e.:
475
+ // / o(i) = b(i) * c(i) * o(i)
476
+ // / which should form the single lattice point (note how a synthetic tensor
477
+ // / i_03_U is created for the sparse output)
478
+ // / {
479
+ // / lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
480
+ // / }
481
+ // / after optimization, the synthetic tensor should be preserved.
482
+ // / {
483
+ // / lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
484
+ // / }
485
+ #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT (CONJ1, CONJ2 ) \
486
+ TEST_F (MergerTest3T1L_SO, vector_##CONJ1##_##CONJ2) { \
487
+ auto em = CONJ1##Expr (t0, t1); \
488
+ auto e = CONJ2##Expr (em, t2); \
489
+ auto p0 = tensorPattern (t0); \
490
+ auto p1 = tensorPattern (t1); \
491
+ auto p2 = tensorPattern (t2); \
492
+ auto s = merger.buildLattices (e, l0); \
493
+ expectNumLatPoints (s, 1 ); \
494
+ expectLatPoint (s, lat (0 ), CONJ2##Pattern (CONJ1##Pattern (p0, p1), p2), \
495
+ loopsToBits ({{l0, t0}, {l0, t1}, {l0, t3}})); \
496
+ s = merger.optimizeSet (s); \
497
+ expectNumLatPoints (s, 1 ); \
498
+ expectLatPoint (s, lat (0 ), CONJ2##Pattern (CONJ1##Pattern (p0, p1), p2), \
499
+ loopsToBits ({{l0, t3}}), true ); \
500
+ }
501
+
502
+ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP (IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
503
+
504
+ #undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
437
505
438
506
// / Vector addition (disjunction) of 2 vectors. i.e.;
439
507
// / a(i) = b(i) + c(i)
0 commit comments