]> git.armaanb.net Git - norepinephrine_wm.git/blobdiff - model.py
Write outer loop
[norepinephrine_wm.git] / model.py
index db9946a45aacd09ca174a2a65437b3bae8a35982..f3079e211decc92113a2d04546e4f0c5df178a34 100644 (file)
--- a/model.py
+++ b/model.py
@@ -6,20 +6,28 @@ import numpy as np
 
 exec(open("conf.py").read())
 
+def fmt_num(num):
+    """
+    Format number to string.
+    """
+
+    return str(num)[:18].zfill(18)
+
+
 def wm_recurrent_function(x):
-       return x
+    return x
 
 
 def inputs_function(x):
-       return x * tau_wm
+    return x * tau_wm
 
 
 def noise_decision_function(t):
-       return np.random.normal(0.0, noise_decision)
+    return np.random.normal(0.0, noise_decision)
 
 
 def noise_bias_function(t):
-       return np.random.normal(0.0, noise_wm)
+    return np.random.normal(0.0, noise_wm)
 
 
 def time_function(t):
@@ -54,6 +62,7 @@ class Alpha(object):
             pass
 
         out = f"./out/{self.__class__.__name__}"
+        plt.figure()
         plt.plot(self.x, self.y)
 
         plt.xlabel("Norepinephrine concentration (nM)")
@@ -75,6 +84,7 @@ class Alpha(object):
 
         #######################################################################
 
+        plt.figure()
         plt.plot(self.x, self.gain)
 
         plt.xlabel("Norepinephrine concentration (nM)")
@@ -86,6 +96,7 @@ class Alpha(object):
 
         #######################################################################
 
+        plt.figure()
         plt.plot(self.x, self.bias)
 
         plt.xlabel("Norepinephrine concentration (nM)")
@@ -96,6 +107,8 @@ class Alpha(object):
         plt.savefig(f"{out}-concentration-bias.png", dpi=1000)
 
     def simulate(self):
+        for i in range(steps):
+            print(f"{self.__class__.__name__}, gain: {fmt_num(self.gain[i])}, bias: {fmt_num(self.bias[i])}")
             with nengo.Network() as net:
                 # Nodes
                 time_node = nengo.Node(output=time_function)
@@ -105,6 +118,8 @@ class Alpha(object):
 
                 # Ensembles
                 wm = nengo.Ensemble(neurons_wm, 2)
+                wm.gain = np.full(wm.n_neurons, self.gain[i])
+                wm.bias = np.full(wm.n_neurons, self.bias[i])
                 decision = nengo.Ensemble(neurons_decide, 2)
                 inputs = nengo.Ensemble(neurons_inputs, 2)
                 output = nengo.Ensemble(neurons_decide, 1)
@@ -124,16 +139,14 @@ class Alpha(object):
                 nengo.Connection(decision, output, function=decision_function)
 
                 # Probes
-                #probes_wm = nengo.Probe(
-                #    wm[0], synapse=0.01, sample_every=dt_sample)
-                #probes_spikes = nengo.Probe(wm.neurons, 'spikes',
-                #                            sample_every=dt_sample)
-                #probe_output = nengo.Probe(output, synapse=None,
-                #                           same_every=dt_sample)
+                # probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=dt_sample)
+                # probes_spikes = nengo.Probe(wm.neurons, 'spikes',
+                #                           sample_every=dt_sample)
+                # probe_output = nengo.Probe(output, synapse=None, same_every=dt_sample)
 
                 # Run simulation
-                with nengo.Simulator(net, dt=dt) as sim:
-                    sim.run(t_cue + t_delay)
+            with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
+                sim.run(t_cue + t_delay)
 
 
 class Alpha1(Alpha):
@@ -145,8 +158,8 @@ class Alpha1(Alpha):
         self.ki = 330
         self.offset = 5.895
         self.pretty = "α1 Receptor"
-        self.gaind = 0.1
-        self.biasd = 0.1
+        self.gaind = -0.04
+        self.biasd = -0.02
         super().__init__()
 
 
@@ -159,20 +172,20 @@ class Alpha2(Alpha):
         self.ki = 56
         self.offset = 1
         self.pretty = "α2 Receptor"
-        self.gaind = -0.04
-        self.biasd = -0.02
+        self.gaind = -0.1
+        self.biasd = 0.1
         super().__init__()
 
 
 def main():
     plt.style.use("ggplot")  # Nice looking and familiar style
-
     a1 = Alpha1()
     # a1.plot()
     a1.simulate()
 
-    #a2 = Alpha2()
+    a2 = Alpha2()
     # a2.plot()
+    a2.simulate()
 
 
 if __name__ == "__main__":