generic: sycl: add missing type checks on scales
t4c1 authored and dzarukin committed Nov 1, 2024
1 parent 7486ed8 commit 7d85c75
Showing 9 changed files with 114 additions and 48 deletions.
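Note on the change: the files below add a shared is_supported_type() helper (last file in this commit) and use it to validate tensor and scale data types in the generic SYCL implementations, so primitive descriptor creation returns unimplemented for unsupported scale data types instead of letting them through. For context, a hypothetical user-side sketch of how scales end up carrying a non-default data type in the first place; it assumes the primitive_attr::set_scales() overload (available in recent oneDNN 3.x releases) that takes group dims and a scales data type, which is not part of this commit:

// Hypothetical user-side sketch, not part of this commit. Assumes the
// oneDNN 3.x primitive_attr::set_scales() overload that accepts group dims
// and a scales data type; older releases only expose set_scales_mask(),
// where scales are implicitly f32.
#include "dnnl.hpp"

int main() {
    dnnl::primitive_attr attr;
    // Per-tensor (mask = 0) scales for the two inputs of a binary primitive,
    // supplied at execution time as f16 tensors instead of the default f32.
    attr.set_scales(DNNL_ARG_SRC_0, /*mask=*/0, /*groups=*/{},
            dnnl::memory::data_type::f16);
    attr.set_scales(DNNL_ARG_SRC_1, /*mask=*/0, /*groups=*/{},
            dnnl::memory::data_type::f16);
    // With this commit, the generic SYCL implementations accept scale data
    // types only in {f32, f16, bf16, s32, s8, u8}; anything else makes the
    // implementation return status::unimplemented at pd creation.
    return 0;
}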
15 changes: 11 additions & 4 deletions src/gpu/generic/sycl/ref_binary.hpp
@@ -54,8 +54,8 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t {
&& check_formats(src0_d, src1_d, dst_d)
&& attr()->has_default_values(
sm::scales_runtime | sm::post_ops)
&& IMPLICATION(!attr()->scales_.has_default_values(),
check_scales_mask())
&& IMPLICATION(
!attr()->scales_.has_default_values(), scales_ok())
&& sycl_post_ops_t::post_ops_ok(attr())
&& md_dims_in_range(src_md(0))
&& md_dims_in_range(src_md(1))
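The IMPLICATION guard above (also used in the convolution, layer normalization, and reorder changes below) only requires the scales check to hold when the attribute actually carries non-default scales; IMPLICATION comes from src/common/utils.hpp and effectively evaluates !(cause) || (effect). A minimal standalone model of that behavior:

// Standalone model of the IMPLICATION guard; a functional stand-in, not the
// actual oneDNN macro.
#include <cassert>

constexpr bool implication(bool cause, bool effect) { return !cause || effect; }

int main() {
    // Default scales: the scales data-type check is vacuously satisfied.
    assert(implication(/*non_default_scales=*/false, /*scales_ok=*/false));
    // Non-default scales with supported data types: pd init can proceed.
    assert(implication(true, true));
    // Non-default scales with an unsupported data type: pd init fails.
    assert(!implication(true, false));
    return 0;
}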
@@ -70,10 +70,17 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t {
private:
status_t init_conf();

bool check_scales_mask() const {
bool scales_ok() const {
const std::vector<int> supported_args
= {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1};
return attr_scales_ok(supported_args);

const auto &scales = attr()->scales_;
bool dt_ok = true;
for (auto arg : supported_args) {
auto &s = scales.get(arg);
dt_ok = dt_ok && is_supported_type(s.data_type_);
}
return dt_ok && attr_scales_ok(supported_args);
}

static bool check_data_types(const memory_desc_wrapper &src0,
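The new scales_ok() combines a per-argument scale data-type check with the existing attr_scales_ok() mask check; the same loop reappears in the convolution, layer normalization, matmul, reorder, and sum changes below. A minimal standalone sketch of the pattern, using simplified stand-in types (data_type_t, scales_t, and the argument ids here are illustrative stand-ins, not the oneDNN internals):

#include <map>
#include <vector>

// Stand-in data-type enum; the supported set mirrors is_supported_type() in
// sycl_io_helper.hpp (last file in this commit).
enum class data_type_t { f32, f16, bf16, s32, s8, u8, f64 };

bool is_supported_type(data_type_t dt) {
    switch (dt) {
        case data_type_t::f32:
        case data_type_t::f16:
        case data_type_t::bf16:
        case data_type_t::s32:
        case data_type_t::s8:
        case data_type_t::u8: return true;
        default: return false;
    }
}

// Stand-in for attr()->scales_: argument id -> scale data type.
using scales_t = std::map<int, data_type_t>;

// Mirrors the loop in scales_ok(): every supported argument that carries
// scales must use a supported data type (the real code additionally ANDs
// this with attr_scales_ok(supported_args)).
bool scales_data_types_ok(const scales_t &scales, const std::vector<int> &args) {
    for (int arg : args) {
        auto it = scales.find(arg);
        if (it != scales.end() && !is_supported_type(it->second)) return false;
    }
    return true;
}

int main() {
    const int arg_src_0 = 1, arg_src_1 = 2; // stand-ins for DNNL_ARG_SRC_0/1
    scales_t good {{arg_src_0, data_type_t::f16}};
    scales_t bad {{arg_src_1, data_type_t::f64}};
    bool ok = scales_data_types_ok(good, {arg_src_0, arg_src_1})
            && !scales_data_types_ok(bad, {arg_src_0, arg_src_1});
    return ok ? 0 : 1;
}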
37 changes: 23 additions & 14 deletions src/gpu/generic/sycl/ref_convolution.hpp
@@ -32,22 +32,16 @@ namespace gpu {
namespace generic {
namespace sycl {

static bool check_convolution_data_types(const memory_desc_wrapper &src0,
inline bool check_convolution_data_types(const memory_desc_wrapper &src0,
const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) {
using namespace data_type;

const auto src0_dt = src0.data_type();
const auto src1_dt = src1.data_type();
const auto dst_dt = dst.data_type();

for (auto t : {src0_dt, src1_dt, dst_dt}) {
if (!utils::one_of(t, f32, bf16, f16, s32, s8, u8)) return false;
for (const auto &mdw : {src0, src1, dst}) {
if (!is_supported_type(mdw.data_type())) return false;
}

return true;
}

static bool check_convolution_formats(const memory_desc_wrapper &src0,
inline bool check_convolution_formats(const memory_desc_wrapper &src0,
const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) {
using namespace format_tag;

@@ -57,7 +51,7 @@ static bool check_convolution_formats(const memory_desc_wrapper &src0,
return true;
}

static bool check_convolution_work_amount(
inline bool check_convolution_work_amount(
const memory_desc_wrapper &weights, dim_t OC) {
auto elems = weights.nelems();
auto work_per_output = elems / OC;
@@ -66,6 +60,18 @@
return work_per_output < 200000;
}

inline bool check_convolution_scales_types(const primitive_attr_t *attr) {
const std::vector<int> supported_args
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};

const auto &scales = attr->scales_;
for (auto arg : supported_args) {
auto dt = scales.get(arg).data_type_;
if (!is_supported_type(dt)) { return false; }
}
return true;
}

struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
using gpu::generic::sycl::primitive_t::primitive_t;

@@ -92,7 +98,8 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
| sm::zero_points_runtime | sm::post_ops
| sm::sum_dt)
&& IMPLICATION(!attr()->scales_.has_default_values(),
attr_scales_ok())
attr_scales_ok()
&& check_convolution_scales_types(attr()))
&& sycl_post_ops_t::post_ops_ok(attr(), false)
&& set_default_alg_kind(alg_kind::convolution_direct);
if (!ok) return status::unimplemented;
@@ -149,7 +156,8 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
&& attr()->has_default_values(sm::scales_runtime
| sm::zero_points_runtime | sm::sum_dt)
&& IMPLICATION(!attr()->scales_.has_default_values(),
attr_scales_ok())
attr_scales_ok()
&& check_convolution_scales_types(attr()))
&& set_default_alg_kind(alg_kind::convolution_direct);
if (!ok) return status::unimplemented;

@@ -205,7 +213,8 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
&& attr()->has_default_values(sm::scales_runtime
| sm::zero_points_runtime | sm::sum_dt)
&& IMPLICATION(!attr()->scales_.has_default_values(),
attr_scales_ok())
attr_scales_ok()
&& check_convolution_scales_types(attr()))
&& set_default_alg_kind(alg_kind::convolution_direct);
if (!ok) return status::unimplemented;

30 changes: 21 additions & 9 deletions src/gpu/generic/sycl/ref_layer_normalizations.hpp
@@ -54,19 +54,31 @@ struct ref_layer_normalization_fwd_t : public gpu::generic::sycl::primitive_t {

const bool ok = is_fwd()
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
&& utils::one_of(
src_md(0)->data_type, f32, bf16, f16, s8, u8)
&& utils::one_of(
dst_md(0)->data_type, f32, bf16, f16, s8, u8)
&& stat_md()->data_type == f32
&& is_supported_type(src_md(0)->data_type)
&& is_supported_type(dst_md(0)->data_type)
&& is_supported_type(stat_md()->data_type)
&& check_scale_shift_data_type({f32, bf16, f16})
&& attr()->has_default_values(sm::scales_runtime)
&& IMPLICATION(
!attr()->scales_.has_default_values(), scales_ok())
&& attr_scales_ok() && set_default_formats_common()
&& md_dims_in_range(src_md());
if (!ok) return status::unimplemented;
return init_conf();
}

bool scales_ok() const {
const std::vector<int> supported_args
= {DNNL_ARG_SRC, DNNL_ARG_DST};

const auto &scales = attr()->scales_;
for (auto arg : supported_args) {
auto dt = scales.get(arg).data_type_;
if (!is_supported_type(dt)) { return false; }
}
return true;
}

status_t init_conf();
sycl_layer_normalization_conf_t conf_;
};
@@ -105,10 +117,10 @@ struct ref_layer_normalization_bwd_t : public gpu::generic::sycl::primitive_t {
const bool ok = !is_fwd()
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
&& (diff_dst_md(0)->format_desc.blocking.inner_nblks == 0)
&& utils::one_of(src_md(0)->data_type, f32, bf16)
&& utils::one_of(diff_dst_md(0)->data_type, f32, bf16)
&& utils::one_of(diff_src_md(0)->data_type, f32, bf16)
&& stat_md()->data_type == f32
&& is_supported_type(src_md(0)->data_type)
&& is_supported_type(diff_dst_md(0)->data_type)
&& is_supported_type(diff_src_md(0)->data_type)
&& is_supported_type(stat_md()->data_type)
&& check_scale_shift_data_type({f32, bf16, f16})
&& attr()->has_default_values()
&& set_default_formats_common()
4 changes: 1 addition & 3 deletions src/gpu/generic/sycl/ref_matmul.hpp
@@ -99,16 +99,14 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
}

bool scales_ok() const {
using namespace data_type;
const std::vector<int> supported_args
= {DNNL_ARG_SRC_0, DNNL_ARG_WEIGHTS_0, DNNL_ARG_DST};

const auto &scales = attr()->scales_;
bool dt_ok = true;
for (auto arg : supported_args) {
auto &s = scales.get(arg);
dt_ok = dt_ok
&& utils::one_of(s.data_type_, s8, s32, f32, f16, bf16);
dt_ok = dt_ok && is_supported_type(s.data_type_);
}
return dt_ok && attr_scales_ok(supported_args);
}
21 changes: 21 additions & 0 deletions src/gpu/generic/sycl/ref_prelu.hpp
@@ -54,6 +54,7 @@ struct ref_prelu_fwd_t : public gpu::generic::sycl::primitive_t {
const bool ok = is_fwd() && set_default_formats()
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
&& (weights_md(0)->format_desc.blocking.inner_nblks == 0)
&& check_data_types(data_d, weights_d, dst_d)
&& md_dims_in_range(src_md())
&& md_dims_in_range(weights_md());

@@ -63,6 +64,15 @@

status_t init_conf();
sycl_prelu_conf_t conf_;

static bool check_data_types(const memory_desc_wrapper &src,
const memory_desc_wrapper &wei,
const memory_desc_wrapper &dst) {
for (const auto &mdw : {src, wei, dst}) {
if (!is_supported_type(mdw.data_type())) return false;
}
return true;
}
};

status_t init(impl::engine_t *engine) override;
@@ -97,6 +107,7 @@ struct ref_prelu_bwd_t : public gpu::generic::sycl::primitive_t {
&& (weights_md(0)->format_desc.blocking.inner_nblks == 0)
&& diff_src_md(0)->data_type == src_md(0)->data_type
&& diff_weights_md(0)->data_type == weights_md(0)->data_type
&& check_data_types(data_d, weights_d, diff_dst_d)
&& md_dims_in_range(diff_src_md())
&& md_dims_in_range(weights_md());

@@ -113,6 +124,16 @@
status_t init_reduction(impl::engine_t *engine);
void init_scratchpad();

static bool check_data_types(const memory_desc_wrapper &src,
const memory_desc_wrapper &wei,
const memory_desc_wrapper &dst) {
for (const auto &mdw : {src, wei, dst}) {
if (!is_supported_type(mdw.data_type())) return false;
}

return true;
}

sycl_prelu_conf_t conf_;
bool reduce_diff_weights_ = false;
memory_desc_t scratch_md_;
23 changes: 16 additions & 7 deletions src/gpu/generic/sycl/ref_reorder.hpp
@@ -54,6 +54,8 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t {
&& check_formats(src_d, dst_d)
&& attr()->has_default_values(
sm::scales_runtime | sm::post_ops)
&& IMPLICATION(
!attr()->scales_.has_default_values(), scales_ok())
&& sycl_post_ops_t::post_ops_ok(attr())
&& md_dims_in_range(dst_md());
if (!ok) return status::unimplemented;
@@ -70,13 +72,8 @@

static bool check_data_types(const memory_desc_wrapper &src,
const memory_desc_wrapper &dst) {
using namespace data_type;

const auto src_dt = src.data_type();
const auto dst_dt = dst.data_type();

for (auto t : {src_dt, dst_dt}) {
if (!utils::one_of(t, f32, bf16, f16, s8, u8)) return false;
for (const auto &mdw : {src, dst}) {
if (!is_supported_type(mdw.data_type())) return false;
}

return true;
@@ -91,6 +88,18 @@
}
return true;
}

bool scales_ok() const {
const std::vector<int> supported_args
= {DNNL_ARG_SRC, DNNL_ARG_DST};

const auto &scales = attr()->scales_;
for (auto arg : supported_args) {
auto dt = scales.get(arg).data_type_;
if (!is_supported_type(dt)) { return false; }
}
return true;
}
};

status_t init(impl::engine_t *engine) override;
12 changes: 5 additions & 7 deletions src/gpu/generic/sycl/ref_resampling.hpp
@@ -41,18 +41,14 @@ struct ref_resampling_fwd_t : public gpu::generic::sycl::primitive_t {
DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_resampling_fwd_t);

status_t init(impl::engine_t *engine) {
using namespace data_type;
using namespace prop_kind;
using namespace alg_kind;
using sm = primitive_attr_t::skip_mask_t;
const memory_desc_wrapper src_d(src_md(0));
const memory_desc_wrapper dst_d(dst_md(0));

const bool ok = is_fwd()
&& utils::one_of(
src_md(0)->data_type, f32, bf16, f16, s32, s8, u8)
&& utils::one_of(
dst_md(0)->data_type, f32, bf16, f16, s32, s8, u8)
const bool ok = is_fwd() && is_supported_type(src_md(0)->data_type)
&& is_supported_type(dst_md(0)->data_type)
&& attr()->has_default_values(sm::post_ops)
&& set_default_params() == status::success
&& attr_.set_default_formats(dst_md(0)) == status::success
@@ -92,7 +88,9 @@ struct ref_resampling_bwd_t : public gpu::generic::sycl::primitive_t {
const memory_desc_wrapper diff_dst_d(diff_dst_md(0));
const memory_desc_wrapper diff_src_d(diff_src_md(0));

bool ok = !is_fwd() && set_default_params() == status::success
bool ok = !is_fwd() && is_supported_type(src_md(0)->data_type)
&& is_supported_type(dst_md(0)->data_type)
&& set_default_params() == status::success
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
&& (diff_dst_md(0)->format_desc.blocking.inner_nblks == 0)
&& attr()->has_default_values()
15 changes: 11 additions & 4 deletions src/gpu/generic/sycl/ref_sum.hpp
@@ -40,27 +40,34 @@ struct ref_sum_t : public gpu::generic::sycl::primitive_t {
DECLARE_SUM_PD_T("dpcpp:ref:any", ref_sum_t);

status_t init(impl::engine_t *engine) {
using namespace data_type;
using namespace format_tag;

const memory_desc_wrapper dst_d(dst_md());
if (!utils::one_of(dst_d.data_type(), f32, bf16, f16, s8, u8))
if (!is_supported_type(dst_d.data_type()))
return status::unimplemented;
// Block formats are not yet supported
// Dimensions can not be > 6
if (!dst_d.is_plain() || dst_d.ndims() > xpu::sycl::md_t::max_dims)
return status::unimplemented;

const int n = n_inputs();
const auto &scales = attr()->scales_;
for (auto i = 0; i < n; ++i) {
const memory_desc_wrapper src_d(src_md(i));
if (!utils::one_of(src_d.data_type(), f32, bf16, f16, s8, u8))
if (!is_supported_type(src_d.data_type())) {
return status::unimplemented;
}
// Block formats are not yet supported
// Dimensions can not be > 6
if (!src_d.is_plain()
|| src_d.ndims() > xpu::sycl::md_t::max_dims)
|| src_d.ndims() > xpu::sycl::md_t::max_dims) {
return status::unimplemented;
}
if (!attr()->scales_.has_default_values()
&& !is_supported_type(
scales.get(DNNL_ARG_SRC + i).data_type_)) {
return status::unimplemented;
}
}

const bool ok = set_default_params() == status::success
5 changes: 5 additions & 0 deletions src/gpu/generic/sycl/sycl_io_helper.hpp
@@ -28,6 +28,11 @@ namespace gpu {
namespace generic {
namespace sycl {

inline bool is_supported_type(data_type_t dt) {
using namespace data_type;
return utils::one_of(dt, f32, f16, bf16, s32, s8, u8);
}

inline int load_int_value(data_type_t dt, const void *ptr, dim_t idx) {
#define CASE(dt) \
case dt: \