@@ -53,6 +53,8 @@ struct quant_entry_t : public c_compatible {
5353 }
5454 status_t set (int mask, data_type_t data_type, int group_ndims,
5555 const dims_t group_dims) {
56+ type_ = type_ | DNNL;
57+ is_set_ = true ;
5658 mask_ = mask;
5759 data_type_ = data_type;
5860 group_ndims_ = group_ndims;
@@ -61,25 +63,98 @@ struct quant_entry_t : public c_compatible {
6163 }
6264 return status::success;
6365 }
66+ status_t set_scales (const dims_t dims, int ndims, data_type_t data_type = data_type::f32 , int mask = 1 ) {
67+ type_ = type_ | OV_SCALES;
68+ is_set_scale = true ;
69+ ndims_scale = ndims;
70+ mask_scale = mask;
71+ data_type_scale = data_type;
72+ if (ndims_scale > 0 ) {
73+ utils::array_copy (dims_scale, dims, ndims_scale);
74+ }
75+ return status::success;
76+ }
77+ status_t set_zero_points (const dims_t dims, int ndims, data_type_t data_type) {
78+ type_ = type_ | OV_ZERO_POINTS;
79+ is_set_wei = true ;
80+ ndims_wei = ndims;
81+ mask_wei = 1 ;
82+ if (ndims_wei > 0 ) {
83+ utils::array_copy (dims_wei, dims, ndims_wei);
84+ }
85+ data_type_wei = data_type;
86+ return status::success;
87+ }
88+ status_t set_zero_points (const dims_t dims, int ndims, data_type_t data_type, int mask) {
89+ type_ = type_ | DNNL;
90+ is_set_wei = true ;
91+ ndims_wei = ndims;
92+ mask_wei = mask;
93+ if (ndims_wei > 0 ) {
94+ utils::array_copy (dims_wei, dims, ndims_wei);
95+ group_ndims_ = ndims;
96+ utils::array_copy (group_dims_, dims, group_ndims_);
97+ }
98+ data_type_wei = data_type;
99+ return status::success;
100+ }
64101 status_t set (const quant_entry_t &other) {
65- return set (other.mask_ , other.data_type_ , other.group_ndims_ ,
66- other.group_dims_ );
102+ type_ = other.type_ ;
103+ is_set_ = other.is_set_ ;
104+ mask_ = other.mask_ ;
105+ data_type_ = other.data_type_ ;
106+ group_ndims_ = other.group_ndims_ ;
107+ if (group_ndims_ > 0 )
108+ utils::array_copy (group_dims_, other.group_dims_ , group_ndims_);
109+ is_set_scale = other.is_set_scale ;
110+ mask_scale = other.mask_scale ;
111+ data_type_scale = other.data_type_scale ;
112+ ndims_scale = other.ndims_scale ;
113+ if (ndims_scale > 0 )
114+ utils::array_cmp (dims_scale, other.dims_scale , ndims_scale);
115+ is_set_wei = other.is_set_wei ;
116+ mask_wei = other.mask_wei ;
117+ data_type_wei = other.data_type_wei ;
118+ ndims_wei = other.ndims_wei ;
119+ if (ndims_wei > 0 )
120+ utils::array_cmp (dims_wei, other.dims_wei , ndims_wei);
121+ return status::success;
67122 }
68-
69123 quant_entry_t &operator =(const quant_entry_t &rhs) {
70124 auto st = this ->set (rhs);
71125 assert (st == status::success);
72126 UNUSED (st);
73127 return *this ;
74128 }
75-
76129 bool has_default_values () const { return *this == default_quant_entry (); }
77130 bool has_default_groups () const {
78131 return this ->group_ndims_ == default_quant_entry ().group_ndims_ ;
79132 }
80-
81- int get_mask () const { return mask_; }
82- data_type_t get_data_type () const { return data_type_; }
133+ int get_mask () const {
134+ if (is_set_wei) return mask_wei;
135+ if (is_set_) return mask_;
136+ if (is_set_scale) return mask_scale;
137+ return INT_MIN;
138+ }
139+ data_type_t get_data_type () const {
140+ if (is_set_wei) return data_type_wei;
141+ if (is_set_) return data_type_;
142+ if (is_set_scale) return data_type_scale;
143+ return data_type::undef;
144+ }
145+ const dims_t & get_dims () const {
146+ if (is_set_wei) return dims_wei;
147+ if (is_set_) return group_dims_;
148+ if (is_set_scale) return dims_scale;
149+ static const dims_t result = {};
150+ return result;
151+ }
152+ int get_ndims () const {
153+ if (is_set_wei) return ndims_wei;
154+ if (is_set_) return group_ndims_;
155+ if (is_set_scale) return ndims_scale;
156+ return 0 ;
157+ }
83158 dim_t get_group (int d) const {
84159 // If groups were not requested, return `1` for convenience.
85160 if (group_ndims_ == default_quant_entry ().group_ndims_ ) return 1 ;
@@ -93,13 +168,33 @@ struct quant_entry_t : public c_compatible {
93168 // `gtests/internals/test_comparison_operators` linking requirements which
94169 // mandates bodies to be in the header file.
95170 bool operator ==(const quant_entry_t &rhs) const {
96- return mask_ == rhs.mask_ && data_type_ == rhs.data_type_
171+ bool result = (type_ == rhs.type_ && is_set_ == rhs.is_set_
172+ && mask_ == rhs.mask_
173+ && data_type_ == rhs.data_type_
97174 && group_ndims_ == rhs.group_ndims_
98175 && IMPLICATION (group_ndims_ > 0 ,
99- utils::array_cmp (
100- group_dims_, rhs.group_dims_ , group_ndims_));
176+ utils::array_cmp (
177+ group_dims_, rhs.group_dims_ , group_ndims_)));
178+
179+ if (!result) return false ;
180+ result = (is_set_scale == rhs.is_set_scale
181+ && mask_scale == rhs.mask_scale
182+ && data_type_scale == rhs.data_type_scale
183+ && ndims_scale == rhs.ndims_scale
184+ && IMPLICATION (ndims_scale > 0 ,
185+ utils::array_cmp (
186+ dims_scale, rhs.dims_scale , ndims_scale)));
187+
188+ if (!result) return false ;
189+ result = (is_set_wei == rhs.is_set_wei
190+ && mask_wei == rhs.mask_wei
191+ && data_type_wei == rhs.data_type_wei
192+ && ndims_wei == rhs.ndims_wei
193+ && IMPLICATION (ndims_wei > 0 ,
194+ utils::array_cmp (
195+ dims_wei, rhs.dims_wei , ndims_wei)));
196+ return result;
101197 }
102-
103198 size_t get_hash () const ;
104199
105200 void serialize (serialization_stream_t &sstream) const ;
@@ -109,23 +204,32 @@ struct quant_entry_t : public c_compatible {
109204 std::string get_verbose () const ;
110205
111206private:
207+ data_type_t data_type_ = data_type::undef;
112208 int group_ndims_ = 0 ;
113209 dims_t group_dims_ {};
114- public:
115210 // Note: INT_MIN is used on purpose to avoid potential issues when
116211 // `(mask & bit)` expression will return `true`. `INT_MIN` is represented
117212 // as `10...0` in bits and will avoid such situations.
118213 int mask_ = INT_MIN;
119- data_type_t data_type_ = data_type::undef ;
214+ bool is_set_ = false ;
120215 // openvino extension
216+ enum entry_type {
217+ NONE = 0 ,
218+ DNNL = 1 ,
219+ OV_SCALES = 2 ,
220+ OV_ZERO_POINTS = 4
221+ };
222+ int type_ = NONE;
121223 // scale
122- bool is_set_ = false ;
123- int ndims_ = 0 ;
124- dims_t dims_ {};
224+ bool is_set_scale = false ;
225+ int ndims_scale = 0 ;
226+ int mask_scale = INT_MIN;
227+ dims_t dims_scale {};
228+ data_type_t data_type_scale = data_type::undef;
125229 // zero_point
126230 bool is_set_wei = false ;
127231 int ndims_wei = 0 ;
128- int mask_wei = 0 ;
232+ int mask_wei = INT_MIN ;
129233 dims_t dims_wei {};
130234 data_type_t data_type_wei = data_type::s32;
131235};
@@ -144,57 +248,35 @@ struct quant_entries_t : public c_compatible {
144248
145249 // See `set(...)` comment for `quant_entry_t` for a design choice
146250 // explanation.
147- status_t set (int arg, int mask) {
251+ virtual status_t set (int arg, int mask) {
148252 return set (arg, mask, default_data_type_, 0 , {});
149253 }
150- status_t set_scales (int arg, const dims_t dims, int ndims, data_type_t data_type = data_type::f32 ) {
151- if (!check_arg (arg)) return status::invalid_arguments;
152- entries_[arg].is_set_ = true ;
153- entries_[arg].ndims_ = ndims;
154- entries_[arg].mask_ = 1 ;
155- entries_[arg].data_type_ = data_type;
156- utils::array_copy (entries_[arg].dims_ , dims, entries_[arg].ndims_ );
157- return status::success;
158- }
159- status_t set_zero_points (int arg, const dims_t dims, int ndims, data_type_t data_type) {
160- const bool supported_arg = utils::one_of (arg, DNNL_ARG_WEIGHTS);
161- if (!supported_arg) return status::unimplemented;
162-
163- switch (arg) {
164- case DNNL_ARG_WEIGHTS:
165- entries_[arg].is_set_wei = true ;
166- entries_[arg].ndims_wei = ndims;
167- entries_[arg].mask_wei = 1 ;
168- utils::array_copy (entries_[arg].dims_wei , dims, ndims);
169- entries_[arg].data_type_wei = data_type;
170- break ;
171- }
172- return status::success;
173- }
174254 const dims_t & get_dims (int arg) const {
175- return get (arg).dims_wei ;
255+ return get (arg).get_dims () ;
176256 }
177257 int get_ndims (int arg) const {
178- const bool supported_arg = utils::one_of (arg, DNNL_ARG_WEIGHTS);
179- if (!supported_arg) return status::unimplemented;
180- switch (arg) {
181- case DNNL_ARG_WEIGHTS: return get (arg).ndims_wei ; break ;
182- default : return 0 ;
183- }
258+ return get (arg).get_ndims ();
184259 }
185- status_t set (int arg, int mask, data_type_t data_type, int group_ndims,
260+ virtual status_t set (int arg, int mask, data_type_t data_type, int group_ndims,
186261 const dims_t group_dims) {
187262 if (!check_arg (arg)) return status::invalid_arguments;
188263 CHECK (entries_[arg].set (mask, data_type, group_ndims, group_dims));
189- if (arg == DNNL_ARG_WEIGHTS) {
190- utils::array_copy (entries_[arg].dims_wei , group_dims, group_ndims);
191- entries_[arg].ndims_wei = group_ndims;
192- }
193264 return status::success;
194265 }
266+ status_t set_scales (int arg, const dims_t dims, int ndims, data_type_t data_type = data_type::f32 ) {
267+ if (!check_arg (arg)) return status::invalid_arguments;
268+ CHECK (entries_[arg].set_scales (dims, ndims, data_type));
269+ return status::success;
270+ }
271+ status_t set_zero_points (int arg, const dims_t dims, int ndims, data_type_t data_type) {
272+ if (arg != DNNL_ARG_WEIGHTS) return status::unimplemented;
273+ CHECK (entries_[arg].set_zero_points (dims, ndims, data_type));
274+ return status::success;
275+ }
276+
195277 // Use this interface with `default_quant_entry` when need to remove a
196278 // specific entry.
197- status_t set (int arg, const quant_entry_t &other) {
279+ virtual status_t set (int arg, const quant_entry_t &other) {
198280 return entries_[arg].set (other);
199281 }
200282
@@ -356,7 +438,23 @@ struct zero_points_t : public quant_entries_t {
356438 }
357439
358440 static zero_points_t deserialize (deserializer_t &d);
441+ status_t set (int arg, int mask) override {
442+ return quant_entries_t::set (arg, mask, default_data_type_, 0 , {});
443+ }
444+ status_t set (int arg, int mask, data_type_t data_type, int group_ndims,
445+ const dims_t group_dims) override {
446+ if (!check_arg (arg)) return status::invalid_arguments;
447+ if (arg == DNNL_ARG_WEIGHTS) {
448+ CHECK (entries_[arg].set_zero_points (group_dims, group_ndims, data_type, mask));
449+ } else {
450+ CHECK (entries_[arg].set (mask, data_type, group_ndims, group_dims));
451+ }
452+ return status::success;
453+ }
359454
455+ status_t set (int arg, const quant_entry_t &other) override {
456+ return quant_entries_t::set (arg, other);
457+ }
360458private:
361459 static constexpr data_type_t default_data_type_ = data_type::s32;
362460
@@ -388,7 +486,7 @@ struct src_dyn_quant_params_t : public c_compatible {
388486 return status::success;
389487 }
390488
391- uint64_t get () {
489+ uint64_t get () const {
392490 return group_size_;
393491 }
394492
@@ -397,6 +495,7 @@ struct src_dyn_quant_params_t : public c_compatible {
397495 return group_size_ == rhs.group_size_ ;
398496 }
399497
498+ private:
400499 uint64_t group_size_;
401500};
402501
0 commit comments