@@ -3172,3 +3172,219 @@ def test_cli_file_input(self):
31723172 self .assertEqual (res .err , b"" )
31733173 self .assertEqual (expected .splitlines (), res .out .decode ("utf8" ).splitlines ())
31743174 self .assertEqual (res .rc , 0 )
3175+
3176+ def compare (left , right ):
3177+ return ast .dump (left ) == ast .dump (right )
3178+
3179+ class ASTOptimiziationTests (unittest .TestCase ):
3180+ binop = {
3181+ "+" : ast .Add (),
3182+ "-" : ast .Sub (),
3183+ "*" : ast .Mult (),
3184+ "/" : ast .Div (),
3185+ "%" : ast .Mod (),
3186+ "<<" : ast .LShift (),
3187+ ">>" : ast .RShift (),
3188+ "|" : ast .BitOr (),
3189+ "^" : ast .BitXor (),
3190+ "&" : ast .BitAnd (),
3191+ "//" : ast .FloorDiv (),
3192+ "**" : ast .Pow (),
3193+ }
3194+
3195+ unaryop = {
3196+ "~" : ast .Invert (),
3197+ "+" : ast .UAdd (),
3198+ "-" : ast .USub (),
3199+ }
3200+
3201+ def wrap_expr (self , expr ):
3202+ return ast .Module (body = [ast .Expr (value = expr )])
3203+
3204+ def wrap_for (self , for_statement ):
3205+ return ast .Module (body = [for_statement ])
3206+
3207+ def assert_ast (self , code , non_optimized_target , optimized_target ):
3208+
3209+ non_optimized_tree = ast .parse (code , optimize = - 1 )
3210+ optimized_tree = ast .parse (code , optimize = 1 )
3211+
3212+ # Is a non-optimized tree equal to a non-optimized target?
3213+ self .assertTrue (
3214+ compare (non_optimized_tree , non_optimized_target ),
3215+ f"{ ast .dump (non_optimized_target )} must equal "
3216+ f"{ ast .dump (non_optimized_tree )} " ,
3217+ )
3218+
3219+ # Is a optimized tree equal to a non-optimized target?
3220+ self .assertFalse (
3221+ compare (optimized_tree , non_optimized_target ),
3222+ f"{ ast .dump (non_optimized_target )} must not equal "
3223+ f"{ ast .dump (non_optimized_tree )} "
3224+ )
3225+
3226+ # Is a optimized tree is equal to an optimized target?
3227+ self .assertTrue (
3228+ compare (optimized_tree , optimized_target ),
3229+ f"{ ast .dump (optimized_target )} must equal "
3230+ f"{ ast .dump (optimized_tree )} " ,
3231+ )
3232+
3233+ def test_folding_binop (self ):
3234+ code = "1 %s 1"
3235+ operators = self .binop .keys ()
3236+
3237+ def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3238+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3239+
3240+ for op in operators :
3241+ result_code = code % op
3242+ non_optimized_target = self .wrap_expr (create_binop (op ))
3243+ optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3244+
3245+ with self .subTest (
3246+ result_code = result_code ,
3247+ non_optimized_target = non_optimized_target ,
3248+ optimized_target = optimized_target
3249+ ):
3250+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3251+
3252+ # Multiplication of constant tuples must be folded
3253+ code = "(1,) * 3"
3254+ non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3255+ optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3256+
3257+ self .assert_ast (code , non_optimized_target , optimized_target )
3258+
3259+ def test_folding_unaryop (self ):
3260+ code = "%s1"
3261+ operators = self .unaryop .keys ()
3262+
3263+ def create_unaryop (operand ):
3264+ return ast .UnaryOp (op = self .unaryop [operand ], operand = ast .Constant (1 ))
3265+
3266+ for op in operators :
3267+ result_code = code % op
3268+ non_optimized_target = self .wrap_expr (create_unaryop (op ))
3269+ optimized_target = self .wrap_expr (ast .Constant (eval (result_code )))
3270+
3271+ with self .subTest (
3272+ result_code = result_code ,
3273+ non_optimized_target = non_optimized_target ,
3274+ optimized_target = optimized_target
3275+ ):
3276+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3277+
3278+ def test_folding_not (self ):
3279+ code = "not (1 %s (1,))"
3280+ operators = {
3281+ "in" : ast .In (),
3282+ "is" : ast .Is (),
3283+ }
3284+ opt_operators = {
3285+ "is" : ast .IsNot (),
3286+ "in" : ast .NotIn (),
3287+ }
3288+
3289+ def create_notop (operand ):
3290+ return ast .UnaryOp (op = ast .Not (), operand = ast .Compare (
3291+ left = ast .Constant (value = 1 ),
3292+ ops = [operators [operand ]],
3293+ comparators = [ast .Tuple (elts = [ast .Constant (value = 1 )])]
3294+ ))
3295+
3296+ for op in operators .keys ():
3297+ result_code = code % op
3298+ non_optimized_target = self .wrap_expr (create_notop (op ))
3299+ optimized_target = self .wrap_expr (
3300+ ast .Compare (left = ast .Constant (1 ), ops = [opt_operators [op ]], comparators = [ast .Constant (value = (1 ,))])
3301+ )
3302+
3303+ with self .subTest (
3304+ result_code = result_code ,
3305+ non_optimized_target = non_optimized_target ,
3306+ optimized_target = optimized_target
3307+ ):
3308+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3309+
3310+ def test_folding_format (self ):
3311+ code = "'%s' % (a,)"
3312+
3313+ non_optimized_target = self .wrap_expr (
3314+ ast .BinOp (
3315+ left = ast .Constant (value = "%s" ),
3316+ op = ast .Mod (),
3317+ right = ast .Tuple (elts = [ast .Name (id = 'a' )]))
3318+ )
3319+ optimized_target = self .wrap_expr (
3320+ ast .JoinedStr (
3321+ values = [
3322+ ast .FormattedValue (value = ast .Name (id = 'a' ), conversion = 115 )
3323+ ]
3324+ )
3325+ )
3326+
3327+ self .assert_ast (code , non_optimized_target , optimized_target )
3328+
3329+
3330+ def test_folding_tuple (self ):
3331+ code = "(1,)"
3332+
3333+ non_optimized_target = self .wrap_expr (ast .Tuple (elts = [ast .Constant (1 )]))
3334+ optimized_target = self .wrap_expr (ast .Constant (value = (1 ,)))
3335+
3336+ self .assert_ast (code , non_optimized_target , optimized_target )
3337+
3338+ def test_folding_comparator (self ):
3339+ code = "1 %s %s1%s"
3340+ operators = [("in" , ast .In ()), ("not in" , ast .NotIn ())]
3341+ braces = [
3342+ ("[" , "]" , ast .List , (1 ,)),
3343+ ("{" , "}" , ast .Set , frozenset ({1 })),
3344+ ]
3345+ for left , right , non_optimized_comparator , optimized_comparator in braces :
3346+ for op , node in operators :
3347+ non_optimized_target = self .wrap_expr (ast .Compare (
3348+ left = ast .Constant (1 ), ops = [node ],
3349+ comparators = [non_optimized_comparator (elts = [ast .Constant (1 )])]
3350+ ))
3351+ optimized_target = self .wrap_expr (ast .Compare (
3352+ left = ast .Constant (1 ), ops = [node ],
3353+ comparators = [ast .Constant (value = optimized_comparator )]
3354+ ))
3355+ self .assert_ast (code % (op , left , right ), non_optimized_target , optimized_target )
3356+
3357+ def test_folding_iter (self ):
3358+ code = "for _ in %s1%s: pass"
3359+ braces = [
3360+ ("[" , "]" , ast .List , (1 ,)),
3361+ ("{" , "}" , ast .Set , frozenset ({1 })),
3362+ ]
3363+
3364+ for left , right , ast_cls , optimized_iter in braces :
3365+ non_optimized_target = self .wrap_for (ast .For (
3366+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3367+ iter = ast_cls (elts = [ast .Constant (1 )]),
3368+ body = [ast .Pass ()]
3369+ ))
3370+ optimized_target = self .wrap_for (ast .For (
3371+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3372+ iter = ast .Constant (value = optimized_iter ),
3373+ body = [ast .Pass ()]
3374+ ))
3375+
3376+ self .assert_ast (code % (left , right ), non_optimized_target , optimized_target )
3377+
3378+ def test_folding_subscript (self ):
3379+ code = "(1,)[0]"
3380+
3381+ non_optimized_target = self .wrap_expr (
3382+ ast .Subscript (value = ast .Tuple (elts = [ast .Constant (value = 1 )]), slice = ast .Constant (value = 0 ))
3383+ )
3384+ optimized_target = self .wrap_expr (ast .Constant (value = 1 ))
3385+
3386+ self .assert_ast (code , non_optimized_target , optimized_target )
3387+
3388+
3389+ if __name__ == "__main__" :
3390+ unittest .main ()
0 commit comments