Skip to content

Commit 4bffd1e

Browse files
authored
Merge branch 'main' into fix-docstring-fill-param
2 parents 316cd17 + 97920a5 commit 4bffd1e

File tree

3 files changed

+23
-36
lines changed

3 files changed

+23
-36
lines changed

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,21 @@ versions.
2121
| `torch` | `torchvision` | Python |
2222
| ------------------ | ------------------ | ------------------- |
2323
| `main` / `nightly` | `main` / `nightly` | `>=3.9`, `<=3.12` |
24-
| `2.5` | `0.20` | `>=3.9`, `<=3.12` |
25-
| `2.4` | `0.19` | `>=3.8`, `<=3.12` |
26-
| `2.3` | `0.18` | `>=3.8`, `<=3.12` |
27-
| `2.2` | `0.17` | `>=3.8`, `<=3.11` |
28-
| `2.1` | `0.16` | `>=3.8`, `<=3.11` |
29-
| `2.0` | `0.15` | `>=3.8`, `<=3.11` |
24+
| `2.8` | `0.23` | `>=3.9`, `<=3.13` |
25+
| `2.7` | `0.22` | `>=3.9`, `<=3.13` |
26+
| `2.6` | `0.21` | `>=3.9`, `<=3.12` |
3027

3128
<details>
3229
<summary>older versions</summary>
3330

3431
| `torch` | `torchvision` | Python |
3532
|---------|-------------------|---------------------------|
33+
| `2.5` | `0.20` | `>=3.9`, `<=3.12` |
34+
| `2.4` | `0.19` | `>=3.8`, `<=3.12` |
35+
| `2.3` | `0.18` | `>=3.8`, `<=3.12` |
36+
| `2.2` | `0.17` | `>=3.8`, `<=3.11` |
37+
| `2.1` | `0.16` | `>=3.8`, `<=3.11` |
38+
| `2.0` | `0.15` | `>=3.8`, `<=3.11` |
3639
| `1.13` | `0.14` | `>=3.7.2`, `<=3.10` |
3740
| `1.12` | `0.13` | `>=3.7`, `<=3.10` |
3841
| `1.11` | `0.12` | `>=3.7`, `<=3.10` |

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,12 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
452452
The output maintains the same dtype as the input.
453453
"""
454+
dtype = parallelogram.dtype
455+
acceptable_dtypes = [torch.float32, torch.float64]
456+
need_cast = dtype not in acceptable_dtypes
457+
if need_cast:
458+
# Up-case to avoid overflow for square operations
459+
parallelogram = parallelogram.to(torch.float32)
454460
out_boxes = parallelogram.clone()
455461

456462
# Calculate parallelogram diagonal vectors
@@ -489,6 +495,10 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
489495
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
490496
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
491497
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
498+
499+
if need_cast:
500+
out_boxes = out_boxes.to(dtype)
501+
492502
return out_boxes
493503

494504

torchvision/transforms/v2/functional/_meta.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
194194
if not inplace:
195195
cxcywhr = cxcywhr.clone()
196196

197-
dtype = cxcywhr.dtype
198-
need_cast = not cxcywhr.is_floating_point()
199-
if need_cast:
200-
cxcywhr = cxcywhr.float()
201-
202197
half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_()
203198
r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0)
204199
cos, sin = r_rad.cos(), r_rad.sin()
@@ -207,22 +202,13 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
207202
# (cy + width / 2 * sin - height / 2 * cos) = y1
208203
cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos))
209204

210-
if need_cast:
211-
cxcywhr.round_()
212-
cxcywhr = cxcywhr.to(dtype)
213-
214205
return cxcywhr
215206

216207

217208
def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
218209
if not inplace:
219210
xywhr = xywhr.clone()
220211

221-
dtype = xywhr.dtype
222-
need_cast = not xywhr.is_floating_point()
223-
if need_cast:
224-
xywhr = xywhr.float()
225-
226212
half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_()
227213
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
228214
cos, sin = r_rad.cos(), r_rad.sin()
@@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
231217
# (y1 - width / 2 * sin + height / 2 * cos) = cy
232218
xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos))
233219

234-
if need_cast:
235-
xywhr.round_()
236-
xywhr = xywhr.to(dtype)
237-
238220
return xywhr
239221

240222

@@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
243225
if not inplace:
244226
xywhr = xywhr.clone()
245227

246-
dtype = xywhr.dtype
247-
need_cast = not xywhr.is_floating_point()
248-
if need_cast:
249-
xywhr = xywhr.float()
250-
251228
wh = xywhr[..., 2:-1]
252229
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
253230
cos, sin = r_rad.cos(), r_rad.sin()
@@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
265242
# y1 + h * cos = y4
266243
xywhr[..., 7].add_(wh[..., 1].mul(cos))
267244

268-
if need_cast:
269-
xywhr.round_()
270-
xywhr = xywhr.to(dtype)
271-
272245
return xywhr
273246

274247

@@ -278,9 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
278251
xyxyxyxy = xyxyxyxy.clone()
279252

280253
dtype = xyxyxyxy.dtype
281-
need_cast = not xyxyxyxy.is_floating_point()
254+
acceptable_dtypes = [torch.float32, torch.float64] # Ensure consistency between CPU and GPU.
255+
need_cast = dtype not in acceptable_dtypes
282256
if need_cast:
283-
xyxyxyxy = xyxyxyxy.float()
257+
# Up-case to avoid overflow for square operations
258+
xyxyxyxy = xyxyxyxy.to(torch.float32)
284259

285260
r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0]))
286261
# x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4
@@ -293,7 +268,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
293268
xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0)
294269

295270
if need_cast:
296-
xyxyxyxy.round_()
297271
xyxyxyxy = xyxyxyxy.to(dtype)
298272

299273
return xyxyxyxy[..., :5]

0 commit comments

Comments
 (0)