
Commit 269825a

Making sure that the static shape is preserved.

Also, don't broadcast inner_dimensions.

There are two possible solutions for preserving the static shape:

1. Directly calculating the static shape: effectively, looking for dimensions where the static shape is known, and overwriting the resulting shape for them.
2. Trying to see if we can carry the static shape through the calculation more effectively.

#2 is selected here, because by tracking the static shape, we also identify optimizations, such as where identity transforms can take place.

PiperOrigin-RevId: 425974190
Change-Id: Ia6c24a3f48444a3fa1e7f7c9914dc97f29eef9bd
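A minimal sketch (not part of the commit, tensor values assumed) contrasting the two approaches, using the internal helpers that appear in the diff below:

import tensorflow as tf
from tensorflow.python.framework import tensor_util

inner_shape = tf.constant([6, 3], dtype=tf.int64)

# Approach 1: recompute static dims after the fact with constant_value, which
# yields a numpy array (or None) that would have to be patched back in.
static_values = tensor_util.constant_value(inner_shape)          # array([6, 3])

# Approach 2 (chosen by this commit): carry a TensorShape alongside the dynamic
# tensor; it can stay partially known (e.g. [None, 3]) and makes identity
# transforms detectable without evaluating any tensor.
static_shape = tensor_util.constant_value_as_shape(inner_shape)  # TensorShape([6, 3])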
1 parent 84a132b commit 269825a

2 files changed: +460 -40 lines changed

tensorflow/python/ops/ragged/ragged_shape.py

Lines changed: 146 additions & 37 deletions
@@ -150,17 +150,21 @@ def __init__(self, row_partitions, inner_shape, dtype=None, validate=False):
                 message=msg))
 
     self._inner_shape.shape.assert_has_rank(1)
+    self._static_inner_shape = tensor_util.constant_value_as_shape(
+        self._inner_shape)
+
     if row_partitions:
       last_row_partition = row_partitions[-1]
       static_nvals = last_row_partition.static_nvals
-      static_inner_shape = tensor_util.constant_value(self._inner_shape)
-      if (static_nvals is not None) and (static_inner_shape is not None):
-        if static_nvals != static_inner_shape[0]:
+      static_inner_shape_nvals = tensor_shape.dimension_value(
+          self._static_inner_shape[0])
+      if static_nvals is not None and static_inner_shape_nvals is not None:
+        if static_nvals != static_inner_shape_nvals:
           raise ValueError("Last row partition does not match inner_shape.")
       elif validate:
         checks.append(
             check_ops.assert_equal(
-                row_partitions[-1].nvals(),
+                last_row_partition.nvals(),
                 self._inner_shape[0],
                 message="Last row partition does not match inner_shape."))
     if checks:
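For context, a small sketch (assumed values, not from the commit) of how tensor_shape.dimension_value behaves in the validation above: it returns a Python int when the dimension is statically known and None otherwise, so the nvals check can run at graph-construction time and fall back to a runtime assert when the shape is only partially known:

from tensorflow.python.framework import tensor_shape

known = tensor_shape.TensorShape([6, 3])
unknown = tensor_shape.TensorShape([None, 3])
tensor_shape.dimension_value(known[0])    # -> 6, so the static check fires
tensor_shape.dimension_value(unknown[0])  # -> None, so validation falls back to assert_equal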
@@ -319,12 +323,11 @@ def dtype(self):
     """The dtype of the shape -- one of tf.int32 or tf.int64."""
     return self._inner_shape.dtype
 
-  def _static_inner_shape(self, truncate_first):
-    """Returns the lengths of the inner shape (if rank known)."""
-    result = tensor_util.constant_value(self.inner_shape, partial=True)
-    if result is None:
+  def _static_inner_shape_as_list(self, truncate_first):
+    """Returns the lengths of the inner shape (if rank known), or [...]."""
+    if self._static_inner_shape.rank is None:
       return [...]
-    result = list(result)
+    result = self._static_inner_shape.as_list()
     if truncate_first:
       return result[1:]
     return result
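The [...] return value above is a literal list containing Ellipsis. A standalone sketch of the same contract (the free-function name is hypothetical, used only for illustration):

from tensorflow.python.framework import tensor_shape

def static_inner_shape_as_list(static_inner_shape, truncate_first):
  """Mirrors the method above: dims as ints/None, or [Ellipsis] if rank unknown."""
  if static_inner_shape.rank is None:
    return [...]
  result = static_inner_shape.as_list()
  return result[1:] if truncate_first else result

static_inner_shape_as_list(tensor_shape.TensorShape([None, 3]), False)  # [None, 3]
static_inner_shape_as_list(tensor_shape.TensorShape(None), True)        # [Ellipsis]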
@@ -347,7 +350,7 @@ def static_lengths(self, ragged_lengths=True):
       Ellipsis at the end.
     """
     if self.num_row_partitions == 0:
-      return self._static_inner_shape(False)
+      return self._static_inner_shape_as_list(False)
     first_dim = self.row_partitions[0].static_nrows
     if isinstance(first_dim, tensor_shape.Dimension):
       first_dim = first_dim.value
@@ -364,7 +367,7 @@ def static_lengths(self, ragged_lengths=True):
       else:
         rp_dims.append(None)
 
-    return rp_dims + self._static_inner_shape(True)
+    return rp_dims + self._static_inner_shape_as_list(True)
 
   def __repr__(self):
     lengths = _list_with_ellipsis_to_str(self.static_lengths())
@@ -440,10 +443,22 @@ def _dimension(self, index):
     elif not self.is_uniform(index):
       raise ValueError("Index " + str(index) + " is not uniform")
     elif index == 0 and self.num_row_partitions > 0:
+      static_nrows = self.row_partitions[0].static_nrows
+      if static_nrows is not None:
+        return constant_op.constant(static_nrows, dtype=self.dtype)
       return self.row_partitions[0].nrows()
     elif self.num_row_partitions == 0:
+      static_result = tensor_shape.dimension_value(
+          self._static_inner_shape[index])
+      if static_result is not None:
+        return constant_op.constant(static_result, dtype=self.dtype)
       return self.inner_shape[index]
     elif index > self.num_row_partitions:
+      static_result = tensor_shape.dimension_value(
+          self._static_inner_shape[index - self.num_row_partitions])
+      if static_result is not None:
+        return constant_op.constant(static_result, dtype=self.dtype)
+
       return self.inner_shape[index - self.num_row_partitions]
     else:
       return self.row_partitions[index - 1].uniform_row_length()
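A hypothetical before/after for _dimension (values assumed): with a statically known dimension the method now emits a constant instead of slicing the dynamic shape tensor, so downstream shape math can constant-fold:

import tensorflow as tf
from tensorflow.python.framework import tensor_shape

inner_shape = tf.constant([6, 3], dtype=tf.int64)      # dynamic shape vector
static_inner_shape = tensor_shape.TensorShape([6, 3])  # carried static shape

index = 1
static_result = tensor_shape.dimension_value(static_inner_shape[index])
dim = (tf.constant(static_result, dtype=inner_shape.dtype)
       if static_result is not None
       else inner_shape[index])   # old behavior: a strided-slice op in the graph
# dim is a constant 3.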
@@ -533,9 +548,7 @@ def _num_slices_in_dimension(self, axis):
             "You can't use negative values if the rank is undefined")
       axis = axis + rank
     if axis == 0:
-      if self.num_row_partitions >= 1:
-        return self.row_partitions[0].nrows()
-      return self.inner_shape[0]
+      return self._dimension(0)
     if axis <= self.num_row_partitions:
       return self.row_partitions[axis - 1].nvals()
     # If self.num_row_partitions = 1, and
@@ -550,7 +563,7 @@ def _num_slices_in_dimension(self, axis):
       return math_ops.reduce_prod(self.inner_shape[:remainder])
 
   def is_uniform(self, axis):
-    """Returns true if the indicated dimension is ragged."""
+    """Returns true if the indicated dimension is uniform."""
     if not isinstance(axis, int):
       raise TypeError("axis must be an integer")
     rank = self.rank
@@ -604,6 +617,9 @@ def _alt_inner_shape(self, new_inner_rank):
     elif new_inner_rank == self.inner_rank:
       return self.inner_shape
     elif new_inner_rank < self.inner_rank:
+      if self._static_inner_shape.is_fully_defined():
+        return _alt_inner_shape_from_tensor_shape(self._static_inner_shape,
+                                                  self.dtype, new_inner_rank)
       first_dimension = self._num_slices_in_dimension(-new_inner_rank)
       if new_inner_rank == 1:
         return array_ops.expand_dims(first_dimension, 0)
@@ -625,10 +641,15 @@ def _alt_inner_shape(self, new_inner_rank):
       return array_ops.concat([array_ops.stack(new_dims), self.inner_shape[1:]],
                               axis=0)
 
+  def _inner_shape_dim(self, dimension):
+    """Returns an int or a tensor representing _inner_shape[dimension]."""
+    result = tensor_shape.dimension_value(self._static_inner_shape[dimension])
+    return self._inner_shape[dimension] if result is None else result
+
   def with_inner_rank(self, inner_rank):
     """Returns the same shape but a different inner_rank.
 
-    All dimensions that are to represented in the inner_shape must be dense.
+    All dimensions that are to be represented in the inner_shape must be dense.
     See inner_rank.
 
     Args:
@@ -690,20 +711,19 @@ def _with_num_row_partitions(self, num_row_partitions):
         raise ValueError("num_row_partitions must be less than rank")
       if num_row_partitions > self.num_row_partitions:
         num_row_partitions_diff = num_row_partitions - self.num_row_partitions
-
-        nvals = self.row_partitions[-1].nvals() if (
-            self.num_row_partitions > 0) else self._dimension(0)
+        new_inner_rank = self.rank - num_row_partitions
+        nvals = self._inner_shape_dim(0)
         more_rp = []
         for i in range(num_row_partitions_diff):
           nrows = nvals
-          row_length = self.inner_shape[i + 1]
+          row_length = self._inner_shape_dim(i + 1)
           nvals = nrows * row_length
           rp = RowPartition.from_uniform_row_length(
-              row_length, nrows=nrows, nvals=nvals)
+              row_length, nrows=nrows, dtype=self.dtype)
           more_rp.append(rp)
+        alt_inner = self._alt_inner_shape(new_inner_rank)
         return RaggedShape(
-            list(self.row_partitions) + more_rp,
-            self._alt_inner_shape(self.rank - num_row_partitions))
+            list(self.row_partitions) + more_rp, alt_inner)
       else:
         assert num_row_partitions < self.num_row_partitions
         return RaggedShape(self.row_partitions[:num_row_partitions],
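A worked illustration (inner shape assumed to be statically [6, 3, 4]) of what _inner_shape_dim buys in the loop above: the nrows/nvals arithmetic stays in Python ints, so no multiply ops are added to the graph:

nvals = 6                          # self._inner_shape_dim(0) -> int, not a tensor
for row_length in [3, 4]:          # self._inner_shape_dim(i + 1) for i = 0, 1
  nrows = nvals
  nvals = nrows * row_length       # 18, then 72: plain integer arithmetic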
@@ -950,8 +970,8 @@ def ragged_binary_elementwise_op_impl(op, x, y):
 
   (shape_z, bcast_xz,
    bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y)
-  x_new_flat = bcast_xz.broadcast_flat_values(x)
-  y_new_flat = bcast_yz.broadcast_flat_values(y)
+  x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False)
+  y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False)
   z_flat = op(x_new_flat, y_new_flat)
   return shape_z._add_row_partitions(z_flat, validate=True)  # pylint: disable=protected-access
 
@@ -1098,18 +1118,18 @@ def with_dependencies(self, checks):
     pass
 
   @classmethod
-  def get_identity_broadcaster(cls, nvals):
+  def get_identity_broadcaster(cls, nvals, dtype=None):
     """Create an identity broadcaster.
 
     TODO(martinz): an identity broadcaster can be far more efficient than a
     generic broadcaster. Add an optimized implementation.
     Args:
       nvals: the number of values for the broadcaster.
-
+      dtype: the dtype of the broadcaster, or None to use the dtype of nvals.
     Returns:
      an identity broadcaster from [0....nvals-1] to [0...nvals-1]
     """
-    return _GatherLayerBroadcaster(math_ops.range(nvals))
+    return _GatherLayerBroadcaster(math_ops.range(nvals, dtype=dtype))
 
   def broadcast_tensor(self, tensor):
     """Broadcast from a dense tensor.
@@ -1487,8 +1507,6 @@ def _broadcast_dynamic_shape_one_layer(a, b):
   """
   a_0 = a[0]
   b_0 = b[0]
-  can_broadcast_from_a = math_ops.equal(a_0, 1)
-  can_broadcast_from_b = math_ops.equal(b_0, 1)
 
   def broadcast_from_a():
     # Assumes a_0 == 1
@@ -1497,22 +1515,38 @@ def broadcast_from_a():
     target = b
     return [a_layer, b_layer, target]
 
+  a_static = tensor_util.constant_value(a)
+  if a_static is not None and a_static[0] == 1:
+    [a_gi, b_gi, target] = broadcast_from_a()
+    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
+    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
+    return [a_layer, b_layer, target]
+
   def broadcast_from_b():
     # Assumes b_0 == 1
     a_layer = math_ops.range(a_0)
     b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
     target = a
     return [a_layer, b_layer, target]
 
+  b_static = tensor_util.constant_value(b)
+  if b_static is not None and b_static[0] == 1:
+    [a_gi, b_gi, target] = broadcast_from_b()
+    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
+    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
+    return [a_layer, b_layer, target]
+
   def broadcast_noop():
     # Assumes a_0 == 1
     a_layer = math_ops.range(a_0)
     b_layer = math_ops.range(b_0)
     target = b
     return [a_layer, b_layer, target]
 
+  can_broadcast_from_a = math_ops.equal(a_0, 1)
+  can_broadcast_from_b = math_ops.equal(b_0, 1)
+
   def broadcast_not_from_a():
-    can_broadcast_from_b = math_ops.equal(b_0, 1)
     return control_flow_ops.cond(
         can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
 
@@ -1552,15 +1586,31 @@ def _broadcast_dynamic_shape_first_layer(a_0, b_0):
     layer_a is a _LayerBroadcaster from a to the target.
     layer_b is a _LayerBroadcaster from b to the target.
   """
-  can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype))
-  can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype))
-
   def broadcast_from_a():
     # Assumes a_0 == 1
     a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
     b_layer = math_ops.range(b_0)
     return [a_layer, b_layer]
 
+  static_a_0 = tensor_util.constant_value(a_0)
+  static_b_0 = tensor_util.constant_value(b_0)
+  if static_a_0 is not None:
+    if static_a_0 == static_b_0:
+      id_broadcaster = _LayerBroadcaster.get_identity_broadcaster(
+          static_a_0, dtype=a_0.dtype)
+      return [id_broadcaster, id_broadcaster]
+    elif static_a_0 == 1:
+      return [
+          _LayerBroadcaster.get_singleton_broadcaster(b_0),
+          _LayerBroadcaster.get_identity_broadcaster(b_0)
+      ]
+
+  if static_b_0 == 1:
+    return [
+        _LayerBroadcaster.get_identity_broadcaster(a_0),
+        _LayerBroadcaster.get_singleton_broadcaster(a_0)
+    ]
+
   def broadcast_from_b():
     # Assumes b_0 == 1
     a_layer = math_ops.range(a_0)
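A hypothetical example of the new static short-circuit above: when both leading dimensions fold to the same constant, both sides get the identity broadcaster and no tf.cond is built at all:

import tensorflow as tf
from tensorflow.python.framework import tensor_util

a_0 = tf.constant(4, dtype=tf.int64)
b_0 = tf.constant(4, dtype=tf.int64)
static_a_0 = tensor_util.constant_value(a_0)   # -> 4
static_b_0 = tensor_util.constant_value(b_0)   # -> 4
if static_a_0 is not None and static_a_0 == static_b_0:
  # Both layers are the identity gather index; no conditional ops are created.
  gather_index = tf.range(static_a_0, dtype=a_0.dtype)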
@@ -1573,6 +1623,9 @@ def broadcast_noop():
     b_layer = math_ops.range(b_0)
     return [a_layer, b_layer]
 
+  can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype))
+  can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype))
+
   def broadcast_not_from_a():
     return control_flow_ops.cond(
         can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
@@ -1663,6 +1716,15 @@ def _broadcast_dynamic_shape_next_layer_half_ragged(
   assert a_1.is_uniform()
   assert not b_1.is_uniform()
 
+  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
+  if static_a_1 == 1:
+    [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
+    ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
+    c_1 = RowPartition.from_row_splits(c_1b.row_splits())
+    ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
+    bc_1 = _LayerBroadcaster.from_gather_index(bc_1.gather_index)
+    return [c_1, ac_1, bc_1]
+
   def broadcast_noop():
     # The sides must be "equal".
     [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
@@ -1730,6 +1792,37 @@ def _broadcast_dynamic_shape_next_layer_both_uniform(
   assert a_1.is_uniform()
   assert b_1.is_uniform()
 
+  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
+  static_b_1 = tensor_util.constant_value(b_1.uniform_row_length())
+
+  if static_a_1 is not None:
+    if static_a_1 == static_b_1:
+      # Here, this dimension is the same, but we may have to broadcast previous
+      # dimensions.
+      [ac_1, _] = _broadcast_half(ac_0, a_1)
+      [bc_1, _] = _broadcast_half(bc_0, b_1)
+      c_1 = RowPartition.from_uniform_row_length(
+          static_a_1,
+          nrows=ac_0.dest_nrows())
+      return [c_1, ac_1, bc_1]
+    elif static_a_1 == 1:
+      [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
+      ac_1 = _LayerBroadcaster.from_gather_index(
+          array_ops.gather(ac_0.gather_index, c_1b.value_rowids()))
+      c_1 = RowPartition.from_uniform_row_length(
+          b_1.uniform_row_length(),
+          nrows=bc_0.dest_nrows())
+      return [c_1, ac_1, bc_1]
+
+  if static_b_1 == 1:
+    [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
+    bc_1 = _LayerBroadcaster.from_gather_index(
+        array_ops.gather(bc_0.gather_index, c_1a.value_rowids()))
+    c_1 = RowPartition.from_uniform_row_length(
+        a_1.uniform_row_length(),
+        nrows=ac_0.dest_nrows())
+    return [c_1, ac_1, bc_1]
+
   def broadcast_noop():
     # Assumes a_1.uniform_row_length() == b_1.uniform_row_length()
     # Both sides broadcast to a single shape.
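A worked sketch (row lengths and nrows assumed) of the static_a_1 == static_b_1 branch above: the merged partition is built directly from the shared static row length, and only the previous layers still need broadcasting:

from tensorflow.python.ops.ragged.row_partition import RowPartition

static_a_1 = 3          # tensor_util.constant_value(a_1.uniform_row_length())
static_b_1 = 3
if static_a_1 == static_b_1:
  # nrows is assumed to be 4 here; in the code above it is ac_0.dest_nrows().
  c_1 = RowPartition.from_uniform_row_length(static_a_1, nrows=4)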
@@ -1920,8 +2013,8 @@ def _broadcast_dynamic_shape_extended_complete(
   ]
   c_num_row_partitions = _get_broadcast_num_row_partitions(a, b)
 
-  c = RaggedShape.from_row_partitions(
-      c_prefix + tuple(c_suffix))._with_num_row_partitions(c_num_row_partitions)
+  c_raw = RaggedShape.from_row_partitions(c_prefix + tuple(c_suffix))
+  c = c_raw._with_num_row_partitions(c_num_row_partitions)
   return (c, _Broadcaster(a, c, ac), _Broadcaster(b, c, bc_prefix + bc_suffix))
 
 
@@ -1946,7 +2039,6 @@ def _broadcast_dynamic_shape_extended_helper(
   assert 1 <= a.rank
   a_rps = a._as_row_partitions()  # pylint: disable=protected-access
   b_rps = b._as_row_partitions()  # pylint: disable=protected-access
-  a_nrows = a[0]
 
   if len(a_rps) < len(b_rps):
     # Note: this includes the case where len(a_rps)==0.
@@ -1957,6 +2049,11 @@ def _broadcast_dynamic_shape_extended_helper(
     #      ...     |     ...
     #   a_rps[-1]  |  b_rps[-1]
 
+    a_nrows = a[0]
+    a_nrows_static = tensor_util.constant_value(a_nrows)
+    if a_nrows_static is not None:
+      a_nrows = a_nrows_static
+
     neg_one_a_rp = RowPartition.from_uniform_row_length(
         uniform_row_length=a_nrows, nrows=1, nvals=a_nrows)
     neg_one_b_rp = b_rps[-(len(a_rps) + 1)]
@@ -2180,3 +2277,15 @@ def _is_int_or_tuple_of_ints(x):
     if not isinstance(y, int):
       return False
   return True
+
+
+def _alt_inner_shape_from_tensor_shape(shape, dtype, new_inner_rank):
+  """Helper for _alt_inner_shape, used directly in _with_num_row_partitions."""
+  if new_inner_rank == 1:
+    return constant_op.constant([shape.num_elements()], dtype=dtype)
+  new_inner_rank_tail_length = new_inner_rank - 1
+  inner_shape_tail = shape[-new_inner_rank_tail_length:].as_list()
+  first_dim = shape[:-new_inner_rank_tail_length].num_elements()
+  return constant_op.constant([first_dim] + inner_shape_tail,
+                              dtype=dtype)
+
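A worked example (inputs assumed) of the new helper: collapsing a fully defined inner shape [6, 3, 4] to an inner rank of 2 yields the constant [18, 4], with all arithmetic done on the static TensorShape:

from tensorflow.python.framework import tensor_shape

shape = tensor_shape.TensorShape([6, 3, 4])
new_inner_rank = 2
tail_length = new_inner_rank - 1
inner_shape_tail = shape[-tail_length:].as_list()   # [4]
first_dim = shape[:-tail_length].num_elements()     # 6 * 3 = 18
# _alt_inner_shape_from_tensor_shape returns constant_op.constant([18, 4], dtype)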

0 commit comments