Skip to content

Commit

Permalink
Merge branch 'ershi/update-typehints' into 'main'
Browse files Browse the repository at this point in the history
Update typehints

See merge request omniverse/warp!1073
  • Loading branch information
shi-eric committed Feb 11, 2025
2 parents 1d4b239 + 70222f7 commit 56a290b
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 258 deletions.
4 changes: 2 additions & 2 deletions warp/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def sametypes(arg_types: Mapping[str, Any]):
return all(types_equal(arg_type_0, t) for t in arg_types_iter)


def sametypes_create_value_func(default):
def sametypes_create_value_func(default: TypeVar):
def fn(arg_types, arg_values):
if arg_types is None:
return default
Expand Down Expand Up @@ -390,7 +390,7 @@ def fn(arg_types, arg_values):
)


def scalar_infer_type(arg_types: Mapping[str, type]):
def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
if arg_types is None:
return Scalar

Expand Down
50 changes: 30 additions & 20 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, message):


# map operator to function name
builtin_operators = {}
builtin_operators: Dict[type[ast.AST], str] = {}

# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
# nice overview of python operators
Expand Down Expand Up @@ -397,12 +397,14 @@ def numpy_value(self):


class Struct:
def __init__(self, cls, key, module):
hash: bytes

def __init__(self, cls: type, key: str, module: warp.context.Module):
self.cls = cls
self.module = module
self.key = key
self.vars: Dict[str, Var] = {}

self.vars = {}
annotations = get_annotations(self.cls)
for label, type in annotations.items():
self.vars[label] = Var(label, type)
Expand Down Expand Up @@ -573,11 +575,11 @@ def __init__(self, value_type):
self.value_type = value_type


def is_reference(type):
def is_reference(type: Any) -> builtins.bool:
return isinstance(type, Reference)


def strip_reference(arg):
def strip_reference(arg: Any) -> Any:
if is_reference(arg):
return arg.value_type
else:
Expand Down Expand Up @@ -605,7 +607,14 @@ def param2str(p):


class Var:
def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
def __init__(
self,
label: str,
type: type,
requires_grad: builtins.bool = False,
constant: Optional[builtins.bool] = None,
prefix: builtins.bool = True,
):
# convert built-in types to wp types
if type == float:
type = float32
Expand All @@ -632,7 +641,7 @@ def __str__(self):
return self.label

@staticmethod
def type_to_ctype(t, value_type=False):
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
if is_array(t):
if hasattr(t.dtype, "_wp_generic_type_str_"):
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
Expand Down Expand Up @@ -663,7 +672,7 @@ def type_to_ctype(t, value_type=False):
else:
return f"wp::{t.__name__}"

def ctype(self, value_type=False):
def ctype(self, value_type: builtins.bool = False) -> str:
return Var.type_to_ctype(self.type, value_type)

def emit(self, prefix: str = "var"):
Expand Down Expand Up @@ -785,7 +794,7 @@ def func_match_args(func, arg_types, kwarg_types):
return True


def get_arg_type(arg: Union[Var, Any]):
def get_arg_type(arg: Union[Var, Any]) -> type:
if isinstance(arg, str):
return str

Expand All @@ -801,7 +810,7 @@ def get_arg_type(arg: Union[Var, Any]):
return type(arg)


def get_arg_value(arg: Union[Var, Any]):
def get_arg_value(arg: Any) -> Any:
if isinstance(arg, Sequence):
return tuple(get_arg_value(x) for x in arg)

Expand Down Expand Up @@ -923,9 +932,6 @@ def __init__(
# for unit testing errors being spit out from kernels.
adj.skip_build = False

# Collect the LTOIR required at link-time
adj.ltoirs = []

# allocate extra space for a function call that requires its
# own shared memory space, we treat shared memory as a stack
# where each function pushes and pops space off, the extra
Expand Down Expand Up @@ -1263,7 +1269,7 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None):

# Bind the positional and keyword arguments to the function's signature
# in order to process them as Python does it.
bound_args = func.signature.bind(*args, **kwargs)
bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

# Type args are the “compile time” argument values we get from codegen.
# For example, when calling `wp.vec3f(...)` from within a kernel,
Expand Down Expand Up @@ -2929,12 +2935,16 @@ def eval_len(obj):

# We want to replace the expression code in-place,
# so reparse it to get the correct column info.
len_value_locs = []
len_value_locs: List[Tuple[int, int, int]] = []
expr_tree = ast.parse(static_code)
assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
expr_root = expr_tree.body[0].value
for expr_node in ast.walk(expr_root):
if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
if (
isinstance(expr_node, ast.Call)
and getattr(expr_node.func, "id", None) == "len"
and len(expr_node.args) == 1
):
len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
try:
len_value = eval(len_expr, len_expr_ctx)
Expand Down Expand Up @@ -3092,9 +3102,9 @@ def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.conte

local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed

constants = {}
types = {}
functions = {}
constants: Dict[str, Any] = {}
types: Dict[Union[Struct, type], Any] = {}
functions: Dict[warp.context.Function, Any] = {}

for node in ast.walk(adj.tree):
if isinstance(node, ast.Name) and node.id not in local_variables:
Expand Down Expand Up @@ -3400,7 +3410,7 @@ def indent(args, stops=1):


# generates a C function name based on the python function name
def make_full_qualified_name(func):
def make_full_qualified_name(func: Union[str, Callable]) -> str:
if not isinstance(func, str):
func = func.__qualname__
return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
Expand Down
Loading

0 comments on commit 56a290b

Please sign in to comment.