]> git.armaanb.net Git - norepinephrine_wm.git/blob - model.py
update
[norepinephrine_wm.git] / model.py
1 from datetime import datetime
2 from os import mkdir
3 import logging
4 import pickle
5
6 import matplotlib.pyplot as plt
7 import matplotlib.ticker as mtick
8 import nengo
9 import numpy as np
10 import pandas as pd
11 from tqdm import tqdm
12
# Load the run configuration by executing conf.py in this module's namespace.
# Names used below but not defined in this file (steps, tau_wm, t_cue, dt, ...)
# are expected to come from there.
# NOTE(security): exec runs arbitrary code; conf.py must be a trusted file.
with open("conf.py") as conf_file:  # context manager closes the handle
    exec(conf_file.read())
14
15
def fmt_num(num, width=18):
    """
    Render ``num`` as a fixed-width string.

    The ``str()`` form of ``num`` is cut to at most ``width`` characters and
    then padded on the right with spaces up to ``width`` characters.
    """

    text = str(num)
    clipped = text[:width]
    return clipped.ljust(width)
22
23
def inputs_function(x):
    """
    Scale ``x`` by the module-level ``tau_wm``.

    Used as the ``function`` of the inputs -> wm connection.
    """
    scaled = x * tau_wm
    return scaled
26
27
def noise_decision_function(t):
    """
    Draw one zero-mean Gaussian sample with standard deviation
    ``noise_decision``. The time argument ``t`` is ignored.
    """
    sample = np.random.normal(loc=0.0, scale=noise_decision)
    return sample
30
31
def noise_bias_function(t):
    """
    Draw one zero-mean Gaussian sample with standard deviation ``noise_wm``.
    The time argument ``t`` is ignored.
    """
    sample = np.random.normal(loc=0.0, scale=noise_wm)
    return sample
34
35
def time_function(t):
    """
    Return ``time_scale`` once ``t`` is strictly past ``t_cue``, otherwise 0.
    """
    if t > t_cue:
        return time_scale
    return 0
38
39
def decision_function(x):
    """
    Return the sign of ``x[0] + x[1]`` as a float.

    Yields 1.0 for a positive sum, -1.0 for a negative sum, and 0.0 when the
    sum is exactly zero (or compares as neither, e.g. NaN).
    """
    total = x[0] + x[1]
    if total > 0.0:
        return 1.0
    if total < 0.0:
        return -1.0
    return 0.0
49
50
class Alpha():
    """
    Base class for alpha receptors. Not to be used directly.

    Subclasses must assign the following attributes *before* calling
    ``super().__init__()``, because the constructor and ``plot`` read them:

    - ``offset``: divisor applied to the concentration inside the sigmoid
    - ``gaind`` / ``biasd``: slope of the gain / bias scalar against activity
    - ``ki``: x position of the "Affinity" marker drawn by ``plot``
    - ``pretty``: human-readable receptor name used in plot titles
    """

    def __init__(self):
        # ``steps`` is a module-level name that is not defined in this file
        # (presumably supplied by conf.py); it must support array division.
        self.x = steps
        # Logistic activity curve; evaluates to 1/1000 at x == 0.
        self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))

        # Per-step multiplicative scalars: 1 plus a receptor-specific slope
        # times the activity at that step.
        self.gains = [self.gaind * activity + 1 for activity in self.y]
        self.biass = [self.biasd * activity + 1 for activity in self.y]

    def plot(self):
        """
        Save three PNG figures for this receptor under ``./out``: activity,
        gain scalar and bias scalar, each against concentration.

        File names are prefixed with the concrete class name. The ``./out``
        directory is not created here.
        """
        out = f"./out/{self.__class__.__name__}"

        title = "Norepinephrine Concentration vs Neuron Activity in " + \
            self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.y)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Activity (%)")
        plt.title(title)

        plt.vlines(self.ki, 0, 1, linestyles="dashed")
        plt.text(1.1 * self.ki, 0.1, "Affinity")

        plt.hlines(0.5, 0, 1000, linestyles="dashed")
        plt.text(1, 0.51, "50%")

        plt.xscale("log")
        # self.y is a fraction in [0, 1]; xmax=1.0 makes 1.0 render as 100%
        # (the formatter's default xmax of 100 would label the axis 0%..1%).
        plt.gca().yaxis.set_major_formatter(
            mtick.PercentFormatter(xmax=1.0))

        plt.draw()
        plt.savefig(f"{out}-norep-activity.png")

        #######################################################################

        # Trailing space after "in" keeps the receptor name from being glued
        # onto the preceding word.
        title = "Concentration vs Gain Scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.gains)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Gain")
        plt.title(title)

        plt.xscale("log")

        plt.draw()
        plt.savefig(f"{out}-concentration-gain.png")

        #######################################################################

        title = "Concentration vs Bias scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.biass)

        plt.xscale("log")

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Bias")
        plt.title(title)

        plt.draw()
        plt.savefig(f"{out}-concentration-bias.png")
123
124
class Alpha1(Alpha):
    """
    Alpha1 receptor: supplies the concrete parameters for ``Alpha``.
    """

    def __init__(self):
        # All attributes are assigned before delegating to the base class,
        # whose constructor reads offset, gaind and biasd.
        self.pretty = "α1 Receptor"
        self.ki = 330
        self.offset = 5.895
        self.biasd = 0.04
        self.gaind = -0.02
        super().__init__()
137
138
class Alpha2(Alpha):
    """
    Alpha2 receptor: supplies the concrete parameters for ``Alpha``.
    """

    def __init__(self):
        # All attributes are assigned before delegating to the base class,
        # whose constructor reads offset, gaind and biasd.
        self.pretty = "α2 Receptor"
        self.ki = 56
        self.offset = 1
        self.biasd = -0.1
        self.gaind = 0.1
        super().__init__()
151
152
class Simulation():
    """
    Sweep a nengo working-memory network over the concentration ``steps``.

    For every step the gain and bias of the ``wm`` ensemble are rescaled by
    the summed Alpha1/Alpha2 scalars, three trials are simulated, and the
    average count of non-zero spike samples is stored in ``num_spikes``.

    Relies on module-level configuration names that are not defined in this
    file (``steps``, ``seed``, ``misperceive``, ``t_cue``, ``t_delay``,
    ``cue_scale``, ``dt``, ``probe_dt``, ``tau``, ``tau_wm`` and the
    ``neurons_*`` sizes); presumably they come from conf.py.
    """

    def __init__(self):
        self.a1 = Alpha1()
        self.a2 = Alpha2()
        # One averaged spike count per concentration step.
        self.num_spikes = np.ones(len(steps))
        # Scratch buffer holding the spike count of each of the three trials.
        self.out = np.ones(3)
        # Index of the trial currently being simulated; read by cue_function.
        self.trial = 0

        # correctly perceived (not necessarily remembered) cues
        self.perceived = np.ones(3)
        rng = np.random.RandomState(seed=seed)
        # whether the cue is on the left (-1) or right (+1)
        self.cues = 2 * rng.randint(2, size=3) - 1
        for n in range(len(self.perceived)):
            if rng.rand() < misperceive:
                self.perceived[n] = 0

    def plot(self):
        """
        Save the concentration vs. spiking-rate figure to
        ``./out/concentration-spiking.png``.
        """
        title = "Norepinephrine Concentration vs Spiking Rate"
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(steps, self.num_spikes)

        # The x axis is deliberately left linear here.
        # plt.xscale("log")

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Spiking rate (spikes/time step)")
        plt.title(title)

        plt.draw()
        plt.savefig("./out/concentration-spiking.png")

    def cue_function(self, t):
        """
        Node output for the cue: ``cue_scale`` times the current trial's cue
        while ``t < t_cue`` and the cue was perceived, otherwise 0.
        """
        if t < t_cue and self.perceived[self.trial] != 0:
            return cue_scale * self.cues[self.trial]
        else:
            return 0

    def run(self):
        """
        Plot the receptor curves, run the sweep, pickle this object to
        ``out/<timestamp>-spikes.pkl`` and plot the resulting spike counts.
        """
        self.a1.plot()
        self.a2.plot()

        with nengo.Network() as net:
            # NOTE(review): the network is created without a seed, so each
            # simulator build may draw different ensemble parameters while
            # the decoders below stay frozen - confirm this is intended.

            # Nodes
            cue_node = nengo.Node(output=self.cue_function)
            time_node = nengo.Node(output=time_function)
            noise_wm_node = nengo.Node(output=noise_bias_function)
            noise_decision_node = nengo.Node(
                output=noise_decision_function)

            # Ensembles
            wm = nengo.Ensemble(neurons_wm, 2)
            decision = nengo.Ensemble(neurons_decide, 2)
            inputs = nengo.Ensemble(neurons_inputs, 2)
            output = nengo.Ensemble(neurons_decide, 1)

            # Connections
            nengo.Connection(cue_node, inputs[0], synapse=None)
            nengo.Connection(time_node, inputs[1], synapse=None)
            nengo.Connection(inputs, wm, synapse=tau_wm,
                             function=inputs_function)
            wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm)
            nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                             transform=np.ones((neurons_wm, 1)) * tau_wm)
            wm_to_decision = nengo.Connection(
                wm[0], decision[0], synapse=tau)
            nengo.Connection(noise_decision_node,
                             decision[1], synapse=None)
            nengo.Connection(decision, output, function=decision_function)

            # Probes (only probe_spikes is read back below)
            probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
            probe_spikes = nengo.Probe(wm.neurons, sample_every=probe_dt)
            probe_output = nengo.Probe(
                output, synapse=None, sample_every=probe_dt)

            # Build once with default parameters to obtain the baseline
            # gain/bias of wm and the solved weights of the two connections.
            with nengo.Simulator(net, dt=dt, progress_bar=False) as base_sim:
                base_gain = base_sim.data[wm].gain
                base_bias = base_sim.data[wm].bias
                recurrent_weights = \
                    base_sim.model.params[wm_recurrent].weights
                decision_weights = \
                    base_sim.model.params[wm_to_decision].weights

            # Pin both connections to the baseline weights so that rebuilding
            # with a different gain/bias does not re-solve them.
            wm_recurrent.solver = MySolver(recurrent_weights)
            wm_to_decision.solver = MySolver(decision_weights)

            # Run simulation
            for i in tqdm(range(len(steps))):
                # Always scale the *baseline* values. Reading them back from
                # the previously rebuilt simulator would multiply the scale
                # factors of successive steps together.
                wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * base_gain
                wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * base_bias

                # One simulator per step, closed as soon as its trials are
                # done instead of being left open.
                with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
                    # NOTE(review): the three trials share one simulator and
                    # it is not reset between them, so simulation time and
                    # probe data presumably carry over from trial to trial -
                    # confirm that this is intended.
                    for self.trial in range(3):
                        logging.info(
                            f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
                        sim.run(t_cue + t_delay)
                        self.out[self.trial] = np.count_nonzero(
                            sim.data[probe_spikes])
                self.num_spikes[i] = np.average(self.out)

        with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
            pickle.dump(self, pout)

        self.plot()
251
252
class MySolver(nengo.solvers.Solver):
    """
    Solver that returns a fixed, precomputed weight matrix instead of
    solving for one.
    """

    def __init__(self, weights):
        # The base-class initialiser is intentionally not called; the
        # attributes below are set by hand, in this order.
        self.weights = False
        self.my_weights = weights
        self._paramdict = {}

    def __call__(self, A, Y, rng=None, E=None):
        # All solver inputs are ignored: hand back the stored matrix,
        # transposed, together with an empty info dictionary.
        fixed = self.my_weights.T
        info = {}
        return fixed, info
261
262
def main():
    """
    Re-plot a saved simulation if ``simulation.pkl`` exists in the working
    directory; otherwise run a fresh simulation.
    """
    logging.info("Initializing simulation")
    plt.style.use("ggplot")  # Nice looking and familiar style

    try:
        data = open("simulation.pkl", "rb")
    except FileNotFoundError:
        # Nothing saved yet: run the full sweep (which also plots).
        Simulation().run()
        return

    # Close the handle once the object has been loaded.
    # NOTE(security): unpickling executes code from the file; only load
    # pickles produced by this program.
    with data:
        simulation = pickle.load(data)
    simulation.plot()
273
274
if __name__ == "__main__":
    # Make sure the output directory exists; an already existing one is fine.
    try:
        mkdir("./out")
    except FileExistsError:
        pass

    # Log to a fresh, timestamped file inside the output directory.
    log_path = f"out/{datetime.now().isoformat()}.log"
    logging.basicConfig(filename=log_path, level=logging.INFO)

    main()