einops does not appear to handle Tensorflow tensors with dynamic shape #187
Comments
Hi @arogozhnikov, thanks for your great library. Any update on this issue? Building on the example above, the following also breaks, where the shape of the tensor is dynamic but the rank is known:
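(The snippet referred to isn't preserved in this extract; a minimal sketch of the kind of case described, dynamic axis sizes but known rank, might look like the following. The pattern and sizes here are illustrative guesses, not the original code.)

```python
import tensorflow as tf
from einops import rearrange

# Rank is known (3), but the first two axis sizes are dynamic (None).
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 4])])
def g(x: tf.Tensor) -> tf.Tensor:
    # Grouping axes of a tensor whose static shape contains None
    # reportedly fails in this situation.
    return rearrange(x, "b n d -> b (n d)")

g(tf.zeros((2, 3, 4)))
```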
Relevant package versions:
Hi @jesnie, the only solution I see right now is to externally use … I've sent a proposal for a generic API that will make such operations possible, and it certainly needs support now:
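For reference, a sketch of such an external workaround, assuming the elided suggestion was along the lines of plain TensorFlow shape ops (the helper name `join_last_two` is made up):

```python
import tensorflow as tf

def join_last_two(x: tf.Tensor) -> tf.Tensor:
    # Dynamic-shape equivalent of "... n m -> ... (n m)" without einops:
    # read the runtime shape and fold the last two axes into one.
    shape = tf.shape(x)
    new_shape = tf.concat([shape[:-2], [shape[-2] * shape[-1]]], axis=0)
    return tf.reshape(x, new_shape)
```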
I'm not sure what you mean by "symbols for shape components". I understand if this is a lot of work, and outside the scope of `einops`. Here's a small, dirty example:

```python
import re
from typing import Any, Dict, List, Mapping, Tuple, Union

import tensorflow as tf

# `...` has no importable type to annotate with here, so alias it to Any.
EllipsisType = Any
# A dim spec is an axis name, a parenthesised group of axis names, or `...`.
DimSpec = Union[str, Tuple[str, ...], EllipsisType]
TensorSpec = Tuple[DimSpec, ...]
RearrangeSpec = Tuple[TensorSpec, TensorSpec]

# The groups match, in order: an ellipsis, a bare axis name, or the contents
# of a parenthesised group of axis names.
DIM_SPEC_RE = re.compile(r"(\.\.\.)|(\w+)|\((.*?)\)")


def parse_tensor_spec(spec: str) -> TensorSpec:
    result: List[DimSpec] = []
    for match in DIM_SPEC_RE.finditer(spec):
        i = match.lastindex
        if i == 1:
            result.append(...)
        elif i == 2:
            var_name = match.group(2)
            assert var_name is not None
            result.append(var_name)
        else:
            assert i == 3
            var_names = match.group(3)
            assert var_names is not None
            result.append(tuple(var_names.split()))
    return tuple(result)


def parse_rearrange_spec(spec: str) -> RearrangeSpec:
    frm, to = spec.split("->")
    return parse_tensor_spec(frm), parse_tensor_spec(to)


def flatten_tensor_spec(spec: TensorSpec) -> Tuple[DimSpec, ...]:
    # Unpack grouped axes so every element is a single axis name (or `...`).
    result: List[DimSpec] = []
    for dim_spec in spec:
        if isinstance(dim_spec, tuple):
            result.extend(dim_spec)
        else:
            result.append(dim_spec)
    return tuple(result)


def parse_shape(t: tf.Tensor, spec: str) -> Mapping[str, tf.Tensor]:
    # Map each named axis in `spec` to its (possibly dynamic) size in `t`.
    frm_spec = parse_tensor_spec(spec)
    frm_shape = tf.shape(t)
    i = tf.zeros((), dtype=tf.int32)
    sizes = {}
    for var_name in frm_spec:
        if isinstance(var_name, str):
            sizes[var_name] = frm_shape[i]
            i += 1
        elif isinstance(var_name, tuple):
            i += 1
        else:
            assert var_name is ...
            # The ellipsis absorbs every axis not claimed by the rest of the spec.
            i += tf.size(frm_shape) - len(frm_spec) + 1
    return sizes


def rearrange(t: tf.Tensor, spec: str, **sizes: Union[int, tf.Tensor]) -> tf.Tensor:
    frm_spec, to_spec = parse_rearrange_spec(spec)
    # Normalise caller-supplied sizes to rank-1 int tensors so they concatenate.
    tf_sizes: Dict[Union[str, EllipsisType], tf.Tensor] = {
        dim_name: tf.reshape(tf.convert_to_tensor(size), [1])
        for dim_name, size in sizes.items()
    }
    i = tf.zeros((), dtype=tf.int32)
    frm_shape = tf.shape(t)
    for var_name in frm_spec:
        if isinstance(var_name, str):
            tf_sizes[var_name] = frm_shape[i : i + 1]
            i += 1
        elif isinstance(var_name, tuple):
            # Sizes of grouped axes must be supplied by the caller via **sizes.
            i += 1
        else:
            assert var_name is ...
            size = tf.size(frm_shape) - len(frm_spec) + 1
            tf_sizes[var_name] = frm_shape[i : i + size]
            i += size
    # Reshape so every axis of `t` corresponds to one entry of the flat spec.
    frm_spec_flat = flatten_tensor_spec(frm_spec)
    frm_shape_flat = tf.concat([tf_sizes[n] for n in frm_spec_flat], axis=0)
    t = tf.reshape(t, frm_shape_flat)
    # Build the permutation taking the flat "from" axes to the flat "to" axes.
    i = tf.zeros((), dtype=tf.int32)
    frm_spec_indices = {}
    for var_name in frm_spec_flat:
        size = tf.size(tf_sizes[var_name])
        frm_spec_indices[var_name] = tf.range(i, i + size)
        i += size
    to_spec_flat = flatten_tensor_spec(to_spec)
    perm = tf.concat([frm_spec_indices[n] for n in to_spec_flat], axis=0)
    t = tf.transpose(t, perm)
    # Finally, collapse grouped axes in the "to" spec into single axes.
    to_shape = []
    for dim_spec in to_spec:
        if isinstance(dim_spec, tuple):
            if not dim_spec:
                # An empty group contributes a single axis of size 1.
                to_shape.append(tf.ones((1,), dtype=tf.int32))
            else:
                to_shape.append(
                    tf.math.reduce_prod([tf_sizes[n] for n in dim_spec], axis=0)
                )
        else:
            to_shape.append(tf_sizes[dim_spec])
    t = tf.reshape(t, tf.concat(to_shape, axis=0))
    return t


def f(a: tf.Tensor) -> tf.Tensor:
    parsed_shape = parse_shape(a, "... n m")
    joined = rearrange(a, "... n m -> ... (n m)")
    # Hypothetically do something interesting with `joined` here.
    return rearrange(joined, "... (n m) -> ... n m", **parsed_shape)


static_shape = tf.Variable(tf.zeros((2, 3, 4, 5)))
dynamic_shape = tf.Variable(tf.zeros((2, 3, 4, 5)), shape=tf.TensorShape(None))
f(static_shape)
f(dynamic_shape)
compiled_f = tf.function(f)
compiled_f(static_shape)
compiled_f(dynamic_shape)
```
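Note that all of the shape arithmetic above stays in TensorFlow ops (`tf.shape`, `tf.size`, `tf.range`, `tf.concat`) rather than Python ints, which appears to be what lets the same function trace under `tf.function` even when the variable's shape is only known at run time.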
Describe the bug
There appear to be several bugs when trying to use `einops` with TensorFlow tensors that have a dynamic shape.

Reproduction steps

Steps to reproduce the behavior:
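(The concrete steps aren't preserved in this extract; a minimal reconstruction consistent with the expected behavior below, reusing the `f`, `static_shape`, and `dynamic_shape` names, might be the following. The body of `f` is an assumption.)

```python
import tensorflow as tf
from einops import rearrange

def f(a: tf.Tensor) -> tf.Tensor:
    # Hypothetical body; any pattern touching a dynamically shaped axis will do.
    return rearrange(a, "... n m -> ... (n m)")

static_shape = tf.Variable(tf.zeros((2, 3, 4, 5)))
dynamic_shape = tf.Variable(tf.zeros((2, 3, 4, 5)), shape=tf.TensorShape(None))

f(static_shape)            # expected to succeed
f(dynamic_shape)           # reportedly fails

compiled_f = tf.function(f)
compiled_f(static_shape)
compiled_f(dynamic_shape)  # reportedly fails as well
```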
Expected behavior

`f` should run without crashing, and do the same for both `static_shape` and `dynamic_shape`.

Your platform