git.armaanb.net Git - norepinephrine_wm.git/blob - model.py
Factor simulation into its own class
[norepinephrine_wm.git] / model.py
1 from datetime import datetime
2 from os import mkdir
3 import logging
4
5 import matplotlib.pyplot as plt
6 import matplotlib.ticker as mtick
7 import nengo
8 import numpy as np
9
# Load experiment parameters from conf.py into this module's globals.
# Several names used below are not defined in this file (tau_wm, steps,
# t_cue, t_delay, dt, neurons_wm, ...) and are presumably provided here.
# conf.py is executed as trusted local code -- never point this at an
# untrusted file.  The `with` block closes the handle, and compiling with the
# real filename makes config errors point at conf.py in tracebacks.
with open("conf.py", encoding="utf-8") as _conf_file:
    exec(compile(_conf_file.read(), "conf.py", "exec"))
del _conf_file
11
12
def fmt_num(num, width=18):
    """
    Render a value as a fixed-width string for aligned log columns.

    The ``str()`` form of ``num`` is cut to at most ``width`` characters and
    then right-padded with spaces, so for a non-negative ``width`` the
    result is exactly ``width`` characters long.
    """
    truncated = str(num)[:width]
    padding = " " * max(width - len(truncated), 0)
    return truncated + padding
19
20
def inputs_function(x):
    """
    Scale the decoded input by ``tau_wm`` before it reaches ``wm``.

    ``wm`` feeds back onto itself through a synapse with time constant
    ``tau_wm``; scaling the input by that same constant matches the usual
    NEF integrator construction.
    """
    scaled = x * tau_wm
    return scaled
23
24
def noise_decision_function(t):
    """
    Return one zero-mean Gaussian sample with std ``noise_decision``.

    ``t`` (simulation time, passed by the nengo Node) is ignored; a fresh
    sample is drawn on every call.
    """
    sample = np.random.normal(loc=0.0, scale=noise_decision)
    return sample
27
28
def noise_bias_function(t):
    """
    Return one zero-mean Gaussian sample with std ``noise_wm``.

    ``t`` (simulation time, passed by the nengo Node) is ignored; a fresh
    sample is drawn on every call.
    """
    sample = np.random.normal(loc=0.0, scale=noise_wm)
    return sample
31
32
def time_function(t):
    """
    Step input: 0 up to and including ``t_cue``, ``time_scale`` afterwards.
    """
    if t > t_cue:
        return time_scale
    return 0
35
36
def decision_function(x):
    """
    Binary decision: +1.0 when ``x[0] + x[1]`` is strictly positive, else -1.0.

    In this model ``x[0]`` receives the working-memory value and ``x[1]``
    the decision noise.
    """
    total = x[0] + x[1]
    if total > 0.0:
        return 1.0
    return -1.0
39
40
class Alpha():
    """
    Base class for alpha receptors. Not to be used directly.

    Subclasses must set ``ki``, ``offset``, ``pretty``, ``gaind`` and
    ``biasd`` before calling ``super().__init__()``, because the
    constructor reads them to build the response curves below.

    Attributes set here:
        x: ``steps`` norepinephrine concentrations, log-spaced from 1 to
            1000 (nM, per the plot axis label).
        y: logistic receptor activity at each concentration, in (0, 1).
        gains: per-step gain scalars ``1 + gaind * y`` (1 means no effect).
        biass: per-step bias scalars ``1 + biasd * y`` (1 means no effect).
    """

    def __init__(self):
        self.x = np.logspace(0, 3, steps)
        # Logistic curve: the 999 puts the activity at 1 / (1 + 999) = 0.1%
        # at zero concentration, and ``offset`` stretches the curve along x.
        self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))

        self.gains = [1 + self.gaind * y for y in self.y]
        self.biass = [1 + self.biasd * y for y in self.y]

    def _curve_figure(self, title, values, ylabel):
        """
        Plot ``values`` against concentration on a log x-axis.

        Returns ``(fig, ax)`` so the caller can add decorations before saving.
        """
        logging.info("Plotting " + title)
        fig, ax = plt.subplots()
        ax.plot(self.x, values)
        ax.set_xscale("log")
        ax.set_xlabel("Norepinephrine concentration (nM)")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        return fig, ax

    @staticmethod
    def _save(fig, path):
        """Write ``fig`` to ``path`` and close it so figures do not pile up."""
        fig.savefig(path)
        plt.close(fig)

    def plot(self):
        """
        Save activity, gain and bias curves for this receptor under ./out/.

        Files are named ``./out/<ClassName>-norep-activity.png``,
        ``...-concentration-gain.png`` and ``...-concentration-bias.png``.
        """
        out = f"./out/{self.__class__.__name__}"

        title = ("Norepinephrine Concentration vs Neuron Activity in "
                 + self.pretty)
        fig, ax = self._curve_figure(title, self.y, "Activity (%)")
        # Dashed marker at the receptor's ``ki`` value.
        ax.vlines(self.ki, 0, 1, linestyles="dashed")
        ax.text(1.1 * self.ki, 0.1, "Affinity")
        # Dashed 50%-activity reference across the plotted concentration
        # range (a log axis cannot start at 0).
        ax.hlines(0.5, self.x[0], self.x[-1], linestyles="dashed")
        ax.text(1, 0.51, "50%")
        # ``y`` is a fraction, so 1.0 must be displayed as 100%.
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1))
        self._save(fig, f"{out}-norep-activity.png")

        title = "Concentration vs Gain Scalar in " + self.pretty
        fig, _ = self._curve_figure(title, self.gains, "Gain")
        self._save(fig, f"{out}-concentration-gain.png")

        title = "Concentration vs Bias Scalar in " + self.pretty
        fig, _ = self._curve_figure(title, self.biass, "Bias")
        self._save(fig, f"{out}-concentration-bias.png")
114
115
class Alpha1(Alpha):
    """
    Alpha1 receptor model: activation lowers neuron gain and raises bias.

    The receptor parameters are assigned first, then ``Alpha.__init__``
    derives the activity, gain and bias curves from them.
    """

    def __init__(self):
        self.pretty = "α1 Receptor"
        self.ki = 330        # marked as "Affinity" on the activity plot
        self.offset = 5.895  # stretches the activity curve along x
        # Change in the gain / bias scalars per unit of receptor activity.
        self.gaind, self.biasd = -0.02, 0.04
        super().__init__()
128
129
class Alpha2(Alpha):
    """
    Alpha2 receptor model: activation raises neuron gain and lowers bias.

    The receptor parameters are assigned first, then ``Alpha.__init__``
    derives the activity, gain and bias curves from them.
    """

    def __init__(self):
        self.pretty = "α2 Receptor"
        self.ki = 56     # marked as "Affinity" on the activity plot
        self.offset = 1  # no stretching of the activity curve along x
        # Change in the gain / bias scalars per unit of receptor activity.
        self.gaind, self.biasd = 0.1, -0.1
        super().__init__()
142
143
class Simulation():
    """
    Sweep norepinephrine concentration through the working-memory model.

    For each of the ``steps`` concentration samples held by the alpha1 and
    alpha2 receptor models, a fresh nengo network is built with that step's
    combined gain and bias applied to the ``wm`` ensemble, then simulated
    for ``t_cue + t_delay`` seconds.
    """

    def __init__(self):
        self.a1 = Alpha1()
        self.a2 = Alpha2()
        # Per-step probe data; filled in by run().
        self.results = []

    @staticmethod
    def _build_network(gain, bias):
        """
        Build one trial's network with the given ``wm`` gain and bias.

        Returns ``(net, probe_wm, probe_output)``.
        """
        with nengo.Network() as net:
            # Nodes
            time_node = nengo.Node(output=time_function)
            noise_wm_node = nengo.Node(output=noise_bias_function)
            noise_decision_node = nengo.Node(output=noise_decision_function)

            # Ensembles
            wm = nengo.Ensemble(neurons_wm, 2)
            # Every wm neuron gets the same receptor-derived gain and bias.
            wm.gain = np.full(wm.n_neurons, gain)
            wm.bias = np.full(wm.n_neurons, bias)
            decision = nengo.Ensemble(neurons_decide, 2)
            inputs = nengo.Ensemble(neurons_inputs, 2)
            output = nengo.Ensemble(neurons_decide, 1)

            # Connections
            nengo.Connection(time_node, inputs[1], synapse=None)
            nengo.Connection(inputs, wm, synapse=tau_wm,
                             function=inputs_function)
            # Recurrent wm -> wm connection that lets wm hold its value.
            nengo.Connection(wm, wm, synapse=tau_wm)
            # The same scalar noise sample is fed to every wm neuron.
            nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                             transform=np.ones((neurons_wm, 1)) * tau_wm)
            nengo.Connection(wm[0], decision[0], synapse=tau)
            nengo.Connection(noise_decision_node, decision[1], synapse=None)
            nengo.Connection(decision, output, function=decision_function)

            # Probes
            probe_wm = nengo.Probe(wm[0], synapse=0.01)
            probe_output = nengo.Probe(output, synapse=None)

        return net, probe_wm, probe_output

    def run(self):
        """
        Run one simulation per concentration step.

        Returns (and stores in ``self.results``) a list with one dict per
        step: the ``gain`` and ``bias`` used, the simulation time points
        (``t``), the filtered ``wm[0]`` probe (``wm``) and the ``output``
        probe.
        """
        self.results = []
        for i in range(steps):
            # Each receptor scalar is 1 + d * activity; summing both and
            # subtracting 1 adds their deviations from the neutral value 1.
            gain = self.a1.gains[i] + self.a2.gains[i] - 1
            bias = self.a1.biass[i] + self.a2.biass[i] - 1
            logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}")

            net, probe_wm, probe_output = self._build_network(gain, bias)
            with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
                sim.run(t_cue + t_delay)
                # Keep the probed traces; otherwise the run produces nothing.
                self.results.append({
                    "gain": gain,
                    "bias": bias,
                    "t": sim.trange(),
                    "wm": sim.data[probe_wm],
                    "output": sim.data[probe_output],
                })
        return self.results
190
191
def main():
    """Entry point: set the plotting style, then run the concentration sweep."""
    logging.info("Initializing simulation")
    # ggplot gives a familiar, readable look to any figures produced.
    plt.style.use("ggplot")
    simulation = Simulation()
    simulation.run()
196
197
if __name__ == "__main__":
    # Plots and logs are written under ./out; create it on first run.
    try:
        mkdir("./out")
    except FileExistsError:
        pass

    # ISO timestamps contain ':', which is not allowed in Windows file
    # names, so replace it with '-' (e.g. 2024-01-02T03-04-05.678901.log).
    stamp = datetime.now().isoformat().replace(":", "-")
    logging.basicConfig(filename=f"out/{stamp}.log", level=logging.DEBUG)

    # Mirror INFO-and-above records to the console as well as the log file.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    main()