diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index 3330193a43..435d7aced8 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -27,6 +27,7 @@ from manim.mobject.graphing.functions import ImplicitFunction, ParametricFunction from manim.mobject.graphing.number_line import NumberLine from manim.mobject.graphing.scale import LinearBase +from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.opengl.opengl_surface import OpenGLSurface from manim.mobject.text.tex_mobject import MathTex @@ -96,10 +97,10 @@ def construct(self): ) # Extra lines and labels for point (1,1) - graphs += grid.get_horizontal_line(grid.c2p(1, 1, 0), color=BLUE) - graphs += grid.get_vertical_line(grid.c2p(1, 1, 0), color=BLUE) - graphs += Dot(point=grid.c2p(1, 1, 0), color=YELLOW) - graphs += Tex("(1,1)").scale(0.75).next_to(grid.c2p(1, 1, 0)) + graphs += grid.get_horizontal_line(grid @ (1, 1, 0), color=BLUE) + graphs += grid.get_vertical_line(grid @ (1, 1, 0), color=BLUE) + graphs += Dot(point=grid @ (1, 1, 0), color=YELLOW) + graphs += Tex("(1,1)").scale(0.75).next_to(grid @ (1, 1, 0)) title = Title( # spaces between braces to prevent SyntaxError r"Graphs of $y=x^{ {1}\over{n} }$ and $y=x^n (n=1,2,3,...,20)$", @@ -145,7 +146,7 @@ def __init__( self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 - def coords_to_point(self, *coords: Sequence[ManimFloat]): + def coords_to_point(self, *coords: ManimFloat): raise NotImplementedError() def point_to_coords(self, point: Point3D): @@ -570,7 +571,7 @@ def get_horizontal_line(self, point: Sequence[float], **kwargs) -> Line: class GetHorizontalLineExample(Scene): def construct(self): ax = Axes().add_coordinates() - point = ax.c2p(-4, 1.5) + point = ax @ (-4, 1.5) dot = Dot(point) line = ax.get_horizontal_line(point, line_func=Line) @@ -1790,6 +1791,14 @@ def construct(self): return T_label_group + def __matmul__(self, coord: Point3D | Mobject): + if isinstance(coord, Mobject): + coord = coord.get_center() + return self.coords_to_point(*coord) + + def __rmatmul__(self, point: Point3D): + return self.point_to_coords(point) + class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): """Creates a set of axes. @@ -1990,6 +1999,7 @@ def coords_to_point( self, *coords: float | Sequence[float] | Sequence[Sequence[float]] | np.ndarray ) -> np.ndarray: """Accepts coordinates from the axes and returns a point with respect to the scene. + Equivalent to `ax @ (coord1)` Parameters ---------- @@ -2018,6 +2028,8 @@ def coords_to_point( >>> ax = Axes() >>> np.around(ax.coords_to_point(1, 0, 0), 2) array([0.86, 0. , 0. ]) + >>> np.around(ax @ (1, 0, 0), 2) + array([0.86, 0. , 0. ]) >>> np.around(ax.coords_to_point([[0, 1], [1, 1], [1, 0]]), 2) array([[0. , 0.75, 0. ], [0.86, 0.75, 0. ], diff --git a/manim/mobject/graphing/number_line.py b/manim/mobject/graphing/number_line.py index 834f3086c2..26df7df044 100644 --- a/manim/mobject/graphing/number_line.py +++ b/manim/mobject/graphing/number_line.py @@ -2,6 +2,7 @@ from __future__ import annotations +from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject __all__ = ["NumberLine", "UnitInterval"] @@ -12,6 +13,7 @@ if TYPE_CHECKING: from manim.mobject.geometry.tips import ArrowTip + from manim.typing import Point3D import numpy as np @@ -344,6 +346,7 @@ def get_tick_range(self) -> np.ndarray: def number_to_point(self, number: float | np.ndarray) -> np.ndarray: """Accepts a value along the number line and returns a point with respect to the scene. + Equivalent to `NumberLine @ number` Parameters ---------- @@ -364,6 +367,8 @@ def number_to_point(self, number: float | np.ndarray) -> np.ndarray: array([0., 0., 0.]) >>> number_line.number_to_point(1) array([1., 0., 0.]) + >>> number_line @ 1 + array([1., 0., 0.]) >>> number_line.number_to_point([1, 2, 3]) array([[1., 0., 0.], [2., 0., 0.], @@ -642,6 +647,14 @@ def _decimal_places_from_step(step) -> int: return 0 return len(step.split(".")[-1]) + def __matmul__(self, other: float): + return self.n2p(other) + + def __rmatmul__(self, other: Point3D | Mobject): + if isinstance(other, Mobject): + other = other.get_center() + return self.p2n(other) + class UnitInterval(NumberLine): def __init__( diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index fd86a9bd4d..4aa71f8968 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -7,7 +7,7 @@ from manim import LEFT, ORIGIN, PI, UR, Axes, Circle, ComplexPlane from manim import CoordinateSystem as CS -from manim import NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig +from manim import Dot, NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig def test_initial_config(): @@ -119,7 +119,15 @@ def test_coords_to_point(): # a point with respect to the axes c2p_coord = np.around(ax.coords_to_point(2, 2), decimals=4) - np.testing.assert_array_equal(c2p_coord, (1.7143, 1.5, 0)) + c2p_coord_matmul = np.around(ax @ (2, 2), decimals=4) + + expected = (1.7143, 1.5, 0) + + np.testing.assert_array_equal(c2p_coord, expected) + np.testing.assert_array_equal(c2p_coord_matmul, c2p_coord) + + mob = Dot().move_to((2, 2, 0)) + np.testing.assert_array_equal(np.around(ax @ mob, decimals=4), expected) def test_coords_to_point_vectorized():