]> git.armaanb.net Git - norepinephrine_wm.git/commitdiff
update
authorArmaan Bhojwani <me@armaanb.net>
Wed, 11 Aug 2021 02:18:08 +0000 (22:18 -0400)
committerArmaan Bhojwani <me@armaanb.net>
Wed, 11 Aug 2021 02:18:08 +0000 (22:18 -0400)
README
conf.py
model.py

diff --git a/README b/README
index 312e3c3411de67902e784c830c55a16a5e2b2960..d916536f679bc4c1ea5143f60a1d825137926f23 100644 (file)
--- a/README
+++ b/README
@@ -8,8 +8,7 @@ Marianne Bezaire.
 
 The work in this repository was in-part derived from "Effects of Guanfacine and
 Phenylephrine on a Spiking Neuron Model of Working Memory" by Peter Duggins,
-Terry Stewart, Xuan Choo, Chris Eliasmith. Some of those original files are
-preserved in old/.
+Terry Stewart, Xuan Choo, Chris Eliasmith.
 <https://github.com/psipeter/drugs_and_working_memory>
 
 Usage
@@ -32,4 +31,7 @@ Run simulation:
 
        python3 model.py
 
-Results are stored in ./out/.
+An overall progress bar is shown during the simulation, and more detailed
+results are stored in a log file. Logs and charts are stored in ./out/. The
+simulation will be skipped if a file named simulation.pkl is present. It will be
+loaded and used instead of the simulation being run.
diff --git a/conf.py b/conf.py
index b9664989ef3d3b51eecf46e6e93b318393967e41..ed17397bb5712a545f0a6739d20dfa34eb97c30e 100644 (file)
--- a/conf.py
+++ b/conf.py
@@ -1,12 +1,13 @@
 import numpy as np
 
-dt = 0.01              # Time step
+dt = 0.05              # Time step
+probe_dt = 0.1         # Time step to sample from probe
 t_cue = 1.0            # Duration of cue presentation
 t_delay = 8.0          # Duration of delay period between cue and decision
 cue_scale = 1.0        # How strong the cuelus is from the visual system
 misperceive = 0.1      # ???
 time_scale = 0.4       # ???
-steps = np.arange(750) # Steps to use
+steps = np.logspace(0, 3, 500) # Steps to use
 noise_wm = 0.005       # Standard deviation of white noise added to WM
 noise_decision = 0.005 # Standard deviation of white noise added to decision
 neurons_decide = 100   # Number of neurons for decision
@@ -14,3 +15,4 @@ neurons_inputs = 100   # Number of neurons for inputs ensemble
 neurons_wm = 100       # Number of neurons for working memory ensemble
 tau_wm = 0.1           # Synapse on recurrent connection in wm
 tau = 0.01             # Synaptic time constant between ensembles
+seed = 3               # Seed for RNG
index db3c69156877bbc37b53828b891ed987b764cfdc..7feeee2be94f4f04b1bdade82760b5d2bc71710e 100644 (file)
--- a/model.py
+++ b/model.py
@@ -12,6 +12,7 @@ from tqdm import tqdm
 
 exec(open("conf.py").read())
 
+
 def fmt_num(num, width=18):
     """
     Format number to string.
@@ -44,7 +45,7 @@ def decision_function(x):
     elif value < 0.0:
         output = -1.0
     return output
-    #return 1.0 if x[0] + x[1] > 0.0 else -1.0
+    # return 1.0 if x[0] + x[1] > 0.0 else -1.0
 
 
 class Alpha():
@@ -60,8 +61,8 @@ class Alpha():
         self.biass = []
 
         for i in range(len(steps)):
-            self.gains.append(1 + self.gaind * self.y[i])
-            self.biass.append(1 + self.biasd * self.y[i])
+            self.gains.append(self.gaind * self.y[i] + 1)
+            self.biass.append(self.biasd * self.y[i] + 1)
 
     def plot(self):
         out = f"./out/{self.__class__.__name__}"
@@ -154,14 +155,14 @@ class Simulation():
         self.a1 = Alpha1()
         self.a2 = Alpha2()
         self.num_spikes = np.ones(len(steps))
-        self.biass = np.ones(len(steps))
-        self.gains = np.ones(len(steps))
         self.out = np.ones(3)
         self.trial = 0
 
-        self.perceived = np.ones(3)  # correctly perceived (not necessarily remembered) cues
-        rng=np.random.RandomState(seed=111)
-        self.cues=2 * rng.randint(2, size=3)-1  # whether the cues is on the left or right
+        # correctly perceived (not necessarily remembered) cues
+        self.perceived = np.ones(3)
+        rng = np.random.RandomState(seed=seed)
+        # whether the cues is on the left or right
+        self.cues = 2 * rng.randint(2, size=3)-1
         for n in range(len(self.perceived)):
             if rng.rand() < misperceive:
                 self.perceived[n] = 0
@@ -172,7 +173,7 @@ class Simulation():
         plt.figure()
         plt.plot(steps, self.num_spikes)
 
-        plt.xscale("log")
+        #plt.xscale("log")
 
         plt.xlabel("Norepinephrine concentration (nM)")
         plt.ylabel("Spiking rate (spikes/time step)")
@@ -188,6 +189,9 @@ class Simulation():
             return 0
 
     def run(self):
+        self.a1.plot()
+        self.a2.plot()
+
         with nengo.Network() as net:
             # Nodes
             cue_node = nengo.Node(output=self.cue_function)
@@ -217,30 +221,55 @@ class Simulation():
             nengo.Connection(decision, output, function=decision_function)
 
             # Probes
-            probes_wm = nengo.Probe(wm[0], synapse=0.01)
-            probe_spikes = nengo.Probe(wm.neurons)
-            probe_output = nengo.Probe(output, synapse=None)
+            probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
+            probe_spikes = nengo.Probe(wm.neurons, sample_every=probe_dt)
+            probe_output = nengo.Probe(
+                output, synapse=None, sample_every=probe_dt)
 
             # Run simulation
             with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
                 for i, _ in tqdm(enumerate(steps), total=len(steps)):
                     wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
-                    wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].gain
+                    wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
+                    wm_recurrent.solver = MySolver(
+                        sim.model.params[wm_recurrent].weights)
+                    wm_to_decision.solver = MySolver(
+                        sim.model.params[wm_to_decision].weights)
+                    sim = nengo.Simulator(net, dt=dt, progress_bar=False)
                     for self.trial in range(3):
-                        logging.debug(f"Simulating: trial: {self.trial}, gain: {fmt_num(self.gains[i])}, bias: {fmt_num(self.biass[i])}")
+                        logging.info(
+                            f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
                         sim.run(t_cue + t_delay)
-                        self.out[self.trial] = np.count_nonzero(sim.data[probe_spikes])
+                        self.out[self.trial] = np.count_nonzero(
+                            sim.data[probe_spikes])
                     self.num_spikes[i] = np.average(self.out)
 
         with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
-            pickle.dump(self.num_spikes, pout)
+            pickle.dump(self, pout)
 
         self.plot()
 
+
+class MySolver(nengo.solvers.Solver):
+    def __init__(self, weights):
+        self.weights = False
+        self.my_weights = weights
+        self._paramdict = dict()
+
+    def __call__(self, A, Y, rng=None, E=None):
+        return self.my_weights.T, dict()
+
+
 def main():
     logging.info("Initializing simulation")
     plt.style.use("ggplot")  # Nice looking and familiar style
-    Simulation().run()
+
+    try:
+        data = open("simulation.pkl", "rb")
+    except FileNotFoundError:
+        Simulation().run()
+    else:
+        pickle.load(data).plot()
 
 
 if __name__ == "__main__":
@@ -250,6 +279,6 @@ if __name__ == "__main__":
         pass
 
     logging.basicConfig(filename=f"out/{datetime.now().isoformat()}.log",
-                        level=logging.DEBUG)
+                        level=logging.INFO)
 
     main()