@@ -39,58 +39,58 @@ using namespace mlir::tosa;
3939namespace {
4040
4141static LogicalResult checkConstantOperandPad (Operation *op) {
42- if (auto pad_op = dyn_cast<tosa::PadOp>(op)) {
42+ if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
4343 DenseElementsAttr paddings;
44- if (!matchPattern (pad_op .getPadding (), m_Constant (&paddings)))
44+ if (!matchPattern (padOp .getPadding (), m_Constant (&paddings)))
4545 return op->emitOpError (" padding of pad is not constant" );
4646
47- DenseElementsAttr pad_const ;
48- // Assume this op is zero-padding if pad_const is not presented.
49- if (pad_op .getPadConst () &&
50- !matchPattern (pad_op .getPadConst (), m_Constant (&pad_const )))
47+ DenseElementsAttr padConst ;
48+ // Assume this op is zero-padding if padConst is not presented.
49+ if (padOp .getPadConst () &&
50+ !matchPattern (padOp .getPadConst (), m_Constant (&padConst )))
5151 return op->emitOpError (" pad_const of pad is not constant" );
5252 }
5353 return success ();
5454}
5555
5656static LogicalResult checkConstantOperandTranspose (Operation *op) {
57- if (auto transpose_op = dyn_cast<tosa::TransposeOp>(op)) {
57+ if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
5858 DenseElementsAttr perms;
59- if (!matchPattern (transpose_op .getPerms (), m_Constant (&perms)))
59+ if (!matchPattern (transposeOp .getPerms (), m_Constant (&perms)))
6060 return op->emitOpError (" perms of transpose is not constant" );
6161 }
6262 return success ();
6363}
6464
6565static LogicalResult checkConstantOperandFullyConnected (Operation *op) {
66- if (auto fc_op = dyn_cast<tosa::FullyConnectedOp>(op)) {
66+ if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
6767 DenseElementsAttr weight;
68- if (!matchPattern (fc_op .getWeight (), m_Constant (&weight)))
68+ if (!matchPattern (fcOp .getWeight (), m_Constant (&weight)))
6969 return op->emitOpError (" weight of fully_connected is not constant" );
7070
7171 DenseElementsAttr bias;
72- if (!matchPattern (fc_op .getBias (), m_Constant (&bias)))
72+ if (!matchPattern (fcOp .getBias (), m_Constant (&bias)))
7373 return op->emitOpError (" bias of fully_connected is not constant" );
7474 }
7575 return success ();
7676}
7777
78- struct tosa_level_t {
78+ struct TosaLevel {
7979 int32_t MAX_RANK = 0 ;
8080 int32_t MAX_KERNEL = 0 ;
8181 int32_t MAX_STRIDE = 0 ;
8282 int32_t MAX_SCALE = 0 ;
8383
8484 // @todo: MAX_LOG2_SIZE value and checks
8585
86- bool operator ==(const tosa_level_t &rhs) {
86+ bool operator ==(const TosaLevel &rhs) {
8787 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
8888 MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
8989 }
9090};
9191
92- static constexpr tosa_level_t TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
93- static constexpr tosa_level_t TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
92+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
93+ static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
9494
9595// ===----------------------------------------------------------------------===//
9696// TOSA Validation Pass.
@@ -108,7 +108,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
108108 void runOnOperation () final ;
109109
110110 LogicalResult applyConstantOperandCheck (Operation *op) {
111- for (auto &checker : const_checkers ) {
111+ for (auto &checker : constCheckers ) {
112112 if (failed (checker (op)))
113113 return failure ();
114114 }
@@ -122,43 +122,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
122122
123123private:
124124 void populateConstantOperandChecks () {
125- const_checkers .emplace_back (checkConstantOperandPad);
126- const_checkers .emplace_back (checkConstantOperandTranspose);
127- const_checkers .emplace_back (checkConstantOperandFullyConnected);
125+ constCheckers .emplace_back (checkConstantOperandPad);
126+ constCheckers .emplace_back (checkConstantOperandTranspose);
127+ constCheckers .emplace_back (checkConstantOperandFullyConnected);
128128 }
129129
130130 bool levelCheckKernel (Operation *op, int32_t v,
131- const std::string &check_desc ) {
132- if (v > tosa_level .MAX_KERNEL ) {
133- op->emitOpError () << " failed level check: " << check_desc ;
131+ const std::string &checkDesc ) {
132+ if (v > tosaLevel .MAX_KERNEL ) {
133+ op->emitOpError () << " failed level check: " << checkDesc ;
134134 return false ;
135135 }
136136 return true ;
137137 }
138138
139139 bool levelCheckStride (Operation *op, int32_t v,
140- const std::string &check_desc ) {
141- if (v > tosa_level .MAX_STRIDE ) {
142- op->emitOpError () << " failed level check: " << check_desc ;
140+ const std::string &checkDesc ) {
141+ if (v > tosaLevel .MAX_STRIDE ) {
142+ op->emitOpError () << " failed level check: " << checkDesc ;
143143 return false ;
144144 }
145145 return true ;
146146 }
147147
148- bool levelCheckScale (Operation *op, int32_t v,
149- const std::string &check_desc) {
150- if (v > tosa_level.MAX_SCALE ) {
151- op->emitOpError () << " failed level check: " << check_desc;
148+ bool levelCheckScale (Operation *op, int32_t v, const std::string &checkDesc) {
149+ if (v > tosaLevel.MAX_SCALE ) {
150+ op->emitOpError () << " failed level check: " << checkDesc;
152151 return false ;
153152 }
154153 return true ;
155154 }
156155
157156 bool levelCheckRank (Operation *op, const Value &v,
158- const std::string &check_desc ) {
157+ const std::string &checkDesc ) {
159158 if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
160- if (type.getRank () > tosa_level .MAX_RANK ) {
161- op->emitOpError () << " failed level check: " << check_desc ;
159+ if (type.getRank () > tosaLevel .MAX_RANK ) {
160+ op->emitOpError () << " failed level check: " << checkDesc ;
162161 return false ;
163162 }
164163 }
@@ -182,8 +181,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
182181 }
183182
184183 bool levelCheckRanks (Operation *op) {
185- #define CHECK_RANKS_FOR (tosa_op ) \
186- if (!levelCheckRanksFor<tosa_op ##Op>(op)) \
184+ #define CHECK_RANKS_FOR (tosaOp ) \
185+ if (!levelCheckRanksFor<tosaOp ##Op>(op)) \
187186 return false ;
188187
189188 // tensor operators:
@@ -257,18 +256,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
257256 // Pool Op: level check kernel/stride/pad values
258257 template <typename T>
259258 bool levelCheckPool (Operation *op) {
260- if (auto pool_op = dyn_cast<T>(op)) {
261- for (auto k : pool_op .getKernel ()) {
259+ if (auto poolOp = dyn_cast<T>(op)) {
260+ for (auto k : poolOp .getKernel ()) {
262261 if (!levelCheckKernel (op, k, " kernel <= MAX_KERNEL" )) {
263262 return false ;
264263 }
265264 }
266- for (auto s : pool_op .getStride ()) {
265+ for (auto s : poolOp .getStride ()) {
267266 if (!levelCheckStride (op, s, " stride <= MAX_STRIDE" )) {
268267 return false ;
269268 }
270269 }
271- for (auto p : pool_op .getPad ()) {
270+ for (auto p : poolOp .getPad ()) {
272271 if (!levelCheckKernel (op, p, " pad <= MAX_KERNEL" )) {
273272 return false ;
274273 }
@@ -280,27 +279,27 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
280279 // Conv Op: level check dilation/stride/pad values
281280 template <typename T>
282281 bool levelCheckConv (Operation *op) {
283- if (auto conv_op = dyn_cast<T>(op)) {
282+ if (auto convOp = dyn_cast<T>(op)) {
284283
285- for (auto k : conv_op .getDilation ()) {
284+ for (auto k : convOp .getDilation ()) {
286285 if (!levelCheckKernel (op, k, " dilation <= MAX_KERNEL" )) {
287286 return false ;
288287 }
289288 }
290- for (auto p : conv_op .getPad ()) {
289+ for (auto p : convOp .getPad ()) {
291290 if (!levelCheckKernel (op, p, " pad <= MAX_KERNEL" )) {
292291 return false ;
293292 }
294293 }
295- for (auto s : conv_op .getStride ()) {
294+ for (auto s : convOp .getStride ()) {
296295 if (!levelCheckStride (op, s, " stride <= MAX_STRIDE" )) {
297296 return false ;
298297 }
299298 }
300- auto dilation = conv_op .getDilation ();
301- if (ShapedType weight_type =
299+ auto dilation = convOp .getDilation ();
300+ if (ShapedType weightType =
302301 dyn_cast<ShapedType>(op->getOperand (1 ).getType ())) {
303- auto shape = weight_type .getShape ();
302+ auto shape = weightType .getShape ();
304303 if (isa<tosa::Conv2DOp>(op)) {
305304 assert (shape.size () == 4 );
306305 assert (dilation.size () == 2 );
@@ -354,9 +353,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
354353 // TransposeConv2d op: level check kH/kW, outpad, and stride
355354 bool levelCheckTransposeConv2d (Operation *op) {
356355 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
357- if (ShapedType filter_type =
356+ if (ShapedType filterType =
358357 transpose.getFilter ().getType ().dyn_cast <ShapedType>()) {
359- auto shape = filter_type .getShape ();
358+ auto shape = filterType .getShape ();
360359 assert (shape.size () == 4 );
361360 // level check kernel sizes for kH and KW
362361 if (!levelCheckKernel (op, shape[1 ], " KH <= MAX_KERNEL" ) ||
@@ -382,13 +381,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
382381 bool levelCheckResize (Operation *op) {
383382 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
384383 auto scale = resize.getScale ();
385- int16_t scale_y_n = scale[0 ];
386- int16_t scale_y_d = scale[1 ];
387- int16_t scale_x_n = scale[2 ];
388- int16_t scale_x_d = scale[3 ];
389- if (!levelCheckScale (op, scale_y_n / scale_y_d ,
384+ int16_t scaleYN = scale[0 ];
385+ int16_t scaleYD = scale[1 ];
386+ int16_t scaleXN = scale[2 ];
387+ int16_t scaleXD = scale[3 ];
388+ if (!levelCheckScale (op, scaleYN / scaleYD ,
390389 " scale_y_n/scale_y_d <= MAX_SCALE" ) ||
391- !levelCheckScale (op, scale_x_n / scale_x_d ,
390+ !levelCheckScale (op, scaleXN / scaleXD ,
392391 " scale_x_n/scale_x_d <= MAX_SCALE" )) {
393392 return false ;
394393 }
@@ -399,22 +398,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
399398 // configure profile and level values from pass options profileName and
400399 // levelName
401400 void configLevelAndProfile () {
402- tosa_level = TOSA_LEVEL_NONE;
401+ tosaLevel = TOSA_LEVEL_NONE;
403402 if (level == TosaLevelEnum::EightK) {
404- tosa_level = TOSA_LEVEL_EIGHTK;
403+ tosaLevel = TOSA_LEVEL_EIGHTK;
405404 }
406405 }
407406
408407 bool CheckVariable (Operation *op);
409408 bool CheckVariableReadOrWrite (Operation *op);
410409
411- SmallVector<std::function<LogicalResult(Operation *)>> const_checkers ;
412- tosa_level_t tosa_level ;
413- DenseMap<StringAttr, mlir::Type> variables_map ;
410+ SmallVector<std::function<LogicalResult(Operation *)>> constCheckers ;
411+ TosaLevel tosaLevel ;
412+ DenseMap<StringAttr, mlir::Type> variablesMap ;
414413};
415414
416415LogicalResult TosaValidation::applyLevelCheck (Operation *op) {
417- if (tosa_level == TOSA_LEVEL_NONE) {
416+ if (tosaLevel == TOSA_LEVEL_NONE) {
418417 // no need to do level checks
419418 return success ();
420419 }
@@ -439,24 +438,24 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439438}
440439
441440inline bool CompatibleTypes (const mlir::Type &type,
442- const mlir::Type &declared_type ) {
441+ const mlir::Type &declaredType ) {
443442 // for now, simply use type equality comparison
444- return type == declared_type ;
443+ return type == declaredType ;
445444}
446445
447446bool TosaValidation::CheckVariable (Operation *op) {
448447 if (isa<mlir::tosa::VariableOp>(op)) {
449- auto name_attr = cast<mlir::StringAttr>(op->getAttr (" name" ));
448+ auto nameAttr = cast<mlir::StringAttr>(op->getAttr (" name" ));
450449
451- if (variables_map .count (name_attr )) {
450+ if (variablesMap .count (nameAttr )) {
452451 op->emitOpError () << " name has already been declared" ;
453452 return false ;
454453 }
455454
456- auto type_attr = cast<mlir::TypeAttr>(op->getAttr (" type" ));
457- mlir::Type type = type_attr .getValue ();
455+ auto typeAttr = cast<mlir::TypeAttr>(op->getAttr (" type" ));
456+ mlir::Type type = typeAttr .getValue ();
458457
459- variables_map[name_attr ] = type;
458+ variablesMap[nameAttr ] = type;
460459 }
461460
462461 return true ;
@@ -465,26 +464,26 @@ bool TosaValidation::CheckVariable(Operation *op) {
465464bool TosaValidation::CheckVariableReadOrWrite (Operation *op) {
466465 if (isa<mlir::tosa::VariableReadOp>(op) ||
467466 isa<mlir::tosa::VariableWriteOp>(op)) {
468- auto name_attr = cast<mlir::StringAttr>(op->getAttr (" name" ));
467+ auto nameAttr = cast<mlir::StringAttr>(op->getAttr (" name" ));
469468
470- if (!variables_map .count (name_attr )) {
469+ if (!variablesMap .count (nameAttr )) {
471470 op->emitOpError () << " name has not been declared" ;
472471 return false ;
473472 }
474473
475- auto var_type = variables_map[name_attr ];
474+ auto varType = variablesMap[nameAttr ];
476475
477476 for (auto v : op->getOperands ()) {
478477 auto type = v.getType ();
479- if (!CompatibleTypes (type, var_type )) {
478+ if (!CompatibleTypes (type, varType )) {
480479 op->emitOpError () << " operand type does not equal variable type" ;
481480 return false ;
482481 }
483482 }
484483
485484 for (auto v : op->getResults ()) {
486485 auto type = v.getType ();
487- if (!CompatibleTypes (type, var_type )) {
486+ if (!CompatibleTypes (type, varType )) {
488487 op->emitOpError () << " result type does not equal variable type" ;
489488 return false ;
490489 }
0 commit comments