Skip to content

Commit 7e580b2

Browse files
Use Literal for ord typehint
1 parent 022bccd commit 7e580b2

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22
from functools import partial
3-
from typing import Callable, Optional, Union
3+
from typing import Callable, Literal, Optional, Union
44

55
import numpy as np
66
from numpy.core.numeric import normalize_axis_tuple # type: ignore
@@ -723,9 +723,12 @@ def _multi_svd_norm(
723723
return result
724724

725725

726+
VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]
727+
728+
726729
def norm(
727730
x: ptb.TensorVariable,
728-
ord: Optional[Union[float, str]] = None,
731+
ord: Optional[Union[float, VALID_ORD]] = None,
729732
axis: Optional[Union[int, tuple[int, ...]]] = None,
730733
keepdims: bool = False,
731734
):
@@ -736,16 +739,22 @@ def norm(
736739
----------
737740
x: TensorVariable
738741
Tensor to take norm of.
739-
ord: float or str, optional
742+
743+
ord: float, str or int, optional
740744
Order of norm. If `ord` is a str, it must be one of the following:
741745
- 'fro' or 'f' : Frobenius norm
742746
- 'nuc' : nuclear norm
743747
- 'inf' : Infinity norm
744748
- '-inf' : Negative infinity norm
745-
Otherwise `ord` must be a float. Default is the Frobenius (L2) norm.
749+
If an integer, order can be one of -2, -1, 0, 1, or 2.
750+
Otherwise `ord` must be a float.
751+
752+
Default is the Frobenius (L2) norm.
753+
746754
axis: tuple of int, optional
747755
Axes over which to compute the norm. If None, norm of entire matrix (or vector) is computed. Row or column
748756
norms can be computed by passing a single integer; this will treat a matrix like a batch of vectors.
757+
749758
keepdims: bool
750759
If True, dummy axes will be inserted into the output so that norm.dnim == x.dnim. Default is False.
751760

0 commit comments

Comments
 (0)