Skip to content

Total time updaters #1520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/source/_static/pull-requests.PNG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion manim/animation/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def update_mobjects(self, dt: float) -> None:
nothing to self.mobject.
"""
for mob in self.get_all_mobjects_to_update():
mob.update(dt)
mob._apply_updaters(dt)

def get_all_mobjects_to_update(self) -> List[Mobject]:
"""Get all mobjects to be updated during the animation.
Expand Down
129 changes: 74 additions & 55 deletions manim/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
color_gradient,
interpolate_color,
)
from ..utils.deprecation import deprecated
from ..utils.exceptions import MultiAnimationOverrideException
from ..utils.iterables import list_update, remove_list_redundancies
from ..utils.paths import straight_path
Expand Down Expand Up @@ -175,7 +176,7 @@ def add_animation_override(
animation_class
The animation type to be overridden
override_func
The function returning an aniamtion replacing the default animation. It gets
The function returning an animation replacing the default animation. It gets
passed the parameters given to the animnation constructor.

Raises
Expand Down Expand Up @@ -773,15 +774,20 @@ def generate_target(self, use_deepcopy=False):

# Updating

def update(self, dt: float = 0, recursive: bool = True) -> "Mobject":
@deprecated(since="v0.7.0", until="v0.8.0", replacement="_apply_updaters")
def update(self, *args, **kwargs):
self._apply_updaters(*args, **kwargs)

def _apply_updaters(self, dt: float = 0, recursive: bool = True) -> "Mobject":
"""Apply all updaters.

Does nothing if updating is suspended.

Parameters
----------
dt
The parameter ``dt`` to pass to the update functions. Usually this is the time in seconds since the last call of ``update``.
The parameter ``dt`` to pass to the update functions. Usually this is the
time in seconds since the last call of ``_apply_updaters``.
recursive
Whether to recursively update all submobjects.

Expand All @@ -800,33 +806,12 @@ def update(self, dt: float = 0, recursive: bool = True) -> "Mobject":
return self
for updater in self.updaters:
parameters = get_parameters(updater)
if "dt" in parameters:
updater(self, dt)
else:
updater(self)
updater(self, dt)
if recursive:
for submob in self.submobjects:
submob.update(dt, recursive)
submob._apply_updaters(dt, recursive)
return self

def get_time_based_updaters(self) -> List[Updater]:
"""Return all updaters using the ``dt`` parameter.

The updaters use this parameter as the input for difference in time.

Returns
-------
List[:class:`Callable`]
The list of time based updaters.

See Also
--------
:meth:`get_updaters`
:meth:`has_time_based_updater`

"""
return [updater for updater in self.updaters if "dt" in get_parameters(updater)]

def has_time_based_updater(self) -> bool:
"""Test if ``self`` has a time based updater.

Expand All @@ -835,15 +820,9 @@ def has_time_based_updater(self) -> bool:
class:`bool`
``True`` if at least one updater uses the ``dt`` parameter, ``False`` otherwise.

See Also
--------
:meth:`get_time_based_updaters`

"""
for updater in self.updaters:
if "dt" in get_parameters(updater):
return True
return False
return any(updater.time_based for updater in self.updaters)

def get_updaters(self) -> List[Updater]:
"""Return all updaters.
Expand All @@ -856,7 +835,6 @@ def get_updaters(self) -> List[Updater]:
See Also
--------
:meth:`add_updater`
:meth:`get_time_based_updaters`

"""
return self.updaters
Expand All @@ -866,24 +844,38 @@ def get_family_updaters(self):

def add_updater(
self,
update_function: Updater,
updater: Updater,
*,
pass_relative_times: Optional[bool] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pass_relative_times: Optional[bool] = None,
pass_relative_times: optional[bool] = True,

index: Optional[int] = None,
call_updater: bool = False,
) -> "Mobject":
"""Add an update function to this mobject.

Update functions, or updaters in short, are functions that are applied to the Mobject in every frame.
Update functions, or updaters in short, are functions that are applied to the
Mobject before every frame.

Parameters
----------
update_function
updater
The update function to be added.
Whenever :meth:`update` is called, this update function gets called using ``self`` as the first parameter.
The updater can have a second parameter ``dt``. If it uses this parameter, it gets called using a second value ``dt``, usually representing the time in seconds since the last call of :meth:`update`.
Whenever :meth:`update` is called, this update function gets called using
``self`` as the first parameter. The updater can have a second parameter
that gets passed information about the execution time.
pass_relative_times
Whether the time information passed to the updater should be relative
(usually the time since last call) or absolute (sum of all relative times
since the updater was added). If ``pass_relative_times`` is ``None``
relative times get passed, except if the second parameter of the updater is
named ``t`` or ``time``.
Comment on lines +865 to +870
Copy link
Member

@jsonvillanueva jsonvillanueva May 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather pass_relative_times have a default argument of True than be None and later make it False if necessary. Also, I don't particularly like the feature of changing the updater to total time if the second argument is t /time... seems a bit hacky.

index
The index at which the new updater should be added in ``self.updaters``. In case ``index`` is ``None`` the updater will be added at the end.
The index at which the new updater should be added in the list of updaters.
The order of this list determines the execution order of the updaters.
In case ``index`` is ``None`` the updater will be added at the end, where it
would be last in the execution order.
call_updater
Whether or not to call the updater initially. If ``True``, the updater will be called using ``dt=0``.
Whether or not to call the updater initially after adding it (using ``0``
as time if applicable).

Returns
-------
Expand Down Expand Up @@ -925,22 +917,47 @@ def construct(self):
:class:`~.UpdateFromFunc`
"""

# Test if updater is time based and whether to use time difference
parameters = list(get_parameters(updater).keys())
time_based = len(parameters) > 1
if time_based:
if parameters[1] not in ["t", "time"] and pass_relative_times is None:
pass_relative_times = True
pass_relative_times = bool(pass_relative_times)
Comment on lines +924 to +926
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, I don't particularly like this hard coded check for the second argument, but with a default of True, you can simplify the logic to only check for if it should be set to False

Suggested change
if parameters[1] not in ["t", "time"] and pass_relative_times is None:
pass_relative_times = True
pass_relative_times = bool(pass_relative_times)
if parameters[1] in ["t", "time"]:
pass_relative_times = False


# Wrap updaters to allow calling all of them using updatable and dt as parameter
if time_based:
if pass_relative_times:
unified_updater = updater
else:

def unified_updater(mob, dt):
unified_updater.total_time += dt
updater(mob, unified_updater.total_time)

unified_updater.total_time = 0
else:
unified_updater = lambda mob, dt: updater(mob)

unified_updater.time_based = time_based
unified_updater.base_func = updater # used to enable removing

if index is None:
self.updaters.append(update_function)
self.updaters.append(unified_updater)
else:
self.updaters.insert(index, update_function)
self.updaters.insert(index, unified_updater)
if call_updater:
update_function(self, 0)
unified_updater(self, 0)
return self

def remove_updater(self, update_function: Updater) -> "Mobject":
def remove_updater(self, updater: Updater) -> "Mobject":
"""Remove an updater.

If the same updater is applied multiple times, every instance gets removed.

Parameters
----------
update_function
updater
The update function to be removed.


Expand All @@ -956,8 +973,12 @@ def remove_updater(self, update_function: Updater) -> "Mobject":
:meth:`get_updaters`

"""
while update_function in self.updaters:
self.updaters.remove(update_function)
self.updaters = list(
filter(
lambda unified_updater: unified_updater.base_func is not updater,
self.updaters,
)
)
return self

def clear_updaters(self, recursive: bool = True) -> "Mobject":
Expand Down Expand Up @@ -1010,9 +1031,7 @@ def match_updaters(self, mobject: "Mobject") -> "Mobject":

"""

self.clear_updaters()
for updater in mobject.get_updaters():
self.add_updater(updater)
self.updaters = list(mobject.get_updaters())
return self

def suspend_updating(self, recursive: bool = True) -> "Mobject":
Expand Down Expand Up @@ -1065,7 +1084,7 @@ def resume_updating(self, recursive: bool = True) -> "Mobject":
if recursive:
for submob in self.submobjects:
submob.resume_updating(recursive)
self.update(dt=0, recursive=recursive)
self._apply_updaters(dt=0, recursive=recursive)
return self

# Transforming operations
Expand Down Expand Up @@ -2291,7 +2310,7 @@ def init_size(num, alignments, sizes):
# make the grid as close to quadratic as possible.
# choosing cols first can results in cols>rows.
# This is favored over rows>cols since in general
# the sceene is wider than high.
# the scene is wider than high.
if rows is None:
rows = ceil(len(mobs) / cols)
if cols is None:
Expand Down Expand Up @@ -2361,7 +2380,7 @@ def reverse(maybe_list):
mobs.extend([placeholder] * (rows * cols - len(mobs)))
grid = [[mobs[flow_order(r, c)] for c in range(cols)] for r in range(rows)]

measured_heigths = [
measured_heights = [
max([grid[r][c].height for c in range(cols)]) for r in range(rows)
]
measured_widths = [
Expand All @@ -2378,7 +2397,7 @@ def init_sizes(sizes, num, measures, name):
sizes[i] if sizes[i] is not None else measures[i] for i in range(num)
]

heights = init_sizes(row_heights, rows, measured_heigths, "row_heights")
heights = init_sizes(row_heights, rows, measured_heights, "row_heights")
widths = init_sizes(col_widths, cols, measured_widths, "col_widths")

x, y = 0, 0
Expand Down
Loading