@@ -794,7 +794,7 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
794
794
// R - number of rows
795
795
// C - number of columns
796
796
// VF - VNNI Factor
797
- #define DEFINE_GET_COORD_ROWPACKED (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
797
+ #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
798
798
INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
799
799
int sg_size = get_sub_group_size(); \
800
800
int wi_id = get_sub_group_local_id(); \
@@ -807,67 +807,32 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
807
807
return result; \
808
808
}
809
809
810
- #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , R , C , slices ) \
811
- INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
812
- int sg_size = get_sub_group_size(); \
813
- int wi_id = get_sub_group_local_id(); \
814
- int elems_per_slice = (R * C / sg_size) / slices; \
815
- int slice_cols = (C / slices); \
816
- int sg_cols_per_wi = slice_cols / sg_size; \
817
- int row = (index % elems_per_slice) / sg_cols_per_wi; \
818
- int col = wi_id + ((index % elems_per_slice) % sg_cols_per_wi) * sg_size + (index / elems_per_slice * slice_cols); \
819
- int2 result = (int2)(row, col); \
820
- return result; \
821
- }
822
-
823
810
// ------ PVC -------
824
- // DEFINE_GET_COORD_ROWPACKED(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF)
825
- // DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices)
826
- // int8
827
- DEFINE_GET_COORD_ROWPACKED (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
828
- DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 16 , 1 )
829
-
830
- // 16bit A
831
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 16 , 1 )
832
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 8 , 16 , 1 )
833
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 1 )
834
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 32 , 1 )
835
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 16 , 1 )
836
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 32 , 2 )
837
-
838
- // 16bit PackedB
839
- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 16 , 1 )
840
- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 64 , 4 )
841
- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 64 , 4 )
842
-
843
- // 16bit Row_major B
844
- DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 16 , 1 )
845
- DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 64 , 4 )
846
- DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 32 , 64 , 4 )
811
+ // layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF
812
+ //int8
813
+ DEFINE_GET_COORD (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
814
+ DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 32 , 16 , 4 )
847
815
848
- // Accumulator
849
- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 8 , 16 , 1 )
850
- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 16 , 16 , 1 )
851
- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 64 , 4 )
852
- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 1 , 64 , 4 )
816
+ //bfloat16
817
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 8 , 16 , 1 )
818
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 16 , 1 )
819
+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 16 , 16 , 2 )
853
820
854
- // Accumulator 16bit
855
- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 8 , 16 , 1 )
856
- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 16 , 16 , 1 )
857
- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 32 , 64 , 2 )
858
- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 1 , 64 , 4 )
821
+ // Accumulator
822
+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 8 , 16 , 1 )
823
+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 16 , 16 , 1 )
859
824
860
825
// --------- XMX8 ------------
861
826
//int8
862
- DEFINE_GET_COORD_ROWPACKED (PackedA , , 8 , 32 , 8 , 32 , 1 )
863
- DEFINE_GET_COORD_ROWPACKED (PackedB , , 8 , 32 , 32 , 8 , 4 )
827
+ DEFINE_GET_COORD (PackedA , , 8 , 32 , 8 , 32 , 1 )
828
+ DEFINE_GET_COORD (PackedB , , 8 , 32 , 32 , 8 , 4 )
864
829
865
830
//bfloat16
866
- DEFINE_GET_COORD_ROWPACKED (PackedA , , 16 , 32 , 8 , 16 , 1 )
867
- DEFINE_GET_COORD_ROWPACKED (PackedB , , 16 , 32 , 16 , 8 , 2 )
831
+ DEFINE_GET_COORD (PackedA , , 16 , 32 , 8 , 16 , 1 )
832
+ DEFINE_GET_COORD (PackedB , , 16 , 32 , 16 , 8 , 2 )
868
833
869
834
// Accumulator
870
- DEFINE_GET_COORD_ROWPACKED (Accumulator , , 32 , 32 , 8 , 8 , 1 )
835
+ DEFINE_GET_COORD (Accumulator , , 32 , 32 , 8 , 8 , 1 )
871
836
872
837
/* experimental large slice support: */
873
838
0 commit comments