diff --git a/apps/gaussian_blur.py b/apps/gaussian_blur.py
index 8d148026..bdb9c9ae 100644
--- a/apps/gaussian_blur.py
+++ b/apps/gaussian_blur.py
@@ -6,7 +6,7 @@
 import scipy.ndimage.filters as F
 
 
-def render(canvas_width, canvas_height, shapes, shape_groups):
+def render(canvas_width, canvas_height, shapes, shape_groups, backward_clamp_gradient_mag=None):
     _render = pydiffvg.RenderFunction.apply
     scene_args = pydiffvg.RenderFunction.serialize_scene(\
         canvas_width, canvas_height, shapes, shape_groups)
@@ -15,7 +15,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups):
                   2, # num_samples_x
                   2, # num_samples_y
                   0, # seed
-                  None,
+                  None, # background_image
+                  backward_clamp_gradient_mag,
                   *scene_args)
     return img
 
diff --git a/apps/generative_models/mnist_vae.py b/apps/generative_models/mnist_vae.py
index c0da6265..e403a26f 100644
--- a/apps/generative_models/mnist_vae.py
+++ b/apps/generative_models/mnist_vae.py
@@ -47,7 +47,7 @@ def _onehot(label):
     return label_onehot.float()
 
 
-def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
+def render(canvas_width, canvas_height, shapes, shape_groups, samples=2, backward_clamp_gradient_mag=None):
     _render = pydiffvg.RenderFunction.apply
     scene_args = pydiffvg.RenderFunction.serialize_scene(
         canvas_width, canvas_height, shapes, shape_groups)
@@ -55,8 +55,9 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
                   canvas_height,
                   samples,
                   samples,
-                  0,
-                  None,
+                  0, # seed
+                  None, # background_image
+                  backward_clamp_gradient_mag,
                   *scene_args)
     return img
 
diff --git a/apps/generative_models/rendering.py b/apps/generative_models/rendering.py
index 4ef475ec..9459c952 100644
--- a/apps/generative_models/rendering.py
+++ b/apps/generative_models/rendering.py
@@ -12,7 +12,7 @@
 
 
 def render(canvas_width, canvas_height, shapes, shape_groups, samples=2,
-           seed=None):
+           seed=None, backward_clamp_gradient_mag=None):
     if seed is None:
         seed = random.randint(0, 1000000)
     _render = pydiffvg.RenderFunction.apply
@@ -21,6 +21,7 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2,
     img = _render(canvas_width, canvas_height, samples, samples,
                   seed, # seed
                   None, # background image
+                  backward_clamp_gradient_mag,
                   *scene_args)
     return img
 
diff --git a/apps/render_svg.py b/apps/render_svg.py
index 0aa92736..cc237c48 100644
--- a/apps/render_svg.py
+++ b/apps/render_svg.py
@@ -7,7 +7,7 @@
 import torch as th
 
 
-def render(canvas_width, canvas_height, shapes, shape_groups):
+def render(canvas_width, canvas_height, shapes, shape_groups, backward_clamp_gradient_mag=None):
     _render = pydiffvg.RenderFunction.apply
     scene_args = pydiffvg.RenderFunction.serialize_scene(\
         canvas_width, canvas_height, shapes, shape_groups)
@@ -16,7 +16,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups):
                   2, # num_samples_x
                   2, # num_samples_y
                   0, # seed
-                  None,
+                  None, # background_image
+                  backward_clamp_gradient_mag,
                   *scene_args)
     return img
 
diff --git a/apps/seam_carving.py b/apps/seam_carving.py
index aa0176ef..d75bd36e 100644
--- a/apps/seam_carving.py
+++ b/apps/seam_carving.py
@@ -106,7 +106,7 @@ def carve_seam(im):
     return new_im
 
 
-def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
+def render(canvas_width, canvas_height, shapes, shape_groups, samples=2, backward_clamp_gradient_mag=None):
     _render = pydiffvg.RenderFunction.apply
     scene_args = pydiffvg.RenderFunction.serialize_scene(\
         canvas_width, canvas_height, shapes, shape_groups)
@@ -116,7 +116,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
                   samples, # num_samples_x
                   samples, # num_samples_y
                   0, # seed
-                  None,
+                  None, # background_image
+                  backward_clamp_gradient_mag,
                   *scene_args)
     return img
 
diff --git a/apps/svg_brush.py b/apps/svg_brush.py
index de54e48e..3dec0e3a 100644
--- a/apps/svg_brush.py
+++ b/apps/svg_brush.py
@@ -35,7 +35,7 @@ def checkerboard(shape, square_size=2):
     res=bin*np.array([[[1., 1., 1.,]]])+(1-bin)*np.array([[[.75, .75, .75,]]])
     return torch.tensor(res,requires_grad=False,dtype=torch.float32)
 
-def render(optim, viewport):
+def render(optim, viewport, backward_clamp_gradient_mag=None):
     scene_args = pydiffvg.RenderFunction.serialize_scene(*optim.build_scene())
     render = pydiffvg.RenderFunction.apply
     img = render(viewport[0], # width
@@ -43,7 +43,8 @@ def render(optim, viewport):
                  2, # num_samples_x
                  2, # num_samples_y
                  0, # seed
-                 None,
+                 None, # background_image
+                 backward_clamp_gradient_mag,
                  *scene_args)
     return img
 
diff --git a/apps/texture_synthesis.py b/apps/texture_synthesis.py
index 3a7ccce0..c837c0dc 100644
--- a/apps/texture_synthesis.py
+++ b/apps/texture_synthesis.py
@@ -33,7 +33,7 @@ def texture_syn(img_path):
     return np.array(target_img)
 
 
-def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
+def render(canvas_width, canvas_height, shapes, shape_groups, samples=2, backward_clamp_gradient_mag=None):
     _render = pydiffvg.RenderFunction.apply
     scene_args = pydiffvg.RenderFunction.serialize_scene(\
         canvas_width, canvas_height, shapes, shape_groups)
@@ -42,7 +42,8 @@ def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
                   samples, # num_samples_x
                   samples, # num_samples_y
                   0, # seed
-                  None,
+                  None, # background_image
+                  backward_clamp_gradient_mag,
                   *scene_args)
     return img
 
diff --git a/pydiffvg/optimize_svg.py b/pydiffvg/optimize_svg.py
index ce0097f5..0c46d2f1 100644
--- a/pydiffvg/optimize_svg.py
+++ b/pydiffvg/optimize_svg.py
@@ -1017,7 +1017,7 @@ def zero_grad(self):
             if issubclass(item.__class__,OptimizableSvg.SvgNode):
                 item.zero_grad()
 
-    def render(self,scale=None,seed=0):
+    def render(self,scale=None,seed=0,backward_clamp_gradient_mag=None):
         #render at native resolution
         scene = self.build_scene()
         scene_args = pydiffvg.RenderFunction.serialize_scene(*scene)
@@ -1029,6 +1029,7 @@ def render(self,scale=None,seed=0):
                      2, # num_samples_y
                      seed, # seed
                      None, # background_image
+                     backward_clamp_gradient_mag,
                      *scene_args)
         return img
 
diff --git a/pydiffvg/render_pytorch.py b/pydiffvg/render_pytorch.py
index b776ce67..63d34b38 100644
--- a/pydiffvg/render_pytorch.py
+++ b/pydiffvg/render_pytorch.py
@@ -81,7 +81,7 @@ def serialize_scene(canvas_width,
                 else:
                     args.append(torch.zeros(shape.points.shape[0] - 1, dtype = torch.int32))
                 args.append(shape.points.cpu())
-                args.append(None)
+                args.append(None)
                 args.append(shape.is_closed)
                 args.append(False) # use_distance_approx
             elif isinstance(shape, pydiffvg.Rect):
@@ -179,6 +179,7 @@ def forward(ctx,
                 num_samples_y,
                 seed,
                 background_image,
+                backward_clamp_gradient_mag,
                 *args):
         """
             Forward rendering pass.
@@ -375,7 +376,7 @@ def forward(ctx,
                 assert(eval_positions.shape[0] == 0)
                 rendered_image = torch.zeros(height, width, 4, device = pydiffvg.get_device())
         else:
-            assert(output_type == OutputType.sdf)
+            assert(output_type == OutputType.sdf)
             if eval_positions.shape[0] == 0:
                 rendered_image = torch.zeros(height, width, 1, device = pydiffvg.get_device())
             else:
@@ -427,6 +428,7 @@ def forward(ctx,
         ctx.output_type = output_type
         ctx.use_prefiltering = use_prefiltering
         ctx.eval_positions = eval_positions
+        ctx.backward_clamp_gradient_mag = backward_clamp_gradient_mag
         return rendered_image
 
     @staticmethod
@@ -457,7 +459,7 @@ def render_grad(grad_img,
         use_prefiltering = args[current_index]
         current_index += 1
         eval_positions = args[current_index]
-        current_index += 1
+        current_index += 1
         shapes = []
         shape_groups = []
         shape_contents = [] # Important to avoid GC deleting the shapes
@@ -672,7 +674,6 @@ def backward(ctx,
                  grad_img):
         if not grad_img.is_contiguous():
             grad_img = grad_img.contiguous()
-        assert(torch.isfinite(grad_img).all())
 
         scene = ctx.scene
         width = ctx.width
@@ -684,6 +685,33 @@ def backward(ctx,
         use_prefiltering = ctx.use_prefiltering
         eval_positions = ctx.eval_positions
         background_image = ctx.background_image
+        backward_clamp_gradient_mag = ctx.backward_clamp_gradient_mag
+
+        if backward_clamp_gradient_mag is None:
+            # No clamp requested: keep the original strict finiteness check.
+            assert torch.isfinite(grad_img).all()
+        elif not torch.isfinite(grad_img).all():
+            # backward_clamp_gradient_mag can be:
+            # - A single int or float: the clamp magnitude in both directions.
+            # - A sequence of one or two ints/floats: element 0 is the magnitude
+            #   of the lower bound, element 1 (or element 0 when there is only
+            #   one element) is the magnitude of the upper bound.
+            # - A sequence of three elements: as above, plus a flag (element 2);
+            #   when truthy, a warning is printed whenever clamping happens.
+            if isinstance(backward_clamp_gradient_mag, (int, float)):
+                clamp_spec = (backward_clamp_gradient_mag,)
+            else:
+                clamp_spec = tuple(backward_clamp_gradient_mag)
+            if len(clamp_spec) == 0:
+                raise ValueError('backward_clamp_gradient_mag must not be an empty sequence')
+            min_ = -float(abs(clamp_spec[0]))
+            max_ = +float(abs(clamp_spec[1 if len(clamp_spec) >= 2 else 0]))
+            if len(clamp_spec) >= 3 and clamp_spec[2]:
+                print(f'Pydiffvg::backward "isfinite" assertion failed: clamping gradient to: {min_}/{max_}')
+            # torch.clamp maps +/-inf onto the bounds but propagates NaN, so
+            # NaN entries are zeroed first to guarantee a finite gradient.
+            grad_img = torch.where(torch.isnan(grad_img), torch.zeros_like(grad_img), grad_img)
+            grad_img = torch.clamp(grad_img, min=min_, max=max_)
 
         if background_image is not None:
             d_background_image = torch.zeros_like(background_image)
@@ -719,6 +747,7 @@ def backward(ctx,
         d_args.append(None) # num_samples_y
         d_args.append(None) # seed
         d_args.append(d_background_image)
+        d_args.append(None) # backward_clamp_gradient_mag
         d_args.append(None) # canvas_width
         d_args.append(None) # canvas_height
         d_args.append(None) # num_shapes