11import warnings
22from functools import partial
3- from typing import Callable , Optional , Union
3+ from typing import Callable , Literal , Optional , Union
44
55import numpy as np
66from numpy .core .numeric import normalize_axis_tuple # type: ignore
@@ -723,9 +723,12 @@ def _multi_svd_norm(
723723 return result
724724
725725
726+ VALID_ORD = Literal ["fro" , "f" , "nuc" , "inf" , "-inf" , 0 , 1 , - 1 , 2 , - 2 ]
727+
728+
726729def norm (
727730 x : ptb .TensorVariable ,
728- ord : Optional [Union [float , str ]] = None ,
731+ ord : Optional [Union [float , VALID_ORD ]] = None ,
729732 axis : Optional [Union [int , tuple [int , ...]]] = None ,
730733 keepdims : bool = False ,
731734):
@@ -736,16 +739,22 @@ def norm(
736739 ----------
737740 x: TensorVariable
738741 Tensor to take norm of.
739- ord: float or str, optional
742+
743+ ord: float, str or int, optional
740744 Order of norm. If `ord` is a str, it must be one of the following:
741745 - 'fro' or 'f' : Frobenius norm
742746 - 'nuc' : nuclear norm
743747 - 'inf' : Infinity norm
744748 - '-inf' : Negative infinity norm
745- Otherwise `ord` must be a float. Default is the Frobenius (L2) norm.
749+ If an integer, order can be one of -2, -1, 0, 1, or 2.
750+ Otherwise `ord` must be a float.
751+
752+ Default is the Frobenius (L2) norm.
753+
746754 axis: tuple of int, optional
747755 Axes over which to compute the norm. If None, norm of entire matrix (or vector) is computed. Row or column
748756 norms can be computed by passing a single integer; this will treat a matrix like a batch of vectors.
757+
749758 keepdims: bool
750759 If True, dummy axes will be inserted into the output so that norm.dnim == x.dnim. Default is False.
751760
0 commit comments