Fixed broken references #3

Merged: 1 commit on Dec 6, 2023
2 changes: 1 addition & 1 deletion src/feature_graph.jl
@@ -6,7 +6,7 @@
"""
FeatureGraph(nf, ef, senders, receivers)

-Data structure that is used as an input for the [GraphNetCore.GraphNetwork](@ref).
+Data structure that is used as an input for the [`GraphNetwork`](@ref).

# Arguments
- `nf`: Node features of the graph.
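For readers unfamiliar with Documenter.jl cross-references, the pattern behind the fixes in this PR can be sketched as follows (hypothetical examples, not lines from this repository):

```julia
# How Documenter.jl treats the link styles touched in this PR:
#
#   [Lux.Chain](@ref)                 # broken: Lux's docstrings are not part of
#                                     # this package's docs build, so the @ref
#                                     # cannot resolve and the build warns
#
#   [GraphNetCore.GraphNetwork](@ref) # resolves, but renders as plain text
#
#   [`GraphNetwork`](@ref)            # resolves locally and renders as code;
#                                     # the style this PR converges on
#
#   [Lux](https://github.com/LuxDL/Lux.jl)  # external packages instead get
#                                           # plain Markdown links
```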
30 changes: 15 additions & 15 deletions src/graph_network.jl
@@ -19,7 +19,7 @@ include("graph_net_blocks.jl")
The central data structure that contains the neural network and the normalisers corresponding to the components of the GNN (edge features, node features and output).

# Arguments
-- `model`: The Encode-Process-Decode model as a [Lux.Chain](@ref).
+- `model`: The Encode-Process-Decode model as a [Lux](https://github.com/LuxDL/Lux.jl) Chain.
- `ps`: Parameters of the model.
- `st`: State of the model.
- `e_norm`: Normaliser for the edge features of the GNN.
@@ -57,7 +57,7 @@ end
"""
build_model(quantities_size::Integer, dims, output_size::Integer, mps::Integer, layer_size::Integer, hidden_layers::Integer, device::Function)

-Constructs the Encode-Process-Decode model as a [Lux.Chain](@ref) with the given arguments.
+Constructs the Encode-Process-Decode model as a [Lux](https://github.com/LuxDL/Lux.jl) Chain with the given arguments.

# Arguments
- `quantities_size`: Sum of dimensions of each node feature.
@@ -66,10 +66,10 @@ Constructs the Encode-Process-Decode model as a [Lux.Chain](@ref) with the given
- `mps`: Number of message passing steps.
- `layer_size`: Size of hidden layers.
- `hidden_layers`: Number of hidden layers.
-- `device`: Device where the model should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `device`: Device where the model should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).

# Returns
-- `model`: The Encode-Process-Decode model as a [Lux.Chain](@ref).
+- `model`: The Encode-Process-Decode model as a [Lux](https://github.com/LuxDL/Lux.jl) Chain.
"""
function build_model(quantities_size::Integer, dims, output_size::Integer, mps::Integer, layer_size::Integer, hidden_layers::Integer, device::Function)
encoder = Encoder(build_mlp(quantities_size, layer_size, layer_size, hidden_layers, dev=device), build_mlp(dims + 1, layer_size, layer_size, hidden_layers, dev=device))
@@ -103,8 +103,8 @@ end


# Arguments
-- `gn`: The used [GraphNetCore.GraphNetwork](@ref).
-- `graph`: Input data stored in a [GraphNetCore.FeatureGraph](@ref).
+- `gn`: The used [`GraphNetwork`](@ref).
+- `graph`: Input data stored in a [`FeatureGraph`](@ref).
- `target_quantities_change`: Derivatives of quantities of interest (e.g. via finite differences from data).
- `mask`: Mask for excluding node types that should not be updated.
- `loss_function`: Loss function that is used to calculate the error.
@@ -124,13 +124,13 @@ end
"""
save!(gn, opt_state, df_train::DataFrame, df_valid::DataFrame, step::Integer, train_loss::Float32, path::String; is_training = true)

-Creates a checkpoint of the [GraphNetCore.GraphNetwork](@ref) at the given training step.
+Creates a checkpoint of the [`GraphNetwork`](@ref) at the given training step.

# Arguments
-- `gn`: The [GraphNetCore.GraphNetwork](@ref) that a checkpoint is created of.
+- `gn`: The [`GraphNetwork`](@ref) that a checkpoint is created of.
- `opt_state`: State of the optimiser.
-- `df_train`: [DataFrames.DataFram](@ref) that stores the train losses at the checkpoints.
-- `df_valid`: [DataFrames.DataFram](@ref) that stores the validation losses at the checkpoints (only improvements are saved).
+- `df_train`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame that stores the train losses at the checkpoints.
+- `df_valid`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame that stores the validation losses at the checkpoints (only improvements are saved).
- `step`: Current training step where the checkpoint is created.
- `train_loss`: Current training loss.
- `path`: Path to the folder where checkpoints are saved.
@@ -178,7 +178,7 @@ end
"""
load(quantities, dims, norms, output, message_steps, ls, hl, opt, device::Function, path::String)

-Loads the [GraphNetCore.GraphNetwork](@ref) from the latest checkpoint at the given path.
+Loads the [`GraphNetwork`](@ref) from the latest checkpoint at the given path.

# Arguments
- `quantities`: Sum of dimensions of each node feature.
@@ -189,14 +189,14 @@ Loads the [GraphNetCore.GraphNetwork](@ref) from the latest checkpoint at the gi
- `ls`: Size of hidden layers.
- `hl`: Number of hidden layers.
- `opt`: Optimiser that is used for training. Set this to `nothing` if you want to use the optimiser from the checkpoint.
-- `device`: Device where the model should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `device`: Device where the model should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).
- `path`: Path to the folder where the checkpoint is.

# Returns
-- `gn`: The loaded [GraphNetCore.GraphNetwork](@ref) from the checkpoint.
+- `gn`: The loaded [`GraphNetwork`](@ref) from the checkpoint.
- `opt_state`: The loaded optimiser state. Is nothing if no checkpoint was found or an optimiser was passed as an argument.
-- `df_train`: [DataFrames.DataFram](@ref) containing the train losses at the checkpoints.
-- `df_valid`: [DataFrames.DataFram](@ref) containing the validation losses at the checkpoints (only improvements are saved).
+- `df_train`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame containing the train losses at the checkpoints.
+- `df_valid`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame containing the validation losses at the checkpoints (only improvements are saved).
"""
function load(quantities, dims, norms, output, message_steps, ls, hl, opt, device::Function, path::String)
if isfile(joinpath(path, "checkpoints"))
10 changes: 5 additions & 5 deletions src/normaliser.jl
@@ -34,7 +34,7 @@ end
Inverts the normalised data.

# Arguments
-- `n`: The used [GraphNetCore.NormaliserOffline](@ref).
+- `n`: The used [`NormaliserOffline`](@ref).
- `data`: Data to be converted back.

# Returns
@@ -75,7 +75,7 @@ It is recommended to use offline normalization since the minimum and maximum do

# Arguments
- `dims`: Dimension of the quantity to normalize.
-- `device`: Device where the Normaliser should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `device`: Device where the Normaliser should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).

# Keyword Arguments
- `max_acc = 10f6`: Maximum number of accumulation steps.
@@ -92,8 +92,8 @@ Online normalization if the minimum and maximum of the quantity is not known.
It is recommended to use offline normalization since the minimum and maximum do not need to be inferred from data.

# Arguments
-- `d`: Dictionary containing the fields of the struct [GraphNetCore.NormaliserOnline](@ref).
-- `device`: Device where the Normaliser should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `d`: Dictionary containing the fields of the struct [`NormaliserOnline`](@ref).
+- `device`: Device where the Normaliser should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).
"""
function NormaliserOnline(d::Dict{String, Any}, device::Function)
NormaliserOnline(d["max_accumulations"], d["std_epsilon"], d["acc_count"], d["num_accumulations"], device(d["acc_sum"]), device(d["acc_sum_squared"]))
@@ -114,7 +114,7 @@ end
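The dictionary constructor shown above can be exercised as follows. This is a sketch: the field values are illustrative, and `cpu_device` is assumed to be the Lux device helper referenced in the surrounding docstrings.

```julia
using Lux: cpu_device  # assumed device helper; any array-moving function works

# Illustrative checkpoint dictionary matching the fields the constructor reads.
d = Dict{String, Any}(
    "max_accumulations" => 10f6,
    "std_epsilon"       => 1f-8,
    "acc_count"         => 100f0,
    "num_accumulations" => 100f0,
    "acc_sum"           => zeros(Float32, 3),
    "acc_sum_squared"   => zeros(Float32, 3),
)

n = NormaliserOnline(d, cpu_device())  # restores the normaliser on the CPU
```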
Inverts the normalised data.

# Arguments
-- `n`: The used [GraphNetCore.NormaliserOnline](@ref).
+- `n`: The used [`NormaliserOnline`](@ref).
- `data`: Data to be converted back.

# Returns
6 changes: 3 additions & 3 deletions src/utils.jl
@@ -14,7 +14,7 @@ Converts the given faces of a mesh to edges.
- `faces`: Two-dimensional array with the node indices in the first dimension.

# Returns
-- A tuple containing the edge pairs. (See [parse_edges](@ref))
+- A tuple containing the edge pairs. (See [`parse_edges`](@ref))
"""
function triangles_to_edges(faces::AbstractArray{T, 2} where T <: Integer)
edges = hcat(faces[1:2, :], faces[2:3, :], permutedims(hcat(faces[3, :], faces[1, :])))
@@ -94,7 +94,7 @@ end
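Based on the line of `triangles_to_edges` visible above, a minimal usage sketch (the exact return layout comes from `parse_edges`, which this diff does not show):

```julia
# Two triangles (columns) sharing the edge (2, 3); rows are node indices.
faces = [1 2;
         2 3;
         3 4]

# Concatenates the three edges of each triangle column-wise, as in the shown
# hcat line: (1,2), (2,3), (3,1) for the first face, and so on.
edges = triangles_to_edges(faces)
```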
"""
mse_reduce(target, output)

-Calculates the mean squared error of the given arguments with [Tullio](@ref) for GPU compatibility.
+Calculates the mean squared error of the given arguments with [Tullio](https://github.com/mcabbott/Tullio.jl) for GPU compatibility.

# Arguments
- `target`: Ground truth from the data.
@@ -111,7 +111,7 @@ end
"""
tullio_reducesum(a, dims)

-Implementation of the function [reducesum](@ref) with [Tullio](@ref) for GPU compatibility.
+Implementation of the function [`reducesum`](@ref) with [Tullio](https://github.com/mcabbott/Tullio.jl) for GPU compatibility.

# Arguments
- `a`: Array as input for reducesum.
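Neither Tullio-based function body appears in this diff; a plausible sketch of the pattern the docstrings describe, with the helper names assumed, is:

```julia
using Tullio

# Sum over a dimension without scalar indexing, keeping the kernel
# GPU-friendly; Tullio sums over indices appearing only on the right-hand side.
tullio_reducesum_sketch(a) = @tullio r[j] := a[i, j]   # reduces over dim 1

# Squared-error reduction in the same style, per column/node:
mse_reduce_sketch(target, output) = @tullio l[j] := (target[i, j] - output[i, j])^2
```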