-
Notifications
You must be signed in to change notification settings - Fork 2
/
py3d_rotate.py
60 lines (37 loc) · 1.61 KB
/
py3d_rotate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from pytorch3d.structures import Pointclouds
# Module-level default device: the first CUDA GPU when one is available,
# otherwise the CPU, so importing code can still run on CPU-only hosts.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def my_rotate(in_data, in_depth, jiao, renderer):
    """Turn a colour + depth image into a point cloud and render it.

    Every pixel becomes one 3-D point with
    ``x = column - w / 2``, ``y = h / 2 - row`` and ``z = 127.5 - depth``.
    Each point carries four features: the three colour channels followed by
    ``R_depth = (column - w / 2) * sin(a) + (depth - 127.5) * cos(a)``,
    where ``a`` is ``jiao`` converted from degrees to radians.

    Args:
        in_data: Colour tensor of shape ``(batch, 3, h, w)`` (the reshape to
            ``(batch, h * w, 3)`` requires exactly three channels). Must be
            on the same device as ``in_depth``.
        in_depth: Depth tensor of shape ``(batch, 1, h, w)``. Zero entries
            are overwritten with 1020 **in place**: the caller's tensor is
            modified, because ``permute`` returns a view.
        jiao: Angle in degrees used for the ``R_depth`` feature.
        renderer: Callable that accepts a ``Pointclouds`` and returns a
            tensor whose last dimension has at least four channels
            (presumably a PyTorch3D point renderer -- confirm at call site).

    Returns:
        A tuple ``(colour, depth)``: ``images[..., :3]`` and
        ``images[..., 3]`` with a trailing singleton dimension added, where
        ``images`` is the renderer output after zeros in channel 3 have been
        replaced by 1020.
    """
    batch, _, h, w = in_depth.shape
    # Build everything on the inputs' device instead of a module-level global.
    dev = in_depth.device

    # Width-major, channels-last layout: (batch, w, h, channels).
    in_data = in_data.permute(0, 3, 2, 1).float()
    in_depth = in_depth.permute(0, 3, 2, 1)

    # Pixels without depth (value 0) are assigned the sentinel depth 1020.
    # NOTE: this writes through the permuted view into the caller's tensor.
    depth = in_depth[..., 0]
    depth[depth == 0] = 1020

    # Centred pixel coordinates, shaped to broadcast over (batch, w, h).
    x_off = (torch.arange(w, device=dev) - w / 2).reshape(w, 1)
    y_off = (h / 2 - torch.arange(h, device=dev)).reshape(1, h)

    # The buffer is (batch, w, h, 3) to match the width-major layout above;
    # allocating it as (batch, h, w, 3) only works for square images.
    verts = torch.zeros((batch, w, h, 3), device=dev)
    verts[..., 0] = x_off
    verts[..., 1] = y_off
    verts[..., 2] = -depth + 127.5
    verts = verts.reshape(batch, h * w, 3)

    # Degrees -> radians; the truncated pi constant is kept so numerical
    # results stay the same as before.
    angle = torch.tensor(jiao * 3.14159 / 180)
    sin_a = torch.sin(angle).to(dev)
    cos_a = torch.cos(angle).to(dev)
    R_depth = x_off.reshape(1, w, 1, 1) * sin_a + (in_depth - 127.5) * cos_a

    # Per-point features: 3 colour channels + 1 R_depth channel.
    rgb = torch.cat(
        (in_data.reshape(batch, h * w, 3), R_depth.reshape(batch, h * w, 1)),
        dim=2,
    )

    point_cloud = Pointclouds(points=verts, features=rgb)
    images = renderer(point_cloud)

    # Zeros in the rendered fourth channel get the same 1020 sentinel.
    rendered_depth = images[..., 3]
    rendered_depth[rendered_depth == 0] = 1020
    return images[..., :3], torch.unsqueeze(images[..., 3], 3)
def rotate_back(in_data, in_depth, renderer):
    """Render a single-channel image as a point cloud placed at ``-depth``.

    Every pixel becomes one 3-D point with
    ``x = column - w / 2``, ``y = h / 2 - row`` and ``z = -depth``, carrying
    the pixel's single channel value as its feature.

    Args:
        in_data: Tensor of shape ``(batch, 1, h, w)`` (the reshape to
            ``(batch, h * w, 1)`` requires exactly one channel).
        in_depth: Tensor of shape ``(batch, h, w, k)``; only channel 0 of the
            last dimension is used. Presumably the depth map returned by
            ``my_rotate`` -- confirm at call site.
        renderer: Callable that accepts a ``Pointclouds`` and returns the
            rendered result.

    Returns:
        Whatever ``renderer`` returns for the constructed point cloud.
    """
    batch, _, h, w = in_data.shape
    # Build the vertices on the features' device instead of a module global.
    dev = in_data.device

    # Width-major layouts: in_data -> (batch, w, h, 1), depth -> (batch, w, h, k).
    in_data = in_data.permute(0, 3, 2, 1).float()
    in_depth = in_depth.permute(0, 2, 1, 3)

    # Centred pixel coordinates, shaped to broadcast over (batch, w, h).
    x_off = (torch.arange(w, device=dev) - w / 2).reshape(w, 1)
    y_off = (h / 2 - torch.arange(h, device=dev)).reshape(1, h)

    # The buffer is (batch, w, h, 3) to match the width-major layout above;
    # allocating it as (batch, h, w, 3) only works for square images.
    verts = torch.zeros((batch, w, h, 3), device=dev)
    verts[..., 0] = x_off
    verts[..., 1] = y_off
    verts[..., 2] = -in_depth[..., 0]
    verts = verts.reshape(batch, h * w, 3)

    rgb = in_data.reshape(batch, h * w, 1)
    point_cloud = Pointclouds(points=verts, features=rgb)
    return renderer(point_cloud)