git.armaanb.net Git - norepinephrine_wm.git/blob - model.py
Write outer loop
[norepinephrine_wm.git] / model.py
1 from os import mkdir
2 import matplotlib.pyplot as plt
3 import matplotlib.ticker as mtick
4 import nengo
5 import numpy as np
6
# Load the simulation settings into this module's globals. conf.py is
# presumably what defines the otherwise-undefined names used below
# (steps, tau_wm, t_cue, t_delay, dt, neurons_wm, ...) -- verify there.
# The path is resolved against the current working directory. The file is
# closed deterministically, and compiling with the real filename makes
# tracebacks from errors inside conf.py point at conf.py, not <string>.
with open("conf.py", encoding="utf-8") as _conf_file:
    exec(compile(_conf_file.read(), "conf.py", "exec"))
8
def fmt_num(num):
    """
    Render ``num`` as a fixed-width, 18-character string.

    The ``str()`` form is cut to its first 18 characters and then
    left-padded with zeros (placed after any leading sign, as
    ``str.zfill`` does) up to 18 characters.
    """
    width = 18
    text = str(num)
    clipped = text[:width]
    return clipped.zfill(width)
15
16
def wm_recurrent_function(x):
    """
    Identity transfer used on the wm -> wm recurrent connection.

    The decoded value is fed back unchanged.
    """
    fed_back = x
    return fed_back
19
20
def inputs_function(x):
    """
    Scale the decoded input by ``tau_wm`` on its way into working memory.

    ``tau_wm`` is a module-level setting (presumably loaded from conf.py).
    The same value is used as the synapse of the inputs -> wm connection.
    """
    scaled = x * tau_wm
    return scaled
23
24
def noise_decision_function(t):
    """
    Return one zero-mean Gaussian sample with std ``noise_decision``.

    ``t`` (the simulation time passed by the Node) is ignored. Samples come
    from NumPy's global random state, so runs are not reproducible unless
    that state is seeded elsewhere.
    """
    return np.random.normal(loc=0.0, scale=noise_decision)
27
28
def noise_bias_function(t):
    """
    Return one zero-mean Gaussian sample with std ``noise_wm``.

    Despite the name, the spread comes from ``noise_wm``. This output drives
    the working-memory noise node. ``t`` is ignored, and NumPy's global random
    state is used.
    """
    return np.random.normal(loc=0.0, scale=noise_wm)
31
32
def time_function(t):
    """
    Step input: 0 up to and including ``t_cue``, ``time_scale`` afterwards.

    Both thresholds are module-level settings (presumably from conf.py).
    """
    if t > t_cue:
        return time_scale
    return 0
35
36
def decision_function(x):
    """
    Map a 2-D decoded value to a binary choice.

    Returns ``1.0`` when ``x[0] + x[1]`` is strictly positive and ``-1.0``
    otherwise (including a sum of exactly zero).
    """
    total = x[0] + x[1]
    if total > 0.0:
        return 1.0
    return -1.0
39
40
class Alpha(object):
    """
    Base class for norepinephrine alpha-receptor models. Not to be used directly.

    Subclasses must set ``ki``, ``offset``, ``pretty``, ``gaind`` and
    ``biasd`` *before* calling ``super().__init__()``. The initializer reads
    ``offset``, ``gaind`` and ``biasd`` immediately. ``plot()`` uses ``ki``
    and ``pretty``.

    Attributes computed by ``__init__``:
        x: ``steps`` log-spaced concentrations from 1 to 1000 (the plots
            label the unit as nM). ``steps`` is a module-level setting,
            presumably from conf.py.
        y: logistic activity curve in (0, 1) evaluated at ``x``.
        gain: list with one ``1 + gaind * y`` entry per step.
        bias: list with one ``1 + biasd * y`` entry per step.
    """

    def __init__(self):
        # 10**0 .. 10**3, evenly spaced on a log axis.
        self.x = np.logspace(0, 3, steps)
        # Logistic curve. Dividing by ``offset`` stretches it along the
        # concentration axis, which differs per receptor subclass.
        self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset))))

        # Plain lists, indexed/iterated per step by simulate().
        self.gain = [1 + self.gaind * y for y in self.y]
        self.bias = [1 + self.biasd * y for y in self.y]

    def plot(self):
        """
        Save three PNG figures for this receptor under ``./out``.

        The files are ``<ClassName>-norep-activity.png``,
        ``<ClassName>-concentration-gain.png`` and
        ``<ClassName>-concentration-bias.png``. Each figure is closed after
        it is saved, so repeated calls do not accumulate open figures.
        """
        try:
            mkdir("./out")
        except FileExistsError:
            pass

        out = f"./out/{self.__class__.__name__}"

        plt.figure()
        plt.plot(self.x, self.y)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel("Activity (%)")
        plt.title("Norepinephrine Concentration vs Neuron Activity in " +
                  self.pretty)

        plt.vlines(self.ki, 0, 1, linestyles="dashed")
        plt.text(1.1 * self.ki, 0.1, "Affinity")

        plt.hlines(0.5, 0, 1000, linestyles="dashed")
        plt.text(1, 0.51, "50%")

        plt.xscale("log")
        # Activity is a fraction in (0, 1), so 1.0 must be labelled 100%.
        # (The default xmax=100 would label the 0.5 marker as "0.5%".)
        plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

        plt.draw()
        plt.savefig(f"{out}-norep-activity.png", dpi=1000)
        plt.close()

        self._plot_response(self.gain, "Gain",
                            f"{out}-concentration-gain.png")
        self._plot_response(self.bias, "Bias",
                            f"{out}-concentration-bias.png")

    def _plot_response(self, values, label, path):
        """Plot ``values`` against concentration, save to ``path``, close."""
        plt.figure()
        plt.plot(self.x, values)

        plt.xlabel("Norepinephrine concentration (nM)")
        plt.ylabel(label)
        plt.title(f"Concentration vs {label} in {self.pretty}")

        plt.draw()
        plt.savefig(path, dpi=1000)
        plt.close()

    def simulate(self):
        """
        Run one Nengo simulation per concentration step.

        For each step, every working-memory neuron gets that step's entry of
        ``self.gain`` and ``self.bias``. Each run lasts ``t_cue + t_delay``
        of simulated time.

        NOTE(review): no probes are attached yet, so the runs record nothing.
        """
        name = self.__class__.__name__
        for gain, bias in zip(self.gain, self.bias):
            print(f"{name}, gain: {fmt_num(gain)}, bias: {fmt_num(bias)}")
            net = self._build_network(gain, bias)
            with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
                sim.run(t_cue + t_delay)

    def _build_network(self, gain, bias):
        """Build the input / working-memory / decision network for one step."""
        with nengo.Network() as net:
            # Nodes
            time_node = nengo.Node(output=time_function)
            noise_wm_node = nengo.Node(output=noise_bias_function)
            noise_decision_node = nengo.Node(output=noise_decision_function)

            # Ensembles
            wm = nengo.Ensemble(neurons_wm, 2)
            # The same gain and bias are applied to every wm neuron.
            wm.gain = np.full(wm.n_neurons, gain)
            wm.bias = np.full(wm.n_neurons, bias)
            decision = nengo.Ensemble(neurons_decide, 2)
            inputs = nengo.Ensemble(neurons_inputs, 2)
            output = nengo.Ensemble(neurons_decide, 1)

            # Connections
            nengo.Connection(time_node, inputs[1], synapse=None)
            nengo.Connection(inputs, wm, synapse=tau_wm,
                             function=inputs_function)
            nengo.Connection(wm, wm, synapse=tau_wm,
                             function=wm_recurrent_function)
            # Noise is injected directly into every wm neuron.
            nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                             transform=np.ones((neurons_wm, 1)) * tau_wm)
            nengo.Connection(wm[0], decision[0], synapse=tau)
            nengo.Connection(noise_decision_node, decision[1], synapse=None)
            nengo.Connection(decision, output, function=decision_function)

            # Probes (not enabled yet)
            # probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=dt_sample)
            # probes_spikes = nengo.Probe(wm.neurons, 'spikes',
            #                             sample_every=dt_sample)
            # probe_output = nengo.Probe(output, synapse=None, sample_every=dt_sample)
        return net
150
151
class Alpha1(Alpha):
    """
    Alpha-1 receptor model.

    Supplies the receptor constants that ``Alpha.__init__`` and
    ``Alpha.plot`` read. ``ki`` is drawn as the "Affinity" marker.
    """

    def __init__(self):
        # Constants must be set before the base initializer runs, because
        # it reads offset, gaind and biasd straight away.
        self.pretty = "α1 Receptor"
        self.ki = 330
        self.offset = 5.895
        self.gaind, self.biasd = -0.04, -0.02
        super().__init__()
164
165
class Alpha2(Alpha):
    """
    Alpha-2 receptor model.

    Supplies the receptor constants that ``Alpha.__init__`` and
    ``Alpha.plot`` read. ``ki`` is drawn as the "Affinity" marker.
    """

    def __init__(self):
        # Constants must be set before the base initializer runs, because
        # it reads offset, gaind and biasd straight away.
        self.pretty = "α2 Receptor"
        self.ki = 56
        self.offset = 1
        self.gaind, self.biasd = -0.1, 0.1
        super().__init__()
178
179
def main():
    """
    Build and simulate the α1 receptor model, then the α2 model.

    Plotting is currently disabled. Re-enable the ``plot()`` call below to
    regenerate the figures.
    """
    plt.style.use("ggplot")  # Nice looking and familiar style
    for receptor_cls in (Alpha1, Alpha2):
        receptor = receptor_cls()
        # receptor.plot()
        receptor.simulate()
189
190
if __name__ == "__main__":
    # Run the simulations only when executed as a script, not on import.
    main()