Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Jun 22, 2023
1 parent 676d1b6 commit 1affc7d
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,34 @@
if TYPE_CHECKING:
from ..function_graph import FunctionGraph

DTYPE_ABBRS = {

FP_DTYPE_ABBRS = {
paddle.bfloat16: 'bfloat16',
paddle.float64: 'float64',
paddle.float32: 'float32',
paddle.float16: 'float16',
}

CP_DTYPE_ABBRS = {
paddle.complex64: 'complex64',
paddle.complex128: 'complex128',
}

INT_DTYPE_ABBRS = {
paddle.int8: 'int8',
paddle.int16: 'int16',
paddle.int32: 'int32',
paddle.int64: 'int64',
paddle.bool: 'bool',
paddle.uint8: 'uint8',
}

DTYPE_ABBRS = {
**FP_DTYPE_ABBRS,
**CP_DTYPE_ABBRS,
**INT_DTYPE_ABBRS,
paddle.bool: 'bool',
}


class ConstantVariable(VariableBase):
def __init__(
Expand Down Expand Up @@ -235,28 +248,17 @@ def is_tensor(self):

def is_complex(self):
dtype = self.meta.dtype
is_cp_dtype = dtype == paddle.complex64 or dtype == paddle.complex128
is_cp_dtype = dtype in CP_DTYPE_ABBRS
return ConstantVariable.wrap_literal(is_cp_dtype)

def is_integer(self):
dtype = self.meta.dtype
is_int_dtype = (
dtype == paddle.int8
or dtype == paddle.uint8
or dtype == paddle.int16
or dtype == paddle.int32
or dtype == paddle.int64
)
is_int_dtype = dtype in INT_DTYPE_ABBRS
return ConstantVariable.wrap_literal(is_int_dtype)

def is_floating_point(self):
dtype = self.meta.dtype
is_fp_dtype = (
dtype == paddle.float32
or dtype == paddle.float64
or dtype == paddle.float16
or dtype == paddle.bfloat16
)
is_fp_dtype = dtype in FP_DTYPE_ABBRS
return ConstantVariable.wrap_literal(is_fp_dtype)

def getattr(self, name: str):
Expand Down

0 comments on commit 1affc7d

Please sign in to comment.