Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Jun 22, 2023
1 parent efc53b0 commit 67d9d66
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 16 deletions.
33 changes: 33 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import partial
from typing import TYPE_CHECKING

import paddle

from ...utils import BreakGraphError, NotImplementException
from ...utils.magic_methods import (
BINARY_OPS,
Expand Down Expand Up @@ -97,6 +99,37 @@
lambda var: var.bool(),
)

# TensorVariable
# Route paddle's tensor-inspection builtins, when called on a traced
# TensorVariable, to the variable's own simulated methods/properties so
# the query is answered from recorded meta information instead of running
# paddle eagerly on a real tensor.
# NOTE(review): register() appears to take (callable, arg-type names,
# kwarg-type names, handler) — confirm against Dispatcher's definition.
Dispatcher.register(
    paddle.is_tensor,
    ("TensorVariable",),
    {},
    # A TensorVariable always models a tensor; delegate to its own check.
    lambda var: var.is_tensor(),
)
Dispatcher.register(
    paddle.is_complex,
    ("TensorVariable",),
    {},
    lambda var: var.is_complex(),
)
Dispatcher.register(
    paddle.is_integer,
    ("TensorVariable",),
    {},
    lambda var: var.is_integer(),
)
Dispatcher.register(
    paddle.is_floating_point,
    ("TensorVariable",),
    {},
    lambda var: var.is_floating_point(),
)
Dispatcher.register(
    paddle.rank,
    ("TensorVariable",),
    {},
    # paddle.rank is answered from the variable's ndim property;
    # presumably the dispatch layer wraps the result — TODO confirm.
    lambda var: var.ndim,
)

# VariableBase
Dispatcher.register(
Expand Down
33 changes: 17 additions & 16 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,33 @@
if TYPE_CHECKING:
from ..function_graph import FunctionGraph

# Abbreviation tables mapping paddle dtypes to their short string names.
# The three category dicts are also used as membership sets by
# TensorVariable.is_floating_point / is_complex / is_integer, so each one
# must contain exactly the dtypes of its category.

# Floating-point dtypes.
FP_DTYPE_ABBRS = {
    paddle.bfloat16: 'bfloat16',
    paddle.float64: 'float64',
    paddle.float32: 'float32',
    paddle.float16: 'float16',
}

# Complex dtypes.
CP_DTYPE_ABBRS = {
    paddle.complex64: 'complex64',
    paddle.complex128: 'complex128',
}

# Integer dtypes. paddle.bool is deliberately NOT listed here: bool is not
# an integral dtype, and including it would make is_integer() report True
# for bool tensors (the previous explicit dtype checks excluded it).
INT_DTYPE_ABBRS = {
    paddle.int8: 'int8',
    paddle.int16: 'int16',
    paddle.int32: 'int32',
    paddle.int64: 'int64',
    paddle.uint8: 'uint8',
}

# Union of all known dtypes, for display/abbreviation purposes;
# bool belongs here but in none of the category dicts above.
DTYPE_ABBRS = {
    **FP_DTYPE_ABBRS,
    **CP_DTYPE_ABBRS,
    **INT_DTYPE_ABBRS,
    paddle.bool: 'bool',
}


class ConstantVariable(VariableBase):
def __init__(
Expand Down Expand Up @@ -233,28 +245,17 @@ def is_tensor(self):

def is_complex(self):
    """Return a ConstantVariable wrapping True iff this tensor's dtype is complex.

    Membership in CP_DTYPE_ABBRS (complex64/complex128) decides the answer,
    using only the recorded meta dtype — no real tensor is touched.
    """
    dtype = self.meta.dtype
    is_cp_dtype = dtype in CP_DTYPE_ABBRS
    return ConstantVariable.wrap_literal(is_cp_dtype)

def is_integer(self):
    """Return a ConstantVariable wrapping True iff this tensor's dtype is integral.

    Membership in INT_DTYPE_ABBRS (int8/int16/int32/int64/uint8) decides the
    answer; bool is not integral and must not appear in that table.
    """
    dtype = self.meta.dtype
    is_int_dtype = dtype in INT_DTYPE_ABBRS
    return ConstantVariable.wrap_literal(is_int_dtype)

def is_floating_point(self):
    """Return a ConstantVariable wrapping True iff this tensor's dtype is floating-point.

    Membership in FP_DTYPE_ABBRS (bfloat16/float16/float32/float64) decides
    the answer, using only the recorded meta dtype.
    """
    dtype = self.meta.dtype
    is_fp_dtype = dtype in FP_DTYPE_ABBRS
    return ConstantVariable.wrap_literal(is_fp_dtype)

def getattr(self, name: str):
Expand Down

0 comments on commit 67d9d66

Please sign in to comment.