@@ -4348,7 +4348,11 @@ def cumsum(
43484348 flatten = True
43494349 else :
43504350 flatten = False
4351- if dtype is not None and x .dtype != convert_np_dtype_to_dtype_ (dtype ):
4351+ if dtype is None :
4352+ dtype = x .dtype
4353+ if not isinstance (dtype , (core .VarDesc .VarType , core .DataType )):
4354+ dtype = convert_np_dtype_to_dtype_ (dtype )
4355+ if dtype is not None and x .dtype != dtype :
43524356 x = cast (x , dtype )
43534357
43544358 if in_dynamic_or_pir_mode ():
@@ -4396,7 +4400,11 @@ def cumsum_(
43964400 flatten = True
43974401 else :
43984402 flatten = False
4399- if dtype is not None and x .dtype != convert_np_dtype_to_dtype_ (dtype ):
4403+ if dtype is None :
4404+ dtype = x .dtype
4405+ if not isinstance (dtype , (core .VarDesc .VarType , core .DataType )):
4406+ dtype = convert_np_dtype_to_dtype_ (dtype )
4407+ if dtype is not None and x .dtype != dtype :
44004408 x = cast_ (x , dtype )
44014409
44024410 if in_dynamic_mode ():
@@ -4652,8 +4660,12 @@ def logcumsumexp(
46524660 flatten = True
46534661 else :
46544662 flatten = False
4655- if dtype is not None and x .dtype != convert_np_dtype_to_dtype_ (dtype ):
4656- x = cast (x , dtype )
4663+ if dtype is None :
4664+ dtype = x .dtype
4665+ if not isinstance (dtype , (core .VarDesc .VarType , core .DataType )):
4666+ dtype = convert_np_dtype_to_dtype_ (dtype )
4667+ if dtype is not None and x .dtype != dtype :
4668+ x = cast_ (x , dtype )
46574669
46584670 if in_dynamic_or_pir_mode ():
46594671 if axis is None :
@@ -4753,9 +4765,12 @@ def cumprod(
47534765 if dim is None :
47544766 dim = - 1
47554767 x = x .flatten (0 , len (x .shape ) - 1 )
4756-
4757- if dtype is not None and x .dtype != convert_np_dtype_to_dtype_ (dtype ):
4758- x = cast (x , dtype )
4768+ if dtype is None :
4769+ dtype = x .dtype
4770+ if not isinstance (dtype , (core .VarDesc .VarType , core .DataType )):
4771+ dtype = convert_np_dtype_to_dtype_ (dtype )
4772+ if dtype is not None and x .dtype != dtype :
4773+ x = cast_ (x , dtype )
47594774
47604775 if in_dynamic_or_pir_mode ():
47614776 return _C_ops .cumprod (x , dim , False , False )
@@ -4802,8 +4817,11 @@ def cumprod_(
48024817 if dim is None :
48034818 dim = - 1
48044819 x = _C_ops .flatten_ (x , 0 , len (x .shape ) - 1 )
4805-
4806- if dtype is not None and x .dtype != convert_np_dtype_to_dtype_ (dtype ):
4820+ if dtype is None :
4821+ dtype = x .dtype
4822+ if not isinstance (dtype , (core .VarDesc .VarType , core .DataType )):
4823+ dtype = convert_np_dtype_to_dtype_ (dtype )
4824+ if dtype is not None and x .dtype != dtype :
48074825 x = cast_ (x , dtype )
48084826
48094827 if in_dynamic_mode ():
@@ -5031,14 +5049,18 @@ def prod(
50315049 [24. , 1680.])
50325050
50335051 """
5052+ if dtype is None :
5053+ dtype = x .dtype
50345054 if dtype is not None :
5055+ if not isinstance (dtype , (core .VarDesc .VarType , core .DataType )):
5056+ dtype = convert_np_dtype_to_dtype_ (dtype )
50355057 check_dtype (
50365058 dtype ,
50375059 'dtype' ,
50385060 ['float32' , 'float64' , 'int32' , 'int64' , "float16" , "uint16" ],
50395061 'prod' ,
50405062 )
5041- if x .dtype != convert_np_dtype_to_dtype_ ( dtype ) :
5063+ if x .dtype != dtype :
50425064 x = cast (x , dtype )
50435065
50445066 # axis is 0-size tensor.
0 commit comments