Skip to content

Commit 39f1015

Browse files
Copilot and justinchuby authored
[torchlib] Implement torch.ops.prims.broadcast_in_dim.default (#2382)
This PR implements the missing `torch.ops.prims.broadcast_in_dim.default` operation that appears in BERT_pytorch and other PyTorch models. ## Overview The `broadcast_in_dim` operation is a primitive that broadcasts a tensor to a target shape by specifying which dimensions of the output correspond to the input tensor dimensions. This is different from standard broadcasting operations. ## Implementation Details **Function signature:** ```python def prims_broadcast_in_dim( a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int] ) -> TensorType: ``` **Parameters:** - `a`: Input tensor to broadcast - `shape`: Target output shape - `broadcast_dimensions`: Specifies which dimensions of the output shape correspond to the input tensor dimensions **Example:** ```python # Input tensor: [3, 4] # Target shape: [2, 3, 5, 4] # broadcast_dimensions: [1, 3] # Result: Input dimension 0 (size 3) maps to output dimension 1 # Input dimension 1 (size 4) maps to output dimension 3 # Output dimensions 0 and 2 are broadcasted (filled from size 1) ``` Fixes #2218. Fix pytorch/pytorch#135343 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 8ed3521 commit 39f1015

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

onnxscript/function_libs/torch_lib/ops/prims.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
176176
raise NotImplementedError()
177177

178178

179+
@torch_op("prims::broadcast_in_dim", trace_only=True)
def prims_broadcast_in_dim(
    a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int]
) -> TensorType:
    """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""

    rank_out = len(shape)

    # No broadcast dimensions means no input axis maps into the output,
    # so the whole (effectively scalar) input is broadcast to the target shape.
    if not broadcast_dimensions:
        return op.Expand(a, common_ops.merge_dims(shape))

    # Build a rank-aligned intermediate shape: start from all ones, then place
    # each input axis's (possibly symbolic) size at the output axis it maps to.
    # broadcast_dimensions is known at trace time, so plain list indexing works.
    intermediate_shape = [1] * rank_out
    for input_axis, output_axis in enumerate(broadcast_dimensions):
        intermediate_shape[output_axis] = op.Shape(a, start=input_axis, end=input_axis + 1)

    # Reshape to the rank-aligned shape, then let Expand broadcast the ones.
    return op.Expand(op.Reshape(a, common_ops.merge_dims(intermediate_shape)), shape)
185206

186207

187208
def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType:

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,35 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
8787
yield opinfo_core.SampleInput(t, kwargs={"p": p})
8888

8989

90+
def sample_inputs_broadcast_in_dim(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs

    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # Each case is (input_shape, target_shape, broadcast_dimensions), where
    # broadcast_dimensions gives, per input axis, the output axis it occupies.
    cases = (
        ((), (3,), ()),  # scalar -> 1-D tensor
        ((3,), (3,), (0,)),  # identity (no-op broadcast)
        ((1, 3, 1), (2, 3, 4), (0, 1, 2)),  # rank-preserving; singleton dims expand
        ((3, 1), (2, 3, 4), (1, 2)),  # input rank 2 -> output rank 3, trailing axes
        ((3, 4), (1, 3, 4), (1, 2)),  # add a leading broadcast axis
        ((3,), (2, 3, 1), (1,)),  # broadcasting inserted around a middle axis
    )

    for input_shape, target_shape, broadcast_dimensions in cases:
        yield opinfo_core.SampleInput(
            make_arg(input_shape), args=(target_shape, broadcast_dimensions)
        )
117+
118+
90119
def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
91120
del op_info
92121
# input_shape, output_size, kernal, dilation, padding, stride
@@ -2687,6 +2716,13 @@ def __init__(self):
26872716
sample_inputs_func=sample_inputs_upsample_trilinear3d_vec,
26882717
supports_out=False,
26892718
),
2719+
opinfo_core.ReductionOpInfo(
2720+
"ops.prims.broadcast_in_dim.default",
2721+
op=torch.ops.prims.broadcast_in_dim.default,
2722+
dtypes=common_dtype.all_types(),
2723+
sample_inputs_func=sample_inputs_broadcast_in_dim,
2724+
supports_out=False,
2725+
),
26902726
opinfo_core.ReductionOpInfo(
26912727
"ops.prims.var.default",
26922728
nan_policy="propagate",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,6 +2136,7 @@ def _where_input_wrangler(
21362136
"Our implementation is based on that for CUDA"
21372137
),
21382138
),
2139+
TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim),
21392140
TorchLibOpInfo(
21402141
"ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}
21412142
),

0 commit comments

Comments
 (0)