Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optional flag to clamp gradient in 'backward' to prevent crash #48

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions apps/gaussian_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.ndimage.filters as F


def render(canvas_width, canvas_height, shapes, shape_groups):
def render(canvas_width, canvas_height, shapes, shape_groups, backward_clamp_gradient_mag=None):
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(\
canvas_width, canvas_height, shapes, shape_groups)
Expand All @@ -15,7 +15,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups):
2, # num_samples_x
2, # num_samples_y
0, # seed
None,
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
7 changes: 4 additions & 3 deletions apps/generative_models/mnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ def _onehot(label):
return label_onehot.float()


def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2, backward_clamp_gradient_mag=None):
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(
canvas_width, canvas_height, shapes, shape_groups)
img = _render(canvas_width,
canvas_height,
samples,
samples,
0,
None,
0, # seed
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
3 changes: 2 additions & 1 deletion apps/generative_models/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def render(canvas_width, canvas_height, shapes, shape_groups, samples=2,
seed=None):
seed=None, backward_clamp_gradient_mag=None):
if seed is None:
seed = random.randint(0, 1000000)
_render = pydiffvg.RenderFunction.apply
Expand All @@ -21,6 +21,7 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2,
img = _render(canvas_width, canvas_height, samples, samples,
seed, # seed
None, # background image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
5 changes: 3 additions & 2 deletions apps/render_svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch as th


def render(canvas_width, canvas_height, shapes, shape_groups):
def render(canvas_width, canvas_height, shapes, shape_groups, backward_clamp_gradient_mag=None):
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(\
canvas_width, canvas_height, shapes, shape_groups)
Expand All @@ -16,7 +16,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups):
2, # num_samples_x
2, # num_samples_y
0, # seed
None,
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
5 changes: 3 additions & 2 deletions apps/seam_carving.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def carve_seam(im):
return new_im


def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2, backward_clamp_gradient_mag=None):
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(\
canvas_width, canvas_height, shapes, shape_groups)
Expand All @@ -116,7 +116,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
samples, # num_samples_x
samples, # num_samples_y
0, # seed
None,
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
5 changes: 3 additions & 2 deletions apps/svg_brush.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ def checkerboard(shape, square_size=2):
res=bin*np.array([[[1., 1., 1.,]]])+(1-bin)*np.array([[[.75, .75, .75,]]])
return torch.tensor(res,requires_grad=False,dtype=torch.float32)

def render(optim, viewport):
def render(optim, viewport, backward_clamp_gradient_mag=None):
scene_args = pydiffvg.RenderFunction.serialize_scene(*optim.build_scene())
render = pydiffvg.RenderFunction.apply
img = render(viewport[0], # width
viewport[1], # height
2, # num_samples_x
2, # num_samples_y
0, # seed
None,
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
5 changes: 3 additions & 2 deletions apps/texture_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def texture_syn(img_path):
return np.array(target_img)


def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2, backward_clamp_gradient_mag=None):
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(\
canvas_width, canvas_height, shapes, shape_groups)
Expand All @@ -42,7 +42,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
samples, # num_samples_x
samples, # num_samples_y
0, # seed
None,
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
3 changes: 2 additions & 1 deletion pydiffvg/optimize_svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ def zero_grad(self):
if issubclass(item.__class__,OptimizableSvg.SvgNode):
item.zero_grad()

def render(self,scale=None,seed=0):
def render(self,scale=None,seed=0,backward_clamp_gradient_mag=None):
#render at native resolution
scene = self.build_scene()
scene_args = pydiffvg.RenderFunction.serialize_scene(*scene)
Expand All @@ -1029,6 +1029,7 @@ def render(self,scale=None,seed=0):
2, # num_samples_y
seed, # seed
None, # background_image
backward_clamp_gradient_mag,
*scene_args)
return img

Expand Down
37 changes: 33 additions & 4 deletions pydiffvg/render_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def serialize_scene(canvas_width,
else:
args.append(torch.zeros(shape.points.shape[0] - 1, dtype = torch.int32))
args.append(shape.points.cpu())
args.append(None)
args.append(None)
args.append(shape.is_closed)
args.append(False) # use_distance_approx
elif isinstance(shape, pydiffvg.Rect):
Expand Down Expand Up @@ -179,6 +179,7 @@ def forward(ctx,
num_samples_y,
seed,
background_image,
backward_clamp_gradient_mag,
*args):
"""
Forward rendering pass.
Expand Down Expand Up @@ -375,7 +376,7 @@ def forward(ctx,
assert(eval_positions.shape[0] == 0)
rendered_image = torch.zeros(height, width, 4, device = pydiffvg.get_device())
else:
assert(output_type == OutputType.sdf)
assert(output_type == OutputType.sdf)
if eval_positions.shape[0] == 0:
rendered_image = torch.zeros(height, width, 1, device = pydiffvg.get_device())
else:
Expand Down Expand Up @@ -427,6 +428,7 @@ def forward(ctx,
ctx.output_type = output_type
ctx.use_prefiltering = use_prefiltering
ctx.eval_positions = eval_positions
ctx.backward_clamp_gradient_mag = backward_clamp_gradient_mag
return rendered_image

@staticmethod
Expand Down Expand Up @@ -457,7 +459,7 @@ def render_grad(grad_img,
use_prefiltering = args[current_index]
current_index += 1
eval_positions = args[current_index]
current_index += 1
current_index += 1
shapes = []
shape_groups = []
shape_contents = [] # Important to avoid GC deleting the shapes
Expand Down Expand Up @@ -672,7 +674,6 @@ def backward(ctx,
grad_img):
if not grad_img.is_contiguous():
grad_img = grad_img.contiguous()
assert(torch.isfinite(grad_img).all())

scene = ctx.scene
width = ctx.width
Expand All @@ -684,6 +685,33 @@ def backward(ctx,
use_prefiltering = ctx.use_prefiltering
eval_positions = ctx.eval_positions
background_image = ctx.background_image
backward_clamp_gradient_mag = ctx.backward_clamp_gradient_mag

if backward_clamp_gradient_mag is None:
assert torch.isfinite(grad_img).all()
else:
try:
assert torch.isfinite(grad_img).all()
except:
# backward_clamp_gradient_mag can be:
# - A single float or int defining the magnitude of the clamp in both directions
# - A sequence of at least one or two floats or ints defining the magnitude
# of the clamp in the min (element 0) and max (element 1, or element 0 if only one element) direction
# To print a warning to the console when the gradient is not finite, pass a sequence of length 3.
# The third element is treated as a Boolean and if True, a warning is printed.
if type(backward_clamp_gradient_mag) is int or type(backward_clamp_gradient_mag) is float:
min_ = -float(abs(backward_clamp_gradient_mag))
max_ = +float(abs(backward_clamp_gradient_mag))
elif len(backward_clamp_gradient_mag) == 1:
min_ = -float(abs(backward_clamp_gradient_mag[0]))
max_ = +float(abs(backward_clamp_gradient_mag[0]))
elif len(backward_clamp_gradient_mag) >= 2:
min_ = -float(abs(backward_clamp_gradient_mag[0]))
max_ = +float(abs(backward_clamp_gradient_mag[1]))
if len(backward_clamp_gradient_mag) >= 3 and backward_clamp_gradient_mag[2]:
print(f'Pydiffvg::backward "isfinite" assertion failed: clamping gradient to: {min_}/{max_}')
backward_clamp_gradient_mag = {"min": min_, "max": max_}
grad_img = torch.clamp(grad_img, **backward_clamp_gradient_mag)

if background_image is not None:
d_background_image = torch.zeros_like(background_image)
Expand Down Expand Up @@ -719,6 +747,7 @@ def backward(ctx,
d_args.append(None) # num_samples_y
d_args.append(None) # seed
d_args.append(d_background_image)
d_args.append(None) # backward_clamp_gradient_mag
d_args.append(None) # canvas_width
d_args.append(None) # canvas_height
d_args.append(None) # num_shapes
Expand Down