Skip to content

Commit

Permalink
Fix stale reference to util.prod.
Browse files Browse the repository at this point in the history
Work around pytype bug. It seems that the line
from functools import cached_property
causes pytype to give up on the entire module. Avoid the member import to fix the type inference.

PiperOrigin-RevId: 513544106
  • Loading branch information
hawkinsp authored and jax authors committed Mar 2, 2023
1 parent a9421a8 commit a002643
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions jax/_src/device_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# On-device arrays.

from functools import partial, partialmethod
import math
import operator
from typing import (Any, List, Optional, Union)
import weakref
Expand All @@ -26,7 +27,6 @@
from jax._src import core
from jax._src import abstract_arrays
from jax._src import profiler
from jax._src import util
from jax._src.config import config
from jax._src.lib import xla_client as xc
from jax._src.typing import Array
Expand Down Expand Up @@ -163,7 +163,7 @@ def dtype(self):

@property
def size(self):
return util.prod(self.aval.shape)
return math.prod(self.aval.shape)

@property
def ndim(self):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import functools
from functools import partial, cached_property
from functools import partial
import itertools as it
import logging
import operator
Expand Down Expand Up @@ -496,7 +496,7 @@ def __eq__(self, other):
def _original_func(f):
if isinstance(f, property):
return cast(property, f).fget
elif isinstance(f, cached_property):
elif isinstance(f, functools.cached_property):
return f.func
return f

Expand Down

0 comments on commit a002643

Please sign in to comment.