git.armaanb.net Git - norepinephrine_wm.git/blob - model.py
More work
[norepinephrine_wm.git] / model.py
1 from datetime import datetime
2 from os import mkdir
3 import logging
4 import pickle
5
6 import matplotlib.pyplot as plt
7 import matplotlib.ticker as mtick
8 import nengo
9 import numpy as np
10 import pandas as pd
11 from tqdm import tqdm
12
# Load the experiment parameters (presumably steps, dt, tau_wm, tau, t_cue,
# t_delay, neuron counts, noise levels, ...) from conf.py straight into this
# module's namespace. conf.py is trusted local code. The `with` block
# guarantees the file handle is closed once it has been read.
with open("conf.py", encoding="utf-8") as conf_file:
    exec(conf_file.read())
14
def fmt_num(num, width=18):
    """
    Render *num* as a fixed-width string for log output.

    The ``str()`` form of *num* is cut to at most *width* characters and then
    padded on the right with spaces up to *width*.
    """
    text = str(num)
    clipped = text[:width]
    return clipped.ljust(width)
21
22
def inputs_function(x):
    """
    Scale the input-ensemble value by ``tau_wm``.

    Used as the function on the inputs -> working-memory connection, whose
    synapse is also ``tau_wm``. ``tau_wm`` is a module-level global
    (presumably defined by conf.py).
    """
    scaled = x * tau_wm
    return scaled
25
26
def noise_decision_function(t):
    """
    Output of the decision-noise node.

    The simulation time *t* is ignored. Each call draws an independent
    zero-mean Gaussian sample whose standard deviation is the global
    ``noise_decision``.
    """
    sigma = noise_decision
    return np.random.normal(0.0, sigma)
29
30
def noise_bias_function(t):
    """
    Output of the working-memory noise node.

    The simulation time *t* is ignored. Each call draws an independent
    zero-mean Gaussian sample whose standard deviation is the global
    ``noise_wm``.
    """
    sigma = noise_wm
    return np.random.normal(0.0, sigma)
33
34
def time_function(t):
    """
    Time-keeping input signal.

    Returns ``time_scale`` once *t* is strictly past ``t_cue``, and 0 while
    the cue period is still running.
    """
    if t > t_cue:
        return time_scale
    return 0
37
38
def decision_function(x):
    """
    Map the decision ensemble's 2-D state to a discrete choice.

    The two components (the working-memory value fed into ``decision[0]``
    and the noise fed into ``decision[1]``) are summed. The result is 1.0
    for a positive sum, -1.0 for a negative sum, and 0.0 otherwise (an exact
    zero or NaN).
    """
    total = x[0] + x[1]
    if total > 0.0:
        return 1.0
    if total < 0.0:
        return -1.0
    return 0.0
48
49
class Alpha():
    """
    Base class for alpha receptors. Not to be used directly.

    Subclasses must set these attributes before calling
    ``super().__init__()``, because the constructor below reads them:

    * ``ki``     -- affinity value, marked as "Affinity" in the activity plot
    * ``offset`` -- divides the concentration inside the logistic activity curve
    * ``pretty`` -- human-readable receptor name used in plot titles
    * ``gaind``  -- gain change at full receptor activity (y == 1)
    * ``biasd``  -- bias change at full receptor activity (y == 1)

    Attributes computed by the constructor:

    * ``x``     -- norepinephrine concentrations: the module-level ``steps``
                   (presumably defined in conf.py)
    * ``y``     -- receptor activity in (0, 1) at each concentration
    * ``gains`` -- list of gain scalars ``1 + gaind * y``
    * ``biass`` -- list of bias scalars ``1 + biasd * y``
    """

    def __init__(self):
        self.x = steps
        # Logistic activity curve: 1/1000 (0.1%) at zero concentration,
        # rising towards 1 as the offset-scaled concentration grows.
        self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))

        self.gains = [1 + self.gaind * activity for activity in self.y]
        self.biass = [1 + self.biasd * activity for activity in self.y]

    def _save_figure(self, path):
        """Render the current figure to *path*, then close it to free memory."""
        plt.draw()
        plt.savefig(path)
        plt.close()

    def plot(self):
        """
        Save three figures to ./out/, each prefixed with the class name.

        The figures show receptor activity, gain scalar, and bias scalar, each
        plotted against norepinephrine concentration on a log x-axis.
        """
        out = f"./out/{self.__class__.__name__}"

        title = "Norepinephrine Concentration vs Neuron Activity in " + \
            self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.y)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Activity (%)")
        plt.title(title)

        plt.vlines(self.ki, 0, 1, linestyles="dashed")
        plt.text(1.1 * self.ki, 0.1, "Affinity")

        plt.hlines(0.5, 0, 1000, linestyles="dashed")
        plt.text(1, 0.51, "50%")

        plt.xscale("log")
        # y is a fraction in [0, 1]. With xmax=1, 1.0 renders as 100%; the
        # default xmax=100 would label 0.5 as "0.5%".
        plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1))

        self._save_figure(f"{out}-norep-activity.png")

        #######################################################################

        title = "Concentration vs Gain Scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.gains)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Gain")
        plt.title(title)

        plt.xscale("log")

        self._save_figure(f"{out}-concentration-gain.png")

        #######################################################################

        title = "Concentration vs Bias scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.biass)

        plt.xscale("log")

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Bias")
        plt.title(title)

        self._save_figure(f"{out}-concentration-bias.png")
122
123
class Alpha1(Alpha):
    """
    Alpha-1 receptor (ki = 330, offset = 5.895).

    Full activation scales gain by (1 - 0.02) and bias by (1 + 0.04).
    """

    def __init__(self):
        # The receptor constants must exist before Alpha.__init__ runs,
        # because the base constructor reads offset, gaind and biasd.
        self.pretty = "α1 Receptor"
        self.ki, self.offset = 330, 5.895
        self.gaind, self.biasd = -0.02, 0.04
        super().__init__()
136
137
class Alpha2(Alpha):
    """
    Alpha-2 receptor (ki = 56, offset = 1).

    Full activation scales gain by (1 + 0.1) and bias by (1 - 0.1).
    """

    def __init__(self):
        # The receptor constants must exist before Alpha.__init__ runs,
        # because the base constructor reads offset, gaind and biasd.
        self.pretty = "α2 Receptor"
        self.ki, self.offset = 56, 1
        self.gaind, self.biasd = 0.1, -0.1
        super().__init__()
150
151
class Simulation():
    """
    Sweep norepinephrine concentration and record working-memory spiking.

    For each concentration in ``steps`` (presumably defined in conf.py):

    1. The alpha1 and alpha2 receptor scalars are summed.
    2. The sums scale the working-memory ensemble's baseline neuron gain
       and bias.
    3. ``len(self.out)`` trials are simulated.
    4. The average spike count over those trials is stored in
       ``self.num_spikes``.
    """

    def __init__(self):
        self.a1 = Alpha1()
        self.a2 = Alpha2()
        self.num_spikes = np.ones(len(steps))
        # Combined (alpha1 + alpha2) bias / gain scalar applied at each step.
        self.biass = np.ones(len(steps))
        self.gains = np.ones(len(steps))
        # Spike count of each trial at the current concentration step.
        self.out = np.ones(3)
        # Index of the trial being simulated; read by cue_function.
        self.trial = 0

        self.perceived = np.ones(3)  # correctly perceived (not necessarily remembered) cues
        rng = np.random.RandomState(seed=111)
        self.cues = 2 * rng.randint(2, size=3) - 1  # cue side per trial: -1 or +1
        for n in range(len(self.perceived)):
            if rng.rand() < misperceive:
                self.perceived[n] = 0

    def plot(self):
        """Save spike count vs concentration to ./out/concentration-spiking.png."""
        title = "Norepinephrine Concentration vs Spiking Rate"
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(steps, self.num_spikes)

        plt.xscale("log")

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Spiking rate (spikes/time step)")
        plt.title(title)

        plt.draw()
        plt.savefig("./out/concentration-spiking.png")
        plt.close()  # free the figure; pyplot keeps it alive otherwise

    def cue_function(self, t):
        """
        Output of the cue node for the current trial.

        Before ``t_cue``, returns ``cue_scale`` times the trial's cue side
        (-1 or +1), unless that cue was misperceived. Returns 0 otherwise.
        """
        if t < t_cue and self.perceived[self.trial] != 0:
            return cue_scale * self.cues[self.trial]
        else:
            return 0

    def _build_network(self):
        """
        Assemble the cue -> inputs -> working memory -> decision network.

        Returns ``(net, wm, probe_spikes)``: the network, the working-memory
        ensemble whose gain/bias are modulated, and the probe on its spikes.
        """
        with nengo.Network() as net:
            # Nodes
            cue_node = nengo.Node(output=self.cue_function)
            time_node = nengo.Node(output=time_function)
            noise_wm_node = nengo.Node(output=noise_bias_function)
            noise_decision_node = nengo.Node(
                output=noise_decision_function)

            # Ensembles
            wm = nengo.Ensemble(neurons_wm, 2)
            decision = nengo.Ensemble(neurons_decide, 2)
            inputs = nengo.Ensemble(neurons_inputs, 2)
            output = nengo.Ensemble(neurons_decide, 1)

            # Connections
            nengo.Connection(cue_node, inputs[0], synapse=None)
            nengo.Connection(time_node, inputs[1], synapse=None)
            nengo.Connection(inputs, wm, synapse=tau_wm,
                             function=inputs_function)
            nengo.Connection(wm, wm, synapse=tau_wm)  # recurrent memory
            nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                             transform=np.ones((neurons_wm, 1)) * tau_wm)
            nengo.Connection(wm[0], decision[0], synapse=tau)
            nengo.Connection(noise_decision_node,
                             decision[1], synapse=None)
            nengo.Connection(decision, output, function=decision_function)

            # Probes (only the spike probe is currently analysed)
            nengo.Probe(wm[0], synapse=0.01)
            probe_spikes = nengo.Probe(wm.neurons)
            nengo.Probe(output, synapse=None)

        return net, wm, probe_spikes

    def run(self):
        """Run the concentration sweep, pickle ``num_spikes`` to out/ and plot it."""
        net, wm, probe_spikes = self._build_network()

        # Build once with nengo's default neuron parameters. This yields the
        # baseline gain and bias that every concentration step scales.
        with nengo.Simulator(net, dt=dt, progress_bar=False) as base_sim:
            base_gain = np.copy(base_sim.data[wm].gain)
            base_bias = np.copy(base_sim.data[wm].bias)

        for i in tqdm(range(len(steps))):
            self.gains[i] = self.a1.gains[i] + self.a2.gains[i]
            self.biass[i] = self.a1.biass[i] + self.a2.biass[i]
            # Gain and bias are baked in when the model is built. Assigning
            # them to an already-built ensemble has no effect, so the
            # network is rebuilt for each step.
            # NOTE(review): rebuilding also re-solves the decoders for the new
            # gain/bias. Confirm this matches the intended model.
            wm.gain = self.gains[i] * base_gain
            wm.bias = self.biass[i] * base_bias
            with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
                for self.trial in range(len(self.out)):
                    logging.debug(f"Simulating: trial: {self.trial}, gain: {fmt_num(self.gains[i])}, bias: {fmt_num(self.biass[i])}")
                    # Restart at t=0 with empty probe data. This way each
                    # trial sees its own cue, and only its own spikes are
                    # counted.
                    sim.reset()
                    sim.run(t_cue + t_delay)
                    self.out[self.trial] = np.count_nonzero(sim.data[probe_spikes])
            self.num_spikes[i] = np.average(self.out)

        with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
            pickle.dump(self.num_spikes, pout)

        self.plot()
239
def main():
    """Set the plot style, then create and run the simulation (which also plots)."""
    logging.info("Initializing simulation")
    plt.style.use("ggplot")  # Nice looking and familiar style
    simulation = Simulation()
    simulation.run()
244
245
if __name__ == "__main__":
    # All outputs (log, pickled spike counts, figures) go under ./out, so
    # create it. Reuse it silently if it already exists.
    try:
        mkdir("./out")
    except FileExistsError:
        pass

    log_path = f"out/{datetime.now().isoformat()}.log"
    logging.basicConfig(filename=log_path, level=logging.DEBUG)

    main()