1 from datetime import datetime
6 import matplotlib.pyplot as plt
7 import matplotlib.ticker as mtick
13 exec(open("conf.py").read())
16 def fmt_num(num, width=18):
18 Format number to string.
21 return str(num)[:width].ljust(width)
24 def inputs_function(x):
28 def noise_decision_function(t):
29 return np.random.normal(0.0, noise_decision)
32 def noise_bias_function(t):
33 return np.random.normal(0.0, noise_wm)
37 return time_scale if t > t_cue else 0
40 def decision_function(x):
48 # return 1.0 if x[0] + x[1] > 0.0 else -1.0
53 Base class for alpha receptors. Not to be used directly.
58 self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))
63 for i in range(len(steps)):
64 self.gains.append(self.gaind * self.y[i] + 1)
65 self.biass.append(self.biasd * self.y[i] + 1)
68 out = f"./out/{self.__class__.__name__}"
70 title = "Norepinepherine Concentration vs Neuron Activity in " + \
72 logging.info("Plotting " + title)
74 plt.plot(self.x, self.y)
76 plt.xlabel("Norepinephrine concentration (nM)")
77 plt.ylabel("Activity (%)")
80 plt.vlines(self.ki, 0, 1, linestyles="dashed")
81 plt.text(1.1 * self.ki, 0.1, "Affinity")
83 plt.hlines(0.5, 0, 1000, linestyles="dashed")
84 plt.text(1, 0.51, "50%")
87 plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter())
90 plt.savefig(f"{out}-norep-activity.png")
92 #######################################################################
94 title = "Concentration vs Gain/Bias scalar in " + self.pretty
95 logging.info("Plotting " + title)
97 plt.plot(self.x, self.biass, label="Bias scalar")
98 plt.plot(self.x, self.gains, label="Gain scalar")
102 plt.xlabel("Norepinephrine concentration (nM)")
108 plt.savefig(f"{out}-concentration-bias-gains.png")
113 Subclass of Alpha representing an alpha1 receptor.
129 Subclass of Alpha representing an alpha2 receptor.
148 self.num_spikes = np.zeros(len(steps))
149 self.num_correct = np.zeros(len(steps))
150 self.out = np.zeros(n_trials)
153 # correctly perceived (not necessarily remembered) cues
154 self.perceived = np.ones(n_trials)
155 rng = np.random.RandomState(seed=seed)
156 # whether the cues is on the left or right
157 self.cues = 2 * rng.randint(2, size=n_trials)-1
158 for n in range(len(self.perceived)):
159 if rng.rand() < misperceive:
160 self.perceived[n] = 0
163 title = "Norepinephrine Concentration vs Spiking Rate"
164 logging.info("Plotting " + title)
166 plt.plot(steps, self.num_spikes)
168 plt.xlabel("Norepinephrine concentration (nM)")
169 plt.ylabel("Spiking rate (spikes/time step)")
173 plt.savefig("./out/concentration-spiking.png")
175 ########################################################################
177 title = "Norepinephrine Concentration vs Accuracy"
178 logging.info("Plotting " + title)
180 correct_df = pd.DataFrame(np.clip(self.num_correct, 0.5, 1.0)).rolling(20).mean()
181 plt.plot(steps, correct_df)
183 plt.xlabel("Norepinephrine concentration (nM)")
184 plt.ylabel("Accuracy")
188 plt.savefig("./out/concentration-correct.png")
190 def cue_function(self, t):
191 if t < t_cue and self.perceived[self.trial] != 0:
192 return cue_scale * self.cues[self.trial]
197 with nengo.Network() as net:
199 cue_node = nengo.Node(output=self.cue_function)
200 time_node = nengo.Node(output=time_function)
201 noise_wm_node = nengo.Node(output=noise_bias_function)
202 noise_decision_node = nengo.Node(
203 output=noise_decision_function)
206 wm = nengo.Ensemble(neurons_wm, 2)
207 decision = nengo.Ensemble(neurons_decide, 2)
208 inputs = nengo.Ensemble(neurons_inputs, 2)
209 output = nengo.Ensemble(neurons_decide, 1)
212 nengo.Connection(cue_node, inputs[0], synapse=None)
213 nengo.Connection(time_node, inputs[1], synapse=None)
214 nengo.Connection(inputs, wm, synapse=tau_wm,
215 function=inputs_function)
216 wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm)
217 nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
218 transform=np.ones((neurons_wm, 1)) * tau_wm)
219 wm_to_decision = nengo.Connection(
220 wm[0], decision[0], synapse=tau)
221 nengo.Connection(noise_decision_node,
222 decision[1], synapse=None)
223 nengo.Connection(decision, output, function=decision_function)
226 wm_probe = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
227 spikes_probe = nengo.Probe(wm.neurons, sample_every=probe_dt)
228 output_probe = nengo.Probe(
229 output, synapse=None, sample_every=probe_dt)
232 for i, _ in tqdm(enumerate(steps), total=len(steps), unit="step"):
233 sim = nengo.Simulator(net, dt=dt, progress_bar=False)
234 wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
235 wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
236 wm_recurrent.solver = MySolver(
237 sim.model.params[wm_recurrent].weights)
238 wm_to_decision.solver = MySolver(
239 sim.model.params[wm_to_decision].weights)
240 sim = nengo.Simulator(net, dt=dt, progress_bar=False)
241 for self.trial in range(n_trials):
243 f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
244 sim.run(t_cue + t_delay)
247 self.out[self.trial] = np.count_nonzero(
248 sim.data[spikes_probe])
250 cue = self.cues[self.trial]
252 out = sim.data[output_probe][int(t_cue + t_delay)][0]
253 if (out * cue) > 0: # check if same sign
254 self.num_correct[i] += np.abs(1 / (out - cue))
256 self.num_spikes[i] = np.average(self.out)
258 with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
259 pickle.dump(self, pout)
263 def get_correct(cue, output_value):
264 return 1 if (cue > 0.0 and output_value > 0.0) or (cue < 0.0 and output_value < 0.0) else 0
267 class MySolver(nengo.solvers.Solver):
268 def __init__(self, weights):
270 self.my_weights = weights
273 def __call__(self, A, Y, rng=None, E=None):
274 return self.my_weights.T, dict()
278 logging.info("Initializing simulation")
279 plt.style.use("ggplot") # Nice looking and familiar style
282 data = open("simulation.pkl", "rb")
283 except FileNotFoundError:
286 pickle.load(data).plot()
289 if __name__ == "__main__":
292 except FileExistsError:
295 logging.basicConfig(filename=f"out/{datetime.now().isoformat()}.log",