@@ -123,12 +123,12 @@ def forward(
123
123
124
124
ray_directions = ray_directions .reshape (- 1 , 3 )
125
125
mask_intersect = mask_intersect .reshape (- 1 )
126
+ # pyre-fixme[9]: object_mask has type `BoolTensor`; used as `Tensor`.
126
127
object_mask = object_mask .reshape (- 1 )
127
128
128
129
in_mask = ~ network_object_mask & object_mask & ~ sampler_mask
129
130
out_mask = ~ object_mask & ~ sampler_mask
130
131
131
- # pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
132
132
mask_left_out = (in_mask | out_mask ) & ~ mask_intersect
133
133
if (
134
134
mask_left_out .sum () > 0
@@ -410,10 +410,17 @@ def ray_sampler(
410
410
if n_p_out > 0 :
411
411
out_pts_idx = torch .argmin (sdf_val [p_out_mask , :], - 1 )
412
412
sampler_pts [mask_intersect_idx [p_out_mask ]] = points [p_out_mask , :, :][
413
- torch .arange (n_p_out ), out_pts_idx , :
413
+ # pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
414
+ # but got `Tensor`.
415
+ torch .arange (n_p_out ),
416
+ out_pts_idx ,
417
+ :,
414
418
]
415
419
sampler_dists [mask_intersect_idx [p_out_mask ]] = pts_intervals [
416
- p_out_mask , :
420
+ p_out_mask ,
421
+ :
422
+ # pyre-fixme[6]: For 1st param expected `Union[bool, float, int]` but
423
+ # got `Tensor`.
417
424
][torch .arange (n_p_out ), out_pts_idx ]
418
425
419
426
# Get Network object mask
@@ -434,10 +441,16 @@ def ray_sampler(
434
441
secant_pts
435
442
]
436
443
z_low = pts_intervals [secant_pts ][
437
- torch .arange (n_secant_pts ), sampler_pts_ind [secant_pts ] - 1
444
+ # pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
445
+ # but got `Tensor`.
446
+ torch .arange (n_secant_pts ),
447
+ sampler_pts_ind [secant_pts ] - 1 ,
438
448
]
439
449
sdf_low = sdf_val [secant_pts ][
440
- torch .arange (n_secant_pts ), sampler_pts_ind [secant_pts ] - 1
450
+ # pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
451
+ # but got `Tensor`.
452
+ torch .arange (n_secant_pts ),
453
+ sampler_pts_ind [secant_pts ] - 1 ,
441
454
]
442
455
cam_loc_secant = cam_loc .reshape (- 1 , 3 )[mask_intersect_idx [secant_pts ]]
443
456
ray_directions_secant = ray_directions .reshape ((- 1 , 3 ))[
@@ -514,6 +527,7 @@ def minimal_sdf_points(
514
527
mask_max_dis = max_dis [mask ].unsqueeze (- 1 )
515
528
mask_min_dis = min_dis [mask ].unsqueeze (- 1 )
516
529
steps = (
530
+ # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
517
531
steps .unsqueeze (0 ).repeat (n_mask_points , 1 ) * (mask_max_dis - mask_min_dis )
518
532
+ mask_min_dis
519
533
)
@@ -533,8 +547,13 @@ def minimal_sdf_points(
533
547
mask_sdf_all = torch .cat (mask_sdf_all ).reshape (- 1 , n )
534
548
min_vals , min_idx = mask_sdf_all .min (- 1 )
535
549
min_mask_points = mask_points_all .reshape (- 1 , n , 3 )[
536
- torch .arange (0 , n_mask_points ), min_idx
550
+ # pyre-fixme[6]: For 2nd param expected `Union[bool, float, int]` but
551
+ # got `Tensor`.
552
+ torch .arange (0 , n_mask_points ),
553
+ min_idx ,
537
554
]
555
+ # pyre-fixme[6]: For 2nd param expected `Union[bool, float, int]` but got
556
+ # `Tensor`.
538
557
min_mask_dist = steps .reshape (- 1 , n )[torch .arange (0 , n_mask_points ), min_idx ]
539
558
540
559
return min_mask_points , min_mask_dist
@@ -553,6 +572,7 @@ def _get_sphere_intersection(
553
572
# cam_loc = cam_loc.unsqueeze(-1)
554
573
# ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
555
574
ray_cam_dot = (ray_directions * cam_loc ).sum (- 1 ) # n_images x n_rays
575
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
556
576
under_sqrt = ray_cam_dot ** 2 - (cam_loc .norm (2 , dim = - 1 ) ** 2 - r ** 2 )
557
577
558
578
under_sqrt = under_sqrt .reshape (- 1 )
0 commit comments