git.armaanb.net Git - norepinephrine_wm.git/blob - model.py
final update
[norepinephrine_wm.git] / model.py
1 from datetime import datetime
2 from os import mkdir
3 import logging
4 import pickle
5
6 import matplotlib.pyplot as plt
7 import matplotlib.ticker as mtick
8 import nengo
9 import numpy as np
10 import pandas as pd
11 from tqdm import tqdm
12
# Load the run configuration by executing conf.py in this module's namespace.
# Names used below but not defined in this file (tau_wm, steps, n_trials, ...)
# presumably come from there.
# NOTE(review): exec runs arbitrary code, so conf.py must be a trusted file.
with open("conf.py") as conf_file:  # context manager closes the handle
    exec(conf_file.read())
14
15
def fmt_num(num, width=18):
    """
    Render a value as fixed-width text.

    The string form of ``num`` is cut off after ``width`` characters and,
    when shorter than that, padded on the right with spaces.
    """

    clipped = str(num)[:width]
    return clipped.ljust(width)
22
23
def inputs_function(x):
    """
    Scale ``x`` by the module-level ``tau_wm`` value.
    """

    scaled = x * tau_wm
    return scaled
26
27
def noise_decision_function(t):
    """
    Draw one Gaussian sample with mean 0.0 and standard deviation
    ``noise_decision``. The argument ``t`` is not used.
    """

    sample = np.random.normal(loc=0.0, scale=noise_decision)
    return sample
30
31
def noise_bias_function(t):
    """
    Draw one Gaussian sample with mean 0.0 and standard deviation
    ``noise_wm``. The argument ``t`` is not used.
    """

    sample = np.random.normal(loc=0.0, scale=noise_wm)
    return sample
34
35
def time_function(t):
    """
    Return ``time_scale`` once ``t`` is past ``t_cue``, and 0 before that
    (including at ``t == t_cue``).
    """

    if t > t_cue:
        return time_scale
    return 0
38
39
def decision_function(x):
    """
    Return the sign of ``x[0] + x[1]`` as a float.

    The result is 1.0 for a positive sum, -1.0 for a negative sum and 0.0
    otherwise.
    """

    total = x[0] + x[1]
    if total > 0.0:
        return 1.0
    if total < 0.0:
        return -1.0
    return 0.0
49
50
class Alpha():
    """
    Base class for alpha receptors. Not to be used directly.

    Subclasses must assign ``ki``, ``offset``, ``pretty``, ``gaind`` and
    ``biasd`` before calling ``super().__init__()``, because the base
    initialiser reads ``offset``, ``gaind`` and ``biasd`` immediately and
    ``plot`` reads ``ki`` and ``pretty``.
    """

    def __init__(self):
        """
        Compute the activity curve and the gain/bias scalars over ``steps``.
        """
        # Concentration axis, taken from the module-level ``steps``.
        self.x = steps
        # Logistic curve; the expression always lies strictly between 0 and 1.
        self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))

        # One scalar per step: 1 plus the activity scaled by the receptor's
        # gain/bias delta.
        self.gains = [self.gaind * level + 1 for level in self.y]
        self.biass = [self.biasd * level + 1 for level in self.y]

    def plot(self):
        """
        Save two figures under ``./out/``, prefixed with the class name:
        activity versus concentration, and gain/bias scalars versus
        concentration.
        """
        out = f"./out/{self.__class__.__name__}"

        title = "Norepinepherine Concentration vs Neuron Activity in " + \
            self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.y)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Activity (%)")
        plt.title(title)

        plt.vlines(self.ki, 0, 1, linestyles="dashed")
        plt.text(1.1 * self.ki, 0.1, "Affinity")

        plt.hlines(0.5, 0, 1000, linestyles="dashed")
        plt.text(1, 0.51, "50%")

        plt.xscale("log")
        # self.y is a fraction in (0, 1), so 1.0 must map to 100%; the
        # formatter's default xmax of 100 would label 1.0 as "1%".
        plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

        plt.draw()
        plt.savefig(f"{out}-norep-activity.png")

        #######################################################################

        title = "Concentration vs Gain/Bias scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.biass, label="Bias scalar")
        plt.plot(self.x, self.gains, label="Gain scalar")

        plt.xscale("log")

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Level")
        plt.title(title)
        plt.legend()

        plt.draw()
        plt.savefig(f"{out}-concentration-bias-gains.png")
109
110
class Alpha1(Alpha):
    """
    Alpha1 receptor: assigns the alpha1 parameters, then runs the ``Alpha``
    initialiser, which reads them.
    """

    def __init__(self):
        self.pretty = "α1"
        self.ki = 330
        self.offset = 5.895
        # The original keeps two commented-out alternatives next to these:
        # gaind = -0.02 and biasd = 0.04.
        self.gaind = -0.1
        self.biasd = 0.1
        super().__init__()
125
126
class Alpha2(Alpha):
    """
    Alpha2 receptor: assigns the alpha2 parameters, then runs the ``Alpha``
    initialiser, which reads them.
    """

    def __init__(self):
        self.pretty = "α2"
        self.ki = 56
        self.offset = 1
        # Signs are opposite to those used by Alpha1.
        self.gaind = 0.1
        self.biasd = -0.1
        super().__init__()
139
140
class Simulation():
    """
    Run the working-memory network once per entry of ``steps`` and plot the
    resulting spiking rate and accuracy.

    All sizes, time constants and scales are read from module-level names
    (``steps``, ``n_trials``, ``seed``, ``misperceive``, ``t_cue``, ...).
    """

    def __init__(self):
        """
        Build and plot both receptor models, allocate the result arrays and
        draw the per-trial cues and perception flags from a seeded RNG.
        """
        self.a1 = Alpha1()
        self.a1.plot()
        self.a2 = Alpha2()
        self.a2.plot()

        self.num_spikes = np.zeros(len(steps))
        self.num_correct = np.zeros(len(steps))
        self.out = np.zeros(n_trials)
        # Index of the trial being simulated; read by cue_function.
        self.trial = 0

        # correctly perceived (not necessarily remembered) cues
        self.perceived = np.ones(n_trials)
        rng = np.random.RandomState(seed=seed)
        # whether the cue is on the left or right (encoded as -1 / +1)
        self.cues = 2 * rng.randint(2, size=n_trials)-1
        for n in range(len(self.perceived)):
            if rng.rand() < misperceive:
                self.perceived[n] = 0

    def plot(self):
        """
        Save the spiking-rate and accuracy figures under ``./out/``.
        """
        title = "Norepinephrine Concentration vs Spiking Rate"
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(steps, self.num_spikes)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Spiking rate (spikes/time step)")
        plt.title(title)

        plt.draw()
        plt.savefig("./out/concentration-spiking.png")

        ########################################################################

        title = "Norepinephrine Concentration vs Accuracy"
        logging.info("Plotting " + title)
        plt.figure()
        # Clip to [0.5, 1.0], then smooth with a 20-sample rolling mean; the
        # first 19 rows of the rolling mean are NaN.
        correct_df = pd.DataFrame(np.clip(self.num_correct, 0.5, 1.0)).rolling(20).mean()
        plt.plot(steps, correct_df)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Accuracy")
        plt.title(title)

        plt.draw()
        plt.savefig("./out/concentration-correct.png")

    def cue_function(self, t):
        """
        Return the scaled cue of the current trial while ``t < t_cue`` and the
        trial is marked as perceived; return 0 otherwise.
        """
        if t < t_cue and self.perceived[self.trial] != 0:
            return cue_scale * self.cues[self.trial]
        else:
            return 0

    def run(self):
        """
        Build the network, simulate ``n_trials`` trials for every entry of
        ``steps``, pickle this object under ``out/`` and call ``plot``.
        """
        with nengo.Network() as net:
            # Nodes
            cue_node = nengo.Node(output=self.cue_function)
            time_node = nengo.Node(output=time_function)
            noise_wm_node = nengo.Node(output=noise_bias_function)
            noise_decision_node = nengo.Node(
                output=noise_decision_function)

            # Ensembles
            wm = nengo.Ensemble(neurons_wm, 2)
            decision = nengo.Ensemble(neurons_decide, 2)
            inputs = nengo.Ensemble(neurons_inputs, 2)
            output = nengo.Ensemble(neurons_decide, 1)

            # Connections
            nengo.Connection(cue_node, inputs[0], synapse=None)
            nengo.Connection(time_node, inputs[1], synapse=None)
            nengo.Connection(inputs, wm, synapse=tau_wm,
                             function=inputs_function)
            wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm)
            nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                             transform=np.ones((neurons_wm, 1)) * tau_wm)
            wm_to_decision = nengo.Connection(
                wm[0], decision[0], synapse=tau)
            nengo.Connection(noise_decision_node,
                             decision[1], synapse=None)
            nengo.Connection(decision, output, function=decision_function)

            # Probes
            wm_probe = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
            spikes_probe = nengo.Probe(wm.neurons, sample_every=probe_dt)
            output_probe = nengo.Probe(
                output, synapse=None, sample_every=probe_dt)

            # Run simulation
            for i, _ in tqdm(enumerate(steps), total=len(steps), unit="step"):
                # Throwaway build, used only to read the built gain, bias and
                # connection weights; the context manager closes it.
                # NOTE(review): wm.gain / wm.bias stay assigned on the
                # ensemble, so the next step's build presumably starts from
                # the already-scaled values - confirm that compounding across
                # steps is intended.
                with nengo.Simulator(net, dt=dt, progress_bar=False) as built:
                    wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * built.data[wm].gain
                    wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * built.data[wm].bias
                    # Pin both connections to the weights just built.
                    wm_recurrent.solver = MySolver(
                        built.model.params[wm_recurrent].weights)
                    wm_to_decision.solver = MySolver(
                        built.model.params[wm_to_decision].weights)

                # Second build picks up the modified gain, bias and solvers.
                with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
                    # self.trial is the loop variable on purpose:
                    # cue_function reads it during sim.run.
                    for self.trial in range(n_trials):
                        logging.info(
                            f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
                        # NOTE(review): the simulator is not reset between
                        # trials, so time and probe data presumably keep
                        # accumulating - confirm this is intended.
                        sim.run(t_cue + t_delay)

                        # Firing rate
                        self.out[self.trial] = np.count_nonzero(
                            sim.data[spikes_probe])

                        cue = self.cues[self.trial]
                        # Correctness
                        # NOTE(review): a duration in seconds is used as a
                        # sample index into the probe data - confirm.
                        out = sim.data[output_probe][int(t_cue + t_delay)][0]
                        if (out * cue) > 0:  # check if same sign
                            # NOTE(review): divides by zero when out == cue.
                            self.num_correct[i] += np.abs(1 / (out - cue))

                self.num_spikes[i] = np.average(self.out)

        # NOTE(review): main() looks for "simulation.pkl", not this
        # timestamped name - confirm how the two are meant to line up.
        with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
            pickle.dump(self, pout)

        self.plot()
262
def get_correct(cue, output_value):
    """
    Return 1 when ``cue`` and ``output_value`` are both strictly positive or
    both strictly negative, and 0 otherwise (including when either is zero).
    """
    both_positive = cue > 0.0 and output_value > 0.0
    both_negative = cue < 0.0 and output_value < 0.0
    if both_positive or both_negative:
        return 1
    return 0
265
266
class MySolver(nengo.solvers.Solver):
    """
    Solver that performs no solving: it hands back a weight matrix supplied
    at construction time.

    NOTE(review): the base-class ``__init__`` is not called and
    ``_paramdict`` is assigned by hand, presumably to satisfy nengo's
    parameter bookkeeping - confirm against the nengo version in use.
    nengo also ships ``nengo.solvers.NoSolver`` for pre-computed values,
    which may be a drop-in replacement.
    """

    def __init__(self, weights):
        """
        Store ``weights`` for later return from ``__call__``.
        """
        # Assigns the ``weights`` attribute of the base solver
        # (presumably its flag for solving full weights) - not the matrix.
        self.weights = False
        # The matrix itself is kept under a separate name.
        self.my_weights = weights
        self._paramdict = {}

    def __call__(self, A, Y, rng=None, E=None):
        """
        Ignore all arguments and return the transposed stored matrix together
        with an empty info dict.
        """
        return self.my_weights.T, dict()
275
276
def main():
    """
    Re-plot a saved simulation if ``simulation.pkl`` exists in the working
    directory; otherwise run a fresh simulation.
    """
    logging.info("Initializing simulation")
    plt.style.use("ggplot")  # Nice looking and familiar style

    # Only the open() call sits in the try body, so a FileNotFoundError
    # raised while unpickling or plotting is not mistaken for a missing file.
    try:
        data = open("simulation.pkl", "rb")
    except FileNotFoundError:
        Simulation().run()
    else:
        # Close the handle once the object is loaded and plotted.
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # files produced by this program.
        with data:
            pickle.load(data).plot()
287
288
if __name__ == "__main__":
    # Create the output directory; an existing one is fine.
    try:
        mkdir("./out")
    except FileExistsError:
        pass

    # One log file per run, named after the start timestamp.
    log_path = f"out/{datetime.now().isoformat()}.log"
    logging.basicConfig(filename=log_path, level=logging.INFO)

    main()
298     main()