]> git.armaanb.net Git - norepinephrine_wm.git/commitdiff
final update draft main
authorArmaan Bhojwani <me@armaanb.net>
Thu, 12 Aug 2021 19:03:57 +0000 (15:03 -0400)
committerArmaan Bhojwani <me@armaanb.net>
Thu, 12 Aug 2021 19:03:57 +0000 (15:03 -0400)
conf.py
model.py

diff --git a/conf.py b/conf.py
index ed17397bb5712a545f0a6739d20dfa34eb97c30e..557d729772661b2c61cbf04a15578543e7350ecf 100644 (file)
--- a/conf.py
+++ b/conf.py
@@ -7,7 +7,7 @@ t_delay = 8.0          # Duration of delay period between cue and decision
 cue_scale = 1.0        # How strong the cuelus is from the visual system
 misperceive = 0.1      # ???
 time_scale = 0.4       # ???
-steps = np.logspace(0, 3, 500) # Steps to use
+steps = np.linspace(0, 750, 250) # Steps to use
 noise_wm = 0.005       # Standard deviation of white noise added to WM
 noise_decision = 0.005 # Standard deviation of white noise added to decision
 neurons_decide = 100   # Number of neurons for decision
@@ -15,4 +15,5 @@ neurons_inputs = 100   # Number of neurons for inputs ensemble
 neurons_wm = 100       # Number of neurons for working memory ensemble
 tau_wm = 0.1           # Synapse on recurrent connection in wm
 tau = 0.01             # Synaptic time constant between ensembles
-seed = 3               # Seed for RNG
+seed = 1               # Seed for RNG
+n_trials = 1
index 7feeee2be94f4f04b1bdade82760b5d2bc71710e..ec89294142909c93a4ad2da27ed6ad8961a42661 100644 (file)
--- a/model.py
+++ b/model.py
@@ -91,35 +91,21 @@ class Alpha():
 
         #######################################################################
 
-        title = "Concentration vs Gain Scalar in" + self.pretty
+        title = "Concentration vs Gain/Bias scalar in " + self.pretty
         logging.info("Plotting " + title)
         plt.figure()
-        plt.plot(self.x, self.gains)
-
-        plt.xlabel("Norepinephrine concentration (nM)")
-        plt.ylabel("Gain")
-        plt.title(title)
-
-        plt.xscale("log")
-
-        plt.draw()
-        plt.savefig(f"{out}-concentration-gain.png")
-
-        #######################################################################
-
-        title = "Concentration vs Bias scalar in " + self.pretty
-        logging.info("Plotting " + title)
-        plt.figure()
-        plt.plot(self.x, self.biass)
+        plt.plot(self.x, self.biass, label="Bias scalar")
+        plt.plot(self.x, self.gains, label="Gain scalar")
 
         plt.xscale("log")
 
         plt.xlabel("Norepinephrine concentration (nM)")
-        plt.ylabel("Bias")
+        plt.ylabel("Level")
         plt.title(title)
+        plt.legend()
 
         plt.draw()
-        plt.savefig(f"{out}-concentration-bias.png")
+        plt.savefig(f"{out}-concentration-bias-gains.png")
 
 
 class Alpha1(Alpha):
@@ -130,9 +116,11 @@ class Alpha1(Alpha):
     def __init__(self):
         self.ki = 330
         self.offset = 5.895
-        self.pretty = "α1 Receptor"
-        self.gaind = -0.02
-        self.biasd = 0.04
+        self.pretty = "α1"
+        #self.gaind = -0.02
+        self.gaind = -0.1
+        #self.biasd = 0.04
+        self.biasd = 0.1
         super().__init__()
 
 
@@ -144,7 +132,7 @@ class Alpha2(Alpha):
     def __init__(self):
         self.ki = 56
         self.offset = 1
-        self.pretty = "α2 Receptor"
+        self.pretty = "α2"
         self.gaind = 0.1
         self.biasd = -0.1
         super().__init__()
@@ -153,16 +141,20 @@ class Alpha2(Alpha):
 class Simulation():
     def __init__(self):
         self.a1 = Alpha1()
+        self.a1.plot()
         self.a2 = Alpha2()
-        self.num_spikes = np.ones(len(steps))
-        self.out = np.ones(3)
+        self.a2.plot()
+
+        self.num_spikes = np.zeros(len(steps))
+        self.num_correct = np.zeros(len(steps))
+        self.out = np.zeros(n_trials)
         self.trial = 0
 
         # correctly perceived (not necessarily remembered) cues
-        self.perceived = np.ones(3)
+        self.perceived = np.ones(n_trials)
         rng = np.random.RandomState(seed=seed)
         # whether the cues is on the left or right
-        self.cues = 2 * rng.randint(2, size=3)-1
+        self.cues = 2 * rng.randint(2, size=n_trials)-1
         for n in range(len(self.perceived)):
             if rng.rand() < misperceive:
                 self.perceived[n] = 0
@@ -173,8 +165,6 @@ class Simulation():
         plt.figure()
         plt.plot(steps, self.num_spikes)
 
-        #plt.xscale("log")
-
         plt.xlabel("Norepinephrine concentration (nM)")
         plt.ylabel("Spiking rate (spikes/time step)")
         plt.title(title)
@@ -182,6 +172,21 @@ class Simulation():
         plt.draw()
         plt.savefig("./out/concentration-spiking.png")
 
+        ########################################################################
+
+        title = "Norepinephrine Concentration vs Accuracy"
+        logging.info("Plotting " + title)
+        plt.figure()
+        correct_df = pd.DataFrame(np.clip(self.num_correct, 0.5, 1.0)).rolling(20).mean()
+        plt.plot(steps, correct_df)
+
+        plt.xlabel("Norepinephrine concentration (nM)")
+        plt.ylabel("Accuracy")
+        plt.title(title)
+
+        plt.draw()
+        plt.savefig("./out/concentration-correct.png")
+
     def cue_function(self, t):
         if t < t_cue and self.perceived[self.trial] != 0:
             return cue_scale * self.cues[self.trial]
@@ -189,9 +194,6 @@ class Simulation():
             return 0
 
     def run(self):
-        self.a1.plot()
-        self.a2.plot()
-
         with nengo.Network() as net:
             # Nodes
             cue_node = nengo.Node(output=self.cue_function)
@@ -221,40 +223,52 @@ class Simulation():
             nengo.Connection(decision, output, function=decision_function)
 
             # Probes
-            probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
-            probe_spikes = nengo.Probe(wm.neurons, sample_every=probe_dt)
-            probe_output = nengo.Probe(
+            wm_probe = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
+            spikes_probe = nengo.Probe(wm.neurons, sample_every=probe_dt)
+            output_probe = nengo.Probe(
                 output, synapse=None, sample_every=probe_dt)
 
             # Run simulation
-            with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
-                for i, _ in tqdm(enumerate(steps), total=len(steps)):
-                    wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
-                    wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
-                    wm_recurrent.solver = MySolver(
-                        sim.model.params[wm_recurrent].weights)
-                    wm_to_decision.solver = MySolver(
-                        sim.model.params[wm_to_decision].weights)
-                    sim = nengo.Simulator(net, dt=dt, progress_bar=False)
-                    for self.trial in range(3):
-                        logging.info(
-                            f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
-                        sim.run(t_cue + t_delay)
-                        self.out[self.trial] = np.count_nonzero(
-                            sim.data[probe_spikes])
-                    self.num_spikes[i] = np.average(self.out)
+            for i, _ in tqdm(enumerate(steps), total=len(steps), unit="step"):
+                sim = nengo.Simulator(net, dt=dt, progress_bar=False)
+                wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
+                wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
+                wm_recurrent.solver = MySolver(
+                    sim.model.params[wm_recurrent].weights)
+                wm_to_decision.solver = MySolver(
+                    sim.model.params[wm_to_decision].weights)
+                sim = nengo.Simulator(net, dt=dt, progress_bar=False)
+                for self.trial in range(n_trials):
+                    logging.info(
+                        f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
+                    sim.run(t_cue + t_delay)
+
+                    # Firing rate
+                    self.out[self.trial] = np.count_nonzero(
+                        sim.data[spikes_probe])
+
+                    cue = self.cues[self.trial]
+                    # Correctness
+                    out = sim.data[output_probe][int(t_cue + t_delay)][0]
+                    if (out * cue) > 0:  # check if same sign
+                        self.num_correct[i] += np.abs(1 / (out - cue))
+
+                self.num_spikes[i] = np.average(self.out)
 
         with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
             pickle.dump(self, pout)
 
         self.plot()
 
+def get_correct(cue, output_value):
+    return 1 if (cue > 0.0 and output_value > 0.0) or (cue < 0.0 and output_value < 0.0) else 0
+
 
 class MySolver(nengo.solvers.Solver):
     def __init__(self, weights):
         self.weights = False
         self.my_weights = weights
-        self._paramdict = dict()
+        self._paramdict = {}
 
     def __call__(self, A, Y, rng=None, E=None):
         return self.my_weights.T, dict()