diff --git a/deepxde/callbacks.py b/deepxde/callbacks.py
index 3039390df..07b066e13 100644
--- a/deepxde/callbacks.py
+++ b/deepxde/callbacks.py
@@ -551,7 +551,6 @@ def __init__(self, period=100, pde_points=True, bc_points=False):
         self.period = period
         self.pde_points = pde_points
         self.bc_points = bc_points
-        self.num_bcs_initial = None
 
         self.epochs_since_last_resample = 0
 
@@ -571,3 +570,50 @@ def on_epoch_end(self):
             raise ValueError(
                 "`num_bcs` changed! Please update the loss function by `model.compile`."
             )
+
+
+class PrintLossWeight(Callback):
+    """Print the initial and current loss weights every `period` epochs.
+
+    Args:
+        period: Interval (number of epochs) between printing loss weights.
+    """
+
+    def __init__(self, period):
+        super().__init__()
+        self.period = period
+        self.initial_loss_weights = None
+        self.current_loss_weights = None
+
+    def on_epoch_begin(self):
+        if self.model.train_state.epoch == 0:
+            self.initial_loss_weights = self.model.loss_weights.numpy().tolist()
+        else:
+            self.current_loss_weights = self.model.loss_weights.numpy().tolist()
+            if self.model.train_state.epoch % self.period == 0:
+                print("Initial loss weights:", self.initial_loss_weights)
+                print("Current loss weights:", self.current_loss_weights)
+
+
+class ManualDynamicLossWeight(Callback):
+    """Change one loss weight to a new value at a specific epoch.
+
+    Args:
+        epoch2change: The epoch at which to change the loss weight.
+        value: The new value of the loss weight.
+        loss_idx: The index of the loss weight to change.
+    """
+
+    def __init__(self, epoch2change, value, loss_idx):
+        super().__init__()
+        self.epoch2change = epoch2change
+        self.value = value
+        self.loss_idx = loss_idx
+
+    def on_epoch_begin(self):
+        if self.model.train_state.epoch == self.epoch2change:
+            current_loss_weights = self.model.loss_weights.numpy()
+            current_loss_weights[self.loss_idx] = self.value
+            self.model.loss_weights = tf.convert_to_tensor(
+                current_loss_weights, dtype=config.default_float()
+            )
diff --git a/deepxde/model.py b/deepxde/model.py
index 4ebdf6859..13d742af3 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -119,7 +119,12 @@ def compile(
         print("Compiling model...")
         self.opt_name = optimizer
         loss_fn = losses_module.get(loss)
-        self.loss_weights = loss_weights
+        if loss_weights is not None and backend_name == "tensorflow":
+            # Store as a tensor so that callbacks can update it during training.
+            loss_weights = tf.convert_to_tensor(
+                loss_weights, dtype=config.default_float()
+            )
+        self.loss_weights = loss_weights
         if external_trainable_variables is None:
             self.external_trainable_variables = []
         else:
@@ -202,7 +207,9 @@ def _compile_tensorflow(self, lr, loss_fn, decay):
         def outputs(training, inputs):
             return self.net(inputs, training=training)
 
-        def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
+        def outputs_losses(
+            training, inputs, targets, auxiliary_vars, losses_fn, loss_weights
+        ):
             self.net.auxiliary_vars = auxiliary_vars
             # Don't call outputs() decorated by @tf.function above, otherwise the
             # gradient of outputs wrt inputs will be lost here.
@@ -218,29 +225,41 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
             losses += [tf.math.reduce_sum(self.net.losses)]
             losses = tf.convert_to_tensor(losses)
             # Weighted losses
-            if self.loss_weights is not None:
-                losses *= self.loss_weights
+            if loss_weights is not None:
+                losses *= loss_weights
             return outputs_, losses
 
         @tf.function(jit_compile=config.xla_jit)
-        def outputs_losses_train(inputs, targets, auxiliary_vars):
+        def outputs_losses_train(inputs, targets, auxiliary_vars, loss_weights):
             return outputs_losses(
-                True, inputs, targets, auxiliary_vars, self.data.losses_train
+                True,
+                inputs,
+                targets,
+                auxiliary_vars,
+                self.data.losses_train,
+                loss_weights,
             )
 
         @tf.function(jit_compile=config.xla_jit)
-        def outputs_losses_test(inputs, targets, auxiliary_vars):
+        def outputs_losses_test(inputs, targets, auxiliary_vars, loss_weights):
             return outputs_losses(
-                False, inputs, targets, auxiliary_vars, self.data.losses_test
+                False,
+                inputs,
+                targets,
+                auxiliary_vars,
+                self.data.losses_test,
+                loss_weights,
             )
 
         opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
 
         @tf.function(jit_compile=config.xla_jit)
-        def train_step(inputs, targets, auxiliary_vars):
+        def train_step(inputs, targets, auxiliary_vars, loss_weights):
             # inputs and targets are np.ndarray and automatically converted to Tensor.
             with tf.GradientTape() as tape:
-                losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                losses = outputs_losses_train(
+                    inputs, targets, auxiliary_vars, loss_weights
+                )[1]
                 total_loss = tf.math.reduce_sum(losses)
             trainable_variables = (
                 self.net.trainable_variables + self.external_trainable_variables
@@ -531,7 +550,7 @@ def _outputs(self, training, inputs):
         outs = self.outputs(self.net.params, training, inputs)
         return utils.to_numpy(outs)
 
-    def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
+    def _outputs_losses(self, training, inputs, targets, auxiliary_vars, loss_weights):
         if training:
             outputs_losses = self.outputs_losses_train
         else:
@@ -540,7 +559,7 @@ def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
             feed_dict = self.net.feed_dict(training, inputs, targets, auxiliary_vars)
             return self.sess.run(outputs_losses, feed_dict=feed_dict)
         if backend_name == "tensorflow":
-            outs = outputs_losses(inputs, targets, auxiliary_vars)
+            outs = outputs_losses(inputs, targets, auxiliary_vars, loss_weights)
         elif backend_name == "pytorch":
             self.net.requires_grad_(requires_grad=False)
             outs = outputs_losses(inputs, targets, auxiliary_vars)
@@ -552,12 +571,14 @@ def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
             outs = outputs_losses(inputs, targets, auxiliary_vars)
         return utils.to_numpy(outs[0]), utils.to_numpy(outs[1])
 
-    def _train_step(self, inputs, targets, auxiliary_vars):
+    def _train_step(self, inputs, targets, auxiliary_vars, loss_weights):
         if backend_name == "tensorflow.compat.v1":
             feed_dict = self.net.feed_dict(True, inputs, targets, auxiliary_vars)
             self.sess.run(self.train_step, feed_dict=feed_dict)
-        elif backend_name in ["tensorflow", "paddle"]:
-            self.train_step(inputs, targets, auxiliary_vars)
+        elif backend_name == "tensorflow":
+            self.train_step(inputs, targets, auxiliary_vars, loss_weights)
+        elif backend_name == "paddle":
+            self.train_step(inputs, targets, auxiliary_vars)
         elif backend_name == "pytorch":
             self.train_step(inputs, targets, auxiliary_vars)
         elif backend_name == "jax":
@@ -669,6 +690,7 @@ def _train_sgd(self, iterations, display_every):
                 self.train_state.X_train,
                 self.train_state.y_train,
                 self.train_state.train_aux_vars,
+                self.loss_weights,
             )
 
             self.train_state.epoch += 1
@@ -827,12 +849,14 @@ def _test(self):
            self.train_state.X_train,
            self.train_state.y_train,
            self.train_state.train_aux_vars,
+            self.loss_weights,
        )
        self.train_state.y_pred_test, self.train_state.loss_test = self._outputs_losses(
            False,
            self.train_state.X_test,
            self.train_state.y_test,
            self.train_state.test_aux_vars,
+            self.loss_weights,
        )
 
        if isinstance(self.train_state.y_test, (list, tuple)):
diff --git a/examples/pinn_inverse/elliptic_inverse_field_manual_dynamic_loss_weights.py b/examples/pinn_inverse/elliptic_inverse_field_manual_dynamic_loss_weights.py
new file mode 100644
index 000000000..fb3f95d7e
--- /dev/null
+++ b/examples/pinn_inverse/elliptic_inverse_field_manual_dynamic_loss_weights.py
@@ -0,0 +1,78 @@
+"""Backend supported: tensorflow"""
+import deepxde as dde
+import matplotlib.pyplot as plt
+import numpy as np
+from deepxde.callbacks import PrintLossWeight, ManualDynamicLossWeight
+
+# Both callbacks require the tensorflow backend (run with DDE_BACKEND=tensorflow).
+dde.config.disable_xla_jit()
+
+
+def gen_traindata(num):
+    # Generate num equally spaced observation points on [-1, 1].
+    xvals = np.linspace(-1, 1, num).reshape(num, 1)
+    uvals = np.sin(np.pi * xvals)
+    return xvals, uvals
+
+
+def pde(x, y):
+    u, q = y[:, 0:1], y[:, 1:2]
+    du_xx = dde.grad.hessian(y, x, component=0, i=0, j=0)
+    return -du_xx + q
+
+
+def sol(x):
+    # Exact solution u(x) = sin(pi*x); the exact source is q(x) = -pi^2 * sin(pi*x).
+    return np.sin(np.pi * x)
+
+
+geom = dde.geometry.Interval(-1, 1)
+bc = dde.icbc.DirichletBC(geom, sol, lambda _, on_boundary: on_boundary, component=0)
+ob_x, ob_u = gen_traindata(100)
+observe_u = dde.icbc.PointSetBC(ob_x, ob_u, component=0)
+
+data = dde.data.PDE(
+    geom,
+    pde,
+    [bc, observe_u],
+    num_domain=200,
+    num_boundary=2,
+    anchors=ob_x,
+    num_test=1000,
+)
+
+net = dde.nn.FNN([1, 40, 40, 40, 2], "tanh", "Glorot uniform")
+print_loss_weight_cb = PrintLossWeight(period=1000)
+# Turn the PDE loss on (weight 0 -> 1) at epoch 5000.
+manual_dynamic_loss_weight_cb = ManualDynamicLossWeight(
+    epoch2change=5000, value=1, loss_idx=0
+)
+model = dde.Model(data, net)
+model.compile("adam", lr=0.0001, loss_weights=[0, 100, 1000])
+losshistory, train_state = model.train(
+    iterations=20000,
+    display_every=1000,
+    callbacks=[print_loss_weight_cb, manual_dynamic_loss_weight_cb],
+)
+# dde.saveplot(losshistory, train_state, issave=True, isplot=True)
+
+# View results
+x = geom.uniform_points(500)
+yhat = model.predict(x)
+uhat, qhat = yhat[:, 0:1], yhat[:, 1:2]
+
+utrue = np.sin(np.pi * x)
+print("l2 relative error for u: " + str(dde.metrics.l2_relative_error(utrue, uhat)))
+plt.figure()
+plt.plot(x, utrue, "-", label="u_true")
+plt.plot(x, uhat, "--", label="u_NN")
+plt.legend()
+
+qtrue = -np.pi**2 * np.sin(np.pi * x)
+print("l2 relative error for q: " + str(dde.metrics.l2_relative_error(qtrue, qhat)))
+plt.figure()
+plt.plot(x, qtrue, "-", label="q_true")
+plt.plot(x, qhat, "--", label="q_NN")
+plt.legend()
+
+plt.show()
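
A note on the `model.py` changes: `loss_weights` is threaded through `outputs_losses_train`, `outputs_losses_test`, and `train_step` as an explicit argument instead of being read from `self.loss_weights` inside the `@tf.function`-decorated closures. A value captured by a `tf.function` closure is frozen into the graph at trace time, so a callback that rebinds `self.loss_weights` mid-training would have no effect; a tensor argument of fixed shape and dtype reuses the existing trace while the new value takes effect. A minimal illustrative sketch of the difference (plain TensorFlow, not deepxde code; all names here are made up):

```python
import tensorflow as tf

w = tf.constant(2.0)

@tf.function
def scaled_closure(x):
    # `w` is captured when the function is first traced; rebinding the
    # Python name `w` afterwards does not update the compiled graph.
    return w * x

@tf.function
def scaled_arg(x, w):
    # `w` is a graph input: a new tensor with the same shape and dtype
    # reuses the existing trace, and the new value takes effect.
    return w * x

x = tf.constant(3.0)
print(scaled_closure(x))  # 6.0
w = tf.constant(5.0)
print(scaled_closure(x))  # still 6.0: the stale capture is used
print(scaled_arg(x, w))   # 15.0
```

This is also why `ManualDynamicLossWeight` replaces the whole tensor rather than mutating it in place.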
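And a condensed usage sketch of the two new callbacks, assuming the tensorflow backend (both read `model.loss_weights` as a `tf.Tensor` via `.numpy()`) and a `data`/`net` pair built as in the example above:

```python
import deepxde as dde
from deepxde.callbacks import PrintLossWeight, ManualDynamicLossWeight

# `data` and `net` are assumed to be constructed as in the example script.
model = dde.Model(data, net)
# Start with the PDE loss switched off and fit the observations first.
model.compile("adam", lr=1e-4, loss_weights=[0, 100, 1000])
losshistory, train_state = model.train(
    iterations=20000,
    callbacks=[
        # Report the initial and current weights every 1000 epochs.
        PrintLossWeight(period=1000),
        # Switch the PDE loss on (weight 0 -> 1) at epoch 5000.
        ManualDynamicLossWeight(epoch2change=5000, value=1, loss_idx=0),
    ],
)
```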