Skip to content

Commit

Permalink
change parameters of Lotka_Volterra example
Browse files Browse the repository at this point in the history
  • Loading branch information
lijialin03 committed Aug 16, 2023
1 parent 3ae3418 commit 271e15c
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions examples/pinn_forward/Lotka_Volterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import matplotlib.pyplot as plt
import numpy as np
from scipy import integrate

# Import tf if using backend tensorflow.compat.v1 or tensorflow
from deepxde.backend import tf

# Import torch if using backend pytorch
# import torch
# Import paddle if using backend paddle
Expand Down Expand Up @@ -51,6 +53,7 @@ def ode_system(x, y):
dp_t - 1 / ub * rb * (0.02 * r * ub * p * ub - 1.06 * p * ub),
]


geom = dde.geometry.TimeDomain(0, 1.0)
data = dde.data.PDE(geom, ode_system, [], 3000, 2, num_test=3000)

Expand All @@ -59,6 +62,7 @@ def ode_system(x, y):
initializer = "Glorot normal"
net = dde.nn.FNN(layer_size, activation, initializer)


# Backend tensorflow.compat.v1 or tensorflow
def input_transform(t):
return tf.concat(
Expand All @@ -73,6 +77,8 @@ def input_transform(t):
),
axis=1,
)


# Backend pytorch
# def input_transform(t):
# return torch.cat(
Expand All @@ -93,20 +99,23 @@ def input_transform(t):
# def input_transform(t):
# if t.ndim == 1:
# t = t[None]
#
#
# return jnp.concatenate(
# [
# jnp.sin(t),
# ],
# axis=1
# )


# hard constraints: x(0) = 100, y(0) = 15
# Backend tensorflow.compat.v1 or tensorflow
def output_transform(t, y):
y1 = y[:, 0:1]
y2 = y[:, 1:2]
return tf.concat([y1 * tf.tanh(t) + 100 / ub, y2 * tf.tanh(t) + 15 / ub], axis=1)


# Backend pytorch
# def output_transform(t, y):
# y1 = y[:, 0:1]
Expand All @@ -131,9 +140,9 @@ def output_transform(t, y):
model = dde.Model(data, net)

model.compile("adam", lr=0.001)
losshistory, train_state = model.train(iterations=1000)
losshistory, train_state = model.train(iterations=50000)
# Most backends except jax can have a second fine tuning of the solution
model.compile("L-BFGS", lr=0.001)
model.compile("L-BFGS")
losshistory, train_state = model.train()
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

Expand Down

0 comments on commit 271e15c

Please sign in to comment.