@@ -642,7 +642,7 @@ def _is_list_or_tuple_(data):
642642 if len (x .shape ) == 4 :
643643 if len (out_shape ) != 2 :
644644 raise ValueError (
645- "size length should be 2 for " " input 4-D tensor."
645+ "size length should be 2 for input 4-D tensor."
646646 )
647647 if contain_var :
648648 attrs ['out_h' ] = size_list [0 ]
@@ -667,6 +667,30 @@ def _is_list_or_tuple_(data):
667667 attrs ['out_w' ] = out_shape [2 ]
668668
669669 elif scale is not None :
670+ # scale in python is float64, but in kernel is float32, so we need to recalculate the scale in float32
671+ # Currently it is only used when x.size is 0.
672+ x_shape = x .shape
673+ if data_format == 'NCW' :
674+ max_dim = x_shape [2 ]
675+ elif data_format == 'NWC' :
676+ max_dim = x_shape [1 ]
677+ elif data_format == 'NCHW' :
678+ max_dim = max (x .shape [2 ], x .shape [3 ])
679+ elif data_format == 'NHWC' :
680+ max_dim = max (x .shape [1 ], x .shape [2 ])
681+ elif data_format == 'NCDHW' :
682+ max_dim = max (x .shape [2 ], x .shape [3 ], x .shape [4 ])
683+ elif data_format == 'NDHWC' :
684+ max_dim = max (x .shape [1 ], x .shape [2 ], x .shape [3 ])
685+ else :
686+ max_dim = 1
687+
688+ def _scale_to_float32 (value ):
689+ if len (str (value )) <= 10 :
690+ return value
691+ # round down
692+ return numpy .float32 (int (value * max_dim ) / max_dim )
693+
670694 if recompute_scale_factor :
671695 if in_dynamic_mode () and isinstance (scale , Variable ):
672696 if scale .shape == []:
@@ -710,11 +734,15 @@ def _is_list_or_tuple_(data):
710734
711735 scale = None
712736 else :
713- if in_dynamic_mode () and isinstance (scale , Variable ):
737+ dynamic_mode = False
738+ if in_dynamic_mode ():
739+ dynamic_mode = True
740+ if dynamic_mode and isinstance (scale , Variable ):
714741 if scale .shape == []:
715742 scale = float (scale )
716743 else :
717744 scale = list (scale .numpy ())
745+
718746 if isinstance (scale , (Variable , paddle .pir .Value )):
719747 scale .stop_gradient = True
720748 inputs ["Scale" ] = scale
@@ -724,7 +752,10 @@ def _is_list_or_tuple_(data):
724752 scale_list = []
725753 for i in range (len (x .shape ) - 2 ):
726754 scale_list .append (scale )
727- attrs ['scale' ] = list (map (float , scale_list ))
755+ if dynamic_mode and x .size == 0 :
756+ attrs ['scale' ] = list (map (_scale_to_float32 , scale_list ))
757+ else :
758+ attrs ['scale' ] = list (map (float , scale_list ))
728759 elif isinstance (scale , (list , tuple )):
729760 if len (scale ) != len (x .shape ) - 2 :
730761 raise ValueError (
@@ -736,7 +767,10 @@ def _is_list_or_tuple_(data):
736767 raise ValueError (
737768 "Attr(scale) should be greater than zero."
738769 )
739- attrs ['scale' ] = list (map (float , scale ))
770+ if dynamic_mode and x .size == 0 :
771+ attrs ['scale' ] = list (map (_scale_to_float32 , scale ))
772+ else :
773+ attrs ['scale' ] = list (map (float , scale ))
740774 else :
741775 raise TypeError (
742776 "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
0 commit comments