diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 932b1d0d10..c907c4c2e0 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -5,9 +5,11 @@ import itertools as it import random import sys -from collections.abc import Iterable, Sequence +import types +from collections.abc import Iterable, Iterator, Sequence from functools import partialmethod, wraps from math import ceil +from typing import TYPE_CHECKING, Any, Callable, TypeVar import moderngl import numpy as np @@ -44,10 +46,33 @@ rotation_matrix_transpose, ) +if TYPE_CHECKING: + import numpy.typing as npt + from typing_extensions import Self, TypeAlias -def affects_shader_info_id(func): + from manim.renderer.shader_wrapper import ShaderWrapper + from manim.typing import ( + ManimFloat, + MappingFunction, + MatrixMN, + PathFuncType, + Point3D, + Point3D_Array, + Vector3D, + ) + + TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] + NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] + Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater + + T = TypeVar("T") + + +def affects_shader_info_id( + func: Callable[[OpenGLMobject], OpenGLMobject], +) -> Callable[[OpenGLMobject], OpenGLMobject]: @wraps(func) - def wrapper(self): + def wrapper(self: OpenGLMobject) -> OpenGLMobject: for mob in self.get_family(): func(mob) mob.refresh_shader_wrapper_id() @@ -93,26 +118,26 @@ class OpenGLMobject: def __init__( self, - color=WHITE, - opacity=1, - dim=3, # TODO, get rid of this + color: ParsableManimColor | Iterable[ParsableManimColor] = WHITE, + opacity: float = 1, + dim: int = 3, # TODO, get rid of this # Lighting parameters # Positive gloss up to 1 makes it reflect the light. - gloss=0.0, + gloss: float = 0.0, # Positive shadow up to 1 makes a side opposite the light darker - shadow=0.0, + shadow: float = 0.0, # For shaders - render_primitive=moderngl.TRIANGLES, - texture_paths=None, - depth_test=False, + render_primitive: int = moderngl.TRIANGLES, + texture_paths: dict[str, str] | None = None, + depth_test: bool = False, # If true, the mobject will not get rotated according to camera position - is_fixed_in_frame=False, - is_fixed_orientation=False, + is_fixed_in_frame: bool = False, + is_fixed_orientation: bool = False, # Must match in attributes of vert shader # Event listener - listen_to_events=False, - model_matrix=None, - should_render=True, + listen_to_events: bool = False, + model_matrix: MatrixMN | None = None, + should_render: bool = True, name: str | None = None, **kwargs, ): @@ -199,7 +224,7 @@ def _assert_valid_submobjects(self, submobjects: Iterable[OpenGLMobject]) -> Sel return self._assert_valid_submobjects_internal(submobjects, OpenGLMobject) def _assert_valid_submobjects_internal( - self, submobjects: list[OpenGLMobject], mob_class: type[OpenGLMobject] + self, submobjects: Iterable[OpenGLMobject], mob_class: type[OpenGLMobject] ) -> Self: for i, submob in enumerate(submobjects): if not isinstance(submob, mob_class): @@ -224,14 +249,14 @@ def _assert_valid_submobjects_internal( return self @classmethod - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) cls._original__init__ = cls.__init__ - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ - def __repr__(self): + def __repr__(self) -> str: return str(self.name) def __sub__(self, other): @@ -247,7 +272,7 @@ def __iadd__(self, mobject): return NotImplemented @classmethod - def set_default(cls, **kwargs): + def set_default(cls, **kwargs) -> None: """Sets the default values of keyword arguments. If this method is called without any additional keyword @@ -294,14 +319,14 @@ def construct(self): else: cls.__init__ = cls._original__init__ - def init_data(self): + def init_data(self) -> None: """Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data. Subclasses can inherit and overwrite this method to extend `self.data`.""" self.points = np.zeros((0, 3)) self.bounding_box = np.zeros((3, 3)) self.rgbas = np.zeros((1, 4)) - def init_colors(self): + def init_colors(self) -> None: """Initializes the colors. Gets called upon creation""" @@ -315,7 +340,7 @@ def init_points(self): # Typically implemented in subclass, unless purposefully left blank pass - def set(self, **kwargs) -> OpenGLMobject: + def set(self, **kwargs) -> Self: """Sets attributes. Mainly to be used along with :attr:`animate` to @@ -349,18 +374,18 @@ def set(self, **kwargs) -> OpenGLMobject: return self - def set_data(self, data): + def set_data(self, data: dict[str, Any]) -> Self: for key in data: self.data[key] = data[key].copy() return self - def set_uniforms(self, uniforms): + def set_uniforms(self, uniforms: dict[str, Any]) -> Self: for key in uniforms: self.uniforms[key] = uniforms[key] # Copy? return self @property - def animate(self): + def animate(self) -> _AnimationBuilder | Self: """Used to animate the application of a method. .. warning:: @@ -448,7 +473,7 @@ def construct(self): return _AnimationBuilder(self) @property - def width(self): + def width(self) -> float: """The width of the mobject. Returns @@ -482,11 +507,11 @@ def construct(self): # Only these methods should directly affect points @width.setter - def width(self, value): + def width(self, value: float) -> None: self.rescale_to_fit(value, 0, stretch=False) @property - def height(self): + def height(self) -> float: """The height of the mobject. Returns @@ -519,11 +544,11 @@ def construct(self): return self.length_over_dim(1) @height.setter - def height(self, value): + def height(self, value: float) -> None: self.rescale_to_fit(value, 1, stretch=False) @property - def depth(self): + def depth(self) -> float: """The depth of the mobject. Returns @@ -540,7 +565,7 @@ def depth(self): return self.length_over_dim(2) @depth.setter - def depth(self, value): + def depth(self, value: float) -> None: self.rescale_to_fit(value, 2, stretch=False) def resize_points(self, new_length, resize_func=resize_array): @@ -549,7 +574,7 @@ def resize_points(self, new_length, resize_func=resize_array): self.refresh_bounding_box() return self - def set_points(self, points): + def set_points(self, points: Point3D_Array) -> Self: if len(points) == len(self.points): self.points[:] = points elif isinstance(points, np.ndarray): @@ -559,23 +584,26 @@ def set_points(self, points): self.refresh_bounding_box() return self - def apply_over_attr_arrays(self, func): + def apply_over_attr_arrays( + self, func: Callable[[npt.NDArray[T]], npt.NDArray[T]] + ) -> Self: + # TODO: OpenGLMobject.get_array_attrs() doesn't even exist! for attr in self.get_array_attrs(): setattr(self, attr, func(getattr(self, attr))) return self - def append_points(self, new_points): + def append_points(self, new_points: Point3D_Array) -> Self: self.points = np.vstack([self.points, new_points]) self.refresh_bounding_box() return self - def reverse_points(self): + def reverse_points(self) -> Self: for mob in self.get_family(): for key in mob.data: mob.data[key] = mob.data[key][::-1] return self - def get_midpoint(self) -> np.ndarray: + def get_midpoint(self) -> Point3D: """Get coordinates of the middle of the path that forms the :class:`~.OpenGLMobject`. Examples @@ -600,11 +628,11 @@ def construct(self): def apply_points_function( self, - func, - about_point=None, - about_edge=ORIGIN, - works_on_bounding_box=False, - ): + func: MappingFunction, + about_point: Point3D | None = None, + about_edge: Vector3D | None = ORIGIN, + works_on_bounding_box: bool = False, + ) -> Self: if about_point is None and about_edge is not None: about_point = self.get_bounding_box_point(about_edge) @@ -630,7 +658,7 @@ def apply_points_function( # Others related to points - def match_points(self, mobject): + def match_points(self, mobject: OpenGLMobject) -> Self: """Edit points, positions, and submobjects to be identical to another :class:`~.OpenGLMobject`, while keeping the style unchanged. @@ -648,29 +676,31 @@ def construct(self): self.wait(0.5) """ self.set_points(mobject.points) + return self - def clear_points(self): + def clear_points(self) -> Self: self.points = np.empty((0, 3)) + return self - def get_num_points(self): + def get_num_points(self) -> int: return len(self.points) - def get_all_points(self): + def get_all_points(self) -> Point3D_Array: if self.submobjects: return np.vstack([sm.points for sm in self.get_family()]) else: return self.points - def has_points(self): + def has_points(self) -> bool: return self.get_num_points() > 0 - def get_bounding_box(self): + def get_bounding_box(self) -> npt.NDArray[float]: if self.needs_new_bounding_box: self.bounding_box = self.compute_bounding_box() self.needs_new_bounding_box = False return self.bounding_box - def compute_bounding_box(self): + def compute_bounding_box(self) -> npt.NDArray[float]: all_points = np.vstack( [ self.points, @@ -690,7 +720,9 @@ def compute_bounding_box(self): mids = (mins + maxs) / 2 return np.array([mins, mids, maxs]) - def refresh_bounding_box(self, recurse_down=False, recurse_up=True): + def refresh_bounding_box( + self, recurse_down: bool = False, recurse_up: bool = True + ) -> Self: for mob in self.get_family(recurse_down): mob.needs_new_bounding_box = True if recurse_up: @@ -698,7 +730,7 @@ def refresh_bounding_box(self, recurse_down=False, recurse_up=True): parent.refresh_bounding_box() return self - def is_point_touching(self, point, buff=MED_SMALL_BUFF): + def is_point_touching(self, point: Point3D, buff: float = MED_SMALL_BUFF) -> bool: bb = self.get_bounding_box() mins = bb[0] - buff maxs = bb[2] + buff @@ -706,22 +738,22 @@ def is_point_touching(self, point, buff=MED_SMALL_BUFF): # Family matters - def __getitem__(self, value): + def __getitem__(self, value: int | slice) -> OpenGLMobject: if isinstance(value, slice): GroupClass = self.get_group_class() return GroupClass(*self.split().__getitem__(value)) return self.split().__getitem__(value) - def __iter__(self): + def __iter__(self) -> Iterator[OpenGLMobject]: return iter(self.split()) - def __len__(self): + def __len__(self) -> int: return len(self.split()) - def split(self): + def split(self) -> Sequence[OpenGLMobject]: return self.submobjects - def assemble_family(self): + def assemble_family(self) -> Self: sub_families = (sm.get_family() for sm in self.submobjects) self.family = [self, *uniq_chain(*sub_families)] self.refresh_has_updater_status() @@ -730,18 +762,16 @@ def assemble_family(self): parent.assemble_family() return self - def get_family(self, recurse=True): + def get_family(self, recurse: bool = True) -> Sequence[OpenGLMobject]: if recurse and hasattr(self, "family"): return self.family else: return [self] - def family_members_with_points(self): + def family_members_with_points(self) -> Sequence[OpenGLMobject]: return [m for m in self.get_family() if m.has_points()] - def add( - self, *mobjects: OpenGLMobject, update_parent: bool = False - ) -> OpenGLMobject: + def add(self, *mobjects: OpenGLMobject, update_parent: bool = False) -> Self: """Add mobjects as submobjects. The mobjects are added to :attr:`submobjects`. @@ -826,7 +856,9 @@ def add( self.assemble_family() return self - def insert(self, index: int, mobject: OpenGLMobject, update_parent: bool = False): + def insert( + self, index: int, mobject: OpenGLMobject, update_parent: bool = False + ) -> Self: """Inserts a mobject at a specific position into self.submobjects Effectively just calls ``self.submobjects.insert(index, mobject)``, @@ -858,9 +890,7 @@ def insert(self, index: int, mobject: OpenGLMobject, update_parent: bool = False self.assemble_family() return self - def remove( - self, *mobjects: OpenGLMobject, update_parent: bool = False - ) -> OpenGLMobject: + def remove(self, *mobjects: OpenGLMobject, update_parent: bool = False) -> Self: """Remove :attr:`submobjects`. The mobjects are removed from :attr:`submobjects`, if they exist. @@ -894,7 +924,7 @@ def remove( self.assemble_family() return self - def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: + def add_to_back(self, *mobjects: OpenGLMobject) -> Self: # NOTE: is the note true OpenGLMobjects? """Add all passed mobjects to the back of the submobjects. @@ -943,7 +973,7 @@ def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: self.submobjects = list_update(mobjects, self.submobjects) return self - def replace_submobject(self, index, new_submob): + def replace_submobject(self, index: int, new_submob: OpenGLMobject) -> Self: self._assert_valid_submobjects([new_submob]) old_submob = self.submobjects[index] if self in old_submob.parents: @@ -954,7 +984,9 @@ def replace_submobject(self, index, new_submob): # Submobject organization - def arrange(self, direction=RIGHT, center=True, **kwargs): + def arrange( + self, direction: Vector3D = RIGHT, center: bool = True, **kwargs + ) -> Self: """Sorts :class:`~.OpenGLMobject` next to each other on screen. Examples @@ -983,14 +1015,14 @@ def arrange_in_grid( rows: int | None = None, cols: int | None = None, buff: float | tuple[float, float] = MED_SMALL_BUFF, - cell_alignment: np.ndarray = ORIGIN, + cell_alignment: Vector3D = ORIGIN, row_alignments: str | None = None, # "ucd" col_alignments: str | None = None, # "lcr" - row_heights: Iterable[float | None] | None = None, - col_widths: Iterable[float | None] | None = None, + row_heights: Sequence[float | None] | None = None, + col_widths: Sequence[float | None] | None = None, flow_order: str = "rd", **kwargs, - ) -> OpenGLMobject: + ) -> Self: """Arrange submobjects in a grid. Parameters @@ -1086,16 +1118,27 @@ def construct(self): start_pos = self.get_center() # get cols / rows values if given (implicitly) - def init_size(num, alignments, sizes): + def init_size( + num: int | None, + alignments: str | None, + sizes: Sequence[float | None] | None, + name: str, + ) -> int: if num is not None: return num if alignments is not None: return len(alignments) if sizes is not None: return len(sizes) + raise ValueError( + f"At least one of the following parameters: '{name}s', " + f"'{name}_alignments' or " + f"'{name}_{'widths' if name == 'col' else 'heights'}', " + "must not be None" + ) - cols = init_size(cols, col_alignments, col_widths) - rows = init_size(rows, row_alignments, row_heights) + cols = init_size(cols, col_alignments, col_widths, "col") + rows = init_size(rows, row_alignments, row_heights, "row") # calculate rows cols if rows is None and cols is None: @@ -1119,16 +1162,19 @@ def init_size(num, alignments, sizes): buff_x = buff_y = buff # Initialize alignments correctly - def init_alignments(alignments, num, mapping, name, dir): - if alignments is None: + def init_alignments( + str_alignments: str | None, + num: int, + mapping: dict[str, Vector3D], + name: str, + direction: Vector3D, + ) -> Sequence[Vector3D]: + if str_alignments is None: # Use cell_alignment as fallback - return [cell_alignment * dir] * num - if len(alignments) != num: + return [cell_alignment * direction] * num + if len(str_alignments) != num: raise ValueError(f"{name}_alignments has a mismatching size.") - alignments = list(alignments) - for i in range(num): - alignments[i] = mapping[alignments[i]] - return alignments + return [mapping[letter] for letter in str_alignments] row_alignments = init_alignments( row_alignments, @@ -1164,11 +1210,12 @@ def init_alignments(alignments, num, mapping, name, dir): # Reverse row_alignments and row_heights. Necessary since the # grid filling is handled bottom up for simplicity reasons. - def reverse(maybe_list): + def reverse(maybe_list: Sequence[Any] | None) -> Sequence[Any] | None: if maybe_list is not None: maybe_list = list(maybe_list) maybe_list.reverse() return maybe_list + return None row_alignments = reverse(row_alignments) row_heights = reverse(row_heights) @@ -1189,7 +1236,12 @@ def reverse(maybe_list): ] # Initialize row_heights / col_widths correctly using measurements as fallback - def init_sizes(sizes, num, measures, name): + def init_sizes( + sizes: Sequence[float | None] | None, + num: int, + measures: Sequence[float], + name: str, + ) -> Sequence[float]: if sizes is None: sizes = [None] * num if len(sizes) != num: @@ -1222,7 +1274,9 @@ def init_sizes(sizes, num, measures, name): self.move_to(start_pos) return self - def get_grid(self, n_rows, n_cols, height=None, **kwargs): + def get_grid( + self, n_rows: int, n_cols: int, height: float | None = None, **kwargs + ) -> OpenGLGroup: """ Returns a new mobject containing multiple copies of this one arranged in a grid @@ -1233,11 +1287,15 @@ def get_grid(self, n_rows, n_cols, height=None, **kwargs): grid.set_height(height) return grid - def duplicate(self, n: int): - """Returns an :class:`~.OpenGLVGroup` containing ``n`` copies of the mobject.""" + def duplicate(self, n: int) -> OpenGLGroup: + """Returns an :class:`~.OpenGLGroup` containing ``n`` copies of the mobject.""" return self.get_group_class()(*[self.copy() for _ in range(n)]) - def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): + def sort( + self, + point_to_num_func: Callable[[Point3D], float] = lambda p: p[0], + submob_func: Callable[[OpenGLMobject], Any] | None = None, + ) -> Self: """Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``.""" if submob_func is not None: self.submobjects.sort(key=submob_func) @@ -1245,7 +1303,7 @@ def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center())) return self - def shuffle(self, recurse=False): + def shuffle(self, recurse: bool = False) -> Self: """Shuffles the order of :attr:`submobjects` Examples @@ -1268,7 +1326,7 @@ def construct(self): self.assemble_family() return self - def invert(self, recursive=False): + def invert(self, recursive: bool = False) -> Self: """Inverts the list of :attr:`submobjects`. Parameters @@ -1297,7 +1355,7 @@ def construct(self): # Copying - def copy(self, shallow: bool = False): + def copy(self, shallow: bool = False) -> OpenGLMobject: """Create and return an identical copy of the :class:`OpenGLMobject` including all :attr:`submobjects`. @@ -1355,14 +1413,14 @@ def copy(self, shallow: bool = False): # setattr(copy_mobject, attr, value.copy()) return copy_mobject - def deepcopy(self): + def deepcopy(self) -> OpenGLMobject: parents = self.parents self.parents = [] result = copy.deepcopy(self) self.parents = parents return result - def generate_target(self, use_deepcopy: bool = False): + def generate_target(self, use_deepcopy: bool = False) -> OpenGLMobject: self.target = None # Prevent exponential explosion if use_deepcopy: self.target = self.deepcopy() @@ -1370,7 +1428,7 @@ def generate_target(self, use_deepcopy: bool = False): self.target = self.copy() return self.target - def save_state(self, use_deepcopy: bool = False): + def save_state(self, use_deepcopy: bool = False) -> Self: """Save the current state (position, color & size). Can be restored with :meth:`~.OpenGLMobject.restore`.""" if hasattr(self, "saved_state"): # Prevent exponential growth of data @@ -1381,7 +1439,7 @@ def save_state(self, use_deepcopy: bool = False): self.saved_state = self.copy() return self - def restore(self): + def restore(self) -> Self: """Restores the state that was previously saved with :meth:`~.OpenGLMobject.save_state`.""" if not hasattr(self, "saved_state") or self.save_state is None: raise Exception("Trying to restore without having saved") @@ -1390,13 +1448,13 @@ def restore(self): # Updating - def init_updaters(self): + def init_updaters(self) -> None: self.time_based_updaters = [] self.non_time_updaters = [] self.has_updaters = False self.updating_suspended = False - def update(self, dt=0, recurse=True): + def update(self, dt: float = 0, recurse: bool = True) -> Self: if not self.has_updaters or self.updating_suspended: return self for updater in self.time_based_updaters: @@ -1408,19 +1466,24 @@ def update(self, dt=0, recurse=True): submob.update(dt, recurse) return self - def get_time_based_updaters(self): + def get_time_based_updaters(self) -> Sequence[TimeBasedUpdater]: return self.time_based_updaters - def has_time_based_updater(self): + def has_time_based_updater(self) -> bool: return len(self.time_based_updaters) > 0 - def get_updaters(self): + def get_updaters(self) -> Sequence[Updater]: return self.time_based_updaters + self.non_time_updaters - def get_family_updaters(self): + def get_family_updaters(self) -> Sequence[Updater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) - def add_updater(self, update_function, index=None, call_updater=False): + def add_updater( + self, + update_function: Updater, + index: int | None = None, + call_updater: bool = False, + ) -> Self: if "dt" in inspect.signature(update_function).parameters: updater_list = self.time_based_updaters else: @@ -1436,14 +1499,14 @@ def add_updater(self, update_function, index=None, call_updater=False): self.update() return self - def remove_updater(self, update_function): + def remove_updater(self, update_function: Updater) -> Self: for updater_list in [self.time_based_updaters, self.non_time_updaters]: while update_function in updater_list: updater_list.remove(update_function) self.refresh_has_updater_status() return self - def clear_updaters(self, recurse=True): + def clear_updaters(self, recurse: bool = True) -> Self: self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() @@ -1452,20 +1515,20 @@ def clear_updaters(self, recurse=True): submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, mobject: OpenGLMobject) -> Self: self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse=True): + def suspend_updating(self, recurse: bool = True) -> Self: self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse=True, call_updater=True): + def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Self: self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -1476,13 +1539,13 @@ def resume_updating(self, recurse=True, call_updater=True): self.update(dt=0, recurse=recurse) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) return self # Transforming operations - def shift(self, vector): + def shift(self, vector: Vector3D) -> Self: self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -1496,7 +1559,7 @@ def scale( about_point: Sequence[float] | None = None, about_edge: Sequence[float] = ORIGIN, **kwargs, - ) -> OpenGLMobject: + ) -> Self: r"""Scale the size by a factor. Default behavior is to scale about the center of the mobject. @@ -1552,7 +1615,7 @@ def construct(self): ) return self - def stretch(self, factor, dim, **kwargs): + def stretch(self, factor: float, dim: int, **kwargs) -> Self: def func(points): points[:, dim] *= factor return points @@ -1560,16 +1623,16 @@ def func(points): self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle, axis=OUT): + def rotate_about_origin(self, angle: float, axis: Vector3D = OUT) -> Self: return self.rotate(angle, axis, about_point=ORIGIN) def rotate( self, - angle, - axis=OUT, + angle: float, + axis: Vector3D = OUT, about_point: Sequence[float] | None = None, **kwargs, - ): + ) -> Self: """Rotates the :class:`~.OpenGLMobject` about a certain point.""" rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( @@ -1579,7 +1642,7 @@ def rotate( ) return self - def flip(self, axis=UP, **kwargs): + def flip(self, axis: Vector3D = UP, **kwargs) -> Self: """Flips/Mirrors an mobject about its center. Examples @@ -1598,7 +1661,7 @@ def construct(self): """ return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function, **kwargs): + def apply_function(self, function: MappingFunction, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -1607,16 +1670,16 @@ def apply_function(self, function, **kwargs): ) return self - def apply_function_to_position(self, function): + def apply_function_to_position(self, function: MappingFunction) -> Self: self.move_to(function(self.get_center())) return self - def apply_function_to_submobject_positions(self, function): + def apply_function_to_submobject_positions(self, function: MappingFunction) -> Self: for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs): + def apply_matrix(self, matrix: MatrixMN, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -1628,7 +1691,9 @@ def apply_matrix(self, matrix, **kwargs): ) return self - def apply_complex_function(self, function, **kwargs): + def apply_complex_function( + self, function: Callable[[complex], complex], **kwargs + ) -> Self: """Applies a complex function to a :class:`OpenGLMobject`. The x and y coordinates correspond to the real and imaginary parts respectively. @@ -1662,7 +1727,7 @@ def R3_func(point): return self.apply_function(R3_func) - def hierarchical_model_matrix(self): + def hierarchical_model_matrix(self) -> MatrixMN: if self.parent is None: return self.model_matrix @@ -1673,7 +1738,12 @@ def hierarchical_model_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(model_matrices))) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): + def wag( + self, + direction: Vector3D = RIGHT, + axis: Vector3D = DOWN, + wag_factor: float = 1.0, + ) -> Self: for mob in self.family_members_with_points(): alphas = np.dot(mob.points, np.transpose(axis)) alphas -= min(alphas) @@ -1690,12 +1760,16 @@ def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): # Positioning methods - def center(self): + def center(self) -> Self: """Moves the mobject to the center of the Scene.""" self.shift(-self.get_center()) return self - def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def align_on_border( + self, + direction: Vector3D, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> Self: """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -1711,22 +1785,30 @@ def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): self.shift(shift_val) return self - def to_corner(self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_corner( + self, + corner: Vector3D = LEFT + DOWN, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> Self: return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_edge( + self, + edge: Vector3D = LEFT, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> Self: return self.align_on_border(edge, buff) def next_to( self, - mobject_or_point, - direction=RIGHT, - buff=DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge=ORIGIN, - submobject_to_align=None, - index_of_submobject_to_align=None, - coor_mask=np.array([1, 1, 1]), - ): + mobject_or_point: OpenGLMobject | Point3D, + direction: Vector3D = RIGHT, + buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, + aligned_edge: Vector3D = ORIGIN, + submobject_to_align: OpenGLMobject | None = None, + index_of_submobject_to_align: int | None = None, + coor_mask: Point3D = np.array([1, 1, 1]), + ) -> Self: """Move this :class:`~.OpenGLMobject` next to another's :class:`~.OpenGLMobject` or coordinate. Examples @@ -1768,7 +1850,7 @@ def construct(self): self.shift((target_point - point_to_align + buff * direction) * coor_mask) return self - def shift_onto_screen(self, **kwargs): + def shift_onto_screen(self, **kwargs) -> Self: space_lengths = [config["frame_x_radius"], config["frame_y_radius"]] for vect in UP, DOWN, LEFT, RIGHT: dim = np.argmax(np.abs(vect)) @@ -1779,7 +1861,7 @@ def shift_onto_screen(self, **kwargs): self.to_edge(vect, **kwargs) return self - def is_off_screen(self): + def is_off_screen(self) -> bool: if self.get_left()[0] > config.frame_x_radius: return True if self.get_right()[0] < config.frame_x_radius: @@ -1790,10 +1872,12 @@ def is_off_screen(self): return True return False - def stretch_about_point(self, factor, dim, point): + def stretch_about_point(self, factor: float, dim: int, point: Point3D) -> Self: return self.stretch(factor, dim, about_point=point) - def rescale_to_fit(self, length, dim, stretch=False, **kwargs): + def rescale_to_fit( + self, length: float, dim: int, stretch: bool = False, **kwargs + ) -> Self: old_length = self.length_over_dim(dim) if old_length == 0: return self @@ -1803,7 +1887,7 @@ def rescale_to_fit(self, length, dim, stretch=False, **kwargs): self.scale(length / old_length, **kwargs) return self - def stretch_to_fit_width(self, width, **kwargs): + def stretch_to_fit_width(self, width: float, **kwargs) -> Self: """Stretches the :class:`~.OpenGLMobject` to fit a width, not keeping height/depth proportional. Returns @@ -1828,15 +1912,15 @@ def stretch_to_fit_width(self, width, **kwargs): """ return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def stretch_to_fit_height(self, height, **kwargs): + def stretch_to_fit_height(self, height: float, **kwargs) -> Self: """Stretches the :class:`~.OpenGLMobject` to fit a height, not keeping width/height proportional.""" return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def stretch_to_fit_depth(self, depth, **kwargs): + def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self: """Stretches the :class:`~.OpenGLMobject` to fit a depth, not keeping width/height proportional.""" return self.rescale_to_fit(depth, 1, stretch=True, **kwargs) - def set_width(self, width, stretch=False, **kwargs): + def set_width(self, width: float, stretch: bool = False, **kwargs) -> Self: """Scales the :class:`~.OpenGLMobject` to fit a width while keeping height/depth proportional. Returns @@ -1863,38 +1947,38 @@ def set_width(self, width, stretch=False, **kwargs): scale_to_fit_width = set_width - def set_height(self, height, stretch=False, **kwargs): + def set_height(self, height: float, stretch: bool = False, **kwargs) -> Self: """Scales the :class:`~.OpenGLMobject` to fit a height while keeping width/depth proportional.""" return self.rescale_to_fit(height, 1, stretch=stretch, **kwargs) scale_to_fit_height = set_height - def set_depth(self, depth, stretch=False, **kwargs): + def set_depth(self, depth: float, stretch: bool = False, **kwargs): """Scales the :class:`~.OpenGLMobject` to fit a depth while keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=stretch, **kwargs) scale_to_fit_depth = set_depth - def set_coord(self, value, dim, direction=ORIGIN): + def set_coord(self, value: float, dim: int, direction: Vector3D = ORIGIN) -> Self: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x, direction=ORIGIN): + def set_x(self, x: float, direction: Vector3D = ORIGIN) -> Self: """Set x value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(x, 0, direction) - def set_y(self, y, direction=ORIGIN): + def set_y(self, y: float, direction: Vector3D = ORIGIN) -> Self: """Set y value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(y, 1, direction) - def set_z(self, z, direction=ORIGIN): + def set_z(self, z: float, direction: Vector3D = ORIGIN) -> Self: """Set z value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor=1.5, **kwargs): + def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1.0 / factor) @@ -1902,10 +1986,10 @@ def space_out_submobjects(self, factor=1.5, **kwargs): def move_to( self, - point_or_mobject, - aligned_edge=ORIGIN, - coor_mask=np.array([1, 1, 1]), - ): + point_or_mobject: Point3D | OpenGLMobject, + aligned_edge: Vector3D = ORIGIN, + coor_mask: Point3D = np.array([1, 1, 1]), + ) -> Self: """Move center of the :class:`~.OpenGLMobject` to certain coordinate.""" if isinstance(point_or_mobject, OpenGLMobject): target = point_or_mobject.get_bounding_box_point(aligned_edge) @@ -1915,7 +1999,12 @@ def move_to( self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject, dim_to_match=0, stretch=False): + def replace( + self, + mobject: OpenGLMobject, + dim_to_match: int = 0, + stretch: bool = False, + ) -> Self: if not mobject.get_num_points() and not mobject.submobjects: self.scale(0) return self @@ -1937,13 +2026,13 @@ def surround( dim_to_match: int = 0, stretch: bool = False, buff: float = MED_SMALL_BUFF, - ): + ) -> Self: self.replace(mobject, dim_to_match, stretch) length = mobject.length_over_dim(dim_to_match) self.scale((length + buff) / length) return self - def put_start_and_end_on(self, start, end): + def put_start_and_end_on(self, start: Point3D, end: Point3D) -> Self: curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): @@ -1968,7 +2057,13 @@ def put_start_and_end_on(self, start, end): # Color functions - def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): + def set_rgba_array( + self, + color: ParsableManimColor | Iterable[ParsableManimColor] | None = None, + opacity: float | Iterable[float] | None = None, + name: str = "rgbas", + recurse: bool = True, + ) -> Self: if color is not None: rgbs = np.array([color_to_rgb(c) for c in listify(color)]) if opacity is not None: @@ -1998,7 +2093,12 @@ def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): mob.data[name] = rgbas.copy() return self - def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True): + def set_rgba_array_direct( + self, + rgbas: npt.NDArray[RGBA_Array_Float], + name: str = "rgbas", + recurse: bool = True, + ) -> Self: """Directly set rgba data from `rgbas` and optionally do the same recursively with submobjects. This can be used if the `rgbas` have already been generated with the correct shape and simply need to be set. @@ -2015,7 +2115,12 @@ def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True): for mob in self.get_family(recurse): mob.data[name] = rgbas.copy() - def set_color(self, color: ParsableManimColor | None, opacity=None, recurse=True): + def set_color( + self, + color: ParsableManimColor | Iterable[ParsableManimColor] | None, + opacity: float | Iterable[float] | None = None, + recurse: bool = True, + ) -> Self: self.set_rgba_array(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array # in case they implement set_color differently @@ -2028,24 +2133,25 @@ def set_color(self, color: ParsableManimColor | None, opacity=None, recurse=True submob.set_color(color, recurse=True) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity( + self, opacity: float | Iterable[float] | None, recurse: bool = True + ) -> Self: self.set_rgba_array(color=None, opacity=opacity, recurse=False) if recurse: for submob in self.submobjects: submob.set_opacity(opacity, recurse=True) return self - def get_color(self): + def get_color(self) -> str: return rgb_to_hex(self.rgbas[0, :3]) - def get_opacity(self): + def get_opacity(self) -> float: return self.rgbas[0, 3] - def set_color_by_gradient(self, *colors): - self.set_submobject_colors_by_gradient(*colors) - return self + def set_color_by_gradient(self, *colors: ParsableManimColor) -> Self: + return self.set_submobject_colors_by_gradient(*colors) - def set_submobject_colors_by_gradient(self, *colors): + def set_submobject_colors_by_gradient(self, *colors: ParsableManimColor) -> Self: if len(colors) == 0: raise Exception("Need at least one color") elif len(colors) == 1: @@ -2059,21 +2165,21 @@ def set_submobject_colors_by_gradient(self, *colors): mob.set_color(color) return self - def fade(self, darkness=0.5, recurse=True): - self.set_opacity(1.0 - darkness, recurse=recurse) + def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self: + return self.set_opacity(1.0 - darkness, recurse=recurse) - def get_gloss(self): + def get_gloss(self) -> float: return self.gloss - def set_gloss(self, gloss, recurse=True): + def set_gloss(self, gloss: float, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.gloss = gloss return self - def get_shadow(self): + def get_shadow(self) -> float: return self.shadow - def set_shadow(self, shadow, recurse=True): + def set_shadow(self, shadow: float, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.shadow = shadow return self @@ -2081,8 +2187,11 @@ def set_shadow(self, shadow, recurse=True): # Background rectangle def add_background_rectangle( - self, color: ParsableManimColor | None = None, opacity: float = 0.75, **kwargs - ): + self, + color: ParsableManimColor | None = None, + opacity: float = 0.75, + **kwargs, + ) -> Self: # TODO, this does not behave well when the mobject has points, # since it gets displayed on top """Add a BackgroundRectangle as submobject. @@ -2120,39 +2229,39 @@ def add_background_rectangle( self.add_to_back(self.background_rectangle) return self - def add_background_rectangle_to_submobjects(self, **kwargs): + def add_background_rectangle_to_submobjects(self, **kwargs) -> Self: for submobject in self.submobjects: submobject.add_background_rectangle(**kwargs) return self - def add_background_rectangle_to_family_members_with_points(self, **kwargs): + def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Self: for mob in self.family_members_with_points(): mob.add_background_rectangle(**kwargs) return self # Getters - def get_bounding_box_point(self, direction): + def get_bounding_box_point(self, direction: Vector3D) -> Point3D: bb = self.get_bounding_box() indices = (np.sign(direction) + 1).astype(int) return np.array([bb[indices[i]][i] for i in range(3)]) - def get_edge_center(self, direction) -> np.ndarray: + def get_edge_center(self, direction: Vector3D) -> Point3D: """Get edge coordinates for certain direction.""" return self.get_bounding_box_point(direction) - def get_corner(self, direction) -> np.ndarray: + def get_corner(self, direction: Vector3D) -> Point3D: """Get corner coordinates for certain direction.""" return self.get_bounding_box_point(direction) - def get_center(self) -> np.ndarray: + def get_center(self) -> Point3D: """Get center coordinates.""" return self.get_bounding_box()[1] - def get_center_of_mass(self): + def get_center_of_mass(self) -> Point3D: return self.get_all_points().mean(0) - def get_boundary_point(self, direction): + def get_boundary_point(self, direction: Vector3D) -> Point3D: all_points = self.get_all_points() boundary_directions = all_points - self.get_center() norms = np.linalg.norm(boundary_directions, axis=1) @@ -2160,7 +2269,7 @@ def get_boundary_point(self, direction): index = np.argmax(np.dot(boundary_directions, np.array(direction).T)) return all_points[index] - def get_continuous_bounding_box_point(self, direction): + def get_continuous_bounding_box_point(self, direction: Vector3D) -> Point3D: dl, center, ur = self.get_bounding_box() corner_vect = ur - center return center + direction / np.max( @@ -2174,86 +2283,86 @@ def get_continuous_bounding_box_point(self, direction): ), ) - def get_top(self) -> np.ndarray: + def get_top(self) -> Point3D: """Get top coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(UP) - def get_bottom(self) -> np.ndarray: + def get_bottom(self) -> Point3D: """Get bottom coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(DOWN) - def get_right(self) -> np.ndarray: + def get_right(self) -> Point3D: """Get right coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(RIGHT) - def get_left(self) -> np.ndarray: + def get_left(self) -> Point3D: """Get left coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(LEFT) - def get_zenith(self) -> np.ndarray: + def get_zenith(self) -> Point3D: """Get zenith coordinates of a box bounding a 3D :class:`~.OpenGLMobject`.""" return self.get_edge_center(OUT) - def get_nadir(self) -> np.ndarray: + def get_nadir(self) -> Point3D: """Get nadir (opposite the zenith) coordinates of a box bounding a 3D :class:`~.OpenGLMobject`.""" return self.get_edge_center(IN) - def length_over_dim(self, dim): + def length_over_dim(self, dim: int) -> float: bb = self.get_bounding_box() return abs((bb[2] - bb[0])[dim]) - def get_width(self): + def get_width(self) -> float: """Returns the width of the mobject.""" return self.length_over_dim(0) - def get_height(self): + def get_height(self) -> float: """Returns the height of the mobject.""" return self.length_over_dim(1) - def get_depth(self): + def get_depth(self) -> float: """Returns the depth of the mobject.""" return self.length_over_dim(2) - def get_coord(self, dim: int, direction=ORIGIN): + def get_coord(self, dim: int, direction: Vector3D = ORIGIN) -> ManimFloat: """Meant to generalize ``get_x``, ``get_y`` and ``get_z``""" return self.get_bounding_box_point(direction)[dim] - def get_x(self, direction=ORIGIN) -> np.float64: + def get_x(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns x coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(0, direction) - def get_y(self, direction=ORIGIN) -> np.float64: + def get_y(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns y coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(1, direction) - def get_z(self, direction=ORIGIN) -> np.float64: + def get_z(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns z coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(2, direction) - def get_start(self): + def get_start(self) -> Point3D: """Returns the point, where the stroke that surrounds the :class:`~.OpenGLMobject` starts.""" self.throw_error_if_no_points() return np.array(self.points[0]) - def get_end(self): + def get_end(self) -> Point3D: """Returns the point, where the stroke that surrounds the :class:`~.OpenGLMobject` ends.""" self.throw_error_if_no_points() return np.array(self.points[-1]) - def get_start_and_end(self): + def get_start_and_end(self) -> tuple[Point3D, Point3D]: """Returns starting and ending point of a stroke as a ``tuple``.""" return self.get_start(), self.get_end() - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> Point3D: points = self.points i, subalpha = integer_interpolate(0, len(points) - 1, alpha) return interpolate(points[i], points[i + 1], subalpha) - def pfp(self, alpha): + def pfp(self, alpha: float) -> Point3D: """Abbreviation for point_from_proportion""" return self.point_from_proportion(alpha) - def get_pieces(self, n_pieces): + def get_pieces(self, n_pieces: int) -> OpenGLMobject: template = self.copy() template.submobjects = [] alphas = np.linspace(0, 1, n_pieces + 1) @@ -2264,34 +2373,36 @@ def get_pieces(self, n_pieces): ) ) - def get_z_index_reference_point(self): + def get_z_index_reference_point(self) -> Point3D: # TODO, better place to define default z_index_group? z_index_group = getattr(self, "z_index_group", self) return z_index_group.get_center() # Match other mobject properties - def match_color(self, mobject: OpenGLMobject): + def match_color(self, mobject: OpenGLMobject) -> Self: """Match the color with the color of another :class:`~.OpenGLMobject`.""" return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject: OpenGLMobject, dim, **kwargs): + def match_dim_size(self, mobject: OpenGLMobject, dim: int, **kwargs) -> Self: """Match the specified dimension with the dimension of another :class:`~.OpenGLMobject`.""" return self.rescale_to_fit(mobject.length_over_dim(dim), dim, **kwargs) - def match_width(self, mobject: OpenGLMobject, **kwargs): + def match_width(self, mobject: OpenGLMobject, **kwargs) -> Self: """Match the width with the width of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject: OpenGLMobject, **kwargs): + def match_height(self, mobject: OpenGLMobject, **kwargs) -> Self: """Match the height with the height of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject: OpenGLMobject, **kwargs): + def match_depth(self, mobject: OpenGLMobject, **kwargs) -> Self: """Match the depth with the depth of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 2, **kwargs) - def match_coord(self, mobject: OpenGLMobject, dim, direction=ORIGIN): + def match_coord( + self, mobject: OpenGLMobject, dim: int, direction: Vector3D = ORIGIN + ) -> Self: """Match the coordinates with the coordinates of another :class:`~.OpenGLMobject`.""" return self.set_coord( mobject.get_coord(dim, direction), @@ -2299,23 +2410,23 @@ def match_coord(self, mobject: OpenGLMobject, dim, direction=ORIGIN): direction=direction, ) - def match_x(self, mobject, direction=ORIGIN): + def match_x(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: """Match x coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 0, direction) - def match_y(self, mobject, direction=ORIGIN): + def match_y(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: """Match y coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 1, direction) - def match_z(self, mobject, direction=ORIGIN): + def match_z(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: """Match z coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 2, direction) def align_to( self, - mobject_or_point: OpenGLMobject | Sequence[float], - direction=ORIGIN, - ): + mobject_or_point: OpenGLMobject | Point3D, + direction: Vector3D = ORIGIN, + ) -> Self: """ Examples: mob1.align_to(mob2, UP) moves mob1 vertically so that its @@ -2335,21 +2446,22 @@ def align_to( self.set_coord(point[dim], dim, direction) return self - def get_group_class(self): + def get_group_class(self) -> type[OpenGLGroup]: return OpenGLGroup @staticmethod - def get_mobject_type_class(): + def get_mobject_type_class() -> type[OpenGLMobject]: """Return the base class of this mobject type.""" return OpenGLMobject # Alignment - def align_data_and_family(self, mobject): + def align_data_and_family(self, mobject: OpenGLMobject) -> Self: self.align_family(mobject) self.align_data(mobject) + return self - def align_data(self, mobject): + def align_data(self, mobject: OpenGLMobject) -> Self: # In case any data arrays get resized when aligned to shader data # self.refresh_shader_data() for mob1, mob2 in zip(self.get_family(), mobject.get_family()): @@ -2365,14 +2477,15 @@ def align_data(self, mobject): mob1.data[key] = resize_preserving_order(arr1, len(arr2)) elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) + return self - def align_points(self, mobject): + def align_points(self, mobject: OpenGLMobject) -> Self: max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject): + def align_family(self, mobject: OpenGLMobject) -> Self: mob1 = self mob2 = mobject n1 = len(mob1) @@ -2385,14 +2498,14 @@ def align_family(self, mobject): sm1.align_family(sm2) return self - def push_self_into_submobjects(self): + def push_self_into_submobjects(self) -> Self: copy = self.deepcopy() copy.submobjects = [] self.resize_points(0) self.add(copy) return self - def add_n_more_submobjects(self, n): + def add_n_more_submobjects(self, n: int) -> Self: if n == 0: return self @@ -2421,7 +2534,13 @@ def add_n_more_submobjects(self, n): # Interpolate - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path()): + def interpolate( + self, + mobject1: OpenGLMobject, + mobject2: OpenGLMobject, + alpha: float, + path_func: PathFuncType = straight_path(), + ) -> Self: """Turns this :class:`~.OpenGLMobject` into an interpolation between ``mobject1`` and ``mobject2``. @@ -2474,7 +2593,9 @@ def construct(self): ) return self - def pointwise_become_partial(self, mobject, a, b): + def pointwise_become_partial( + self, mobject: OpenGLMobject, a: float, b: float + ) -> None: """ Set points in such a way as to become only part of mobject. @@ -2491,7 +2612,7 @@ def become( match_depth: bool = False, match_center: bool = False, stretch: bool = False, - ): + ) -> Self: """Edit all data and submobjects to be identical to another :class:`~.OpenGLMobject` @@ -2551,7 +2672,7 @@ def construct(self): # Locking data - def lock_data(self, keys): + def lock_data(self, keys: Iterable[str]) -> None: """ To speed up some animations, particularly transformations, it can be handy to acknowledge which pieces of data @@ -2565,7 +2686,9 @@ def lock_data(self, keys): self.refresh_shader_data() self.locked_data_keys = set(keys) - def lock_matching_data(self, mobject1, mobject2): + def lock_matching_data( + self, mobject1: OpenGLMobject, mobject2: OpenGLMobject + ) -> Self: for sm, sm1, sm2 in zip( self.get_family(), mobject1.get_family(), @@ -2582,57 +2705,57 @@ def lock_matching_data(self, mobject1, mobject2): ) return self - def unlock_data(self): + def unlock_data(self) -> None: for mob in self.get_family(): mob.locked_data_keys = set() # Operations touching shader uniforms @affects_shader_info_id - def fix_in_frame(self): + def fix_in_frame(self) -> Self: self.is_fixed_in_frame = 1.0 return self @affects_shader_info_id - def fix_orientation(self): + def fix_orientation(self) -> Self: self.is_fixed_orientation = 1.0 self.fixed_orientation_center = tuple(self.get_center()) self.depth_test = True return self @affects_shader_info_id - def unfix_from_frame(self): + def unfix_from_frame(self) -> Self: self.is_fixed_in_frame = 0.0 return self @affects_shader_info_id - def unfix_orientation(self): + def unfix_orientation(self) -> Self: self.is_fixed_orientation = 0.0 self.fixed_orientation_center = (0, 0, 0) self.depth_test = False return self @affects_shader_info_id - def apply_depth_test(self): + def apply_depth_test(self) -> Self: self.depth_test = True return self @affects_shader_info_id - def deactivate_depth_test(self): + def deactivate_depth_test(self) -> Self: self.depth_test = False return self # Shader code manipulation - def replace_shader_code(self, old, new): + def replace_shader_code(self, old_code: str, new_code: str) -> Self: # TODO, will this work with VMobject structure, given # that it does not simpler return shader_wrappers of # family? for wrapper in self.get_shader_wrapper_list(): - wrapper.replace_code(old, new) + wrapper.replace_code(old_code, new_code) return self - def set_color_by_code(self, glsl_code): + def set_color_by_code(self, glsl_code: str) -> Self: """ Takes a snippet of code and inserts it into a context which has the following variables: @@ -2644,11 +2767,11 @@ def set_color_by_code(self, glsl_code): def set_color_by_xyz_func( self, - glsl_snippet, - min_value=-5.0, - max_value=5.0, - colormap="viridis", - ): + glsl_snippet: str, + min_value: float = -5.0, + max_value: float = 5.0, + colormap: str = "viridis", + ) -> Self: """ Pass in a glsl expression in terms of x, y and z which returns a float. @@ -2665,11 +2788,11 @@ def set_color_by_xyz_func( # For shader data - def refresh_shader_wrapper_id(self): + def refresh_shader_wrapper_id(self) -> Self: self.get_shader_wrapper().refresh_id() return self - def get_shader_wrapper(self): + def get_shader_wrapper(self) -> ShaderWrapper: from manim.renderer.shader_wrapper import ShaderWrapper # if hasattr(self, "__shader_wrapper"): @@ -2686,7 +2809,7 @@ def get_shader_wrapper(self): ) return self.shader_wrapper - def get_shader_wrapper_list(self): + def get_shader_wrapper_list(self) -> Sequence[ShaderWrapper]: shader_wrappers = it.chain( [self.get_shader_wrapper()], *(sm.get_shader_wrapper_list() for sm in self.submobjects), @@ -2703,7 +2826,7 @@ def get_shader_wrapper_list(self): result.append(shader_wrapper) return result - def check_data_alignment(self, array, data_key): + def check_data_alignment(self, array: npt.NDArray, data_key: str) -> Self: # Makes sure that self.data[key] can be broadcast into # the given array, meaning its length has to be either 1 # or the length of the array @@ -2715,45 +2838,50 @@ def check_data_alignment(self, array, data_key): ) return self - def get_resized_shader_data_array(self, length): + def get_resized_shader_data_array(self, length: float) -> npt.NDArray: # If possible, try to populate an existing array, rather # than recreating it each frame points = self.points shader_data = np.zeros(len(points), dtype=self.shader_dtype) return shader_data - def read_data_to_shader(self, shader_data, shader_data_key, data_key): + def read_data_to_shader( + self, + shader_data: npt.NDArray, # has structured data type, ex. ("point", np.float32, (3,)) + shader_data_key: str, + data_key: str, + ) -> None: if data_key in self.locked_data_keys: return self.check_data_alignment(shader_data, data_key) shader_data[shader_data_key] = self.data[data_key] - def get_shader_data(self): + def get_shader_data(self) -> npt.NDArray: shader_data = self.get_resized_shader_data_array(self.get_num_points()) self.read_data_to_shader(shader_data, "point", "points") return shader_data - def refresh_shader_data(self): + def refresh_shader_data(self) -> None: self.get_shader_data() - def get_shader_uniforms(self): + def get_shader_uniforms(self) -> dict[str, Any]: return self.uniforms - def get_shader_vert_indices(self): + def get_shader_vert_indices(self) -> Sequence[int]: return self.shader_indices @property - def submobjects(self): + def submobjects(self) -> Sequence[OpenGLMobject]: return self._submobjects if hasattr(self, "_submobjects") else [] @submobjects.setter - def submobjects(self, submobject_list): + def submobjects(self, submobject_list: Iterable[OpenGLMobject]) -> None: self.remove(*self.submobjects) self.add(*submobject_list) # Errors - def throw_error_if_no_points(self): + def throw_error_if_no_points(self) -> None: if not self.has_points(): message = ( "Cannot call OpenGLMobject.{} " + "for a OpenGLMobject with no points" @@ -2763,38 +2891,42 @@ def throw_error_if_no_points(self): class OpenGLGroup(OpenGLMobject): - def __init__(self, *mobjects, **kwargs): + def __init__(self, *mobjects: OpenGLMobject, **kwargs): super().__init__(**kwargs) self.add(*mobjects) class OpenGLPoint(OpenGLMobject): def __init__( - self, location=ORIGIN, artificial_width=1e-6, artificial_height=1e-6, **kwargs + self, + location: Point3D = ORIGIN, + artificial_width: float = 1e-6, + artificial_height: float = 1e-6, + **kwargs, ): self.artificial_width = artificial_width self.artificial_height = artificial_height super().__init__(**kwargs) self.set_location(location) - def get_width(self): + def get_width(self) -> float: return self.artificial_width - def get_height(self): + def get_height(self) -> float: return self.artificial_height - def get_location(self): + def get_location(self) -> Point3D: return self.points[0].copy() - def get_bounding_box_point(self, *args, **kwargs): + def get_bounding_box_point(self, *args, **kwargs) -> Point3D: return self.get_location() - def set_location(self, new_loc): + def set_location(self, new_loc: Point3D) -> None: self.set_points(np.array(new_loc, ndmin=2, dtype=float)) class _AnimationBuilder: - def __init__(self, mobject): + def __init__(self, mobject: OpenGLMobject): self.mobject = mobject self.mobject.generate_target() @@ -2806,7 +2938,7 @@ def __init__(self, mobject): self.cannot_pass_args = False self.anim_args = {} - def __call__(self, **kwargs): + def __call__(self, **kwargs) -> Self: if self.cannot_pass_args: raise ValueError( "Animation arguments must be passed before accessing methods and can only be passed once", @@ -2817,7 +2949,7 @@ def __call__(self, **kwargs): return self - def __getattr__(self, method_name): + def __getattr__(self, method_name: str) -> Callable[..., Self]: method = getattr(self.mobject.target, method_name) has_overridden_animation = hasattr(method, "_override_animate") @@ -2845,7 +2977,7 @@ def update_target(*method_args, **method_kwargs): return update_target - def build(self): + def build(self) -> _MethodAnimation: from manim.animation.transform import _MethodAnimation if self.overridden_animation: @@ -2859,7 +2991,7 @@ def build(self): return anim -def override_animate(method): +def override_animate(method: types.FunctionType) -> types.FunctionType: r"""Decorator for overriding method animations. This allows to specify a method (returning an :class:`~.Animation`)