Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix training restarts #141

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions cloud-microphysics/app/train-cloud-microphysics.f90
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,32 @@ subroutine read_train_write(training_configuration, base_name, plot_unit, previo
real(rkind), parameter :: keep = 0.01
real(rkind), allocatable :: cost(:)
real(rkind), allocatable :: harvest(:)
integer i, batch, lon, lat, level, time, network_unit, io_status, final_step, epoch
integer i, batch, lon, lat, level, time, network_unit, io_status, epoch
integer(int64) start_training, finish_training

open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')

if (.not. allocated(end_step)) end_step = t_end

print *,"Defining tensors from time step", start_step, "through", end_step, "with strides of", stride

! The following temporary copies are required by gfortran bug 100650 and possibly 49324
! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
inputs = [( [( [( [( &
tensor_t( &
[ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]

outputs = [( [( [( [( &
tensor_t( &
[dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
dqs_dt(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]

read_or_initialize_engine: &
if (io_status==0) then
print *,"Reading network from file " // network_file
trainable_engine = trainable_engine_t(inference_engine_t(file_t(string_t(network_file))))
Expand All @@ -240,68 +261,48 @@ subroutine read_train_write(training_configuration, base_name, plot_unit, previo
initialize_network: &
block
character(len=len('YYYYMMDD')) date
type(tensor_range_t) input_range, output_range

call date_and_time(date)

print *,"Calculating input tensor component ranges."
input_range = tensor_range_t( &
associate(input_range => tensor_range_t( &
layer = "inputs", &
minima = [minval(pressure_in), minval(potential_temperature_in), minval(temperature_in), &
minval(qv_in), minval(qc_in), minval(qr_in), minval(qs_in)], &
maxima = [maxval(pressure_in), maxval(potential_temperature_in), maxval(temperature_in), &
maxval(qv_in), maxval(qc_in), maxval(qr_in), maxval(qs_in)] &
)
print *,"Calculating output tensor component ranges."
output_range = tensor_range_t( &
layer = "outputs", &
minima = [minval(dpt_dt), minval(dqv_dt), minval(dqc_dt), minval(dqr_dt), minval(dqs_dt)], &
maxima = [maxval(dpt_dt), maxval(dqv_dt), maxval(dqc_dt), maxval(dqr_dt), maxval(dqs_dt)] &
)
print *,"Initializing a new network"

associate(activation => training_configuration%differentiable_activation_strategy())
associate( &
model_name => string_t("Simple microphysics"), &
author => string_t("Inference Engine"), &
date_string => string_t(date), &
activation_name => activation%function_name(), &
residual_network => string_t(trim(merge("true ", "false", training_configuration%skip_connections()))) &
)
trainable_engine = trainable_engine_t( &
training_configuration, perturbation_magnitude=0.05, &
metadata = [model_name, author, date_string, activation_name, residual_network], &
input_range = input_range, output_range = output_range &
)
))
print *,"Calculating output tensor component ranges."
associate(output_range => tensor_range_t( &
layer = "outputs", &
minima = [minval(dpt_dt), minval(dqv_dt), minval(dqc_dt), minval(dqr_dt), minval(dqs_dt)], &
maxima = [maxval(dpt_dt), maxval(dqv_dt), maxval(dqc_dt), maxval(dqr_dt), maxval(dqs_dt)] &
))
associate(activation => training_configuration%differentiable_activation_strategy())
associate( &
model_name => string_t("Simple microphysics"), &
author => string_t("Inference Engine"), &
date_string => string_t(date), &
activation_name => activation%function_name(), &
residual_network => string_t(trim(merge("true ", "false", training_configuration%skip_connections()))) &
)
trainable_engine = trainable_engine_t( &
training_configuration, perturbation_magnitude=0.05, &
metadata = [model_name, author, date_string, activation_name, residual_network], &
input_range = input_range, output_range = output_range &
)
end associate
end associate
end associate
end associate
end block initialize_network
end if

if (.not. allocated(end_step)) end_step = t_end

print *,"Defining tensors from time step", start_step, "through", end_step, "with strides of", stride

! The following temporary copies are required by gfortran bug 100650 and possibly 49324
! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
inputs = [( [( [( [( &
tensor_t( &
[ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]

outputs = [( [( [( [( &
tensor_t( &
[dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
dqs_dt(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]
end if read_or_initialize_engine

print *,"Normalizing input tensors"
inputs = input_range%map_to_training_range(inputs)
inputs = trainable_engine%map_to_input_training_range(inputs)

print *,"Normalizing output tensors"
outputs = output_range%map_to_training_range(outputs)
outputs = trainable_engine%map_to_output_training_range(outputs)

print *, "Eliminating",int(100*(1.-keep)),"% of the grid points that have all-zero time derivatives"

Expand Down
1 change: 1 addition & 0 deletions src/inference_engine/inference_engine_m_.f90
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ module inference_engine_m_
end type

type exchange_t
type(tensor_range_t) input_range_, output_range_
type(string_t) metadata_(size(key))
real(rkind), allocatable :: weights_(:,:,:), biases_(:,:)
integer, allocatable :: nodes_(:)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
use sourcery_formats_m, only : separated_values
implicit none

#ifndef NO_EXTRAPOLATION
#define NO_EXTRAPOLATION .false.
#endif

interface assert_consistency
procedure inference_engine_consistency
procedure difference_consistency
Expand All @@ -23,6 +19,9 @@
contains

module procedure to_exchange
exchange%input_range_ = self%input_range_
exchange%output_range_ = self%output_range_
exchange%metadata_ = self%metadata_
exchange%metadata_ = self%metadata_
exchange%weights_ = self%weights_
exchange%biases_ = self%biases_
Expand All @@ -38,8 +37,6 @@

call assert_consistency(self)

if (NO_EXTRAPOLATION) call assert(self%input_range_%in_range(inputs), "inference_engine_s(infer): inputs in range")

associate(w => self%weights_, b => self%biases_, n => self%nodes_, output_layer => ubound(self%nodes_,1))

allocate(a(maxval(n), input_layer:output_layer))
Expand All @@ -61,8 +58,6 @@

end associate

if (NO_EXTRAPOLATION) call assert(self%output_range_%in_range(outputs), "inference_engine_s(infer): outputs in range")

end procedure

pure subroutine inference_engine_consistency(self)
Expand Down Expand Up @@ -228,7 +223,7 @@ pure subroutine set_activation_strategy(inference_engine)
end associate
end block

inference_engine = hidden_layers%inference_engine(metadata, output_layer)
inference_engine = hidden_layers%inference_engine(metadata, output_layer, input_range, output_range)

call set_activation_strategy(inference_engine)
call assert_consistency(inference_engine)
Expand Down
4 changes: 3 additions & 1 deletion src/inference_engine/layer_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module layer_m
use sourcery_string_m, only : string_t
use kind_parameters_m, only : rkind
use inference_engine_m_, only : inference_engine_t
use tensor_range_m, only : tensor_range_t
implicit none

private
Expand Down Expand Up @@ -39,11 +40,12 @@ recursive module function construct_layer(layer_lines, start) result(layer)

interface

module function inference_engine(hidden_layers, metadata, output_layer) result(inference_engine_)
module function inference_engine(hidden_layers, metadata, output_layer, input_range, output_range) &
  result(inference_engine_)
  ! Interface for building an inference_engine_t from a chain of hidden layers and an
  ! output layer.  The metadata strings and the input/output tensor ranges are received
  ! as read-only arguments alongside the layers.
  implicit none
  class(layer_t), target, intent(in) :: hidden_layers
  type(string_t), dimension(:), intent(in) :: metadata
  type(layer_t), target, intent(in) :: output_layer
  type(tensor_range_t), intent(in) :: input_range
  type(tensor_range_t), intent(in) :: output_range
  type(inference_engine_t) :: inference_engine_
end function

Expand Down
2 changes: 1 addition & 1 deletion src/inference_engine/layer_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@

end do loop_over_output_neurons

inference_engine_ = inference_engine_t(metadata, weights, biases, nodes)
inference_engine_ = inference_engine_t(metadata, weights, biases, nodes, input_range, output_range)
end block
end associate
end associate
Expand Down
24 changes: 20 additions & 4 deletions src/inference_engine/trainable_engine_m.F90
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ module trainable_engine_m
procedure :: num_inputs
procedure :: num_outputs
procedure :: to_inference_engine
procedure :: map_to_training_range
procedure :: map_from_training_range
procedure :: map_to_input_training_range
procedure :: map_from_input_training_range
procedure :: map_to_output_training_range
procedure :: map_from_output_training_range
end type

integer, parameter :: input_layer = 0
Expand Down Expand Up @@ -123,14 +125,28 @@ module function to_inference_engine(self) result(inference_engine)
type(inference_engine_t) :: inference_engine
end function

elemental module function map_to_training_range(self, tensor) result(normalized_tensor)
elemental module function map_to_input_training_range(self, tensor) result(normalized_tensor)
  ! Interface: elemental function taking the engine and one tensor and returning a
  ! tensor named normalized_tensor (the input-side "to training range" mapping).
  implicit none
  type(tensor_t), intent(in) :: tensor
  class(trainable_engine_t), intent(in) :: self
  type(tensor_t) :: normalized_tensor
end function

elemental module function map_from_training_range(self, tensor) result(unnormalized_tensor)
elemental module function map_from_input_training_range(self, tensor) result(unnormalized_tensor)
  ! Interface: elemental function taking the engine and one tensor and returning a
  ! tensor named unnormalized_tensor (the input-side "from training range" mapping).
  implicit none
  type(tensor_t), intent(in) :: tensor
  class(trainable_engine_t), intent(in) :: self
  type(tensor_t) :: unnormalized_tensor
end function

elemental module function map_to_output_training_range(self, tensor) result(normalized_tensor)
  ! Interface: elemental function taking the engine and one tensor and returning a
  ! tensor named normalized_tensor (the output-side "to training range" mapping).
  implicit none
  type(tensor_t), intent(in) :: tensor
  class(trainable_engine_t), intent(in) :: self
  type(tensor_t) :: normalized_tensor
end function

elemental module function map_from_output_training_range(self, tensor) result(unnormalized_tensor)
implicit none
class(trainable_engine_t), intent(in) :: self
type(tensor_t), intent(in) :: tensor
Expand Down
14 changes: 12 additions & 2 deletions src/inference_engine/trainable_engine_s.F90
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
module procedure construct_from_inference_engine

associate(exchange => inference_engine%to_exchange())
trainable_engine%input_range_ = exchange%input_range_
trainable_engine%output_range_ = exchange%output_range_
trainable_engine%metadata_ = exchange%metadata_
trainable_engine%w = exchange%weights_
trainable_engine%b = exchange%biases_
Expand Down Expand Up @@ -304,11 +306,19 @@ pure function e(j,n) result(unit_vector)

end procedure

module procedure map_to_training_range
module procedure map_to_input_training_range
  ! Delegates to the map_to_training_range binding of the engine's stored input range.
  associate(input_range => self%input_range_)
    normalized_tensor = input_range%map_to_training_range(tensor)
  end associate
end procedure

module procedure map_from_training_range
module procedure map_from_input_training_range
  ! Delegates to the map_from_training_range binding of the engine's stored input range.
  associate(input_range => self%input_range_)
    unnormalized_tensor = input_range%map_from_training_range(tensor)
  end associate
end procedure

module procedure map_to_output_training_range
  ! Delegates to the map_to_training_range binding of the engine's stored output range.
  associate(output_range => self%output_range_)
    normalized_tensor = output_range%map_to_training_range(tensor)
  end associate
end procedure

module procedure map_from_output_training_range
  ! Delegates to the map_from_training_range binding of the engine's stored output range.
  associate(output_range => self%output_range_)
    unnormalized_tensor = output_range%map_from_training_range(tensor)
  end associate
end procedure

Expand Down
Loading