From 70222f701c037d0ca5be2ed1e6f4af59fe5b8a50 Mon Sep 17 00:00:00 2001
From: Eric Shi
Date: Tue, 11 Feb 2025 08:13:26 -0800
Subject: [PATCH] Add/fix typehints and refactoring for mypy issues

---
 warp/builtins.py |   4 +-
 warp/codegen.py  |  50 ++++---
 warp/context.py  | 339 +++++++++++++++++++++++++----------------
 warp/types.py    |  60 +++++----
 warp/utils.py    | 102 +++++++-------
 5 files changed, 297 insertions(+), 258 deletions(-)

diff --git a/warp/builtins.py b/warp/builtins.py
index ad85597af..080c5a400 100644
--- a/warp/builtins.py
+++ b/warp/builtins.py
@@ -32,7 +32,7 @@ def sametypes(arg_types: Mapping[str, Any]):
     return all(types_equal(arg_type_0, t) for t in arg_types_iter)
 
 
-def sametypes_create_value_func(default):
+def sametypes_create_value_func(default: TypeVar):
     def fn(arg_types, arg_values):
         if arg_types is None:
             return default
@@ -390,7 +390,7 @@ def fn(arg_types, arg_values):
     )
 
 
-def scalar_infer_type(arg_types: Mapping[str, type]):
+def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
     if arg_types is None:
         return Scalar
 
diff --git a/warp/codegen.py b/warp/codegen.py
index 3696b6261..d17fc7d24 100644
--- a/warp/codegen.py
+++ b/warp/codegen.py
@@ -49,7 +49,7 @@ def __init__(self, message):
 
 
 # map operator to function name
-builtin_operators = {}
+builtin_operators: Dict[type[ast.AST], str] = {}
 
 # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
 # nice overview of python operators
@@ -397,12 +397,14 @@ def numpy_value(self):
 
 
 class Struct:
-    def __init__(self, cls, key, module):
+    hash: bytes
+
+    def __init__(self, cls: type, key: str, module: warp.context.Module):
         self.cls = cls
         self.module = module
         self.key = key
+        self.vars: Dict[str, Var] = {}
 
-        self.vars = {}
         annotations = get_annotations(self.cls)
         for label, type in annotations.items():
             self.vars[label] = Var(label, type)
@@ -573,11 +575,11 @@ def __init__(self, value_type):
         self.value_type = value_type
 
 
-def is_reference(type):
+def is_reference(type: Any) -> builtins.bool:
     return isinstance(type, Reference)
 
 
-def strip_reference(arg):
+def strip_reference(arg: Any) -> Any:
     if is_reference(arg):
         return arg.value_type
     else:
@@ -605,7 +607,14 @@ def param2str(p):
 
 
 class Var:
-    def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
+    def __init__(
+        self,
+        label: str,
+        type: type,
+        requires_grad: builtins.bool = False,
+        constant: Optional[builtins.bool] = None,
+        prefix: builtins.bool = True,
+    ):
         # convert built-in types to wp types
         if type == float:
             type = float32
@@ -632,7 +641,7 @@ def __str__(self):
         return self.label
 
     @staticmethod
-    def type_to_ctype(t, value_type=False):
+    def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
         if is_array(t):
             if hasattr(t.dtype, "_wp_generic_type_str_"):
                 dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
@@ -663,7 +672,7 @@ def type_to_ctype(t, value_type=False):
         else:
             return f"wp::{t.__name__}"
 
-    def ctype(self, value_type=False):
+    def ctype(self, value_type: builtins.bool = False) -> str:
         return Var.type_to_ctype(self.type, value_type)
 
     def emit(self, prefix: str = "var"):
@@ -785,7 +794,7 @@ def func_match_args(func, arg_types, kwarg_types):
     return True
 
 
-def get_arg_type(arg: Union[Var, Any]):
+def get_arg_type(arg: Union[Var, Any]) -> type:
     if isinstance(arg, str):
         return str
 
@@ -801,7 +810,7 @@ def get_arg_type(arg: Union[Var, Any]):
     return type(arg)
 
 
-def get_arg_value(arg: Union[Var, Any]):
+def get_arg_value(arg: Any) -> Any:
     if isinstance(arg, Sequence):
         return tuple(get_arg_value(x) for x in arg)
 
@@ -923,9 +932,6 @@ def __init__(
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False
 
-        # Collect the LTOIR required at link-time
-        adj.ltoirs = []
-
         # allocate extra space for a function call that requires its
        # own shared memory space, we treat shared memory as a stack
        # where each function pushes and pops space off, the extra
@@ -1263,7 +1269,7 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
 
         # Bind the positional and keyword arguments to the function's signature
         # in order to process them as Python does it.
-        bound_args = func.signature.bind(*args, **kwargs)
+        bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
 
         # Type args are the “compile time” argument values we get from codegen.
         # For example, when calling `wp.vec3f(...)` from within a kernel,
@@ -2929,12 +2935,16 @@ def eval_len(obj):
 
         # We want to replace the expression code in-place,
         # so reparse it to get the correct column info.
-        len_value_locs = []
+        len_value_locs: List[Tuple[int, int, int]] = []
         expr_tree = ast.parse(static_code)
         assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
         expr_root = expr_tree.body[0].value
         for expr_node in ast.walk(expr_root):
-            if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
+            if (
+                isinstance(expr_node, ast.Call)
+                and getattr(expr_node.func, "id", None) == "len"
+                and len(expr_node.args) == 1
+            ):
                 len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
                 try:
                     len_value = eval(len_expr, len_expr_ctx)
@@ -3092,9 +3102,9 @@ def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.conte
     local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed
 
-    constants = {}
-    types = {}
-    functions = {}
+    constants: Dict[str, Any] = {}
+    types: Dict[Union[Struct, type], Any] = {}
+    functions: Dict[warp.context.Function, Any] = {}
 
     for node in ast.walk(adj.tree):
         if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3400,7 +3410,7 @@ def indent(args, stops=1):
 
 
 # generates a C function name based on the python function name
-def make_full_qualified_name(func):
+def make_full_qualified_name(func: Union[str, Callable]) -> str:
     if not isinstance(func, str):
         func = func.__qualname__
     return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
diff --git a/warp/context.py b/warp/context.py
index 39957cd6a..7a67c54c2 100644
--- a/warp/context.py
+++ b/warp/context.py
@@ -26,7 +26,21 @@
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, get_args, get_origin
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+)
 
 import numpy as np
 
@@ -34,6 +48,7 @@
 import warp.build
 import warp.codegen
 import warp.config
+from warp.types import Array
 
 
 # represents either a built-in or user-defined function
@@ -62,10 +77,10 @@ def get_function_args(func):
 complex_type_hints = (Any, Callable, Tuple)
 sequence_types = (list, tuple)
 
-function_key_counts = {}
+function_key_counts: Dict[str, int] = {}
 
 
-def generate_unique_function_identifier(key):
+def generate_unique_function_identifier(key: str) -> str:
     # Generate unique identifiers for user-defined functions in native code.
     # - Prevents conflicts when a function is redefined and old versions are still in use.
     # - Prevents conflicts between multiple closures returned from the same function.
@@ -98,40 +113,40 @@ def generate_unique_function_identifier(key):
 class Function:
     def __init__(
         self,
-        func,
-        key,
-        namespace,
-        input_types=None,
-        value_type=None,
-        value_func=None,
-        export_func=None,
-        dispatch_func=None,
-        lto_dispatch_func=None,
-        module=None,
-        variadic=False,
-        initializer_list_func=None,
-        export=False,
-        doc="",
-        group="",
-        hidden=False,
-        skip_replay=False,
-        missing_grad=False,
-        generic=False,
-        native_func=None,
-        defaults=None,
-        custom_replay_func=None,
-        native_snippet=None,
-        adj_native_snippet=None,
-        replay_snippet=None,
-        skip_forward_codegen=False,
-        skip_reverse_codegen=False,
-        custom_reverse_num_input_args=-1,
-        custom_reverse_mode=False,
-        overloaded_annotations=None,
-        code_transformers=None,
-        skip_adding_overload=False,
-        require_original_output_arg=False,
-        scope_locals=None,  # the locals() where the function is defined, used for overload management
+        func: Optional[Callable],
+        key: str,
+        namespace: str,
+        input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
+        value_type: Optional[type] = None,
+        value_func: Optional[Callable[[Mapping[str, type], Mapping[str, Any]], type]] = None,
+        export_func: Optional[Callable[[Dict[str, type]], Dict[str, type]]] = None,
+        dispatch_func: Optional[Callable] = None,
+        lto_dispatch_func: Optional[Callable] = None,
+        module: Optional[Module] = None,
+        variadic: bool = False,
+        initializer_list_func: Optional[Callable[[Dict[str, Any], type], bool]] = None,
+        export: bool = False,
+        doc: str = "",
+        group: str = "",
+        hidden: bool = False,
+        skip_replay: bool = False,
+        missing_grad: bool = False,
+        generic: bool = False,
+        native_func: Optional[str] = None,
+        defaults: Optional[Dict[str, Any]] = None,
+        custom_replay_func: Optional[Function] = None,
+        native_snippet: Optional[str] = None,
+        adj_native_snippet: Optional[str] = None,
+        replay_snippet: Optional[str] = None,
+        skip_forward_codegen: bool = False,
+        skip_reverse_codegen: bool = False,
+        custom_reverse_num_input_args: int = -1,
+        custom_reverse_mode: bool = False,
+        overloaded_annotations: Optional[Dict[str, type]] = None,
+        code_transformers: Optional[List[ast.NodeTransformer]] = None,
+        skip_adding_overload: bool = False,
+        require_original_output_arg: bool = False,
+        scope_locals: Optional[Dict[str, Any]] = None,
     ):
         if code_transformers is None:
             code_transformers = []
@@ -156,7 +171,7 @@ def __init__(
         self.native_snippet = native_snippet
         self.adj_native_snippet = adj_native_snippet
         self.replay_snippet = replay_snippet
-        self.custom_grad_func = None
+        self.custom_grad_func: Optional[Function] = None
         self.require_original_output_arg = require_original_output_arg
 
         self.generic_parent = None  # generic function that was used to instantiate this overload
@@ -172,6 +187,7 @@ def __init__(
         )
         self.missing_grad = missing_grad  # whether builtin is missing a corresponding adjoint
         self.generic = generic
+        self.mangled_name: Optional[str] = None
 
         # allow registering functions with a different name in Python and native code
         if native_func is None:
@@ -188,8 +204,8 @@ def __init__(
             # user-defined function
 
             # generic and concrete overload lookups by type signature
-            self.user_templates = {}
-            self.user_overloads = {}
+            self.user_templates: Dict[str, Function] = {}
+            self.user_overloads: Dict[str, Function] = {}
 
             # user defined (Python) function
             self.adj = warp.codegen.Adjoint(
@@ -220,19 +236,17 @@ def __init__(
             # builtin function
 
             # embedded linked list of all overloads
-            # the builtin_functions dictionary holds
-            # the list head for a given key (func name)
-            self.overloads = []
+            # the builtin_functions dictionary holds the list head for a given key (func name)
+            self.overloads: List[Function] = []
 
             # builtin (native) function, canonicalize argument types
-            for k, v in input_types.items():
-                self.input_types[k] = warp.types.type_to_warp(v)
+            if input_types is not None:
+                for k, v in input_types.items():
+                    self.input_types[k] = warp.types.type_to_warp(v)
 
             # cache mangled name
             if self.export and self.is_simple():
                 self.mangled_name = self.mangle()
-            else:
-                self.mangled_name = None
 
         if not skip_adding_overload:
             self.add_overload(self)
@@ -263,7 +277,7 @@ def __init__(
             signature_params.append(param)
         self.signature = inspect.Signature(signature_params)
 
-        # scope for resolving overloads
+        # scope for resolving overloads, the locals() where the function is defined
         if scope_locals is None:
             scope_locals = inspect.currentframe().f_back.f_locals
 
@@ -325,10 +339,10 @@ def __call__(self, *args, **kwargs):
         # this function has no overloads, call it like a plain Python function
         return self.func(*args, **kwargs)
 
-    def is_builtin(self):
+    def is_builtin(self) -> bool:
         return self.func is None
 
-    def is_simple(self):
+    def is_simple(self) -> bool:
         if self.variadic:
             return False
 
@@ -342,9 +356,8 @@ def is_simple(self):
 
         return True
 
-    def mangle(self):
-        # builds a mangled name for the C-exported
-        # function, e.g.: builtin_normalize_vec3()
+    def mangle(self) -> str:
+        """Build a mangled name for the C-exported function, e.g.: `builtin_normalize_vec3()`."""
 
         name = "builtin_" + self.key
 
@@ -360,7 +373,7 @@ def mangle(self):
 
         return "_".join([name, *types])
 
-    def add_overload(self, f):
+    def add_overload(self, f: Function) -> None:
         if self.is_builtin():
             # todo: note that it is an error to add two functions
             # with the exact same signature as this would cause compile
@@ -375,7 +388,7 @@ def add_overload(self, f):
         else:
             # get function signature based on the input types
             sig = warp.types.get_signature(
-                f.input_types.values(), func_name=f.key, arg_names=list(f.input_types.keys())
+                list(f.input_types.values()), func_name=f.key, arg_names=list(f.input_types.keys())
             )
 
             # check if generic
@@ -384,7 +397,7 @@ def add_overload(self, f):
             else:
                 self.user_overloads[sig] = f
 
-    def get_overload(self, arg_types, kwarg_types):
+    def get_overload(self, arg_types: List[type], kwarg_types: Mapping[str, type]) -> Optional[Function]:
         assert not self.is_builtin()
 
         for f in self.user_overloads.values():
@@ -437,7 +450,7 @@ def __repr__(self):
         return f""
 
 
-def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
     uses_non_warp_array_type = False
 
     init()
@@ -754,7 +767,7 @@ def get_mangled_name(self):
 
 
 # decorator to register function, @func
-def func(f):
+def func(f: Callable) -> Callable:
     name = warp.codegen.make_full_qualified_name(f)
 
     scope_locals = inspect.currentframe().f_back.f_locals
@@ -777,14 +790,18 @@ def func(f):
     return functools.update_wrapper(g, f)
 
 
-def func_native(snippet, adj_snippet=None, replay_snippet=None):
+def func_native(snippet: str, adj_snippet: Optional[str] = None, replay_snippet: Optional[str] = None):
     """
    Decorator to register native code snippet, @func_native
     """
 
-    scope_locals = inspect.currentframe().f_back.f_locals
+    frame = inspect.currentframe()
+    if frame is None or frame.f_back is None:
+        scope_locals = {}
+    else:
+        scope_locals = frame.f_back.f_locals
 
-    def snippet_func(f):
+    def snippet_func(f: Callable) -> Callable:
         name = warp.codegen.make_full_qualified_name(f)
         m = get_module(f.__module__)
@@ -958,7 +975,7 @@ def wrapper(replay_fn):
 
 # decorator to register kernel, @kernel, custom_name may be a string
 # that creates a kernel with a different name from the actual function
-def kernel(f=None, *, enable_backward=None):
+def kernel(f: Optional[Callable] = None, *, enable_backward: Optional[bool] = None):
     def wrapper(f, *args, **kwargs):
         options = {}
 
@@ -983,7 +1000,7 @@ def wrapper(f, *args, **kwargs):
 
 
 # decorator to register struct, @struct
-def struct(c):
+def struct(c: type):
     m = get_module(c.__module__)
 
     s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
     s = functools.update_wrapper(s, c)
@@ -1096,47 +1113,47 @@ def get_generic_vtypes():
 
 
 def add_builtin(
-    key,
-    input_types=None,
-    constraint=None,
-    value_type=None,
-    value_func=None,
-    export_func=None,
-    dispatch_func=None,
-    lto_dispatch_func=None,
-    doc="",
-    namespace="wp::",
-    variadic=False,
+    key: str,
+    input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
+    constraint: Optional[Callable[[Mapping[str, type]], bool]] = None,
+    value_type: Optional[type] = None,
+    value_func: Optional[Callable] = None,
+    export_func: Optional[Callable] = None,
+    dispatch_func: Optional[Callable] = None,
+    lto_dispatch_func: Optional[Callable] = None,
+    doc: str = "",
+    namespace: str = "wp::",
+    variadic: bool = False,
     initializer_list_func=None,
-    export=True,
-    group="Other",
-    hidden=False,
-    skip_replay=False,
-    missing_grad=False,
-    native_func=None,
-    defaults=None,
-    require_original_output_arg=False,
+    export: bool = True,
+    group: str = "Other",
+    hidden: bool = False,
+    skip_replay: bool = False,
+    missing_grad: bool = False,
+    native_func: Optional[str] = None,
+    defaults: Optional[Dict[str, Any]] = None,
+    require_original_output_arg: bool = False,
 ):
     """Main entry point to register a new built-in function.
 
     Args:
-        key (str): Function name. Multiple overloaded functions can be registered
+        key: Function name. Multiple overloaded functions can be registered
             under the same name as long as their signature differ.
-        input_types (Mapping[str, Any]): Signature of the user-facing function.
+        input_types: Signature of the user-facing function.
             Variadic arguments are supported by prefixing the parameter names
             with asterisks as in `*args` and `**kwargs`. Generic arguments are
             supported with types such as `Any`, `Float`, `Scalar`, etc.
-        constraint (Callable): For functions that define generic arguments and
+        constraint: For functions that define generic arguments and
             are to be exported, this callback is used to specify whether some
            combination of inferred arguments are valid or not.
-        value_type (Any): Type returned by the function.
-        value_func (Callable): Callback used to specify the return type when
+        value_type: Type returned by the function.
+        value_func: Callback used to specify the return type when
            `value_type` isn't enough.
-        export_func (Callable): Callback used during the context stage to specify
+        export_func: Callback used during the context stage to specify
            the signature of the underlying C++ function, not accounting for
           the template parameters. If not provided, `input_types` is used.
-        dispatch_func (Callable): Callback used during the codegen stage to specify
+        dispatch_func: Callback used during the codegen stage to specify
            the runtime and template arguments to be passed to the underlying
            C++ function. In other words, this allows defining a mapping between
            the signatures of the user-facing and the C++ functions, and even to
@@ -1144,27 +1161,26 @@ def add_builtin(
            The arguments returned must be of type `codegen.Var`.
            If not provided, all arguments passed by the users when calling
            the built-in are passed as-is as runtime arguments to the C++
            function.
-        lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+        lto_dispatch_func: Same as dispatch_func, but takes an 'option' dict
            as extra argument (indicating tile_size and target architecture) and
            returns an LTO-IR buffer as extra return value
-        doc (str): Used to generate the Python's docstring and the HTML documentation.
+        doc: Used to generate the Python's docstring and the HTML documentation.
         namespace: Namespace for the underlying C++ function.
-        variadic (bool): Whether the function declares variadic arguments.
-        initializer_list_func (bool): Whether to use the initializer list syntax
-            when passing the arguments to the underlying C++ function.
-        export (bool): Whether the function is to be exposed to the Python
+        variadic: Whether the function declares variadic arguments.
+        initializer_list_func: Callback to determine whether to use the
+            initializer list syntax when passing the arguments to the underlying
+            C++ function.
+        export: Whether the function is to be exposed to the Python
            interpreter so that it becomes available from within the `warp`
            module.
-        group (str): Classification used for the documentation.
-        hidden (bool): Whether to add that function into the documentation.
-        skip_replay (bool): Whether operation will be performed during
+        group: Classification used for the documentation.
+        hidden: Whether to add that function into the documentation.
+        skip_replay: Whether operation will be performed during
            the forward replay in the backward pass.
-        missing_grad (bool): Whether the function is missing a corresponding
-            adjoint.
-        native_func (str): Name of the underlying C++ function.
-        defaults (Mapping[str, Any]): Default values for the parameters defined
-            in `input_types`.
-        require_original_output_arg (bool): Used during the codegen stage to
+        missing_grad: Whether the function is missing a corresponding adjoint.
+        native_func: Name of the underlying C++ function.
+        defaults: Default values for the parameters defined in `input_types`.
+        require_original_output_arg: Used during the codegen stage to
            specify whether an adjoint parameter corresponding to the return
            value should be included in the signature of the backward function.
     """
@@ -1346,19 +1362,14 @@ def initializer_list_func(args, return_type):
 
 def register_api_function(
     function: Function,
     group: str = "Other",
-    hidden=False,
+    hidden: bool = False,
 ):
     """Main entry point to register a Warp Python function to be part of
     the Warp API and appear in the documentation.
 
     Args:
-        function (Function): Warp function to be registered.
-        group (str): Classification used for the documentation.
-        input_types (Mapping[str, Any]): Signature of the user-facing function.
-            Variadic arguments are supported by prefixing the parameter names
-            with asterisks as in `*args` and `**kwargs`. Generic arguments are
-            supported with types such as `Any`, `Float`, `Scalar`, etc.
-        value_type (Any): Type returned by the function.
-        hidden (bool): Whether to add that function into the documentation.
+        function: Warp function to be registered.
+        group: Classification used for the documentation.
+        hidden: Whether to add that function into the documentation.
     """
     function.group = group
     function.hidden = hidden
@@ -1366,10 +1377,10 @@ def register_api_function(
 
 
 # global dictionary of modules
-user_modules = {}
+user_modules: Dict[str, Module] = {}
 
 
-def get_module(name):
+def get_module(name: str) -> Module:
     # some modules might be manually imported using `importlib` without being
     # registered into `sys.modules`
     parent = sys.modules.get(name, None)
@@ -1457,7 +1468,7 @@ def __init__(self, module):
         # save the module hash
         self.module_hash = ch.digest()
 
-    def hash_kernel(self, kernel):
+    def hash_kernel(self, kernel: Kernel) -> bytes:
         # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
 
         ch = hashlib.sha256()
@@ -1471,7 +1482,7 @@ def hash_kernel(self, kernel):
 
         return h
 
-    def hash_function(self, func):
+    def hash_function(self, func: Function) -> bytes:
         # NOTE: This method hashes all possible overloads that a function call could resolve to.
         # The exact overload will be resolved at build time, when the argument types are known.
 
@@ -1486,7 +1497,7 @@ def hash_function(self, func):
         ch.update(bytes(func.key, "utf-8"))
 
         # include all concrete and generic overloads
-        overloads = {**func.user_overloads, **func.user_templates}
+        overloads: Dict[str, Function] = {**func.user_overloads, **func.user_templates}
         for sig in sorted(overloads.keys()):
             ovl = overloads[sig]
@@ -1517,7 +1528,7 @@ def hash_function(self, func):
 
         return h
 
-    def hash_adjoint(self, adj):
+    def hash_adjoint(self, adj: warp.codegen.Adjoint) -> bytes:
         # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
         # Even instances of generic kernels and functions have unique adjoints with
         # different argument types.
@@ -1566,7 +1577,7 @@ def hash_adjoint(self, adj):
 
         return ch.digest()
 
-    def get_constant_bytes(self, value):
+    def get_constant_bytes(self, value) -> bytes:
         if isinstance(value, int):
             # this also handles builtins.bool
             return bytes(ctypes.c_int(value))
@@ -1584,7 +1595,7 @@ def get_constant_bytes(self, value):
         else:
             raise TypeError(f"Invalid constant type: {type(value)}")
 
-    def get_module_hash(self):
+    def get_module_hash(self) -> bytes:
         return self.module_hash
 
     def get_unique_kernels(self):
@@ -2503,7 +2514,7 @@ def __new__(cls, *args, **kwargs):
         instance.owner = False
         return instance
 
-    def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
+    def __init__(self, device: Union["Device", str, None] = None, priority: int = 0, **kwargs):
         """Initialize the stream on a device with an optional specified priority.
 
         Args:
@@ -2943,18 +2954,14 @@ def can_access(self, other):
 
 
 class Graph:
-    def __new__(cls, *args, **kwargs):
-        instance = super(Graph, cls).__new__(cls)
-        instance.graph_exec = None
-        return instance
-
     def __init__(self, device: Device, capture_id: int):
         self.device = device
         self.capture_id = capture_id
-        self.module_execs = set()
+        self.module_execs: Set[ModuleExec] = set()
+        self.graph_exec: Optional[ctypes.c_void_p] = None
 
     def __del__(self):
-        if not self.graph_exec:
+        if not hasattr(self, "graph_exec") or not hasattr(self, "device") or not self.graph_exec:
             return
 
         # use CUDA context guard to avoid side effects during garbage collection
@@ -4164,7 +4171,7 @@ def set_device(ident: Devicelike) -> None:
     device.make_current()
 
 
-def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
+def map_cuda_device(alias: str, context: Optional[ctypes.c_void_p] = None) -> Device:
     """Assign a device alias to a CUDA context.
 
     This function can be used to create a wp.Device for an external CUDA context.
@@ -4591,7 +4598,7 @@ def wait_event(event: Event):
     get_stream().wait_event(event)
 
 
-def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
+def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
     """Get the elapsed time between two recorded events.
 
     Both events must have been previously recorded with
@@ -4616,7 +4623,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Op
 
     return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
 
 
-def wait_stream(other_stream: Stream, event: Event = None):
+def wait_stream(other_stream: Stream, event: Optional[Event] = None):
     """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
 
     Args:
@@ -4783,7 +4790,7 @@ def unmap(self):
 
 
 def zeros(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     dtype=float,
     device: Devicelike = None,
     requires_grad: bool = False,
@@ -4811,7 +4818,7 @@ def zeros(
 
 
 def zeros_like(
-    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
 ) -> warp.array:
     """Return a zero-initialized array with the same type and dimension of another array
 
@@ -4833,7 +4840,7 @@ def zeros_like(
 
 
 def ones(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     dtype=float,
     device: Devicelike = None,
     requires_grad: bool = False,
@@ -4857,7 +4864,7 @@ def ones(
 
 
 def ones_like(
-    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
 ) -> warp.array:
     """Return a one-initialized array with the same type and dimension of another array
 
@@ -4875,7 +4882,7 @@ def ones_like(
 
 
 def full(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     value=0,
     dtype=Any,
     device: Devicelike = None,
@@ -4941,7 +4948,11 @@ def full(
 
 
 def full_like(
-    src: warp.array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array,
+    value: Any,
+    device: Devicelike = None,
+    requires_grad: Optional[bool] = None,
+    pinned: Optional[bool] = None,
 ) -> warp.array:
     """Return an array with all elements initialized to the given value with the same type and dimension of another array
 
@@ -4963,7 +4974,9 @@ def full_like(
     return arr
 
 
-def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
+def clone(
+    src: warp.array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+) -> warp.array:
     """Clone an existing array, allocates a copy of the src memory
 
     Args:
@@ -4984,7 +4997,7 @@ def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None
 
 
 def empty(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
    dtype=float,
    device: Devicelike = None,
    requires_grad: bool = False,
@@ -5017,7 +5030,7 @@ def empty(
 
 
 def empty_like(
-    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
 ) -> warp.array:
     """Return an uninitialized array with the same type and dimension of another array
 
@@ -5390,12 +5403,12 @@ def launch(
     adj_inputs: Sequence = [],
     adj_outputs: Sequence = [],
     device: Devicelike = None,
-    stream: Stream = None,
-    adjoint=False,
-    record_tape=True,
-    record_cmd=False,
-    max_blocks=0,
-    block_dim=256,
+    stream: Optional[Stream] = None,
+    adjoint: bool = False,
+    record_tape: bool = True,
+    record_cmd: bool = False,
+    max_blocks: int = 0,
+    block_dim: int = 256,
 ):
     """Launch a Warp kernel on the target device
 
@@ -5821,7 +5834,12 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
     return get_module(m.__name__).options
 
 
-def capture_begin(device: Devicelike = None, stream=None, force_module_load=None, external=False):
+def capture_begin(
+    device: Devicelike = None,
+    stream: Optional[Stream] = None,
+    force_module_load: Optional[bool] = None,
+    external: bool = False,
+):
"""Begin capture of a CUDA graph Captures all subsequent kernel launches and memory operations on CUDA devices. @@ -5888,16 +5906,15 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None runtime.captures[capture_id] = graph -def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph: - """Ends the capture of a CUDA graph +def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> Graph: + """End the capture of a CUDA graph. Args: - device: The CUDA device where capture began stream: The CUDA stream where capture began Returns: - A Graph object that can be launched with :func:`~warp.capture_launch()` + A :class:`Graph` object that can be launched with :func:`~warp.capture_launch()` """ if stream is not None: @@ -5931,12 +5948,12 @@ def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph: return graph -def capture_launch(graph: Graph, stream: Stream = None): +def capture_launch(graph: Graph, stream: Optional[Stream] = None): """Launch a previously captured CUDA graph Args: - graph: A Graph as returned by :func:`~warp.capture_end()` - stream: A Stream to launch the graph on (optional) + graph: A :class:`Graph` as returned by :func:`~warp.capture_end()` + stream: A :class:`Stream` to launch the graph on """ if stream is not None: @@ -5952,24 +5969,28 @@ def capture_launch(graph: Graph, stream: Stream = None): def copy( - dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None + dest: warp.array, + src: warp.array, + dest_offset: int = 0, + src_offset: int = 0, + count: int = 0, + stream: Optional[Stream] = None, ): """Copy array contents from `src` to `dest`. Args: - dest: Destination array, must be at least as big as source buffer + dest: Destination array, must be at least as large as source buffer src: Source array dest_offset: Element offset in the destination array src_offset: Element offset in the source array count: Number of array elements to copy (will copy all elements if set to 0) - stream: The stream on which to perform the copy (optional) + stream: The stream on which to perform the copy The stream, if specified, can be from any device. If the stream is omitted, then Warp selects a stream based on the following rules: (1) If the destination array is on a CUDA device, use the current stream on the destination device. (2) Otherwise, if the source array is on a CUDA device, use the current stream on the source device. If neither source nor destination are on a CUDA device, no stream is used for the copy. 
- """ from warp.context import runtime diff --git a/warp/types.py b/warp/types.py index 65d64258b..1493f9cbf 100644 --- a/warp/types.py +++ b/warp/types.py @@ -62,7 +62,9 @@ class Transformation(Generic[Float]): class Array(Generic[DType]): - pass + device: Optional[warp.context.Device] + dtype: type + size: int int_tuple_type_hints = { @@ -1145,7 +1147,7 @@ def __init__(self): class launch_bounds_t(ctypes.Structure): _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)] - def __init__(self, shape): + def __init__(self, shape: Sequence[int]): if isinstance(shape, int): # 1d launch self.ndim = 1 @@ -1266,7 +1268,7 @@ def type_scalar_type(dtype): } -def type_size_in_bytes(dtype): +def type_size_in_bytes(dtype: type) -> int: size = _type_size_cache.get(dtype) if size is None: @@ -1285,7 +1287,7 @@ def type_size_in_bytes(dtype): return size -def type_to_warp(dtype): +def type_to_warp(dtype: type) -> type: if dtype == float: return float32 elif dtype == int: @@ -1296,7 +1298,7 @@ def type_to_warp(dtype): return dtype -def type_typestr(dtype): +def type_typestr(dtype: type) -> str: if dtype == bool: return "|b1" elif dtype == float16: @@ -1386,25 +1388,25 @@ def type_is_transformation(t): # returns true for all value types (int, float, bool, scalars, vectors, matrices) -def type_is_value(x): +def type_is_value(x: Any) -> builtins.bool: return x in value_types or hasattr(x, "_wp_scalar_type_") # equivalent of the above but for values -def is_int(x): +def is_int(x: Any) -> builtins.bool: return type_is_int(type(x)) -def is_float(x): +def is_float(x: Any) -> builtins.bool: return type_is_float(type(x)) -def is_value(x): +def is_value(x: Any) -> builtins.bool: return type_is_value(type(x)) -# returns true if the passed *instance* is one of the array types -def is_array(a): +def is_array(a) -> builtins.bool: + """Return true if the passed *instance* is one of the array types.""" return isinstance(a, array_types) @@ -1606,7 +1608,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None): return array_ctype -class array(Array): +class array(Array[DType]): """A fixed-size multi-dimensional array containing values of the same type. Attributes: @@ -1635,21 +1637,21 @@ def __new__(cls, *args, **kwargs): def __init__( self, - data: Optional[Union[List, Tuple, npt.NDArray]] = None, - dtype: Union[DType, Any] = Any, - shape: Optional[Tuple[int, ...]] = None, + data: Union[List, Tuple, npt.NDArray, None] = None, + dtype: Any = Any, + shape: Union[int, Tuple[int, ...], List[int], None] = None, strides: Optional[Tuple[int, ...]] = None, length: Optional[int] = None, ptr: Optional[int] = None, capacity: Optional[int] = None, device=None, - pinned: bool = False, - copy: bool = True, - owner: bool = False, # deprecated - pass deleter instead + pinned: builtins.bool = False, + copy: builtins.bool = True, + owner: builtins.bool = False, # deprecated - pass deleter instead deleter: Optional[Callable[[int, int], None]] = None, ndim: Optional[int] = None, grad: Optional[array] = None, - requires_grad: bool = False, + requires_grad: builtins.bool = False, ): """Constructs a new Warp array object @@ -2947,7 +2949,7 @@ def from_ipc_handle( # A base class for non-contiguous arrays, providing the implementation of common methods like # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_(). 
-class noncontiguous_array_base(Generic[T]):
+class noncontiguous_array_base(Array[T]):
     def __init__(self, array_type_id):
         self.type_id = array_type_id
         self.is_contiguous = False
@@ -3044,12 +3046,18 @@ def check_index_array(indices, expected_device):
         raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")
 
 
-class indexedarray(noncontiguous_array_base[T]):
+class indexedarray(noncontiguous_array_base):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None
 
-    def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
+    def __init__(
+        self,
+        data: Optional[array] = None,
+        indices: Union[array, List[array], None] = None,
+        dtype=None,
+        ndim: Optional[int] = None,
+    ):
         super().__init__(ARRAY_TYPE_INDEXED)
 
         # canonicalize types
@@ -3642,7 +3650,7 @@ def __new__(cls, *args, **kwargs):
         instance.id = None
         return instance
 
-    def __init__(self, data: array, copy: bool = True):
+    def __init__(self, data: array, copy: builtins.bool = True):
         """Class representing a sparse grid.
 
         Args:
@@ -5129,7 +5137,7 @@ def infer_argument_types(args, template_types, arg_names=None):
 }
 
 
-def get_type_code(arg_type):
+def get_type_code(arg_type: type) -> str:
     if arg_type == Any:
         # special case for generics
         # note: since Python 3.11 Any is a type, so we check for it first
@@ -5193,8 +5201,8 @@ def get_type_code(arg_type):
         raise TypeError(f"Unrecognized type '{arg_type}'")
 
 
-def get_signature(arg_types, func_name=None, arg_names=None):
-    type_codes = []
+def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+    type_codes: List[str] = []
     for i, arg_type in enumerate(arg_types):
         try:
             type_codes.append(get_type_code(arg_type))
diff --git a/warp/utils.py b/warp/utils.py
index 037bba072..50ad15cfe 100644
--- a/warp/utils.py
+++ b/warp/utils.py
@@ -5,13 +5,15 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+from __future__ import annotations
+
 import cProfile
 import ctypes
 import os
 import sys
 import time
 import warnings
-from typing import Any, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 
@@ -665,37 +667,38 @@ class ScopedTimer:
 
     def __init__(
         self,
-        name,
-        active=True,
-        print=True,
-        detailed=False,
-        dict=None,
-        use_nvtx=False,
-        color="rapids",
-        synchronize=False,
-        cuda_filter=0,
-        report_func=None,
-        skip_tape=False,
+        name: str,
+        active: bool = True,
+        print: bool = True,
+        detailed: bool = False,
+        dict: Optional[Dict[str, List[float]]] = None,
+        use_nvtx: bool = False,
+        color: Union[int, str] = "rapids",
+        synchronize: bool = False,
+        cuda_filter: int = 0,
+        report_func: Optional[Callable[[List[TimingResult], str], None]] = None,
+        skip_tape: bool = False,
     ):
         """Context manager object for a timer
 
         Parameters:
-            name (str): Name of timer
-            active (bool): Enables this timer
-            print (bool): At context manager exit, print elapsed time to sys.stdout
-            detailed (bool): Collects additional profiling data using cProfile and calls ``print_stats()`` at context exit
-            dict (dict): A dictionary of lists to which the elapsed time will be appended using ``name`` as a key
-            use_nvtx (bool): If true, timing functionality is replaced by an NVTX range
-            color (int or str): ARGB value (e.g. 0x00FFFF) or color name (e.g. 'cyan') associated with the NVTX range
-            synchronize (bool): Synchronize the CPU thread with any outstanding CUDA work to return accurate GPU timings
-            cuda_filter (int): Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL``
-            report_func (Callable): A callback function to print the activity report (``wp.timing_print()`` is used by default)
-            skip_tape (bool): If true, the timer will not be recorded in the tape
+            name: Name of timer
+            active: Enables this timer
+            print: At context manager exit, print elapsed time to ``sys.stdout``
+            detailed: Collects additional profiling data using cProfile and calls ``print_stats()`` at context exit
+            dict: A dictionary of lists to which the elapsed time will be appended using ``name`` as a key
+            use_nvtx: If true, timing functionality is replaced by an NVTX range
+            color: ARGB value (e.g. 0x00FFFF) or color name (e.g. 'cyan') associated with the NVTX range
+            synchronize: Synchronize the CPU thread with any outstanding CUDA work to return accurate GPU timings
+            cuda_filter: Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL``
+            report_func: A callback function to print the activity report.
+                If ``None``, :func:`wp.timing_print() ` will be used.
+            skip_tape: If true, the timer will not be recorded in the tape
 
         Attributes:
             extra_msg (str): Can be set to a string that will be added to the printout at context exit.
             elapsed (float): The duration of the ``with`` block used with this object
-            timing_results (list[TimingResult]): The list of activity timing results, if collection was requested using ``cuda_filter``
+            timing_results (List[TimingResult]): The list of activity timing results, if collection was requested using ``cuda_filter``
         """
         self.name = name
         self.active = active and self.enabled
@@ -791,7 +794,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 # Allow temporarily enabling/disabling mempool allocators
 class ScopedMempool:
-    def __init__(self, device, enable: bool):
+    def __init__(self, device: Devicelike, enable: bool):
         self.device = wp.get_device(device)
         self.enable = enable
 
@@ -805,7 +808,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 # Allow temporarily enabling/disabling mempool access
 class ScopedMempoolAccess:
-    def __init__(self, target_device, peer_device, enable: bool):
+    def __init__(self, target_device: Devicelike, peer_device: Devicelike, enable: bool):
         self.target_device = target_device
         self.peer_device = peer_device
         self.enable = enable
@@ -820,7 +823,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 # Allow temporarily enabling/disabling peer access
 class ScopedPeerAccess:
-    def __init__(self, target_device, peer_device, enable: bool):
+    def __init__(self, target_device: Devicelike, peer_device: Devicelike, enable: bool):
         self.target_device = target_device
         self.peer_device = peer_device
         self.enable = enable
@@ -834,7 +837,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 class ScopedCapture:
-    def __init__(self, device=None, stream=None, force_module_load=None, external=False):
+    def __init__(self, device: Devicelike = None, stream=None, force_module_load=None, external=False):
         self.device = device
         self.stream = stream
         self.force_module_load = force_module_load
@@ -898,31 +901,28 @@ class timing_result_t(ctypes.Structure):
 
 
 class TimingResult:
-    """Timing result for a single activity.
+ """Timing result for a single activity.""" - Parameters: - raw_result (warp.utils.timing_result_t): The result structure obtained from C++ (internal use only) + def __init__(self, device, name, filter, elapsed): + self.device: warp.context.Device = device + """The device where the activity was recorded.""" - Attributes: - device (warp.Device): The device where the activity was recorded. - name (str): The activity name. - filter (int): The type of activity (e.g., ``warp.TIMING_KERNEL``). - elapsed (float): The elapsed time in milliseconds. - """ + self.name: str = name + """The activity name.""" - def __init__(self, device, name, filter, elapsed): - self.device = device - self.name = name - self.filter = filter - self.elapsed = elapsed + self.filter: int = filter + """The type of activity (e.g., ``warp.TIMING_KERNEL``).""" + + self.elapsed: float = elapsed + """The elapsed time in milliseconds.""" -def timing_begin(cuda_filter=TIMING_ALL, synchronize=True): +def timing_begin(cuda_filter: int = TIMING_ALL, synchronize: bool = True) -> None: """Begin detailed activity timing. Parameters: - cuda_filter (int): Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL`` - synchronize (bool): Whether to synchronize all CUDA devices before timing starts + cuda_filter: Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL`` + synchronize: Whether to synchronize all CUDA devices before timing starts """ if synchronize: @@ -931,14 +931,14 @@ def timing_begin(cuda_filter=TIMING_ALL, synchronize=True): warp.context.runtime.core.cuda_timing_begin(cuda_filter) -def timing_end(synchronize=True): +def timing_end(synchronize: bool = True) -> List[TimingResult]: """End detailed activity timing. Parameters: - synchronize (bool): Whether to synchronize all CUDA devices before timing ends + synchronize: Whether to synchronize all CUDA devices before timing ends Returns: - list[TimingResult]: A list of ``TimingResult`` objects for all recorded activities. + A list of :class:`TimingResult` objects for all recorded activities. """ if synchronize: @@ -977,12 +977,12 @@ def timing_end(synchronize=True): return results -def timing_print(results, indent=""): +def timing_print(results: List[TimingResult], indent: str = "") -> None: """Print timing results. Parameters: - results (list[TimingResult]): List of ``TimingResult`` objects. - indent (str): Optional indentation for the output. + results: List of :class:`TimingResult` objects to print. + indent: Optional indentation to prepend to all output lines. """ if not results: