Commit ccf2291

Ruishenl authored and facebook-github-bot committed
Optimize list_to_packed to avoid for loop (#1737)
Summary: For larger N and Mi values (e.g. N=154, Mi=238) I noticed that list_to_packed() had become a bottleneck for my application. By removing the for loop and running on GPU, I see a 10-20x speedup.

Pull Request resolved: #1737
Reviewed By: MichaelRamamonjisoa
Differential Revision: D54187993
Pulled By: bottler
fbshipit-source-id: 16399a24cb63b48c30460c7d960abef603b115d0
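As a rough illustration of the claim above, here is a minimal timing sketch (not part of the commit) that exercises list_to_packed on inputs of roughly the sizes mentioned (N=154 tensors with Mi=238 rows each); the feature width of 3 and the iteration count are arbitrary assumptions.

import time

import torch
from pytorch3d.structures.utils import list_to_packed

# Assumed setup: N=154 list elements, each with Mi=238 rows and 3 columns.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = [torch.randn(238, 3, device=device) for _ in range(154)]

list_to_packed(x)  # warm-up call so one-time CUDA initialization is excluded

start = time.perf_counter()
for _ in range(100):
    list_to_packed(x)
if device.type == "cuda":
    torch.cuda.synchronize()  # wait for queued GPU kernels before reading the clock
elapsed = time.perf_counter() - start
print(f"100 calls to list_to_packed took {elapsed:.4f}s on {device}")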
1 parent 128be02 commit ccf2291

File tree

1 file changed: +14 -15 lines


pytorch3d/structures/utils.py

+14 -15

@@ -135,22 +135,21 @@ def list_to_packed(x: List[torch.Tensor]):
         - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the
           index of the element in the list the item belongs to.
     """
-    N = len(x)
-    num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device)
-    item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device)
-    item_packed_to_list_idx = []
-    cur = 0
-    for i, y in enumerate(x):
-        num = len(y)
-        num_items[i] = num
-        item_packed_first_idx[i] = cur
-        item_packed_to_list_idx.append(
-            torch.full((num,), i, dtype=torch.int64, device=y.device)
-        )
-        cur += num
-
+    if not x:
+        raise ValueError("Input list is empty")
+    device = x[0].device
+    sizes = [xi.shape[0] for xi in x]
+    sizes_total = sum(sizes)
+    num_items = torch.tensor(sizes, dtype=torch.int64, device=device)
+    item_packed_first_idx = torch.zeros_like(num_items)
+    item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0)
+    item_packed_to_list_idx = torch.arange(
+        sizes_total, dtype=torch.int64, device=device
+    )
+    item_packed_to_list_idx = (
+        torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1
+    )
     x_packed = torch.cat(x, dim=0)
-    item_packed_to_list_idx = torch.cat(item_packed_to_list_idx, dim=0)
 
     return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx
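To make the vectorized return values concrete, here is a small illustrative example (not from the commit); the input tensors are made up, and the printed values follow from the cumsum/bucketize logic in the new code.

import torch
from pytorch3d.structures.utils import list_to_packed

# Two list elements with M0=2 and M1=3 rows each.
x = [torch.arange(2.0).reshape(2, 1), torch.arange(3.0).reshape(3, 1)]
x_packed, num_items, first_idx, to_list_idx = list_to_packed(x)

print(x_packed.shape)  # torch.Size([5, 1]) -> sum(Mi) packed rows
print(num_items)       # tensor([2, 3])     -> Mi for each list element
print(first_idx)       # tensor([0, 2])     -> exclusive cumsum of the sizes
print(to_list_idx)     # tensor([0, 0, 1, 1, 1]) -> bucketize(arange(5), [0, 2]) - 1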
