@@ -2559,17 +2559,17 @@ def roll(x, shift, axis=None):
25592559 )
25602560
25612561
2562- def stack (tensors : Sequence [TensorVariable ], axis : int = 0 ):
2562+ def stack (tensors : Sequence ["TensorLike" ], axis : int = 0 ):
25632563 """Stack tensors in sequence on given axis (default is 0).
25642564
2565- Take a sequence of tensors and stack them on given axis to make a single
2566- tensor. The size in dimension `axis` of the result will be equal to the number
2567- of tensors passed.
2565+ Take a sequence of tensors or tensor-like constant and stack them on
2566+ given axis to make a single tensor. The size in dimension `axis` of the
2567+ result will be equal to the number of tensors passed.
25682568
25692569 Parameters
25702570 ----------
2571- tensors : Sequence[TensorVariable ]
2572- A list of tensors to be stacked.
2571+ tensors : Sequence[TensorLike ]
2572+ A list of tensors or tensor-like constants to be stacked.
25732573 axis : int
25742574 The index of the new axis. Default value is 0.
25752575
@@ -2604,11 +2604,11 @@ def stack(tensors: Sequence[TensorVariable], axis: int = 0):
26042604 (2, 2, 2, 3, 2)
26052605 """
26062606 if not isinstance (tensors , Sequence ):
2607- raise TypeError ("First argument should be Sequence[TensorVariable] " )
2607+ raise TypeError ("First argument should be a Sequence. " )
26082608 elif len (tensors ) == 0 :
2609- raise ValueError ("No tensor arguments provided" )
2609+ raise ValueError ("No tensor arguments provided. " )
26102610
2611- # If all tensors are scalars of the same type , call make_vector.
2611+ # If all tensors are scalars, call make_vector.
26122612 # It makes the graph simpler, by not adding DimShuffles and SpecifyShapes
26132613
26142614 # This should be an optimization!
@@ -2618,12 +2618,13 @@ def stack(tensors: Sequence[TensorVariable], axis: int = 0):
26182618 # optimization.
26192619 # See ticket #660
26202620 if all (
2621- # In case there are explicit ints in tensors
2622- isinstance (t , (np .number , float , int , builtins .complex ))
2621+ # In case there are explicit scalars in tensors
2622+ isinstance (t , Number )
2623+ or (isinstance (t , np .ndarray ) and t .ndim == 0 )
26232624 or (isinstance (t , Variable ) and isinstance (t .type , TensorType ) and t .ndim == 0 )
26242625 for t in tensors
26252626 ):
2626- # in case there is direct int
2627+ # In case there is direct scalar
26272628 tensors = list (map (as_tensor_variable , tensors ))
26282629 dtype = aes .upcast (* [i .dtype for i in tensors ])
26292630 return MakeVector (dtype )(* tensors )
0 commit comments