Skip to content

Commit

Permalink
Added numerics methods for tensorflow tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Nov 22, 2024
1 parent 8d501c7 commit 79ded03
Show file tree
Hide file tree
Showing 9 changed files with 779 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docs/api/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def collect_api_entities() -> APIInfo:
"nncf.tensor.functions.numpy_linalg",
"nncf.tensor.functions.torch_numeric",
"nncf.tensor.functions.torch_linalg",
"nncf.tensor.functions.tf_numeric",
"nncf.tensor.functions.tf_linalg",
]

with mock(mock_modules):
Expand Down
1 change: 1 addition & 0 deletions nncf/tensor/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class TensorBackend(Enum):

numpy = auto()
torch = auto()
tf = auto()


@dataclass
Expand Down
4 changes: 4 additions & 0 deletions nncf/tensor/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def _initialize_backends():
import nncf.tensor.functions.numpy_linalg
import nncf.tensor.functions.numpy_numeric

with contextlib.suppress(ImportError):
import nncf.tensor.functions.tf_linalg
import nncf.tensor.functions.tf_numeric

with contextlib.suppress(ImportError):
import nncf.tensor.functions.torch_linalg
import nncf.tensor.functions.torch_numeric # noqa: F401
Expand Down
5 changes: 5 additions & 0 deletions nncf/tensor/functions/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ def get_numeric_backend_fn(fn_name: str, backend: TensorBackend) -> Callable:
from nncf.tensor.functions import torch_numeric

return getattr(torch_numeric, fn_name)

if backend == TensorBackend.tf:
from nncf.tensor.functions import tf_numeric

return getattr(tf_numeric, fn_name)
95 changes: 95 additions & 0 deletions nncf/tensor/functions/tf_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional, Tuple, Union

import tensorflow as tf

from nncf.tensor.functions import linalg


@linalg.norm.register(tf.Tensor)
def _(
a: tf.Tensor,
ord: Optional[Union[str, float, int]] = None,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> tf.Tensor:
if axis is None:
axis = 0 if a._rank() == 1 else (0, 1)

if ord is None or (a._rank() == 1 and ord == "fro"):
ord = "euclidean"

with tf.device(a.device):
if ord == "nuc":
s, _, _ = tf.linalg.svd(a)
return tf.reduce_sum(s)

return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)


@linalg.cholesky.register(tf.Tensor)
def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor:
with tf.device(a.device):
cholesky = tf.linalg.cholesky(a)
if upper:
perm = list(range(tf.rank(a)))
perm[-1], perm[-2] = perm[-2], perm[-1]
cholesky = tf.transpose(cholesky, perm=perm)
return cholesky


@linalg.cholesky_inverse.register(tf.Tensor)
def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor:
with tf.device(a.device):
if upper:
perm = list(range(tf.rank(a)))
perm[-1], perm[-2] = perm[-2], perm[-1]
a = tf.transpose(a, perm=perm)

eye = tf.eye(a.shape[0], dtype=a.dtype)
return tf.linalg.cholesky_solve(a, eye)


@linalg.inv.register(tf.Tensor)
def _(a: tf.Tensor) -> tf.Tensor:
with tf.device(a.device):
return tf.linalg.inv(a)


@linalg.pinv.register(tf.Tensor)
def _(a: tf.Tensor) -> tf.Tensor:
with tf.device(a.device):
return tf.linalg.pinv(a)


@linalg.lstsq.register(tf.Tensor)
def _(a: tf.Tensor, b: tf.Tensor, driver: Optional[str] = None) -> tf.Tensor:
with tf.device(a.device):
if driver is not None:
warnings.warn("Driver specifying is not supported in TensorFlow lstsq method")
if tf.rank(b) == 1:
b = tf.expand_dims(b, axis=0)
perm = list(range(tf.rank(b)))
perm[-1], perm[-2] = perm[-2], perm[-1]
b = tf.transpose(b, perm=perm)

return tf.linalg.lstsq(a, b)


@linalg.svd.register(tf.Tensor)
def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor:
with tf.device(a.device):
s, u, v = tf.linalg.svd(a, full_matrices=full_matrices)

return u, s, tf.transpose(v)
Loading

0 comments on commit 79ded03

Please sign in to comment.