Skip to content

Commit

Permalink
Shuffled checked properties in bli_l3_check.c. (#676)
Browse files Browse the repository at this point in the history
Details:
- Added certain checks for matrix structure to the level-3 operations'
  _check() functions, and slightly reorganized existing checks.
  • Loading branch information
fgvanzee authored Oct 18, 2022
1 parent 9453e0f commit 23f5b8d
Showing 1 changed file with 122 additions and 57 deletions.
179 changes: 122 additions & 57 deletions frame/3/bli_l3_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,22 @@ void bli_gemm_check
const cntx_t* cntx
)
{
//err_t e_val;
err_t e_val;

// Check basic properties of the operation.

bli_gemm_basic_check( alpha, a, b, beta, c, cntx );

// Check object structure.

// NOTE: Can't perform these checks as long as bli_gemm_check() is called
// from bli_l3_int(), which is in the execution path for structured
// level-3 operations such as hemm.
e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

//e_val = bli_check_general_object( a );
//bli_check_error_code( e_val );
e_val = bli_check_general_object( b );
bli_check_error_code( e_val );

//e_val = bli_check_general_object( b );
//bli_check_error_code( e_val );
e_val = bli_check_general_object( c );
bli_check_error_code( e_val );
}

void bli_gemmt_check
Expand All @@ -83,6 +82,14 @@ void bli_gemmt_check

e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check object structure.

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );
}

void bli_hemm_check
Expand All @@ -102,10 +109,21 @@ void bli_hemm_check

bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( a );
bli_check_error_code( e_val );

// Check object structure.

e_val = bli_check_hermitian_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );

e_val = bli_check_general_object( c );
bli_check_error_code( e_val );
}

void bli_herk_check
Expand All @@ -127,18 +145,26 @@ void bli_herk_check

bli_herk_basic_check( alpha, a, &ah, beta, c, cntx );

// Check for real-valued alpha and beta.

e_val = bli_check_real_valued_object( alpha );
bli_check_error_code( e_val );
// Check matrix squareness.

e_val = bli_check_real_valued_object( beta );
e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check matrix structure.

e_val = bli_check_hermitian_object( c );
bli_check_error_code( e_val );

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

// Check for real-valued alpha and beta.

e_val = bli_check_real_valued_object( alpha );
bli_check_error_code( e_val );

e_val = bli_check_real_valued_object( beta );
bli_check_error_code( e_val );
}

void bli_her2k_check
Expand All @@ -162,15 +188,26 @@ void bli_her2k_check

bli_her2k_basic_check( alpha, a, &bh, b, &ah, beta, c, cntx );

// Check for real-valued beta.
// Check matrix squareness.

e_val = bli_check_real_valued_object( beta );
e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check matrix structure.

e_val = bli_check_hermitian_object( c );
bli_check_error_code( e_val );

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );

// Check for real-valued beta.

e_val = bli_check_real_valued_object( beta );
bli_check_error_code( e_val );
}

void bli_symm_check
Expand All @@ -190,10 +227,21 @@ void bli_symm_check

bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( a );
bli_check_error_code( e_val );

// Check object structure.

e_val = bli_check_symmetric_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );

e_val = bli_check_general_object( c );
bli_check_error_code( e_val );
}

void bli_syrk_check
Expand All @@ -215,10 +263,18 @@ void bli_syrk_check

bli_herk_basic_check( alpha, a, &at, beta, c, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check matrix structure.

e_val = bli_check_symmetric_object( c );
bli_check_error_code( e_val );

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );
}

void bli_syr2k_check
Expand All @@ -242,10 +298,21 @@ void bli_syr2k_check

bli_her2k_basic_check( alpha, a, &bt, b, &at, beta, c, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check matrix structure.

e_val = bli_check_symmetric_object( c );
bli_check_error_code( e_val );

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );
}

void bli_trmm3_check
Expand All @@ -261,14 +328,25 @@ void bli_trmm3_check
{
err_t e_val;

// Perform checks common to hemm/symm/trmm/trsm.
// Check basic properties of the operation.

bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( a );
bli_check_error_code( e_val );

// Check object structure.

e_val = bli_check_triangular_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );

e_val = bli_check_general_object( c );
bli_check_error_code( e_val );
}

void bli_trmm_check
Expand All @@ -282,14 +360,22 @@ void bli_trmm_check
{
err_t e_val;

// Perform checks common to hemm/symm/trmm/trsm.
// Check basic properties of the operation.

bli_hemm_basic_check( side, alpha, a, b, &BLIS_ZERO, b, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( a );
bli_check_error_code( e_val );

// Check object structure.

e_val = bli_check_triangular_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );
}

void bli_trsm_check
Expand All @@ -307,10 +393,18 @@ void bli_trsm_check

bli_hemm_basic_check( side, alpha, a, b, &BLIS_ZERO, b, cntx );

// Check matrix squareness.

e_val = bli_check_square_object( a );
bli_check_error_code( e_val );

// Check object structure.

e_val = bli_check_triangular_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );
}

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -385,6 +479,14 @@ void bli_gemmt_basic_check

e_val = bli_check_level3_dims( a, b, c );
bli_check_error_code( e_val );

// Check for consistent datatypes.

e_val = bli_check_consistent_object_datatypes( c, a );
bli_check_error_code( e_val );

e_val = bli_check_consistent_object_datatypes( c, b );
bli_check_error_code( e_val );
}

void bli_hemm_basic_check
Expand Down Expand Up @@ -417,11 +519,6 @@ void bli_hemm_basic_check
bli_check_error_code( e_val );
}

// Check matrix squareness.

e_val = bli_check_square_object( a );
bli_check_error_code( e_val );

// Check for consistent datatypes.

e_val = bli_check_consistent_object_datatypes( c, a );
Expand Down Expand Up @@ -452,19 +549,6 @@ void bli_herk_basic_check
e_val = bli_check_level3_dims( a, ah, c );
bli_check_error_code( e_val );

// Check matrix squareness.

e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check matrix structure.

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( ah );
bli_check_error_code( e_val );

// Check for consistent datatypes.

e_val = bli_check_consistent_object_datatypes( c, a );
Expand Down Expand Up @@ -501,25 +585,6 @@ void bli_her2k_basic_check
e_val = bli_check_level3_dims( b, ah, c );
bli_check_error_code( e_val );

// Check matrix squareness.

e_val = bli_check_square_object( c );
bli_check_error_code( e_val );

// Check matrix structure.

e_val = bli_check_general_object( a );
bli_check_error_code( e_val );

e_val = bli_check_general_object( bh );
bli_check_error_code( e_val );

e_val = bli_check_general_object( b );
bli_check_error_code( e_val );

e_val = bli_check_general_object( ah );
bli_check_error_code( e_val );

// Check for consistent datatypes.

e_val = bli_check_consistent_object_datatypes( c, a );
Expand Down Expand Up @@ -586,13 +651,13 @@ void bli_l3_basic_check
e_val = bli_check_object_buffer( alpha );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( a );
e_val = bli_check_object_buffer( beta );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( b );
e_val = bli_check_object_buffer( a );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( beta );
e_val = bli_check_object_buffer( b );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( c );
Expand Down

0 comments on commit 23f5b8d

Please sign in to comment.