-
Notifications
You must be signed in to change notification settings - Fork 0
/
ComputeShader.compute
159 lines (133 loc) · 5.27 KB
/
ComputeShader.compute
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#pragma enable_d3d11_debug_symbols
// Each "#pragma kernel <name>" directive names a function to compile as a compute kernel; a file can declare many kernels.
// Per-neuron state, stored in neuron_buffer (one element per neuron).
// NOTE(review): the field order and types presumably have to match the host-side struct used to fill the buffer - confirm before reordering.
struct Neuron
{
// membrane potential stored at the end of the previous calc step
float p_prev;
// index of this neuron's first entry in in_neuron_buffer / weight_buffer; calc treats the next neuron's in_start as the end of the range
int in_start;
// set to 1 by calc when the new potential exceeds p_thresh, reset to 0 on the following step
int spiked;
// value of sim_time at which this neuron last spiked
float last_spike_time;
};
// One recorded incoming spike, stored in synapse_buffer.
struct Synapse
{
// position of the incoming connection in in_neuron_buffer; the same index addresses its weight in weight_buffer
int in_neuron_buffer_index; // this is also the index used by the weights buffer
// value of sim_time at which the incoming spike was recorded
float time;
};
// ---- Simulation parameters ----
// None of these are initialized in this file; presumably the host script assigns them before dispatching calc - confirm against the caller.
// current simulation time; per the original author's note it is incremented in the host's Update
const float sim_time; //incremented in Update
// neuron properties
// NOTE(review): calc declares a local variable named p_in that shadows this global, so the kernel in this file never reads it
float p_in;
// a previous potential below p_min is replaced by p_rest
const float p_min;
// a potential above p_thresh makes the neuron spike
const float p_thresh;
// potential assigned when the previous potential was below p_min
const float p_rest;
// potential assigned on the step after the previous potential exceeded p_thresh
const float p_refract;
// amount subtracted from the potential on every integrating step
const float leak;
// learning properties
// minimum time difference required before a weight update is applied
const float no_stdp_window;
// a_minus / tau_minus: amplitude and time constant of the weight update applied when this neuron spikes
// a_plus / tau_plus: amplitude and time constant of the weight update applied when an incoming spike arrives
const float a_minus;
const float a_plus;
const float tau_minus;
const float tau_plus;
// scales every weight change
const float learning_rate;
// weight changes are scaled by the distance to these bounds: (weight_old - weight_min) or (weight_max - weight_old)
const float weight_min;
const float weight_max;
// defines how a neuron indexes into synapses buffer:
// neuron i reads the slots [i * synapses_to_keep, (i + 1) * synapses_to_keep)
const int synapses_to_keep;
// synapses_to_keep incoming synapses for each neuron, in a circular buffer
RWStructuredBuffer<Synapse> synapse_buffer;
// keeps track of which index to overwrite in the synapse_buffer to implement a circular buffer that persists across simulation steps
// (one entry per neuron; calc wraps it back to 0 after synapses_to_keep - 1)
RWStructuredBuffer<int> synapse_write_index_buffer;
// neurons
RWStructuredBuffer<Neuron> neuron_buffer;
// incoming neuron ids (e.g. neuron_buffer[in_neuron_buffer[this_neuron.in_start]] is the first incoming neuron)
// a negative value at a neuron's in_start makes calc skip that neuron as having no inputs
StructuredBuffer<int> in_neuron_buffer;
// incoming neuron weights, with same indexing as in_neuron_buffer
RWStructuredBuffer<float> weight_buffer;
// NOTE(review): only referenced from a commented-out line in this file
RWStructuredBuffer<float> weight_delta_buffer;
#pragma kernel calc
// calc: advances one neuron (id.x) by one simulation step.
//
// Steps, in order:
//   1. Skip the neuron if the first entry of its input range is negative (no inputs).
//   2. For every incoming neuron whose spiked flag is set: add the connection weight to the
//      input potential, record the spike in this neuron's slice of the circular synapse_buffer,
//      and apply the a_plus / tau_plus weight update relative to this neuron's last spike.
//   3. Compute the new potential from p_prev (refractory, rest, or integrate) and store it.
//   4. If the new potential exceeds p_thresh: set spiked, apply the a_minus / tau_minus weight
//      update to every synapse recorded since the previous spike, and store the spike time.
//
// NOTE(review): there is no bounds guard on id.x, and neuron_buffer[neuron_index + 1] is read to
// find the end of the input range, so the host presumably sizes the buffers/dispatch accordingly
// (e.g. a trailing sentinel neuron) - confirm against the dispatching script.
// NOTE(review): the spiked flag of other neurons is read here while their own threads may write
// it in the same dispatch; results can depend on thread scheduling unless the host accounts for it.
[numthreads(64, 1, 1)]
void calc(uint3 id : SV_DispatchThreadID)
{
    const uint neuron_index = id.x;
    const Neuron this_neuron = neuron_buffer[neuron_index];
    const int in_start = this_neuron.in_start;

    // there is nothing to do if the neuron has no inputs
    if (in_neuron_buffer[in_start] < 0)
    {
        return;
    }

    // the input range of this neuron ends where the next neuron's range begins
    const int in_end = neuron_buffer[neuron_index + 1].in_start;

    // this neuron's slice of the circular synapse buffer: [synapses_start, synapses_end)
    const uint synapses_start = synapses_to_keep * neuron_index;
    const uint synapses_end = synapses_to_keep * (neuron_index + 1);

    // time elapsed since this neuron last spiked
    const float dt = sim_time - this_neuron.last_spike_time;

    // local accumulator; intentionally shadows the file-level p_in, as in the original code
    float p_in = 0;
    for (int in_neuron_buffer_index = in_start; in_neuron_buffer_index < in_end; in_neuron_buffer_index++)
    {
        const uint in_neuron_index = in_neuron_buffer[in_neuron_buffer_index];
        if (neuron_buffer[in_neuron_index].spiked > 0)
        {
            // if the in neuron spiked, add its weight contribution
            p_in += weight_buffer[in_neuron_buffer_index];

            // Record the incoming spike. The stored write index is relative to this neuron's
            // slice, so it must be offset by synapses_start; without the offset every neuron
            // would write into slots 0..synapses_to_keep-1 while reading back from its own slice.
            const int synapse_write_index = synapse_write_index_buffer[neuron_index];
            const uint synapse_slot = synapses_start + (uint)synapse_write_index;
            synapse_buffer[synapse_slot].in_neuron_buffer_index = in_neuron_buffer_index;
            synapse_buffer[synapse_slot].time = sim_time;

            // rotate back to overwrite the beginning of the circular buffer if this is the end
            if (synapse_write_index == synapses_to_keep - 1)
            {
                synapse_write_index_buffer[neuron_index] = 0;
            }
            else
            {
                synapse_write_index_buffer[neuron_index] = synapse_write_index + 1;
            }

            // STDP against this neuron's last spike (the incoming spike arrives dt after it).
            // NOTE(review): the original comment calls this depression, but with positive a_plus,
            // learning_rate and (weight_old - weight_min) the weight increases - confirm the
            // intended sign convention of the host-supplied constants.
            if (dt >= no_stdp_window)
            {
                const float weight_delta = a_plus * exp(-dt / tau_plus);
                const float weight_old = weight_buffer[in_neuron_buffer_index];
                weight_buffer[in_neuron_buffer_index] += learning_rate * weight_delta * (weight_old - weight_min);
            }
        }
    }

    // calculate new potential
    // The branches are exhaustive so that p is always assigned, including when p_prev equals
    // p_min or p_thresh exactly (those cases integrate normally).
    const float p_prev = this_neuron.p_prev;
    float p;
    if (p_prev > p_thresh)
    {
        // the neuron spiked on the previous step: clear the flag and go refractory
        neuron_buffer[neuron_index].spiked = 0;
        p = p_refract;
    }
    else if (p_prev < p_min)
    {
        p = p_rest;
    }
    else
    {
        p = p_prev + p_in - leak;
    }
    neuron_buffer[neuron_index].p_prev = p;

    if (p > p_thresh)
    {
        // spike if above threshold
        neuron_buffer[neuron_index].spiked = 1;

        // STDP with the presynaptic spikes leading up to this spike; synapse_dt is never positive.
        // NOTE(review): the original comment calls this potentiation, but with positive a_minus,
        // learning_rate and (weight_max - weight_old) the weight decreases - confirm the intended
        // sign convention of the host-supplied constants.
        for (uint synapse_buffer_index = synapses_start; synapse_buffer_index < synapses_end; synapse_buffer_index++)
        {
            const Synapse synapse = synapse_buffer[synapse_buffer_index];
            if (synapse.time > this_neuron.last_spike_time) // ignore synapses older than the previous spike
            {
                const float synapse_dt = synapse.time - sim_time;
                if (synapse_dt <= -no_stdp_window)
                {
                    // Use the per-synapse time difference in the exponent. The original used dt
                    // (time since this neuron's last spike, non-negative), which made the term
                    // grow with elapsed time and be identical for every synapse.
                    const float weight_old = weight_buffer[synapse.in_neuron_buffer_index];
                    const float weight_delta = -a_minus * exp(synapse_dt / tau_minus);
                    //weight_delta_buffer[in_neuron_buffer_index] = weight_delta;
                    weight_buffer[synapse.in_neuron_buffer_index] += learning_rate * weight_delta * (weight_max - weight_old);
                }
            }
        }
        neuron_buffer[neuron_index].last_spike_time = sim_time;
    }
}