Rendering of occluded objects looks strange #217

Closed
akarsakov opened this issue May 29, 2020 · 9 comments

akarsakov commented May 29, 2020

Hello,

I'm trying to render a dental arch using PyTorch3D and got a strange result (please see the screenshot):

[Image: Screenshot 2020-05-29 at 14 58 30]

When I removed some of the occluded teeth, the rendered image was OK.
I also tried decimating the mesh, reducing the number of faces by a factor of 4, and the rendering was OK too. Is this a limitation of the framework? Could you please help me with this issue?

Here is the code to reproduce:

import torch
import numpy as np
import matplotlib.pyplot as plt
import trimesh

from pytorch3d.structures import Meshes, Textures
from pytorch3d.transforms import euler_angles_to_matrix

from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    HardPhongShader, PointLights
)

device = torch.device("cuda:0")
torch.cuda.set_device(device)

stl_file = "teeth.stl"

def load_mesh(fn):
    mesh = trimesh.load_mesh(fn)
    
    verts, faces = torch.from_numpy(mesh.vertices).float(), torch.from_numpy(mesh.faces)
    verts /= 1000

    verts_rgb = torch.ones_like(verts)[None]
    textures = Textures(verts_rgb=verts_rgb.to(device))

    mesh = Meshes(
        verts=[verts.to(device)],   
        faces=[faces.to(device)], 
        textures=textures
    )
    
    return mesh

orig_mesh = load_mesh(stl_file)

cameras = OpenGLPerspectiveCameras(device=device, fov=0.27, degrees=False)
# NOTE: blend_params is defined here but never passed to the shader below.
blend_params = BlendParams(sigma=1e-2, gamma=1e-7)
raster_settings = RasterizationSettings(
    image_size=1024, 
    blur_radius=0., 
    faces_per_pixel=10, 
)

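# NOTE: lights is defined here but never passed to the shader below,
# so the shader falls back to its default lights.
lights = PointLights(device=device, location=((2.0, -2.0, -2.0),))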
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras)
)

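# Camera rotation (from Euler angles, in radians) and translation, each with a batch dimension.
angles = torch.tensor(np.radians([90, 40, 0]), dtype=torch.float32, device=device)
R = euler_angles_to_matrix(angles, convention="XYZ").unsqueeze(0)
T = torch.tensor([0, 0, 0.5], dtype=torch.float32, device=device).unsqueeze(0)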

image_ref = phong_renderer(meshes_world=orig_mesh, R=R, T=T).squeeze()

plt.figure(figsize=(10, 10))
plt.imshow(image_ref.cpu().numpy()[:, :, 1], cmap='gray')

Thanks in advance!

Here is the STL file:
teeth.stl.zip

nikhilaravi added the "how to" label (How to use PyTorch3D in my project) on May 29, 2020
nikhilaravi self-assigned this on May 29, 2020
@nikhilaravi
Contributor

@akarsakov thank you for providing a very clear description of the issue and code to reproduce it! Off the top of my head one option could be to try backface culling - cull_backfaces is an option in raster_settings. But I will investigate this further and get back to you!

@akarsakov
Author

Hi @nikhilaravi,

Thank you for the fast response!
I've tried backface culling, but it doesn't help.

@gkioxari
Contributor

Hi @akarsakov
I wanted to let you know I am working on this and will let you know if I find the issue!

@gkioxari
Contributor

gkioxari commented Jul 10, 2020

@akarsakov
I ran your code and what I get is different from what you get (see the image below), so I am not sure what is going on. The code is exactly the same as yours; the only difference is that I converted the STL file to an OBJ file and used the load_obj function from PyTorch3D. Here is the OBJ file: teeth.obj.zip

[Image: teeth]

@akarsakov
Author

Hi @gkioxari,
Thank you so much for your help!
I've also tried to reproduce it with the OBJ file you attached, and the rendering is still broken.
That's interesting; the issue is probably with the environment. I've created a Colab notebook that reproduces the example: https://colab.research.google.com/drive/1STWJK3mk38bs9CUnqFcyxAHcCsBMP1po?usp=sharing
I'd really appreciate it if you could look at it, play around with it, and compare it with your code, environment, and libraries.

Thanks in advance!

@jcjohnson
Contributor

Hi @akarsakov, thanks for setting up a reproducing example! Can you change the permissions on the Colab notebook to be world-readable? I can't open it at the moment.

@akarsakov
Author

@bottler
Contributor

bottler commented Oct 12, 2020

This happens because of an overflow in statically allocated bins in the coarse-to-fine rasterization. The explanation is the same as #348. You can fix it by increasing max_faces_per_bin - e.g. adding max_faces_per_bin=30000 to the rasterization settings looks OK. Alternatively you can use the naive rasterizer, by adding bin_size=0 to the RasterizationSettings.
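
To make the overflow explanation concrete, here is a minimal 1D toy model of coarse-to-fine rasterization with fixed-capacity bins. It is only a sketch for intuition, not PyTorch3D's actual kernel: the functions coarse_bin and fine_raster, the 1D "faces", and the rule that overflowing faces are simply dropped are all assumptions made for illustration.

```python
from typing import List, Optional, Tuple

# Toy 1D "face": covers pixels x_start <= x < x_end at a given depth
# (smaller depth = closer to the camera). Hypothetical, for illustration only.
Face = Tuple[int, int, float]


def coarse_bin(faces: List[Face], num_pixels: int, bin_size: int,
               max_faces_per_bin: int) -> List[List[int]]:
    """Coarse phase: record which faces overlap each bin, up to a fixed capacity."""
    num_bins = (num_pixels + bin_size - 1) // bin_size
    bins: List[List[int]] = [[] for _ in range(num_bins)]
    for idx, (x0, x1, _) in enumerate(faces):
        first_bin = max(x0, 0) // bin_size
        last_bin = (min(x1, num_pixels) - 1) // bin_size
        for b in range(first_bin, last_bin + 1):
            if len(bins[b]) < max_faces_per_bin:
                bins[b].append(idx)
            # Otherwise the bin is full and this face is dropped (assumed behavior).
    return bins


def fine_raster(faces: List[Face], bins: List[List[int]], num_pixels: int,
                bin_size: int) -> List[Optional[int]]:
    """Fine phase: for each pixel, pick the closest face among those in its bin."""
    result: List[Optional[int]] = []
    for x in range(num_pixels):
        best: Optional[int] = None
        for idx in bins[x // bin_size]:
            x0, x1, depth = faces[idx]
            if x0 <= x < x1 and (best is None or depth < faces[best][2]):
                best = idx
        result.append(best)
    return result


# Three far faces spanning the whole image, then one near face in the middle.
faces = [(0, 8, 5.0), (0, 8, 6.0), (0, 8, 7.0), (2, 6, 1.0)]

# Capacity 2: the near face (index 3) never makes it into the bin.
overflowed = fine_raster(faces, coarse_bin(faces, 8, 8, max_faces_per_bin=2), 8, 8)
# Capacity 4: every face fits, so the near face wins where it covers a pixel.
roomy = fine_raster(faces, coarse_bin(faces, 8, 8, max_faces_per_bin=4), 8, 8)

print(overflowed)  # [0, 0, 0, 0, 0, 0, 0, 0]
print(roomy)       # [0, 0, 3, 3, 3, 3, 0, 0]
```

In this toy, the small capacity makes a far face show through where a near face should be, which matches the symptom in the screenshot; it also suggests why decimating the mesh or removing occluded teeth helped, since both reduce the number of faces per bin.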

@akarsakov
Author

@bottler
Thanks for the explanation! Increasing the value of max_faces_per_bin works perfectly.
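
For anyone landing here with the same symptom, the two options from the thread might be applied to the original repro roughly as follows. This is a sketch, not a verified fix for every mesh: raster_kwargs and make_raster_settings are hypothetical helper names, the image_size, blur_radius, and faces_per_pixel values are copied from the repro code, and 30000 is just the value reported to look OK for this particular mesh.

```python
def raster_kwargs(use_naive: bool = False) -> dict:
    """Keyword arguments for RasterizationSettings (hypothetical helper)."""
    # Base values copied from the repro code in this issue.
    kwargs = dict(image_size=1024, blur_radius=0.0, faces_per_pixel=10)
    if use_naive:
        # bin_size=0 selects the naive rasterizer, which does not use bins.
        kwargs["bin_size"] = 0
    else:
        # Give each bin more room; 30000 is the value suggested in this thread.
        kwargs["max_faces_per_bin"] = 30000
    return kwargs


def make_raster_settings(use_naive: bool = False):
    """Build the settings object (hypothetical helper)."""
    # Imported lazily so raster_kwargs can be inspected without PyTorch3D installed.
    from pytorch3d.renderer import RasterizationSettings
    return RasterizationSettings(**raster_kwargs(use_naive))
```

The result of make_raster_settings() would replace raster_settings in the repro. The naive rasterizer avoids the overflow entirely but is likely to be slower on large meshes, so raising max_faces_per_bin is probably the first thing to try.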
