]> git.armaanb.net Git - norepinephrine_wm.git/blob - model.py
Log things
[norepinephrine_wm.git] / model.py
1 from datetime import datetime
2 from os import mkdir
3 import logging
4
5 import matplotlib.pyplot as plt
6 import matplotlib.ticker as mtick
7 import nengo
8 import numpy as np
9
# Load simulation parameters by executing conf.py in this module's namespace.
# Names used below but not defined in this file (steps, dt, tau, tau_wm,
# t_cue, t_delay, time_scale, noise_wm, noise_decision, neurons_wm,
# neurons_decide, neurons_inputs) are presumably defined there.
# The path is resolved relative to the current working directory.
# NOTE(review): exec runs arbitrary code, so conf.py must be trusted.
with open("conf.py", encoding="utf-8") as conf_file:
    exec(conf_file.read())
11
12
def fmt_num(num, width=18):
    """
    Render ``num`` as a fixed-width column.

    The ``str()`` form of ``num`` is cut to at most ``width`` characters and
    then right-padded with spaces up to ``width``.
    """
    clipped = str(num)[:width]
    padding = " " * (width - len(clipped))
    return clipped + padding
19
20
def inputs_function(x):
    """Return ``x`` multiplied by the working-memory synapse constant ``tau_wm``."""
    scaled = x * tau_wm
    return scaled
23
24
def noise_decision_function(t):
    """
    Draw one zero-mean Gaussian sample with standard deviation ``noise_decision``.

    ``t`` (the simulation time passed by a nengo Node) is ignored.
    """
    return np.random.normal(loc=0.0, scale=noise_decision)
27
28
def noise_bias_function(t):
    """
    Draw one zero-mean Gaussian sample with standard deviation ``noise_wm``.

    ``t`` (the simulation time passed by a nengo Node) is ignored.
    """
    return np.random.normal(loc=0.0, scale=noise_wm)
31
32
def time_function(t):
    """Return ``time_scale`` once ``t`` is strictly past ``t_cue``, otherwise 0."""
    if t > t_cue:
        return time_scale
    return 0
35
36
def decision_function(x):
    """Return 1.0 when ``x[0] + x[1]`` is positive, otherwise -1.0."""
    total = x[0] + x[1]
    if total > 0.0:
        return 1.0
    return -1.0
39
40
class Alpha(object):
    """
    Base class for alpha receptors. Not to be used directly.

    Subclasses must set ``offset``, ``gaind`` and ``biasd`` before calling
    ``super().__init__()`` (the initializer reads them). They must also set
    ``ki`` and ``pretty``, which ``plot`` reads.

    Attributes set by the initializer:
        x: ``steps`` norepinephrine concentrations (nM), log-spaced 1..1000.
        y: receptor activity at each concentration, a fraction in (0, 1).
        gains: gain scalar ``1 + gaind * y`` for each concentration.
        biass: bias scalar ``1 + biasd * y`` for each concentration.
    """

    def __init__(self):
        # `steps` is not defined in this file; it is expected to come from
        # conf.py, which is exec'd at import time.
        self.x = np.logspace(0, 3, steps)
        # Logistic activity curve; `offset` stretches it along the x-axis.
        self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))

        # Each scalar is 1 at zero activity and shifts linearly with activity.
        self.gains = list(1 + self.gaind * self.y)
        self.biass = list(1 + self.biasd * self.y)

    @staticmethod
    def _save_and_close(path):
        """Render the current pyplot figure to ``path`` and close it."""
        plt.draw()
        plt.savefig(path)
        # Close each figure so repeated plot() calls don't accumulate open
        # figures (matplotlib keeps them alive and warns past 20).
        plt.close()

    def plot(self):
        """
        Save three PNG plots for this receptor under ``./out/``.

        The files are ``<ClassName>-norep-activity.png``,
        ``<ClassName>-concentration-gain.png`` and
        ``<ClassName>-concentration-bias.png``. The ``./out`` directory must
        already exist.
        """
        out = f"./out/{self.__class__.__name__}"

        title = "Norepinephrine Concentration vs Neuron Activity in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.y)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Activity (%)")
        plt.title(title)

        # Mark the receptor affinity (ki) and the half-activation level.
        plt.vlines(self.ki, 0, 1, linestyles="dashed")
        plt.text(1.1 * self.ki, 0.1, "Affinity")

        plt.hlines(0.5, 0, 1000, linestyles="dashed")
        plt.text(1, 0.51, "50%")

        plt.xscale("log")
        # y is a fraction in (0, 1), so 1.0 must map to 100%. The default
        # xmax=100 would label 0.5 as "0.5%".
        plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1))

        self._save_and_close(f"{out}-norep-activity.png")

        #######################################################################

        title = "Concentration vs Gain Scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.gains)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Gain")
        plt.title(title)

        plt.xscale("log")

        self._save_and_close(f"{out}-concentration-gain.png")

        #######################################################################

        title = "Concentration vs Bias scalar in " + self.pretty
        logging.info("Plotting " + title)
        plt.figure()
        plt.plot(self.x, self.biass)

        plt.xscale("log")

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Bias")
        plt.title(title)

        self._save_and_close(f"{out}-concentration-bias.png")
113
114
class Alpha1(Alpha):
    """
    α1 receptor.

    Its gain delta is negative and its bias delta positive, so as activity
    rises the gain scalar falls and the bias scalar grows.
    """

    def __init__(self):
        # All receptor parameters must exist before the base initializer runs.
        self.pretty = "α1 Receptor"
        self.ki, self.offset = 330, 5.895
        self.gaind, self.biasd = -0.02, 0.04
        super().__init__()
127
128
class Alpha2(Alpha):
    """
    α2 receptor.

    Its gain delta is positive and its bias delta negative, so as activity
    rises the gain scalar grows and the bias scalar falls.
    """

    def __init__(self):
        # All receptor parameters must exist before the base initializer runs.
        self.pretty = "α2 Receptor"
        self.ki, self.offset = 56, 1
        self.gaind, self.biasd = 0.1, -0.1
        super().__init__()
141
142
def _build_network(gain, bias):
    """
    Build the input / working-memory / decision network for one gain and bias.

    Returns ``(net, probe_wm, probe_output)``.
    """
    with nengo.Network() as net:
        # Nodes
        time_node = nengo.Node(output=time_function)
        noise_wm_node = nengo.Node(output=noise_bias_function)
        noise_decision_node = nengo.Node(output=noise_decision_function)

        # Ensembles
        wm = nengo.Ensemble(neurons_wm, 2)
        # Give every wm neuron the same receptor-modulated gain and bias.
        wm.gain = np.full(wm.n_neurons, gain)
        wm.bias = np.full(wm.n_neurons, bias)
        decision = nengo.Ensemble(neurons_decide, 2)
        inputs = nengo.Ensemble(neurons_inputs, 2)
        output = nengo.Ensemble(neurons_decide, 1)

        # Connections
        nengo.Connection(time_node, inputs[1], synapse=None)
        nengo.Connection(inputs, wm, synapse=tau_wm,
                         function=inputs_function)
        nengo.Connection(wm, wm, synapse=tau_wm)  # recurrent connection
        nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                         transform=np.ones((neurons_wm, 1)) * tau_wm)
        nengo.Connection(wm[0], decision[0], synapse=tau)
        nengo.Connection(noise_decision_node, decision[1], synapse=None)
        nengo.Connection(decision, output, function=decision_function)

        # Probes
        probe_wm = nengo.Probe(wm[0], synapse=0.01)
        probe_output = nengo.Probe(output, synapse=None)

    return net, probe_wm, probe_output


def simulate(a1, a2):
    """
    Run one network simulation per concentration step.

    For each step i, the wm ensemble's gain is
    ``a1.gains[i] + a2.gains[i] - 1`` (bias likewise). The network is then
    rebuilt and run for ``t_cue + t_delay`` of simulated time.

    Parameters
    ----------
    a1, a2 : Alpha
        Receptor models whose ``gains`` / ``biass`` lists are combined.

    Returns
    -------
    list of dict
        One entry per step with keys ``"gain"``, ``"bias"``, ``"t"``
        (simulation time points), ``"wm"`` (filtered ``wm[0]`` probe data)
        and ``"output"`` (output ensemble probe data). Callers that ignore
        the return value are unaffected.
    """
    results = []
    for gain1, gain2, bias1, bias2 in zip(a1.gains, a2.gains,
                                          a1.biass, a2.biass):
        # Each scalar is 1 + delta; subtract one baseline of 1 so the two
        # receptors' deltas add: 1 + delta1 + delta2.
        gain = gain1 + gain2 - 1
        bias = bias1 + bias2 - 1
        logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}")

        net, probe_wm, probe_output = _build_network(gain, bias)

        # Run simulation
        with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
            sim.run(t_cue + t_delay)
            results.append({
                "gain": gain,
                "bias": bias,
                "t": sim.trange(),
                "wm": sim.data[probe_wm],
                "output": sim.data[probe_output],
            })

    return results
183
184
def main():
    """Plot the α1 and α2 receptor curves, then run the simulation sweep."""
    logging.info("Initializing simulation")
    # Nice looking and familiar style
    plt.style.use("ggplot")

    # Build and plot each receptor in turn (α1 first, then α2).
    receptors = []
    for receptor_cls in (Alpha1, Alpha2):
        receptor = receptor_cls()
        receptor.plot()
        receptors.append(receptor)

    simulate(*receptors)
196
197
if __name__ == "__main__":
    # Plots and logs are written under ./out; create it if it is missing.
    try:
        mkdir("./out")
    except FileExistsError:
        pass

    # datetime.isoformat() contains ":", which is not allowed in Windows file
    # names, so use "-" between the time fields. Microseconds are kept so
    # runs started within the same second still get separate log files.
    stamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
    # The log file receives DEBUG and above...
    logging.basicConfig(filename=f"out/{stamp}.log", level=logging.DEBUG)
    # ...while only INFO and above is echoed to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger("").addHandler(console)

    main()