@@ -150,17 +150,21 @@ def __init__(self, row_partitions, inner_shape, dtype=None, validate=False):
                 message=msg))
 
     self._inner_shape.shape.assert_has_rank(1)
+    self._static_inner_shape = tensor_util.constant_value_as_shape(
+        self._inner_shape)
+
     if row_partitions:
       last_row_partition = row_partitions[-1]
       static_nvals = last_row_partition.static_nvals
-      static_inner_shape = tensor_util.constant_value(self._inner_shape)
-      if (static_nvals is not None) and (static_inner_shape is not None):
-        if static_nvals != static_inner_shape[0]:
+      static_inner_shape_nvals = tensor_shape.dimension_value(
+          self._static_inner_shape[0])
+      if static_nvals is not None and static_inner_shape_nvals is not None:
+        if static_nvals != static_inner_shape_nvals:
           raise ValueError("Last row partition does not match inner_shape.")
       elif validate:
         checks.append(
             check_ops.assert_equal(
-                row_partitions[-1].nvals(),
+                last_row_partition.nvals(),
                 self._inner_shape[0],
                 message="Last row partition does not match inner_shape."))
     if checks:
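The constructor now folds the 1-D `inner_shape` tensor into a cached `TensorShape` once, so later queries can stay static. A minimal sketch of what `constant_value_as_shape` yields, assuming TensorFlow's private `tensor_util` module (its import path can shift between releases):

    import tensorflow as tf
    from tensorflow.python.framework import tensor_util

    inner_shape = tf.constant([6, 3], dtype=tf.int64)
    # Constant values fold into a fully known TensorShape at trace time.
    static = tensor_util.constant_value_as_shape(inner_shape)
    print(static)                                # (6, 3)
    print(tf.compat.dimension_value(static[0]))  # 6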
@@ -319,12 +323,11 @@ def dtype(self):
     """The dtype of the shape -- one of tf.int32 or tf.int64."""
     return self._inner_shape.dtype
 
-  def _static_inner_shape(self, truncate_first):
-    """Returns the lengths of the inner shape (if rank known)."""
-    result = tensor_util.constant_value(self.inner_shape, partial=True)
-    if result is None:
+  def _static_inner_shape_as_list(self, truncate_first):
+    """Returns the lengths of the inner shape (if rank known), or [...]."""
+    if self._static_inner_shape.rank is None:
       return [...]
-    result = list(result)
+    result = self._static_inner_shape.as_list()
     if truncate_first:
       return result[1:]
     return result
@@ -347,7 +350,7 @@ def static_lengths(self, ragged_lengths=True):
       Ellipsis at the end.
     """
     if self.num_row_partitions == 0:
-      return self._static_inner_shape(False)
+      return self._static_inner_shape_as_list(False)
     first_dim = self.row_partitions[0].static_nrows
     if isinstance(first_dim, tensor_shape.Dimension):
       first_dim = first_dim.value
@@ -364,7 +367,7 @@ def static_lengths(self, ragged_lengths=True):
       else:
         rp_dims.append(None)
 
-    return rp_dims + self._static_inner_shape(True)
+    return rp_dims + self._static_inner_shape_as_list(True)
 
   def __repr__(self):
     lengths = _list_with_ellipsis_to_str(self.static_lengths())
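For orientation (not part of the commit): `static_lengths` returns a plain Python list where an int is a uniform dimension, a tuple holds a ragged dimension's row lengths, `None` marks a statically unknown size, and a trailing `...` stands for inner dimensions of unknown rank. A rough public-API analogue of the ragged case:

    import tensorflow as tf

    rt = tf.ragged.constant([[1, 2], [3]])
    # rt's shape would render as [2, (2, 1)].
    nrows = int(rt.nrows().numpy())                # 2
    row_lengths = tuple(rt.row_lengths().numpy())  # (2, 1)
    print([nrows, row_lengths])                    # [2, (2, 1)]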
@@ -440,10 +443,22 @@ def _dimension(self, index):
     elif not self.is_uniform(index):
       raise ValueError("Index " + str(index) + " is not uniform")
     elif index == 0 and self.num_row_partitions > 0:
+      static_nrows = self.row_partitions[0].static_nrows
+      if static_nrows is not None:
+        return constant_op.constant(static_nrows, dtype=self.dtype)
       return self.row_partitions[0].nrows()
     elif self.num_row_partitions == 0:
+      static_result = tensor_shape.dimension_value(
+          self._static_inner_shape[index])
+      if static_result is not None:
+        return constant_op.constant(static_result, dtype=self.dtype)
       return self.inner_shape[index]
     elif index > self.num_row_partitions:
+      static_result = tensor_shape.dimension_value(
+          self._static_inner_shape[index - self.num_row_partitions])
+      if static_result is not None:
+        return constant_op.constant(static_result, dtype=self.dtype)
+
       return self.inner_shape[index - self.num_row_partitions]
     else:
       return self.row_partitions[index - 1].uniform_row_length()
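The pattern added here recurs throughout the commit: consult the cached static shape first, and fall back to slicing the dynamic `inner_shape` tensor only when the size is unknown. A standalone sketch using the public `tf.compat.dimension_value`:

    import tensorflow as tf

    static_inner_shape = tf.TensorShape([None, 3])
    inner_shape = tf.constant([6, 3], dtype=tf.int64)

    def dimension(index):
      static = tf.compat.dimension_value(static_inner_shape[index])
      if static is not None:
        # Statically known: emit a constant, no read of inner_shape.
        return tf.constant(static, dtype=tf.int64)
      return inner_shape[index]  # dynamic fallback

    print(dimension(1))  # constant 3, folded at trace time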
@@ -533,9 +548,7 @@ def _num_slices_in_dimension(self, axis):
             "You can't use negative values if the rank is undefined")
       axis = axis + rank
     if axis == 0:
-      if self.num_row_partitions >= 1:
-        return self.row_partitions[0].nrows()
-      return self.inner_shape[0]
+      return self._dimension(0)
     if axis <= self.num_row_partitions:
       return self.row_partitions[axis - 1].nvals()
     # If self.num_row_partitions = 1, and
@@ -550,7 +563,7 @@ def _num_slices_in_dimension(self, axis):
     return math_ops.reduce_prod(self.inner_shape[:remainder])
 
   def is_uniform(self, axis):
-    """Returns true if the indicated dimension is ragged."""
+    """Returns true if the indicated dimension is uniform."""
     if not isinstance(axis, int):
       raise TypeError("axis must be an integer")
     rank = self.rank
@@ -604,6 +617,9 @@ def _alt_inner_shape(self, new_inner_rank):
     elif new_inner_rank == self.inner_rank:
       return self.inner_shape
     elif new_inner_rank < self.inner_rank:
+      if self._static_inner_shape.is_fully_defined():
+        return _alt_inner_shape_from_tensor_shape(self._static_inner_shape,
+                                                  self.dtype, new_inner_rank)
       first_dimension = self._num_slices_in_dimension(-new_inner_rank)
       if new_inner_rank == 1:
         return array_ops.expand_dims(first_dimension, 0)
@@ -625,10 +641,15 @@ def _alt_inner_shape(self, new_inner_rank):
       return array_ops.concat([array_ops.stack(new_dims), self.inner_shape[1:]],
                               axis=0)
 
+  def _inner_shape_dim(self, dimension):
+    """Returns an int or a tensor representing _inner_shape[dimension]."""
+    result = tensor_shape.dimension_value(self._static_inner_shape[dimension])
+    return self._inner_shape[dimension] if result is None else result
+
   def with_inner_rank(self, inner_rank):
     """Returns the same shape but a different inner_rank.
 
-    All dimensions that are to represented in the inner_shape must be dense.
+    All dimensions that are to be represented in the inner_shape must be dense.
     See inner_rank.
 
     Args:
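The `_alt_inner_shape` fast path above fires only when every inner dimension is known, with `TensorShape.is_fully_defined()` as the gate; `_inner_shape_dim` then gives the rest of the file a single accessor that returns a Python int when possible and a tensor slice otherwise. The gate itself:

    import tensorflow as tf

    tf.TensorShape([2, 3, 4]).is_fully_defined()  # True
    tf.TensorShape([None, 3]).is_fully_defined()  # False
    tf.TensorShape(None).is_fully_defined()       # False (unknown rank)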
@@ -690,20 +711,19 @@ def _with_num_row_partitions(self, num_row_partitions):
       raise ValueError("num_row_partitions must be less than rank")
     if num_row_partitions > self.num_row_partitions:
       num_row_partitions_diff = num_row_partitions - self.num_row_partitions
-
-      nvals = self.row_partitions[-1].nvals() if (
-          self.num_row_partitions > 0) else self._dimension(0)
+      new_inner_rank = self.rank - num_row_partitions
+      nvals = self._inner_shape_dim(0)
       more_rp = []
       for i in range(num_row_partitions_diff):
         nrows = nvals
-        row_length = self.inner_shape[i + 1]
+        row_length = self._inner_shape_dim(i + 1)
         nvals = nrows * row_length
         rp = RowPartition.from_uniform_row_length(
-            row_length, nrows=nrows, nvals=nvals)
+            row_length, nrows=nrows, dtype=self.dtype)
         more_rp.append(rp)
+      alt_inner = self._alt_inner_shape(new_inner_rank)
       return RaggedShape(
-          list(self.row_partitions) + more_rp,
-          self._alt_inner_shape(self.rank - num_row_partitions))
+          list(self.row_partitions) + more_rp, alt_inner)
     else:
       assert num_row_partitions < self.num_row_partitions
       return RaggedShape(self.row_partitions[:num_row_partitions],
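With static inner dimensions, `nvals = nrows * row_length` above becomes plain Python-int arithmetic, and each synthesized partition is built from constants in the shape's own dtype. A sketch of the factory this loop calls; `RowPartition` is TF-internal, so the import path below is an assumption that may move between versions:

    import tensorflow as tf
    from tensorflow.python.ops.ragged.row_partition import RowPartition

    # Two rows of uniform length three, i.e. a dense 2 x 3 partitioning.
    rp = RowPartition.from_uniform_row_length(3, nrows=2, dtype=tf.int64)
    print(rp.uniform_row_length())  # 3
    print(rp.nvals())               # 6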
@@ -950,8 +970,8 @@ def ragged_binary_elementwise_op_impl(op, x, y):
 
   (shape_z, bcast_xz,
    bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y)
-  x_new_flat = bcast_xz.broadcast_flat_values(x)
-  y_new_flat = bcast_yz.broadcast_flat_values(y)
+  x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False)
+  y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False)
   z_flat = op(x_new_flat, y_new_flat)
   return shape_z._add_row_partitions(z_flat, validate=True)  # pylint: disable=protected-access
 
@@ -1098,18 +1118,18 @@ def with_dependencies(self, checks):
     pass
 
   @classmethod
-  def get_identity_broadcaster(cls, nvals):
+  def get_identity_broadcaster(cls, nvals, dtype=None):
     """Create an identity broadcaster.
 
     TODO(martinz): an identity broadcaster can be far more efficient than a
     generic broadcaster. Add an optimized implementation.
     Args:
       nvals: the number of values for the broadcaster.
-
+      dtype: the dtype of the broadcaster, or None to use the dtype of nvals.
     Returns:
       an identity broadcaster from [0....nvals-1] to [0...nvals-1]
     """
-    return _GatherLayerBroadcaster(math_ops.range(nvals))
+    return _GatherLayerBroadcaster(math_ops.range(nvals, dtype=dtype))
 
   def broadcast_tensor(self, tensor):
     """Broadcast from a dense tensor.
@@ -1487,8 +1507,6 @@ def _broadcast_dynamic_shape_one_layer(a, b):
   """
   a_0 = a[0]
   b_0 = b[0]
-  can_broadcast_from_a = math_ops.equal(a_0, 1)
-  can_broadcast_from_b = math_ops.equal(b_0, 1)
 
   def broadcast_from_a():
     # Assumes a_0 == 1
@@ -1497,22 +1515,38 @@ def broadcast_from_a():
     target = b
     return [a_layer, b_layer, target]
 
+  a_static = tensor_util.constant_value(a)
+  if a_static is not None and a_static[0] == 1:
+    [a_gi, b_gi, target] = broadcast_from_a()
+    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
+    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
+    return [a_layer, b_layer, target]
+
   def broadcast_from_b():
     # Assumes b_0 == 1
     a_layer = math_ops.range(a_0)
     b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
     target = a
     return [a_layer, b_layer, target]
 
+  b_static = tensor_util.constant_value(b)
+  if b_static is not None and b_static[0] == 1:
+    [a_gi, b_gi, target] = broadcast_from_b()
+    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
+    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
+    return [a_layer, b_layer, target]
+
   def broadcast_noop():
-    # Assumes a_0 == b_0
+    # Assumes a_0 == b_0
     a_layer = math_ops.range(a_0)
     b_layer = math_ops.range(b_0)
     target = b
     return [a_layer, b_layer, target]
 
+  can_broadcast_from_a = math_ops.equal(a_0, 1)
+  can_broadcast_from_b = math_ops.equal(b_0, 1)
+
   def broadcast_not_from_a():
-    can_broadcast_from_b = math_ops.equal(b_0, 1)
     return control_flow_ops.cond(
         can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
 
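The shape of this change repeats below: if the leading dimension is recoverable at trace time, pick the branch in Python and never emit a `tf.cond`. A generic sketch of the pattern with the public `tf.get_static_value` (the internal `tensor_util.constant_value` used in the diff is its twin):

    import tensorflow as tf

    def cond_with_static_shortcut(a_0, if_one, otherwise):
      a_0_static = tf.get_static_value(a_0)
      if a_0_static is not None:
        # Decided while tracing: no cond op lands in the graph.
        return if_one() if a_0_static == 1 else otherwise()
      return tf.cond(tf.equal(a_0, 1), if_one, otherwise)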
@@ -1552,15 +1586,31 @@ def _broadcast_dynamic_shape_first_layer(a_0, b_0):
     layer_a is a _LayerBroadcaster from a to the target.
     layer_b is a _LayerBroadcaster from b to the target.
   """
-  can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype))
-  can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype))
-
   def broadcast_from_a():
     # Assumes a_0 == 1
     a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
     b_layer = math_ops.range(b_0)
     return [a_layer, b_layer]
 
+  static_a_0 = tensor_util.constant_value(a_0)
+  static_b_0 = tensor_util.constant_value(b_0)
+  if static_a_0 is not None:
+    if static_a_0 == static_b_0:
+      id_broadcaster = _LayerBroadcaster.get_identity_broadcaster(
+          static_a_0, dtype=a_0.dtype)
+      return [id_broadcaster, id_broadcaster]
+    elif static_a_0 == 1:
+      return [
+          _LayerBroadcaster.get_singleton_broadcaster(b_0),
+          _LayerBroadcaster.get_identity_broadcaster(b_0)
+      ]
+
+  if static_b_0 == 1:
+    return [
+        _LayerBroadcaster.get_identity_broadcaster(a_0),
+        _LayerBroadcaster.get_singleton_broadcaster(a_0)
+    ]
+
   def broadcast_from_b():
     # Assumes b_0 == 1
     a_layer = math_ops.range(a_0)
@@ -1573,6 +1623,9 @@ def broadcast_noop():
     b_layer = math_ops.range(b_0)
     return [a_layer, b_layer]
 
+  can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype))
+  can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype))
+
   def broadcast_not_from_a():
     return control_flow_ops.cond(
         can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
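Here the two broadcaster kinds reduce to two gather indices: an identity broadcaster gathers [0, ..., n-1], and a singleton broadcaster (size-1 source) gathers index 0 everywhere. Assuming that is what `get_singleton_broadcaster` wraps, the indices look like:

    import tensorflow as tf

    n = tf.constant(4, dtype=tf.int64)
    identity_index = tf.range(n)                    # [0, 1, 2, 3]
    singleton_index = tf.zeros([n], dtype=n.dtype)  # [0, 0, 0, 0]
    print(tf.gather(tf.constant([7.0]), singleton_index))  # [7. 7. 7. 7.]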
@@ -1663,6 +1716,15 @@ def _broadcast_dynamic_shape_next_layer_half_ragged(
   assert a_1.is_uniform()
   assert not b_1.is_uniform()
 
+  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
+  if static_a_1 == 1:
+    [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
+    ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
+    c_1 = RowPartition.from_row_splits(c_1b.row_splits())
+    ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
+    bc_1 = _LayerBroadcaster.from_gather_index(bc_1.gather_index)
+    return [c_1, ac_1, bc_1]
+
   def broadcast_noop():
     # The sides must be "equal".
     [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
@@ -1730,6 +1792,37 @@ def _broadcast_dynamic_shape_next_layer_both_uniform(
   assert a_1.is_uniform()
   assert b_1.is_uniform()
 
+  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
+  static_b_1 = tensor_util.constant_value(b_1.uniform_row_length())
+
+  if static_a_1 is not None:
+    if static_a_1 == static_b_1:
+      # Here, this dimension is the same, but we may have to broadcast previous
+      # dimensions.
+      [ac_1, _] = _broadcast_half(ac_0, a_1)
+      [bc_1, _] = _broadcast_half(bc_0, b_1)
+      c_1 = RowPartition.from_uniform_row_length(
+          static_a_1,
+          nrows=ac_0.dest_nrows())
+      return [c_1, ac_1, bc_1]
+    elif static_a_1 == 1:
+      [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
+      ac_1 = _LayerBroadcaster.from_gather_index(
+          array_ops.gather(ac_0.gather_index, c_1b.value_rowids()))
+      c_1 = RowPartition.from_uniform_row_length(
+          b_1.uniform_row_length(),
+          nrows=bc_0.dest_nrows())
+      return [c_1, ac_1, bc_1]
+
+  if static_b_1 == 1:
+    [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
+    bc_1 = _LayerBroadcaster.from_gather_index(
+        array_ops.gather(bc_0.gather_index, c_1a.value_rowids()))
+    c_1 = RowPartition.from_uniform_row_length(
+        a_1.uniform_row_length(),
+        nrows=ac_0.dest_nrows())
+    return [c_1, ac_1, bc_1]
+
   def broadcast_noop():
     # Assumes a_1.uniform_row_length() == b_1.uniform_row_length()
     # Both sides broadcast to a single shape.
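When one side's uniform row length is statically 1, broadcasting its previous layer down reduces to a gather indexed by the other partition's `value_rowids()`. A self-contained illustration of that gather, with made-up data:

    import tensorflow as tf

    rt = tf.RaggedTensor.from_row_lengths([10, 20, 30, 40, 50],
                                          row_lengths=[2, 3])
    rowids = rt.value_rowids()             # [0, 0, 1, 1, 1]
    one_per_row = tf.constant([100, 200])  # the size-1 side, one value per row
    print(tf.gather(one_per_row, rowids))  # [100 100 200 200 200]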
@@ -1920,8 +2013,8 @@ def _broadcast_dynamic_shape_extended_complete(
   ]
   c_num_row_partitions = _get_broadcast_num_row_partitions(a, b)
 
-  c = RaggedShape.from_row_partitions(
-      c_prefix + tuple(c_suffix))._with_num_row_partitions(c_num_row_partitions)
+  c_raw = RaggedShape.from_row_partitions(c_prefix + tuple(c_suffix))
+  c = c_raw._with_num_row_partitions(c_num_row_partitions)
   return (c, _Broadcaster(a, c, ac), _Broadcaster(b, c, bc_prefix + bc_suffix))
 
 
@@ -1946,7 +2039,6 @@ def _broadcast_dynamic_shape_extended_helper(
   assert 1 <= a.rank
   a_rps = a._as_row_partitions()  # pylint: disable=protected-access
   b_rps = b._as_row_partitions()  # pylint: disable=protected-access
-  a_nrows = a[0]
 
   if len(a_rps) < len(b_rps):
     # Note: this includes the case where len(a_rps)==0.
@@ -1957,6 +2049,11 @@ def _broadcast_dynamic_shape_extended_helper(
     #   ...       | ...
     #   a_rps[-1] | b_rps[-1]
 
+    a_nrows = a[0]
+    a_nrows_static = tensor_util.constant_value(a_nrows)
+    if a_nrows_static is not None:
+      a_nrows = a_nrows_static
+
     neg_one_a_rp = RowPartition.from_uniform_row_length(
         uniform_row_length=a_nrows, nrows=1, nvals=a_nrows)
     neg_one_b_rp = b_rps[-(len(a_rps) + 1)]
@@ -2180,3 +2277,15 @@ def _is_int_or_tuple_of_ints(x):
     if not isinstance(y, int):
       return False
   return True
+
+
+def _alt_inner_shape_from_tensor_shape(shape, dtype, new_inner_rank):
+  """Helper for _alt_inner_shape, used directly in _with_num_row_partitions."""
+  if new_inner_rank == 1:
+    return constant_op.constant([shape.num_elements()], dtype=dtype)
+  new_inner_rank_tail_length = new_inner_rank - 1
+  inner_shape_tail = shape[-new_inner_rank_tail_length:].as_list()
+  first_dim = shape[:-new_inner_rank_tail_length].num_elements()
+  return constant_op.constant([first_dim] + inner_shape_tail,
+                              dtype=dtype)
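The new helper works entirely on `TensorShape`, so the collapsed shape comes out as a constant folded at trace time. What the arithmetic does for a fully known shape collapsed to inner rank 2:

    import tensorflow as tf

    shape = tf.TensorShape([2, 3, 4])  # inner rank 3
    tail = shape[-1:].as_list()        # [4]
    first = shape[:-1].num_elements()  # 2 * 3 = 6
    print(tf.constant([first] + tail, dtype=tf.int64))  # [6 4]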