Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert domain_descriptors into a struct #2

Open
wants to merge 1 commit into
base: cleanup
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 33 additions & 119 deletions Tutorial.ipynb

Large diffs are not rendered by default.

123 changes: 20 additions & 103 deletions src/NN_things.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,6 @@ function skew_symm_NN(
diagonals;
channel,
)

#domain_range,(I,J),(X,x),(Omega,omega),(W,R),(IP,ip),(INTEG,integ) = domain_descriptors

(physics_width, stencil_width, conv_pad_size) = pad_sizes

if constraints
Expand Down Expand Up @@ -364,19 +361,12 @@ function loss_function(
model;
subgrid_loss = true,
)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors

#us,dus,ss,dss,as,bs = stop_gradient() do
# [hcat([k[:,i,j] for i in 1:size(k)[2] for j in 1:size(k)[3]]...) for k in (us,dus,ss,dss,as,bs)]
#end

X = domain_descriptors.grids.coarse

pred_dus, pred_dss = model(us, ss, X, as, bs)

l1 = Flux.Losses.mse(dus, pred_dus)
Expand Down Expand Up @@ -411,15 +401,6 @@ function trajectory_loss_function(
T;
subgrid_loss = true,
)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors

traj_us, traj_ss = stop_gradient() do
[hcat([k[:, i, j] for i = 1:size(k)[2] for j = 1:size(k)[3]]...) for k in (us, ss)]
end
Expand All @@ -429,6 +410,8 @@ function trajectory_loss_function(

current_f = model_wrapper(model, as = BC_as, bs = BC_bs, eval_BCs = true)

X = domain_descriptors.grids.coarse

init_cond = [us[:, 1, :]; ss[:, 1, :]]
NN_var, NN_dus, NN_ts = simulation(
init_cond,
Expand Down Expand Up @@ -567,14 +550,8 @@ function padding(vec, pad_size, a = 0, b = 0; anti_symm_outflow = false)
end

function subgrid_gradients(u_prime, du_prime, domain_descriptors, subgrid_filter)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
I = domain_descriptors.I
J = domain_descriptors.J
function extract_jacobian(u_prime, subgrid_filter = subgrid_filter, I = I, J = J)
jac = jacobian(subgrid_filter, u_prime)[1]
flat_jac = zeros(I * J)
Expand All @@ -587,23 +564,17 @@ function subgrid_gradients(u_prime, du_prime, domain_descriptors, subgrid_filter
end
full_jac = zeros(I, size(u_prime)[2])

R = domain_descriptors.R
for i = 1:size(u_prime)[2]
full_jac[:, i] .+= R' * (extract_jacobian(u_prime[:, i]) .* du_prime[:, i])
end
return full_jac
end

function gen_NN_subgrid_filter(layers, domain_descriptors, outflow = false)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
NN = NeuralNetwork(layers, bias = false, activation_function = tanh)
select_mat = gen_subgrid_filter_select_mat(domain_descriptors)
J = domain_descriptors.J
Px = stop_gradient() do
reverse(Matrix{Float64}(LinearAlgebra.I, J, J), dims = 2)
end
Expand All @@ -615,15 +586,6 @@ function gen_NN_subgrid_filter(layers, domain_descriptors, outflow = false)
Px = Px,
outflow = outflow,
)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors

f1 = vcat([NN(u_primes[select_mat[:, i], :]) for i = 1:size(select_mat)[2]]...)
f2 = vcat([NN(Px * u_primes[select_mat[:, i], :]) for i = 1:size(select_mat)[2]]...)
filtered = f1 .- f2
Expand Down Expand Up @@ -680,27 +642,14 @@ function conv_NN(sizes, channels, strides = 0, bias = true)
end

function gen_subgrid_filter_select_mat(domain_descriptors)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
x = domain_descriptors.grids.fine
J = domain_descriptors.J
I = domain_descriptors.I
selects = reshape(collect(1:size(x)[1]), (J, I))
return selects
end

function gen_t_stencil(params, domain_descriptors, outflow)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
if outflow
lambda = 0.0
t_tilde = params
Expand All @@ -721,14 +670,7 @@ function gen_t_stencil(params, domain_descriptors, outflow)
end

function gen_subgrid_filter(domain_descriptors, outflow = false)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
J = domain_descriptors.J
params = zeros(J) .+ rand(Uniform(-10^(-20), 10^(-20)), J)
select_mat = gen_subgrid_filter_select_mat(domain_descriptors)
function subgrid_filter(
Expand All @@ -748,14 +690,7 @@ end
#bar = gen_S_and_K(subgrid_filter_stencil,"just_avg")

function subgrid_filter_loss(u_primes, subgrid_filter, domain_descriptors)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
W = domain_descriptors.W

filtered = subgrid_filter(u_primes)
filtered_squared = 1 / 2 * (filtered) .^ 2
Expand All @@ -767,14 +702,8 @@ function subgrid_filter_loss(u_primes, subgrid_filter, domain_descriptors)
end

function gen_T(subgrid_filter_stencil, domain_descriptors)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
I = domain_descriptors.I
J = domain_descriptors.J
dimensions = (I, I * J)
mat = spzeros(dimensions)
for i = 1:I
Expand All @@ -784,14 +713,6 @@ function gen_T(subgrid_filter_stencil, domain_descriptors)
end

function gen_train_test_set(d, f, domain_descriptors, fraction, train_fraction)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
us, dus, ts, as, bs, Fs = d
ending_index = Int(floor(train_fraction * size(us)[2]))

Expand Down Expand Up @@ -851,21 +772,14 @@ function gen_trajectory_data(
traj_steps,
T_mat,
)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors

traj_fraction = 1 / traj_steps

us, dus, ts, as, bs, Fs = d

traj_indexes = cut_indexes(indexes, traj_fraction, randomize = false, uniform = false)

I = domain_descriptors.I

traj_data = Dict()
traj_data["u_bar"] = Array{Float64}(undef, I, traj_steps + 1, 0)
traj_data["s"] = Array{Float64}(undef, I, traj_steps + 1, 0)
Expand Down Expand Up @@ -905,6 +819,9 @@ function gen_trajectory_data(
trajectory_data["b"] = reshape(bs[:, trajs], (1, traj_steps + 1, data_size))

#### Process F ########
W = domain_descriptors.W
R = domain_descriptors.R
interpolation_matrix = domain_descriptors.interpolation_matrix
F_bar = W * interpolation_matrix * Fs[:, trajs]
F_prime = T_mat * (interpolation_matrix * Fs[:, trajs] .- R * F_bar)

Expand Down
90 changes: 60 additions & 30 deletions src/discretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@ using SparseArrays
using Random
using Distributions

"""
General type for triples of types, one for the fine grid, one for the coarse grid and one for the reference grid.
"""
struct Triplet{T}
coarse::T
fine::T
reference::T
end

Grid = Vector{Float64}
Volume = Matrix{Float64}
InnerProduct = Function
Integrator = Function

struct DomainDescriptors
b::Float64
interpolation_matrix::SparseMatrixCSC{Float64,Int}
N::Int
I::Int
J::Int
W::SparseMatrixCSC{Float64,Int}
R::SparseMatrixCSC{Float64,Int}
grids::Triplet{Grid}
volumes::Triplet{Volume}
inner_products::Triplet{InnerProduct}
integrators::Triplet{Integrator}
end

function gen_stencil(N, coeffs, positions)
mat = spzeros((N, N))
stencil_width = size(coeffs)[1]
Expand Down Expand Up @@ -56,6 +84,9 @@ function generate_domain_and_filters(b, I, N; spectral = false)
ref_x = ref_x .+ 1 / 2 * ref_dx
ref_omega = Diagonal(ref_dx * ones(size(ref_x)[1]))

grids = Triplet{Grid}(X, x, ref_x)
volumes = Triplet{Volume}(Omega, omega, ref_omega)

interpolation_matrix = construct_weighted_interpolation_matrix(ref_x, x)

mapper = dx * ones((J, size(X)[1]))
Expand All @@ -65,24 +96,32 @@ function generate_domain_and_filters(b, I, N; spectral = false)
IP(a, b, omega = Omega) = inner_product(a, b, omega)
ip(a, b, omega = omega) = inner_product(a, b, omega)
ref_ip(a, b, omega = ref_omega) = inner_product(a, b, omega)
inner_products = Triplet{InnerProduct}(IP, ip, ref_ip)

INTEG(a) = IP(a, ones(size(a)))
integ(a) = ip(a, ones(size(a)))
ref_integ(a) = ref_ip(a, ones(size(a)))
integrators = Triplet{Integrator}(INTEG, integ, ref_integ)

if spectral
W, R = spectral_filter(X, x, domain_range)
else
W, R = gen_W(mapper), gen_R(mapper)
end
domain_descriptors = b,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega, ref_omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ)
return domain_descriptors

DomainDescriptors(
b,
interpolation_matrix,
N,
I,
J,
W,
R,
grids,
volumes,
inner_products,
integrators,
)
end

function fourier_basis(x, domain_range)
Expand Down Expand Up @@ -207,31 +246,28 @@ function gen_fourier(
end

function process_HR_solution(us, dus, ts, domain_descriptors, f, return_primes = false)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors
Es = (1 / 2) * ref_ip(us, us)
dEs = ref_ip(us, dus)
ips = domain_descriptors.inner_products
Es = (1 / 2) * ips.reference(us, us)
dEs = ips.reference(us, dus)

interpolation_matrix = domain_descriptors.interpolation_matrix
us = interpolation_matrix * us
dus = interpolation_matrix * dus

W = domain_descriptors.W
X = domain_descriptors.grids.coarse
us_bar = W * us
phys_dus = f(us_bar, X, ts)
dus_bar = W * dus

Es_bar = (1 / 2) * IP(us_bar, us_bar)
Es_bar = (1 / 2) * ips.coarse(us_bar, us_bar)
Es_prime = Es .- Es_bar

dEs_bar = IP(us_bar, dus_bar)
dEs_bar = ips.coarse(us_bar, dus_bar)
dEs_prime = dEs .- dEs_bar

if return_primes
R = domain_descriptors.R
HR_us_bar = R * us_bar
HR_us_prime = us .- HR_us_bar

Expand Down Expand Up @@ -449,17 +485,11 @@ function gen_rand_condition(
scaling = 1,
in_outflow = false,
)
domain_range,
interpolation_matrix,
(N, I, J),
(X, x, ref_x),
(Omega, omega),
(W, R),
(IP, ip, ref_ip),
(INTEG, integ, ref_integ) = domain_descriptors

b(t) = NaN * t

domain_range = domain_descriptors.b
ref_x = domain_descriptors.grids.reference

if in_outflow
a = gen_fourier(2 * pi, min_mode, max_mode, offset = offset, scaling = scaling)

Expand Down
Loading