Skip to content

Commit

Permalink
Declare key Op attributes in the class signature
Browse files Browse the repository at this point in the history
Creating new attributes at runtime lead to unclear typing and mypy errors.
  • Loading branch information
Michael Osthege authored and brandonwillard committed Mar 1, 2022
1 parent d81da63 commit 4474eb9
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions aesara/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ class Op(MetaObject):
"""

itypes = None
otypes = None
params_type: Optional[ParamsType] = None

def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
Expand All @@ -208,13 +212,13 @@ def make_node(self, *inputs: Variable) -> Apply:
The constructed `Apply` node.
"""
if not hasattr(self, "itypes"):
if self.itypes is None:
raise NotImplementedError(
"You can either define itypes and otypes,\
or implement make_node"
)

if not hasattr(self, "otypes"):
if self.otypes is None:
raise NotImplementedError(
"You can either define itypes and otypes,\
or implement make_node"
Expand Down Expand Up @@ -446,7 +450,7 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:

def get_params(self, node: Apply) -> Params:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if hasattr(self, "params_type") and isinstance(self.params_type, ParamsType):
if isinstance(self.params_type, ParamsType):
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
# Let's print missing attributes for debugging.
Expand Down Expand Up @@ -1091,7 +1095,7 @@ def __get_op_params(self) -> List[Tuple[str, Any]]:
"""
params: List[Tuple[str, Any]] = []
if hasattr(self, "params_type") and isinstance(self.params_type, ParamsType):
if isinstance(self.params_type, ParamsType):
wrapper = self.params_type
params.append(("PARAMS_TYPE", wrapper.name))
for i in range(wrapper.length):
Expand All @@ -1111,7 +1115,7 @@ def __get_op_params(self) -> List[Tuple[str, Any]]:

def c_code_cache_version(self):
version = (hash(tuple(self.func_codes)),)
if hasattr(self, "params_type"):
if self.params_type is not None:
version += (self.params_type.c_code_cache_version(),)
return version

Expand Down

0 comments on commit 4474eb9

Please sign in to comment.