
Commit e628d1e

OneAdder and milancurcic authored Mar 5, 2025
Embedding Layer (#205)
* embedding_layer: initial forward implementation
* embedding_layer: implementation of embedding layer
* embedding_layer: remove gradient attribute
* embedding_layer: guard against zeros
* embedding_layer: plumbing
* embedding_layer: positional encoding
* embedding_layer: update tests
* embedding_layer: add more comments
* embedding_layer: update cmake
* embedding_layer: pr fixes
* embedding_layer: add absolute positional encoding
* embedding_layer: update constructor and tests
* embedding_layer: make integer input generics
* embedding_layer: update readme

Co-authored-by: milancurcic <caomaco@gmail.com>
1 parent e68e6c2 commit e628d1e

12 files changed: +496 -8 lines changed
 

‎CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ add_library(neural-fortran
   src/nf/nf_layer_submodule.f90
   src/nf/nf_linear2d_layer.f90
   src/nf/nf_linear2d_layer_submodule.f90
+  src/nf/nf_embedding_layer.f90
+  src/nf/nf_embedding_layer_submodule.f90
   src/nf/nf_loss.f90
   src/nf/nf_loss_submodule.f90
   src/nf/nf_maxpool2d_layer.f90

‎README.md

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 | Layer type | Constructor name | Supported input layers | Rank of output array | Forward pass | Backward pass |
 |------------|------------------|------------------------|----------------------|--------------|---------------|
 | Input | `input` | n/a | 1, 2, 3 | n/a | n/a |
+| Embedding | `embedding` | n/a | 2 | ✅ | ✅ |
 | Dense (fully-connected) | `dense` | `input1d`, `dense`, `dropout`, `flatten` | 1 | ✅ | ✅ |
 | Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ |
 | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) |

‎src/nf.f90

Lines changed: 3 additions & 2 deletions
@@ -6,13 +6,14 @@ module nf
     conv2d, &
     dense, &
     dropout, &
+    embedding, &
     flatten, &
     input, &
+    layernorm, &
     linear2d, &
     maxpool2d, &
     reshape, &
-    self_attention, &
-    layernorm
+    self_attention
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network

‎src/nf/nf_embedding_layer.f90

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
module nf_embedding_layer

  use nf_activation, only: activation_function
  use nf_base_layer, only: base_layer

  implicit none

  private
  public :: embedding_layer

  type, extends(base_layer) :: embedding_layer
    !! Embedding Layer
    !! Stores inputs as a trainable lookup table. Inputs are
    !! integer indicies in a dictionary of `vocab_size`.
    !! This layer converts them into a table of shape
    !! (`sequence_length`, `model_dimension`)
    integer :: sequence_length, vocab_size, model_dimension
    integer :: positional

    real, allocatable :: weights(:, :)
    real, allocatable :: output(:, :)
    real, allocatable :: dw(:, :) ! weight gradients

  contains

    procedure :: backward
    procedure :: forward
    procedure :: positional_trigonometric
    procedure :: positional_absolute
    procedure :: init
    procedure :: get_num_params
    procedure :: get_params
    procedure :: get_gradients
    procedure :: set_params

  end type embedding_layer

  interface embedding_layer
    module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
      integer, intent(in) :: vocab_size, model_dimension
      integer, optional :: positional
      type(embedding_layer) :: res
    end function embedding_layer_cons
  end interface embedding_layer

  interface
    pure module subroutine forward(self, input)
      !! Get vectors by indicis in the dictionary
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: input(:)
    end subroutine forward

    pure module subroutine backward(self, input, gradient)
      !! Update gradient at `input` indices
      !! dw_i = W_i + d_output_i
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: input(:)
      real, intent(in) :: gradient(:, :)
    end subroutine backward

    pure module subroutine positional_trigonometric(self, pos)
      !! Sum embedding with positional info (trigonometric, not trianable)
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: pos
    end subroutine positional_trigonometric

    pure module subroutine positional_absolute(self, pos)
      !! Sum embedding with absolute position
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: pos
    end subroutine positional_absolute

    module subroutine init(self, input_shape)
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: input_shape(:)
    end subroutine init

    pure module function get_num_params(self) result(num_params)
      class(embedding_layer), intent(in) :: self
      integer :: num_params
    end function get_num_params

    module function get_params(self) result(params)
      class(embedding_layer), intent(in), target :: self
      real, allocatable :: params(:)
    end function get_params

    module function get_gradients(self) result(gradients)
      class(embedding_layer), intent(in), target :: self
      real, allocatable :: gradients(:)
    end function get_gradients

    module subroutine set_params(self, params)
      class(embedding_layer), intent(in out) :: self
      real, intent(in), target :: params(:)
    end subroutine set_params
  end interface
end module nf_embedding_layer
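A minimal usage sketch of this module on its own (it mirrors the shapes used in test_embedding_layer.f90 below; the sizes are illustrative, not part of this commit):

    use nf_embedding_layer, only: embedding_layer

    type(embedding_layer) :: emb
    integer :: tokens(3)

    ! Illustrative sizes: vocabulary of 4 tokens, 2-dimensional embeddings,
    ! sequence length of 3. Weights default to 0.1 after init until set.
    emb = embedding_layer(vocab_size=4, model_dimension=2)
    call emb % init([3])

    tokens = [2, 1, 3]          ! integer indices into the lookup table, in [1, vocab_size]
    call emb % forward(tokens)  ! emb % output now has shape (sequence_length, model_dimension) = (3, 2)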
src/nf/nf_embedding_layer_submodule.f90

Lines changed: 137 additions & 0 deletions

@@ -0,0 +1,137 @@
#define NONE 0
#define TRIGONOMETRIC 1
#define ABSOLUTE 2

submodule(nf_embedding_layer) nf_embedding_layer_submodule
  use nf_base_layer, only: base_layer
  implicit none
contains
  module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
    integer, intent(in) :: vocab_size, model_dimension
    integer, optional :: positional
    type(embedding_layer) :: res

    res % vocab_size = vocab_size
    res % model_dimension = model_dimension
    if (.not. present(positional)) then
      res % positional = NONE
    else
      res % positional = positional
    end if
  end function embedding_layer_cons

  module subroutine init(self, input_shape)
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    self % sequence_length = input_shape(1)

    allocate(self % output(self % sequence_length, self % model_dimension))

    allocate(self % weights(self % vocab_size, self % model_dimension))
    self % weights = 0.1

    allocate(self % dw(self % vocab_size, self % model_dimension))
    self % dw = 0.0
  end subroutine init

  pure module subroutine forward(self, input)
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: input(:)
    integer :: i, index

    do concurrent(i = 1: self % sequence_length)
      index = input(i)
      if (index > size(self % weights, 1)) then
        index = 1
      elseif (index == 0) then
        index = 1
      end if

      self % output(i, :) = self % weights(index, :)

      if (self % positional == TRIGONOMETRIC) then
        call self % positional_trigonometric(i)
      elseif (self % positional == ABSOLUTE) then
        call self % positional_absolute(i)
      end if
    end do
  end subroutine forward

  pure module subroutine backward(self, input, gradient)
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: input(:)
    real, intent(in) :: gradient(:, :)
    integer :: i

    do concurrent(i = 1: self % sequence_length)
      self % dw(input(i), :) = self % dw(input(i), :) + gradient(i, :)
    end do
  end subroutine backward

  pure module subroutine positional_trigonometric(self, pos)
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: pos
    integer :: i
    real :: theta

    do concurrent(i = 1: floor(real(self % model_dimension) / 2))
      theta = (pos - 1) / 10000 ** (real(2 * (i-1)) / self % model_dimension)
      self % output(pos, 2 * i - 1) = self % output(pos, 2 * i - 1) + sin(theta)
      self % output(pos, 2 * i) = self % output(pos, 2 * i) + cos(theta)
    end do
  end subroutine positional_trigonometric

  pure module subroutine positional_absolute(self, pos)
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: pos
    integer :: i

    do concurrent(i = 1: self % model_dimension)
      self % output(pos, i) = self % output(pos, i) + pos - 1
    end do
  end subroutine positional_absolute

  pure module function get_num_params(self) result(num_params)
    class(embedding_layer), intent(in) :: self
    integer :: num_params
    num_params = self % vocab_size * self % model_dimension
  end function get_num_params

  module function get_params(self) result(params)
    class(embedding_layer), intent(in), target :: self
    real, allocatable :: params(:)
    real, pointer :: w_(:) => null()

    w_(1: product(shape(self % weights))) => self % weights
    params = w_
  end function get_params

  module function get_gradients(self) result(gradients)
    class(embedding_layer), intent(in), target :: self
    real, allocatable :: gradients(:)
    real, pointer :: dw_(:) => null()

    dw_(1: product(shape(self % dw))) => self % dw
    gradients = dw_
  end function get_gradients

  module subroutine set_params(self, params)
    class(embedding_layer), intent(in out) :: self
    real, intent(in), target :: params(:)

    real, pointer :: p_(:,:) => null()

    ! check if the number of parameters is correct
    if (size(params) /= self % get_num_params()) then
      error stop 'Error: number of parameters does not match'
    end if

    associate(n => self % vocab_size * self % model_dimension)
      ! reshape the weights
      p_(1:self % vocab_size, 1:self % model_dimension) => params(1 : n)
      self % weights = p_
    end associate

  end subroutine set_params
end submodule nf_embedding_layer_submodule
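For reference, `positional_trigonometric` above is the transformer-style sinusoidal encoding; written in LaTeX to match the code's 1-based position pos and pair index i:

  \theta_i = \frac{\mathrm{pos}-1}{10000^{\,2(i-1)/d_{\mathrm{model}}}}, \qquad
  \mathrm{output}(\mathrm{pos},\,2i-1) \mathrel{+}= \sin\theta_i, \qquad
  \mathrm{output}(\mathrm{pos},\,2i) \mathrel{+}= \cos\theta_i, \qquad
  i = 1, \dots, \lfloor d_{\mathrm{model}}/2 \rfloor.

`positional_absolute` instead adds the 0-based position pos - 1 to every component of the row.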

‎src/nf/nf_layer_constructors.f90

Lines changed: 18 additions & 0 deletions
@@ -18,6 +18,7 @@ module nf_layer_constructors
     maxpool2d, &
     reshape, &
     self_attention, &
+    embedding, &
     layernorm

   interface input
@@ -233,6 +234,23 @@ module function self_attention(num_heads) result(res)
       !! Resulting layer instance
   end function self_attention

+  module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
+    !! Embedding layer constructor.
+    !!
+    !! This layer is for inputting token indices from the dictionary to the network.
+    !! Works as a trainable lookup table that converts each index into a vector.
+    !! Embedding layer must be the first layer in a network.
+    integer, intent(in) :: sequence_length
+      !! max len of input sequence
+    integer, intent(in) :: vocab_size
+      !! length of token vocabulary
+    integer, intent(in) :: model_dimension
+      !! size of target embeddings
+    integer, optional, intent(in) :: positional
+      !! positional encoding
+    type(layer) :: res
+  end function embedding
+
   module function layernorm() result(res)
     !! Layer Normalization
     !! ((x − mean(x)) / sqrt(variance(x) + eps) * gamma + beta
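Putting the constructor together with the positional-encoding codes defined in the submodule above (0 = none, 1 = trigonometric, 2 = absolute), a hedged sketch of a network that starts with an embedding layer; the downstream layers and all sizes here are illustrative only, since allowed layer sequences are not yet enforced (see the !TODO in nf_network_submodule.f90 below):

    use nf, only: network, embedding, linear2d, flatten, dense

    type(network) :: net

    ! positional=1 requests the trigonometric (sinusoidal) encoding.
    net = network([ &
      embedding(sequence_length=16, vocab_size=1000, model_dimension=32, positional=1), &
      linear2d(8), &   ! an illustrative 2-d layer following the embedding
      flatten(), &
      dense(10) &
    ])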

‎src/nf/nf_layer_constructors_submodule.f90

Lines changed: 20 additions & 1 deletion
@@ -12,6 +12,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_embedding_layer, only: embedding_layer
   use nf_layernorm_layer, only: layernorm_layer
   use nf_activation, only: activation_function, relu, sigmoid

@@ -172,6 +173,7 @@ module function linear2d(out_features) result(res)

   end function linear2d

+
   module function self_attention(num_heads) result(res)
     integer, intent(in) :: num_heads
     type(layer) :: res
@@ -180,9 +182,26 @@ module function self_attention(num_heads) result(res)
     allocate(res % p, source=self_attention_layer(num_heads))
   end function self_attention

-  module function layernorm() result(res)
+
+  module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
+    integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    integer, optional, intent(in) :: positional
     type(layer) :: res
+    type(embedding_layer) :: embedding_layer_instance
+
+    embedding_layer_instance = embedding_layer(vocab_size, model_dimension, positional)
+    call embedding_layer_instance % init([sequence_length])
+    res % name = 'embedding'
+    res % layer_shape = [sequence_length, model_dimension]
+    res % input_layer_shape = [integer ::]
+    allocate(res % p, source=embedding_layer_instance)
+    res % initialized = .true.
+
+  end function embedding
+

+  module function layernorm() result(res)
+    type(layer) :: res
     res % name = 'layernorm'
     allocate(res % p, source=layernorm_layer())
   end function layernorm

‎src/nf/nf_layer_submodule.f90

Lines changed: 21 additions & 0 deletions
@@ -12,6 +12,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_embedding_layer, only: embedding_layer
   use nf_layernorm_layer, only: layernorm_layer
   use nf_optimizers, only: optimizer_base_type

@@ -61,6 +62,8 @@ pure module subroutine backward_1d(self, previous, gradient)
         call this_layer % backward(prev_layer % output, gradient)
       type is(self_attention_layer)
         call this_layer % backward(prev_layer % output, gradient)
+      type is(embedding_layer)
+        call this_layer % backward(prev_layer % output, gradient)
       type is(layernorm_layer)
         call this_layer % backward(prev_layer % output, gradient)
     end select
@@ -83,6 +86,8 @@ pure module subroutine backward_2d(self, previous, gradient)
     select type(prev_layer => previous % p)
       type is(input2d_layer)
         call this_layer % backward(prev_layer % output, gradient)
+      type is(embedding_layer)
+        call this_layer % backward(prev_layer % output, gradient)
       type is(linear2d_layer)
         call this_layer % backward(prev_layer % output, gradient)
       type is(self_attention_layer)
@@ -96,6 +101,8 @@ pure module subroutine backward_2d(self, previous, gradient)
     select type(prev_layer => previous % p)
       type is(input2d_layer)
        call this_layer % backward(prev_layer % output, gradient)
+      type is(embedding_layer)
+        call this_layer % backward(prev_layer % output, gradient)
      type is(linear2d_layer)
        call this_layer % backward(prev_layer % output, gradient)
      type is(self_attention_layer)
@@ -271,6 +278,8 @@ module subroutine forward(self, input)
     select type(prev_layer => input % p)
       type is(input2d_layer)
         call this_layer % forward(prev_layer % output)
+      type is(embedding_layer)
+        call this_layer % forward(prev_layer % output)
       type is(linear2d_layer)
         call this_layer % forward(prev_layer % output)
       type is(self_attention_layer)
@@ -285,6 +294,8 @@ module subroutine forward(self, input)
     select type(prev_layer => input % p)
       type is(input2d_layer)
         call this_layer % forward(prev_layer % output)
+      type is(embedding_layer)
+        call this_layer % forward(prev_layer % output)
       type is(linear2d_layer)
         call this_layer % forward(prev_layer % output)
       type is(self_attention_layer)
@@ -338,6 +349,8 @@ pure module subroutine get_output_2d(self, output)

       type is(input2d_layer)
         allocate(output, source=this_layer % output)
+      type is(embedding_layer)
+        allocate(output, source=this_layer % output)
       type is(linear2d_layer)
         allocate(output, source=this_layer % output)
       type is(self_attention_layer)
@@ -460,6 +473,8 @@ elemental module function get_num_params(self) result(num_params)
         num_params = this_layer % get_num_params()
       type is (self_attention_layer)
         num_params = this_layer % get_num_params()
+      type is (embedding_layer)
+        num_params = this_layer % get_num_params()
       type is (layernorm_layer)
         num_params = this_layer % get_num_params()
       class default
@@ -495,6 +510,8 @@ module function get_params(self) result(params)
         params = this_layer % get_params()
       type is (self_attention_layer)
         params = this_layer % get_params()
+      type is (embedding_layer)
+        params = this_layer % get_params()
       type is (layernorm_layer)
         params = this_layer % get_params()
       class default
@@ -530,6 +547,8 @@ module function get_gradients(self) result(gradients)
         gradients = this_layer % get_gradients()
       type is (self_attention_layer)
         gradients = this_layer % get_gradients()
+      type is (embedding_layer)
+        gradients = this_layer % get_gradients()
       type is (layernorm_layer)
         gradients = this_layer % get_gradients()
       class default
@@ -589,6 +608,8 @@ module subroutine set_params(self, params)

       type is (self_attention_layer)
         call this_layer % set_params(params)
+      type is (embedding_layer)
+        call this_layer % set_params(params)

       type is (layernorm_layer)
         call this_layer % set_params(params)

‎src/nf/nf_network.f90

Lines changed: 17 additions & 2 deletions
@@ -32,17 +32,19 @@ module nf_network

     procedure, private :: evaluate_batch_1d
     procedure, private :: forward_1d
+    procedure, private :: forward_1d_int
     procedure, private :: forward_2d
     procedure, private :: forward_3d
     procedure, private :: predict_1d
+    procedure, private :: predict_1d_int
     procedure, private :: predict_2d
     procedure, private :: predict_3d
     procedure, private :: predict_batch_1d
     procedure, private :: predict_batch_3d

     generic :: evaluate => evaluate_batch_1d
-    generic :: forward => forward_1d, forward_2d, forward_3d
-    generic :: predict => predict_1d, predict_2d, predict_3d
+    generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d
+    generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d
     generic :: predict_batch => predict_batch_1d, predict_batch_3d

   end type network
@@ -95,6 +97,12 @@ module subroutine forward_1d(self, input)
       !! 1-d input data
   end subroutine forward_1d

+  module subroutine forward_1d_int(self, input)
+    !! Same as `forward_1d` except `integer`
+    class(network), intent(in out) :: self
+    integer, intent(in) :: input(:)
+  end subroutine forward_1d_int
+
   module subroutine forward_2d(self, input)
     !! Apply a forward pass through the network.
     !!
@@ -137,6 +145,13 @@ module function predict_1d(self, input) result(res)
       !! Output of the network
   end function predict_1d

+  module function predict_1d_int(self, input) result(res)
+    !! Same as `predict_1d` except `integer`
+    class(network), intent(in out) :: self
+    integer, intent(in) :: input(:)
+    real, allocatable :: res(:)
+  end function predict_1d_int
+
   module function predict_2d(self, input) result(res)
     !! Return the output of the network given the input 1-d array.
     class(network), intent(in out) :: self
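With these overloads, `forward` and `predict` also accept a rank-1 integer array of token indices when the first layer is an embedding layer. A minimal sketch, assuming a `net` built as in the constructor example above (note that `predict` reads the output of the final layer, which per `predict_1d_int` below must be a dense, dropout, or flatten layer):

    integer :: tokens(16)
    real, allocatable :: y(:)

    tokens = 1                  ! token indices in [1, vocab_size]
    call net % forward(tokens)  ! dispatches to forward_1d_int
    y = net % predict(tokens)   ! dispatches to predict_1d_int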

‎src/nf/nf_network_submodule.f90

Lines changed: 45 additions & 3 deletions
@@ -11,6 +11,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_embedding_layer, only: embedding_layer
   use nf_layernorm_layer, only: layernorm_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
@@ -47,7 +48,7 @@ module function network_from_layers(layers) result(res)
       error stop 'Error: A network must have at least 2 layers.'

     ! The first layer must be an input layer
-    if (.not. layers(1) % name == 'input') &
+    if (.not. layers(1) % name == 'input' .and. .not. layers(1) % name == 'embedding') &
       error stop 'Error: First layer in the network must be an input layer.'

     !TODO Ensure that the layers are in allowed sequence:
@@ -210,8 +211,9 @@ module subroutine forward_1d(self, input)
     integer :: n

     ! Set the input array into the input layer
-    select type(input_layer => self % layers(1) % p); type is(input1d_layer)
-      call input_layer % set(input)
+    select type(input_layer => self % layers(1) % p)
+      type is(input1d_layer)
+        call input_layer % set(input)
     end select

     do n = 2, size(self % layers)
@@ -220,6 +222,21 @@ module subroutine forward_1d(self, input)

   end subroutine forward_1d

+  module subroutine forward_1d_int(self, input)
+    class(network), intent(in out) :: self
+    integer, intent(in) :: input(:)
+    integer :: n
+
+    select type(input_layer => self % layers(1) % p)
+      type is(embedding_layer)
+        call input_layer % forward(input)
+    end select
+
+    do n = 2, size(self % layers)
+      call self % layers(n) % forward(self % layers(n - 1))
+    end do
+
+  end subroutine forward_1d_int

   module subroutine forward_2d(self, input)
     class(network), intent(in out) :: self
@@ -284,6 +301,31 @@ module function predict_1d(self, input) result(res)

   end function predict_1d

+  module function predict_1d_int(self, input) result(res)
+    class(network), intent(in out) :: self
+    integer, intent(in) :: input(:)
+    real, allocatable :: res(:)
+    integer :: n, num_layers
+
+    num_layers = size(self % layers)
+
+    call self % set_training_mode(.false.)
+    call self % forward(input)
+    call self % set_training_mode(.true.)
+
+    select type(output_layer => self % layers(num_layers) % p)
+      type is(dense_layer)
+        res = output_layer % output
+      type is(dropout_layer)
+        res = output_layer % output
+      type is(flatten_layer)
+        res = output_layer % output
+      class default
+        error stop 'network % output not implemented for ' // &
+          trim(self % layers(num_layers) % name) // ' layer'
+    end select
+
+  end function predict_1d_int

   module function predict_2d(self, input) result(res)
     class(network), intent(in out) :: self

‎test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ foreach(execid
   insert_flatten
   reshape_layer
   multihead_attention_layer
+  embedding_layer
   layernorm
   dense_network
   get_set_network_params

‎test/test_embedding_layer.f90

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
program test_embedding_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_embedding_layer, only: embedding_layer
  use nf_layer, only: layer
  use nf_layer_constructors, only: embedding_constructor => embedding
  implicit none

  logical :: ok = .true.
  integer :: sample_input(3) = [2, 1, 3]

  call test_simple(ok, sample_input)
  call test_positional_trigonometric(ok, sample_input)
  call test_positional_absolute(ok, sample_input)

  if (ok) then
    print '(a)', 'test_embedding_layer: All tests passed.'
  else
    write(stderr, '(a)') 'test_embedding_layer: One or more tests failed.'
    error stop 1
  end if

contains
  subroutine test_simple(ok, sample_input)
    logical, intent(in out) :: ok
    integer, intent(in) :: sample_input(:)

    real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
    real :: output_flat(6)
    real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
    real :: dw_flat(8)
    real :: expected_dw_flat(8) = reshape([0.2, 0.1, 0.3, 0., 0.6, 0.4, 0.6, 0.], [8])
    type(embedding_layer) :: embedding

    embedding = embedding_layer(vocab_size=4, model_dimension=2)
    call embedding % init([3])
    embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])

    call embedding % forward(sample_input)

    output_flat = reshape(embedding % output, [6])
    if (.not. all(output_flat.eq.expected_output_flat)) then
      ok = .false.
      write(stderr, '(a)') 'forward returned incorrect values.. failed'
    end if

    call embedding % backward(sample_input, sample_gradient)
    dw_flat = reshape(embedding % dw, shape(dw_flat))
    if (.not. all(dw_flat.eq.expected_dw_flat)) then
      ok = .false.
      write(stderr, '(a)') 'backward returned incorrect dw values.. failed'
    end if
  end subroutine test_simple

  subroutine test_positional_trigonometric(ok, sample_input)
    logical, intent(in out) :: ok
    integer, intent(in) :: sample_input(:)

    real :: output_flat(12)
    real :: expected_output_flat(12) = reshape([&
        0.3, 0.941471, 1.4092975,&
        1.3, 0.64030236, 0.08385316,&
        0.3, 0.10999984, 0.51999867,&
        1.3, 1.09995, 1.4998&
    ], [12])
    type(embedding_layer) :: embedding

    real :: theta
    integer :: i, pos

    embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=1)
    call embedding % init([3])
    embedding % weights = reshape([&
        0.1, 0.3, 0.5, 0.7, 0.2,&
        0.1, 0.3, 0.5, 0.7, 0.2,&
        0.1, 0.3, 0.5, 0.7, 0.2,&
        0.1, 0.3, 0.5, 0.7, 0.2&
    ], [5, 4])

    call embedding % forward(sample_input)

    output_flat = reshape(embedding % output, [12])
    if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
      ok = .false.
      write(stderr, '(a)') 'trigonometric positional encoding returned incorrect values.. failed'
    end if
  end subroutine test_positional_trigonometric

  subroutine test_positional_absolute(ok, sample_input)
    logical, intent(in out) :: ok
    integer, intent(in) :: sample_input(:)

    real :: output_flat(12)
    real :: expected_output_flat(12) = reshape([&
        0.3, 1.1, 2.5,&
        0.3, 1.1, 2.5,&
        0.3, 1.1, 2.5,&
        0.3, 1.1, 2.5&
    ], [12])
    type(embedding_layer) :: embedding

    real :: theta
    integer :: i, pos

    embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=2)
    call embedding % init([3])
    embedding % weights = reshape([&
        0.1, 0.3, 0.5, 0.7, 0.2,&
        0.1, 0.3, 0.5, 0.7, 0.2,&
        0.1, 0.3, 0.5, 0.7, 0.2,&
        0.1, 0.3, 0.5, 0.7, 0.2&
    ], [5, 4])

    call embedding % forward(sample_input)

    output_flat = reshape(embedding % output, [12])
    if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
      ok = .false.
      write(stderr, '(a)') 'absolute positional encoding returned incorrect values.. failed'
    end if
  end subroutine test_positional_absolute

  subroutine test_embedding_constructor(ok, sample_input)
    logical, intent(in out) :: ok
    integer, intent(in) :: sample_input(:)

    type(layer) :: embedding_constructed

    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4)
    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=0)
    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=1)
    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=2)
  end subroutine test_embedding_constructor
end program test_embedding_layer
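As a worked check of test_simple above: the weights array has shape (vocab_size, model_dimension) = (4, 2) and is filled column-major, so its rows are (0.1, 0.2), (0.3, 0.4), (0.5, 0.6), (0.7, 0.8). The input [2, 1, 3] selects rows 2, 1, and 3, giving an output with rows (0.3, 0.4), (0.1, 0.2), (0.5, 0.6), which flattens column-major to [0.3, 0.1, 0.5, 0.4, 0.2, 0.6], i.e. expected_output_flat. Likewise, backward scatters the gradient rows (0.1, 0.4), (0.2, 0.6), (0.3, 0.6) into dw rows 2, 1, and 3, leaving row 4 at zero, which flattens to expected_dw_flat.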
