Skip to content

Commit 0083ddb

Browse files
tiger100256-huazhai219
authored andcommitted
[CPU][fix] fix matmul decompress test case for migration v3.8 (#1)
* fix matmul decompress test case Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * save tmp Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * [FORK][FIX] IP weights compression: scalar scale [FORK][FEATURE] InnerProduct primitive: squashed weight decompression Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * [FORK][FIX] IP weights compression: max bcast blocking computation [FORK][FEATURE] InnerProduct primitive: squashed weight decompression * fix compile issue Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * fix crash issue Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * try to fix compare issue Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * contiue fix some accrucy issue Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * fix f4_e2m1 Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * continue to fix f4e2m1 Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * fix confict on smoke_FC_(2|3)D_I8_sparse Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * clean debug and unused code Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> * revert this change, should affect test case Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> --------- Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> Co-authored-by: dmitrygo <dmitry.gorokhov@intel.com> [FORK][FIX] add missing override Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> fix compilation error move quant functions to hpp file [FORK][Fix] Fix condition compilation Signed-off-by: HU Yuan2 <yuan2.hu@intel.com>
1 parent e46a773 commit 0083ddb

26 files changed

+365
-228
lines changed

src/common/memory_desc_wrapper.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,7 @@ struct memory_desc_wrapper : public c_compatible {
393393
if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
394394
return false;
395395
if (has_runtime_dims_or_strides() || has_broadcast()) return false;
396-
return nelems(with_padding) * data_type_size()
397-
/ sub_byte_data_type_multiplier()
396+
return utils::div_up(nelems(with_padding)* data_type_size(), sub_byte_data_type_multiplier())
398397
== size(0, /* include_additional_size = */ false);
399398
}
400399

@@ -714,7 +713,7 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
714713
rhs.padded_dims() + ds, ndims() - ds)
715714
&& custom_cpm(padded_offsets() + ds,
716715
rhs.padded_offsets() + ds, ndims() - ds))
717-
&& IMPLICATION(check_off0, (offset0() == DNNL_RUNTIME_DIM_VAL || rhs.offset0() ==DNNL_RUNTIME_DIM_VAL || offset0() == rhs.offset0()));
716+
&& IMPLICATION(check_off0, (offset0() == DNNL_RUNTIME_DIM_VAL || rhs.offset0() ==DNNL_RUNTIME_DIM_VAL || offset0() == rhs.offset0()));
718717
}
719718

720719
inline bool memory_desc_wrapper::consistent_with(

src/common/primitive_attr_quant.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,22 @@ quant_entry_t quant_entry_t::deserialize(deserializer_t &d) {
5555

5656
std::string quant_entry_t::get_verbose() const {
5757
std::string s;
58-
s.append(std::to_string(mask_));
59-
s.append(":").append(dnnl_dt2str(data_type_));
58+
s.append(std::to_string(get_mask()));
59+
s.append(":").append(dnnl_dt2str(get_data_type()));
60+
s.append(":").append(std::to_string(type_));
61+
s.append(":");
6062
if (group_ndims_ > 0) {
61-
s.append(":")
62-
.append(std::to_string(group_dims_[0]))
63+
s.append(std::to_string(group_dims_[0]))
6364
.append("x")
6465
.append(std::to_string(group_dims_[1]));
6566
}
67+
s.append(":");
68+
if (get_ndims() > 0) {
69+
s.append(std::to_string(get_dims()[0]))
70+
.append("x")
71+
.append(std::to_string(get_dims()[1]));
72+
}
73+
6674
return s;
6775
}
6876

src/common/primitive_attr_quant.hpp

Lines changed: 155 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ struct quant_entry_t : public c_compatible {
5353
}
5454
status_t set(int mask, data_type_t data_type, int group_ndims,
5555
const dims_t group_dims) {
56+
type_ = type_ | DNNL;
57+
is_set_ = true;
5658
mask_ = mask;
5759
data_type_ = data_type;
5860
group_ndims_ = group_ndims;
@@ -61,25 +63,98 @@ struct quant_entry_t : public c_compatible {
6163
}
6264
return status::success;
6365
}
66+
status_t set_scales(const dims_t dims, int ndims, data_type_t data_type = data_type::f32, int mask = 1) {
67+
type_ = type_ | OV_SCALES;
68+
is_set_scale = true;
69+
ndims_scale = ndims;
70+
mask_scale = mask;
71+
data_type_scale = data_type;
72+
if (ndims_scale > 0) {
73+
utils::array_copy(dims_scale, dims, ndims_scale);
74+
}
75+
return status::success;
76+
}
77+
status_t set_zero_points(const dims_t dims, int ndims, data_type_t data_type) {
78+
type_ = type_ | OV_ZERO_POINTS;
79+
is_set_wei = true;
80+
ndims_wei = ndims;
81+
mask_wei = 1;
82+
if (ndims_wei > 0) {
83+
utils::array_copy(dims_wei, dims, ndims_wei);
84+
}
85+
data_type_wei = data_type;
86+
return status::success;
87+
}
88+
status_t set_zero_points(const dims_t dims, int ndims, data_type_t data_type, int mask) {
89+
type_ = type_ | DNNL;
90+
is_set_wei = true;
91+
ndims_wei = ndims;
92+
mask_wei = mask;
93+
if (ndims_wei > 0) {
94+
utils::array_copy(dims_wei, dims, ndims_wei);
95+
group_ndims_ = ndims;
96+
utils::array_copy(group_dims_, dims, group_ndims_);
97+
}
98+
data_type_wei = data_type;
99+
return status::success;
100+
}
64101
status_t set(const quant_entry_t &other) {
65-
return set(other.mask_, other.data_type_, other.group_ndims_,
66-
other.group_dims_);
102+
type_ = other.type_;
103+
is_set_ = other.is_set_;
104+
mask_ = other.mask_;
105+
data_type_ = other.data_type_;
106+
group_ndims_ = other.group_ndims_;
107+
if(group_ndims_ > 0)
108+
utils::array_copy(group_dims_, other.group_dims_, group_ndims_);
109+
is_set_scale = other.is_set_scale;
110+
mask_scale = other.mask_scale;
111+
data_type_scale = other.data_type_scale;
112+
ndims_scale = other.ndims_scale;
113+
if (ndims_scale > 0)
114+
utils::array_cmp(dims_scale, other.dims_scale, ndims_scale);
115+
is_set_wei = other.is_set_wei;
116+
mask_wei = other.mask_wei;
117+
data_type_wei = other.data_type_wei;
118+
ndims_wei = other.ndims_wei;
119+
if(ndims_wei > 0)
120+
utils::array_cmp(dims_wei, other.dims_wei, ndims_wei);
121+
return status::success;
67122
}
68-
69123
quant_entry_t &operator=(const quant_entry_t &rhs) {
70124
auto st = this->set(rhs);
71125
assert(st == status::success);
72126
UNUSED(st);
73127
return *this;
74128
}
75-
76129
bool has_default_values() const { return *this == default_quant_entry(); }
77130
bool has_default_groups() const {
78131
return this->group_ndims_ == default_quant_entry().group_ndims_;
79132
}
80-
81-
int get_mask() const { return mask_; }
82-
data_type_t get_data_type() const { return data_type_; }
133+
int get_mask() const {
134+
if (is_set_wei) return mask_wei;
135+
if (is_set_) return mask_;
136+
if (is_set_scale) return mask_scale;
137+
return INT_MIN;
138+
}
139+
data_type_t get_data_type() const {
140+
if (is_set_wei) return data_type_wei;
141+
if (is_set_) return data_type_;
142+
if (is_set_scale) return data_type_scale;
143+
return data_type::undef;
144+
}
145+
const dims_t& get_dims() const {
146+
if (is_set_wei) return dims_wei;
147+
if (is_set_) return group_dims_;
148+
if (is_set_scale) return dims_scale;
149+
static const dims_t result = {};
150+
return result;
151+
}
152+
int get_ndims() const {
153+
if (is_set_wei) return ndims_wei;
154+
if (is_set_) return group_ndims_;
155+
if (is_set_scale) return ndims_scale;
156+
return 0;
157+
}
83158
dim_t get_group(int d) const {
84159
// If groups were not requested, return `1` for convenience.
85160
if (group_ndims_ == default_quant_entry().group_ndims_) return 1;
@@ -93,13 +168,33 @@ struct quant_entry_t : public c_compatible {
93168
// `gtests/internals/test_comparison_operators` linking requirements which
94169
// mandates bodies to be in the header file.
95170
bool operator==(const quant_entry_t &rhs) const {
96-
return mask_ == rhs.mask_ && data_type_ == rhs.data_type_
171+
bool result = (type_ == rhs.type_ && is_set_ == rhs.is_set_
172+
&& mask_ == rhs.mask_
173+
&& data_type_ == rhs.data_type_
97174
&& group_ndims_ == rhs.group_ndims_
98175
&& IMPLICATION(group_ndims_ > 0,
99-
utils::array_cmp(
100-
group_dims_, rhs.group_dims_, group_ndims_));
176+
utils::array_cmp(
177+
group_dims_, rhs.group_dims_, group_ndims_)));
178+
179+
if (!result) return false;
180+
result = (is_set_scale == rhs.is_set_scale
181+
&& mask_scale == rhs.mask_scale
182+
&& data_type_scale == rhs.data_type_scale
183+
&& ndims_scale == rhs.ndims_scale
184+
&& IMPLICATION(ndims_scale > 0,
185+
utils::array_cmp(
186+
dims_scale, rhs.dims_scale, ndims_scale)));
187+
188+
if (!result) return false;
189+
result = (is_set_wei == rhs.is_set_wei
190+
&& mask_wei == rhs.mask_wei
191+
&& data_type_wei == rhs.data_type_wei
192+
&& ndims_wei == rhs.ndims_wei
193+
&& IMPLICATION(ndims_wei > 0,
194+
utils::array_cmp(
195+
dims_wei, rhs.dims_wei, ndims_wei)));
196+
return result;
101197
}
102-
103198
size_t get_hash() const;
104199

105200
void serialize(serialization_stream_t &sstream) const;
@@ -109,23 +204,32 @@ struct quant_entry_t : public c_compatible {
109204
std::string get_verbose() const;
110205

111206
private:
207+
data_type_t data_type_ = data_type::undef;
112208
int group_ndims_ = 0;
113209
dims_t group_dims_ {};
114-
public:
115210
// Note: INT_MIN is used on purpose to avoid potential issues when
116211
// `(mask & bit)` expression will return `true`. `INT_MIN` is represented
117212
// as `10...0` in bits and will avoid such situations.
118213
int mask_ = INT_MIN;
119-
data_type_t data_type_ = data_type::undef;
214+
bool is_set_ = false;
120215
// openvino extension
216+
enum entry_type {
217+
NONE = 0,
218+
DNNL = 1,
219+
OV_SCALES = 2,
220+
OV_ZERO_POINTS = 4
221+
};
222+
int type_ = NONE;
121223
// scale
122-
bool is_set_ = false;
123-
int ndims_ = 0;
124-
dims_t dims_ {};
224+
bool is_set_scale = false;
225+
int ndims_scale = 0;
226+
int mask_scale = INT_MIN;
227+
dims_t dims_scale {};
228+
data_type_t data_type_scale = data_type::undef;
125229
// zero_point
126230
bool is_set_wei = false;
127231
int ndims_wei = 0;
128-
int mask_wei = 0;
232+
int mask_wei = INT_MIN;
129233
dims_t dims_wei {};
130234
data_type_t data_type_wei = data_type::s32;
131235
};
@@ -144,57 +248,35 @@ struct quant_entries_t : public c_compatible {
144248

145249
// See `set(...)` comment for `quant_entry_t` for a design choice
146250
// explanation.
147-
status_t set(int arg, int mask) {
251+
virtual status_t set(int arg, int mask) {
148252
return set(arg, mask, default_data_type_, 0, {});
149253
}
150-
status_t set_scales(int arg, const dims_t dims, int ndims, data_type_t data_type = data_type::f32) {
151-
if (!check_arg(arg)) return status::invalid_arguments;
152-
entries_[arg].is_set_ = true;
153-
entries_[arg].ndims_ = ndims;
154-
entries_[arg].mask_ = 1;
155-
entries_[arg].data_type_ = data_type;
156-
utils::array_copy(entries_[arg].dims_, dims, entries_[arg].ndims_);
157-
return status::success;
158-
}
159-
status_t set_zero_points(int arg, const dims_t dims, int ndims, data_type_t data_type) {
160-
const bool supported_arg = utils::one_of(arg, DNNL_ARG_WEIGHTS);
161-
if (!supported_arg) return status::unimplemented;
162-
163-
switch (arg) {
164-
case DNNL_ARG_WEIGHTS:
165-
entries_[arg].is_set_wei = true;
166-
entries_[arg].ndims_wei = ndims;
167-
entries_[arg].mask_wei = 1;
168-
utils::array_copy(entries_[arg].dims_wei, dims, ndims);
169-
entries_[arg].data_type_wei = data_type;
170-
break;
171-
}
172-
return status::success;
173-
}
174254
const dims_t & get_dims(int arg) const {
175-
return get(arg).dims_wei;
255+
return get(arg).get_dims();
176256
}
177257
int get_ndims(int arg) const {
178-
const bool supported_arg = utils::one_of(arg, DNNL_ARG_WEIGHTS);
179-
if (!supported_arg) return status::unimplemented;
180-
switch (arg) {
181-
case DNNL_ARG_WEIGHTS: return get(arg).ndims_wei; break;
182-
default: return 0;
183-
}
258+
return get(arg).get_ndims();
184259
}
185-
status_t set(int arg, int mask, data_type_t data_type, int group_ndims,
260+
virtual status_t set(int arg, int mask, data_type_t data_type, int group_ndims,
186261
const dims_t group_dims) {
187262
if (!check_arg(arg)) return status::invalid_arguments;
188263
CHECK(entries_[arg].set(mask, data_type, group_ndims, group_dims));
189-
if (arg == DNNL_ARG_WEIGHTS) {
190-
utils::array_copy(entries_[arg].dims_wei, group_dims, group_ndims);
191-
entries_[arg].ndims_wei = group_ndims;
192-
}
193264
return status::success;
194265
}
266+
status_t set_scales(int arg, const dims_t dims, int ndims, data_type_t data_type = data_type::f32) {
267+
if (!check_arg(arg)) return status::invalid_arguments;
268+
CHECK(entries_[arg].set_scales(dims, ndims, data_type));
269+
return status::success;
270+
}
271+
status_t set_zero_points(int arg, const dims_t dims, int ndims, data_type_t data_type) {
272+
if (arg != DNNL_ARG_WEIGHTS) return status::unimplemented;
273+
CHECK(entries_[arg].set_zero_points(dims, ndims, data_type));
274+
return status::success;
275+
}
276+
195277
// Use this interface with `default_quant_entry` when need to remove a
196278
// specific entry.
197-
status_t set(int arg, const quant_entry_t &other) {
279+
virtual status_t set(int arg, const quant_entry_t &other) {
198280
return entries_[arg].set(other);
199281
}
200282

@@ -356,7 +438,23 @@ struct zero_points_t : public quant_entries_t {
356438
}
357439

358440
static zero_points_t deserialize(deserializer_t &d);
441+
status_t set(int arg, int mask) override {
442+
return quant_entries_t::set(arg, mask, default_data_type_, 0, {});
443+
}
444+
status_t set(int arg, int mask, data_type_t data_type, int group_ndims,
445+
const dims_t group_dims) override {
446+
if (!check_arg(arg)) return status::invalid_arguments;
447+
if (arg == DNNL_ARG_WEIGHTS) {
448+
CHECK(entries_[arg].set_zero_points(group_dims, group_ndims, data_type, mask));
449+
} else {
450+
CHECK(entries_[arg].set(mask, data_type, group_ndims, group_dims));
451+
}
452+
return status::success;
453+
}
359454

455+
status_t set(int arg, const quant_entry_t &other) override {
456+
return quant_entries_t::set(arg, other);
457+
}
360458
private:
361459
static constexpr data_type_t default_data_type_ = data_type::s32;
362460

@@ -388,7 +486,7 @@ struct src_dyn_quant_params_t : public c_compatible {
388486
return status::success;
389487
}
390488

391-
uint64_t get() {
489+
uint64_t get() const {
392490
return group_size_;
393491
}
394492

@@ -397,6 +495,7 @@ struct src_dyn_quant_params_t : public c_compatible {
397495
return group_size_ == rhs.group_size_;
398496
}
399497

498+
private:
400499
uint64_t group_size_;
401500
};
402501

src/common/primitive_hashing_utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ size_t get_attr_hash(const primitive_attr_t &attr) {
192192
seed = hash_combine(
193193
seed, get_md_hash(attr.dropout_.user_dropout_desc_));
194194
}
195-
seed = hash_combine(seed, attr.src_dyn_quant_params_.group_size_);
195+
seed = hash_combine(seed, attr.src_dyn_quant_params_.get());
196196
// Combined hash for attributes
197197
return seed;
198198
}

src/common/primitive_serialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ void serialize(serialization_stream_t &sstream, const primitive_attr_t &attr) {
232232
int zero = 0;
233233
sstream.append(zero);
234234
}
235-
sstream.append(attr.src_dyn_quant_params_.group_size_);
235+
sstream.append(attr.src_dyn_quant_params_.get());
236236
}
237237

238238
void serialize(serialization_stream_t &sstream, const concat_desc_t &desc) {

src/common/verbose.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) {
833833

834834
const src_dyn_quant_params_t &dyn_qp = attr->src_dyn_quant_params_;
835835
if (!dyn_qp.has_default_values()) {
836-
ss << "src_dyn_quant_group_size:" << dyn_qp.group_size_ << ";";
836+
ss << "src_dyn_quant_group_size:" << dyn_qp.get() << ";";
837837
}
838838

839839
if (!attr->dropout_.has_default_values()) {

0 commit comments

Comments
 (0)