-
Notifications
You must be signed in to change notification settings - Fork 0
/
Manager.cs
235 lines (190 loc) · 8.71 KB
/
Manager.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Per-neuron simulation state that is copied to and from the GPU "neuron_buffer".
/// Keep the field order as declared: the host-side ComputeBuffer stride is built as
/// sizeof(float) + sizeof(int) + sizeof(int) + sizeof(float), and presumably mirrors
/// the HLSL struct in the compute shader (TODO confirm against the shader source).
/// </summary>
public struct Neuron
{
    // membrane potential carried over from the previous step (initialised to p_rest by Manager)
    public float p_prev;
    // offset of this neuron's first input id in in_neuron_buffer; a negative id stored at that
    // offset is treated as "no inputs" (the original note said -1 marks a neuron without inputs)
    public int in_start;
    // spike flag written back from the GPU (initialised to 0)
    public int spiked;
    // simulation time of the most recent spike (initialised before t = 0, outside the learning window)
    public float last_spike_time;

    /// <summary>Creates a neuron with every field given explicitly.</summary>
    /// <param name="p_prev">Starting potential.</param>
    /// <param name="in_start">Offset of the first input id in in_neuron_buffer.</param>
    /// <param name="spiked">Initial spike flag.</param>
    /// <param name="last_spike_time">Initial last-spike timestamp.</param>
    public Neuron(float p_prev, int in_start, int spiked, float last_spike_time) : this()
    {
        this.last_spike_time = last_spike_time;
        this.spiked = spiked;
        this.in_start = in_start;
        this.p_prev = p_prev;
    }
};
/// <summary>
/// One recorded synapse event stored in the GPU "synapse_buffer".
/// Keep the field order as declared: the host-side ComputeBuffer stride is
/// sizeof(int) + sizeof(float), presumably mirroring the HLSL struct (TODO confirm).
/// </summary>
public struct Synapse
{
    // index into in_neuron_buffer; the same index addresses weight_buffer / weight_delta_buffer
    public int in_neuron_buffer_index;
    // simulation time associated with this synapse entry
    public float time;

    /// <summary>Creates a synapse entry.</summary>
    /// <param name="in_neuron_buffer_index">Index into in_neuron_buffer (and the weight buffers).</param>
    /// <param name="time">Simulation time of the entry.</param>
    public Synapse(int in_neuron_buffer_index, float time) : this()
    {
        this.time = time;
        this.in_neuron_buffer_index = in_neuron_buffer_index;
    }
};
/// <summary>
/// Loads the "simple_network" description, uploads neuron, synapse and weight data to the
/// "calc" kernel of the "ComputeShader" resource, steps the simulation every
/// <see cref="frames_per_step"/> frames, and mirrors the results onto scene objects:
/// nodes are GameObjects named by neuron index, edges are GameObjects named by
/// source index followed by target index.
/// </summary>
public class Manager : MonoBehaviour
{
    // NOTE(review): presumably the [numthreads] X size of the "calc" kernel; it is not used to
    // size the dispatch below — confirm against the shader before deriving the group count from it.
    private const int group_size = 64;
    // Frames between simulation steps.
    private const int frames_per_step = 60;
    // NOTE(review): fixed thread-group count passed to Dispatch, kept from the original; whether
    // it covers every neuron depends on the kernel's [numthreads] — confirm.
    private const int dispatch_group_count = 64;
    private int num_neurons;
    private int sim_time;
    // neuron properties
    public float p_min;
    public float p_thresh;
    public float p_rest;
    public float p_refract;
    public float leak;
    // learning properties
    public float no_stdp_window;
    public float a_minus;
    public float a_plus;
    public float tau_minus;
    public float tau_plus;
    public float learning_rate;
    public float weight_min;
    public float weight_max;
    // number of synapse slots reserved per neuron; synapse_buffer holds synapses_to_keep * neuron count entries
    public int synapses_to_keep;
    public Neuron[] neuron_buffer;
    public int[] in_neuron_buffer;
    public float[] weight_buffer;
    public float[] weight_delta_buffer;
    private ComputeBuffer _neuron_buffer;
    private ComputeBuffer _synapse_write_index_buffer;
    private ComputeBuffer _synapse_buffer;
    private ComputeBuffer _in_neuron_buffer;
    private ComputeBuffer _weight_buffer;
    private ComputeBuffer _weight_delta_buffer;
    // NOTE(review): public so it shows in the inspector, but Start overwrites it with
    // Resources.Load("ComputeShader"), so any inspector assignment is ignored.
    public ComputeShader _compute_shader;
    private int kernel;

    /// <summary>
    /// Reads the network CSV, builds the host-side buffers, loads the compute shader,
    /// uploads parameters and buffers, and binds them to the "calc" kernel.
    /// Disables this component (instead of throwing every frame) when the network would
    /// produce an empty GPU buffer or the shader resource cannot be loaded.
    /// </summary>
    void Start()
    {
        // initialize buffer data
        List<List<int>> in_neurons_for_neurons = CSVReader.Read("simple_network");
        num_neurons = in_neurons_for_neurons.Count;
        Debug.Log(num_neurons);

        int[] synapse_write_index_buffer;
        Synapse[] synapse_buffer;
        BuildHostBuffers(in_neurons_for_neurons, out synapse_write_index_buffer, out synapse_buffer);

        // ComputeBuffer cannot be created with a count of 0.
        if (neuron_buffer.Length == 0 || synapse_buffer.Length == 0 || in_neuron_buffer.Length == 0)
        {
            Debug.LogError("Manager: network produced an empty buffer (neurons=" + neuron_buffer.Length
                + ", synapse slots=" + synapse_buffer.Length
                + ", inputs=" + in_neuron_buffer.Length + "); disabling.");
            enabled = false;
            return;
        }

        // put data in shader
        _compute_shader = Resources.Load<ComputeShader>("ComputeShader");
        if (_compute_shader == null)
        {
            Debug.LogError("Manager: could not load ComputeShader resource; disabling.");
            enabled = false;
            return;
        }
        SetShaderParameters();
        CreateAndBindBuffers(synapse_write_index_buffer, synapse_buffer);
        sim_time = 0;
    }

    /// <summary>
    /// Flattens the per-neuron input lists into neuron_buffer, in_neuron_buffer,
    /// weight_buffer and weight_delta_buffer, and allocates the synapse buffers.
    /// </summary>
    /// <param name="in_neurons_for_neurons">One list of input neuron ids per neuron, in neuron-index order.</param>
    /// <param name="synapse_write_index_buffer">Receives one write index (0) per neuron.</param>
    /// <param name="synapse_buffer">Receives synapses_to_keep placeholder slots per neuron.</param>
    private void BuildHostBuffers(List<List<int>> in_neurons_for_neurons, out int[] synapse_write_index_buffer, out Synapse[] synapse_buffer)
    {
        int neuron_total = in_neurons_for_neurons.Count;
        neuron_buffer = new Neuron[neuron_total];
        synapse_write_index_buffer = new int[neuron_total];
        synapse_buffer = new Synapse[synapses_to_keep * neuron_total];
        List<int> in_neuron_buffer_list = new List<int>();
        List<float> weight_buffer_list = new List<float>();
        int neuron_count = 0;
        foreach (List<int> neuron_in_neurons in in_neurons_for_neurons)
        {
            // in_start = number of input ids written so far; the next neuron's in_start marks this one's end.
            // initial last_spike_time before the start of simulation and outside learning window
            neuron_buffer[neuron_count] = new Neuron(p_rest, in_neuron_buffer_list.Count, 0, -20);
            synapse_write_index_buffer[neuron_count] = 0;
            // Slots for this neuron start at neuron_count * synapses_to_keep (presumably the same layout
            // the shader uses — confirm). All slots point at neuron 0, time 0, and are meant to be ignored.
            int synapse_start = neuron_count * synapses_to_keep;
            for (int i = 0; i < synapses_to_keep; i++)
            {
                synapse_buffer[synapse_start + i] = new Synapse(0, 0);
            }
            foreach (int in_neuron_index in neuron_in_neurons)
            {
                in_neuron_buffer_list.Add(in_neuron_index);
                weight_buffer_list.Add(1f);
            }
            neuron_count += 1;
        }
        in_neuron_buffer = in_neuron_buffer_list.ToArray();
        weight_buffer = weight_buffer_list.ToArray();
        // NOTE(review): deltas start as a copy of the weights (1f each) rather than 0 — confirm
        // this is what the "calc" kernel expects.
        weight_delta_buffer = weight_buffer_list.ToArray();
    }

    /// <summary>Uploads the neuron and learning parameters as shader uniforms.</summary>
    private void SetShaderParameters()
    {
        _compute_shader.SetFloat("p_in", 0f);
        _compute_shader.SetFloat("p_min", p_min);
        _compute_shader.SetFloat("p_thresh", p_thresh);
        _compute_shader.SetFloat("p_rest", p_rest);
        _compute_shader.SetFloat("p_refract", p_refract);
        _compute_shader.SetFloat("leak", leak);
        _compute_shader.SetFloat("no_stdp_window", no_stdp_window);
        _compute_shader.SetFloat("a_minus", a_minus);
        _compute_shader.SetFloat("a_plus", a_plus);
        _compute_shader.SetFloat("tau_minus", tau_minus);
        _compute_shader.SetFloat("tau_plus", tau_plus);
        _compute_shader.SetFloat("learning_rate", learning_rate);
        _compute_shader.SetFloat("weight_min", weight_min);
        _compute_shader.SetFloat("weight_max", weight_max);
        _compute_shader.SetInt("synapses_to_keep", synapses_to_keep);
    }

    /// <summary>
    /// Creates the GPU buffers, uploads the host data, looks up the "calc" kernel and binds every buffer to it.
    /// </summary>
    /// <param name="synapse_write_index_buffer">Initial per-neuron synapse write indices.</param>
    /// <param name="synapse_buffer">Initial synapse slots.</param>
    private void CreateAndBindBuffers(int[] synapse_write_index_buffer, Synapse[] synapse_buffer)
    {
        // Stride matches the Neuron field layout: float p_prev, int in_start, int spiked, float last_spike_time.
        _neuron_buffer = new ComputeBuffer(neuron_buffer.Length, sizeof(float) + sizeof(int) + sizeof(int) + sizeof(float));
        _synapse_write_index_buffer = new ComputeBuffer(synapse_write_index_buffer.Length, sizeof(int));
        // Stride matches the Synapse field layout: int in_neuron_buffer_index, float time.
        _synapse_buffer = new ComputeBuffer(synapse_buffer.Length, sizeof(int) + sizeof(float));
        _in_neuron_buffer = new ComputeBuffer(in_neuron_buffer.Length, sizeof(int));
        _weight_buffer = new ComputeBuffer(weight_buffer.Length, sizeof(float));
        _weight_delta_buffer = new ComputeBuffer(weight_delta_buffer.Length, sizeof(float));
        _neuron_buffer.SetData(neuron_buffer);
        _synapse_write_index_buffer.SetData(synapse_write_index_buffer);
        _synapse_buffer.SetData(synapse_buffer);
        _in_neuron_buffer.SetData(in_neuron_buffer);
        _weight_buffer.SetData(weight_buffer);
        _weight_delta_buffer.SetData(weight_delta_buffer);
        kernel = _compute_shader.FindKernel("calc");
        _compute_shader.SetBuffer(kernel, "neuron_buffer", _neuron_buffer);
        _compute_shader.SetBuffer(kernel, "synapse_write_index_buffer", _synapse_write_index_buffer);
        _compute_shader.SetBuffer(kernel, "synapse_buffer", _synapse_buffer);
        _compute_shader.SetBuffer(kernel, "in_neuron_buffer", _in_neuron_buffer);
        _compute_shader.SetBuffer(kernel, "weight_buffer", _weight_buffer);
        _compute_shader.SetBuffer(kernel, "weight_delta_buffer", _weight_delta_buffer);
    }

    /// <summary>
    /// Every <see cref="frames_per_step"/> frames: runs one kernel dispatch, reads neuron and
    /// weight data back, refreshes node and edge visuals, and advances sim_time.
    /// </summary>
    void Update()
    {
        if (_neuron_buffer == null || Time.frameCount % frames_per_step != 0)
        {
            return;
        }
        _compute_shader.SetFloat("sim_time", sim_time);
        _neuron_buffer.SetData(neuron_buffer);
        _compute_shader.Dispatch(kernel, dispatch_group_count, 1, 1);
        _neuron_buffer.GetData(neuron_buffer);
        _weight_buffer.GetData(weight_buffer);
        _weight_delta_buffer.GetData(weight_delta_buffer);
        // NOTE(review): GameObject.Find runs for every node and edge on each step; consider caching if the scene is large.
        for (int i = 0; i < neuron_buffer.Length; i++)
        {
            UpdateNode(i);
            UpdateIncomingEdges(i);
        }
        sim_time += 1;
    }

    /// <summary>Copies neuron <paramref name="i"/>'s spike flag and potential onto the Node named after its index.</summary>
    private void UpdateNode(int i)
    {
        Debug.Log("looking for " + i.ToString());
        GameObject node_object = GameObject.Find(i.ToString());
        if (node_object == null)
        {
            Debug.LogWarning("Manager: no node GameObject named " + i.ToString());
            return;
        }
        Node node = node_object.GetComponent<Node>();
        if (node == null)
        {
            Debug.LogWarning("Manager: GameObject " + i.ToString() + " has no Node component");
            return;
        }
        node.spiked = neuron_buffer[i].spiked;
        node.potential = neuron_buffer[i].p_prev;
    }

    /// <summary>
    /// Writes "weight delta" onto the first child's TextMesh of each incoming edge of neuron <paramref name="i"/>.
    /// Input ids for neuron i occupy in_neuron_buffer[in_start, in_end), where in_end is the next neuron's
    /// in_start, or the buffer length for the last neuron.
    /// </summary>
    private void UpdateIncomingEdges(int i)
    {
        int in_start = neuron_buffer[i].in_start;
        int in_end = (i + 1 < neuron_buffer.Length) ? neuron_buffer[i + 1].in_start : in_neuron_buffer.Length;
        // Skip neurons with no stored inputs, or whose first stored input id is negative ("no inputs" marker).
        if (in_start < 0 || in_start >= in_end || in_neuron_buffer[in_start] < 0)
        {
            return;
        }
        for (int j = in_start; j < in_end; j++)
        {
            // NOTE(review): kept from the original; skips neuron 0's first input edge for an undocumented reason — confirm.
            if (i == 0 && j == 0)
            {
                continue;
            }
            int in_neuron_index = in_neuron_buffer[j];
            Debug.Log("i " + i.ToString());
            Debug.Log("j " + j.ToString());
            Debug.Log("in_neuron_index " + in_neuron_index.ToString());
            Debug.Log("weight buffer length " + weight_buffer.Length.ToString());
            Debug.Log("neuron buffer length " + neuron_buffer.Length.ToString());
            float weight = weight_buffer[j];
            float weight_delta = weight_delta_buffer[j];
            // NOTE(review): plain concatenation is ambiguous for multi-digit ids ("1"+"12" == "11"+"2").
            string edge_name = in_neuron_index.ToString() + i.ToString();
            Debug.Log("looking for " + edge_name);
            GameObject edge_object = GameObject.Find(edge_name);
            if (!edge_object)
            {
                continue;
            }
            edge_object.transform.GetChild(0).GetComponent<TextMesh>().text = weight.ToString() + " " + weight_delta.ToString();
        }
    }

    /// <summary>Releases the GPU buffers created in Start; ComputeBuffers must be released explicitly.</summary>
    void OnDestroy()
    {
        ReleaseBuffer(ref _neuron_buffer);
        ReleaseBuffer(ref _synapse_write_index_buffer);
        ReleaseBuffer(ref _synapse_buffer);
        ReleaseBuffer(ref _in_neuron_buffer);
        ReleaseBuffer(ref _weight_buffer);
        ReleaseBuffer(ref _weight_delta_buffer);
    }

    /// <summary>Releases <paramref name="buffer"/> if it exists and clears the reference.</summary>
    private static void ReleaseBuffer(ref ComputeBuffer buffer)
    {
        if (buffer != null)
        {
            buffer.Release();
            buffer = null;
        }
    }
}