From 5a3641109d4da2de801586db3c5265f0b8cf08d7 Mon Sep 17 00:00:00 2001 From: Armaan Bhojwani Date: Thu, 12 Aug 2021 15:03:57 -0400 Subject: [PATCH] final update --- conf.py | 5 ++- model.py | 120 +++++++++++++++++++++++++++++++------------------------ 2 files changed, 70 insertions(+), 55 deletions(-) diff --git a/conf.py b/conf.py index ed17397..557d729 100644 --- a/conf.py +++ b/conf.py @@ -7,7 +7,7 @@ t_delay = 8.0 # Duration of delay period between cue and decision cue_scale = 1.0 # How strong the cuelus is from the visual system misperceive = 0.1 # ??? time_scale = 0.4 # ??? -steps = np.logspace(0, 3, 500) # Steps to use +steps = np.linspace(0, 750, 250) # Steps to use noise_wm = 0.005 # Standard deviation of white noise added to WM noise_decision = 0.005 # Standard deviation of white noise added to decision neurons_decide = 100 # Number of neurons for decision @@ -15,4 +15,5 @@ neurons_inputs = 100 # Number of neurons for inputs ensemble neurons_wm = 100 # Number of neurons for working memory ensemble tau_wm = 0.1 # Synapse on recurrent connection in wm tau = 0.01 # Synaptic time constant between ensembles -seed = 3 # Seed for RNG +seed = 1 # Seed for RNG +n_trials = 1 diff --git a/model.py b/model.py index 7feeee2..ec89294 100644 --- a/model.py +++ b/model.py @@ -91,35 +91,21 @@ class Alpha(): ####################################################################### - title = "Concentration vs Gain Scalar in" + self.pretty + title = "Concentration vs Gain/Bias scalar in " + self.pretty logging.info("Plotting " + title) plt.figure() - plt.plot(self.x, self.gains) - - plt.xlabel("Norepinephrine concentration (nM)") - plt.ylabel("Gain") - plt.title(title) - - plt.xscale("log") - - plt.draw() - plt.savefig(f"{out}-concentration-gain.png") - - ####################################################################### - - title = "Concentration vs Bias scalar in " + self.pretty - logging.info("Plotting " + title) - plt.figure() - plt.plot(self.x, self.biass) + plt.plot(self.x, self.biass, label="Bias scalar") + plt.plot(self.x, self.gains, label="Gain scalar") plt.xscale("log") plt.xlabel("Norepinephrine concentration (nM)") - plt.ylabel("Bias") + plt.ylabel("Level") plt.title(title) + plt.legend() plt.draw() - plt.savefig(f"{out}-concentration-bias.png") + plt.savefig(f"{out}-concentration-bias-gains.png") class Alpha1(Alpha): @@ -130,9 +116,11 @@ class Alpha1(Alpha): def __init__(self): self.ki = 330 self.offset = 5.895 - self.pretty = "α1 Receptor" - self.gaind = -0.02 - self.biasd = 0.04 + self.pretty = "α1" + #self.gaind = -0.02 + self.gaind = -0.1 + #self.biasd = 0.04 + self.biasd = 0.1 super().__init__() @@ -144,7 +132,7 @@ class Alpha2(Alpha): def __init__(self): self.ki = 56 self.offset = 1 - self.pretty = "α2 Receptor" + self.pretty = "α2" self.gaind = 0.1 self.biasd = -0.1 super().__init__() @@ -153,16 +141,20 @@ class Alpha2(Alpha): class Simulation(): def __init__(self): self.a1 = Alpha1() + self.a1.plot() self.a2 = Alpha2() - self.num_spikes = np.ones(len(steps)) - self.out = np.ones(3) + self.a2.plot() + + self.num_spikes = np.zeros(len(steps)) + self.num_correct = np.zeros(len(steps)) + self.out = np.zeros(n_trials) self.trial = 0 # correctly perceived (not necessarily remembered) cues - self.perceived = np.ones(3) + self.perceived = np.ones(n_trials) rng = np.random.RandomState(seed=seed) # whether the cues is on the left or right - self.cues = 2 * rng.randint(2, size=3)-1 + self.cues = 2 * rng.randint(2, size=n_trials)-1 for n in range(len(self.perceived)): if rng.rand() < misperceive: self.perceived[n] = 0 @@ -173,8 +165,6 @@ class Simulation(): plt.figure() plt.plot(steps, self.num_spikes) - #plt.xscale("log") - plt.xlabel("Norepinephrine concentration (nM)") plt.ylabel("Spiking rate (spikes/time step)") plt.title(title) @@ -182,6 +172,21 @@ class Simulation(): plt.draw() plt.savefig("./out/concentration-spiking.png") + ######################################################################## + + title = "Norepinephrine Concentration vs Accuracy" + logging.info("Plotting " + title) + plt.figure() + correct_df = pd.DataFrame(np.clip(self.num_correct, 0.5, 1.0)).rolling(20).mean() + plt.plot(steps, correct_df) + + plt.xlabel("Norepinephrine concentration (nM)") + plt.ylabel("Accuracy") + plt.title(title) + + plt.draw() + plt.savefig("./out/concentration-correct.png") + def cue_function(self, t): if t < t_cue and self.perceived[self.trial] != 0: return cue_scale * self.cues[self.trial] @@ -189,9 +194,6 @@ class Simulation(): return 0 def run(self): - self.a1.plot() - self.a2.plot() - with nengo.Network() as net: # Nodes cue_node = nengo.Node(output=self.cue_function) @@ -221,40 +223,52 @@ class Simulation(): nengo.Connection(decision, output, function=decision_function) # Probes - probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt) - probe_spikes = nengo.Probe(wm.neurons, sample_every=probe_dt) - probe_output = nengo.Probe( + wm_probe = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt) + spikes_probe = nengo.Probe(wm.neurons, sample_every=probe_dt) + output_probe = nengo.Probe( output, synapse=None, sample_every=probe_dt) # Run simulation - with nengo.Simulator(net, dt=dt, progress_bar=False) as sim: - for i, _ in tqdm(enumerate(steps), total=len(steps)): - wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain - wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias - wm_recurrent.solver = MySolver( - sim.model.params[wm_recurrent].weights) - wm_to_decision.solver = MySolver( - sim.model.params[wm_to_decision].weights) - sim = nengo.Simulator(net, dt=dt, progress_bar=False) - for self.trial in range(3): - logging.info( - f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}") - sim.run(t_cue + t_delay) - self.out[self.trial] = np.count_nonzero( - sim.data[probe_spikes]) - self.num_spikes[i] = np.average(self.out) + for i, _ in tqdm(enumerate(steps), total=len(steps), unit="step"): + sim = nengo.Simulator(net, dt=dt, progress_bar=False) + wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain + wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias + wm_recurrent.solver = MySolver( + sim.model.params[wm_recurrent].weights) + wm_to_decision.solver = MySolver( + sim.model.params[wm_to_decision].weights) + sim = nengo.Simulator(net, dt=dt, progress_bar=False) + for self.trial in range(n_trials): + logging.info( + f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}") + sim.run(t_cue + t_delay) + + # Firing rate + self.out[self.trial] = np.count_nonzero( + sim.data[spikes_probe]) + + cue = self.cues[self.trial] + # Correctness + out = sim.data[output_probe][int(t_cue + t_delay)][0] + if (out * cue) > 0: # check if same sign + self.num_correct[i] += np.abs(1 / (out - cue)) + + self.num_spikes[i] = np.average(self.out) with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout: pickle.dump(self, pout) self.plot() +def get_correct(cue, output_value): + return 1 if (cue > 0.0 and output_value > 0.0) or (cue < 0.0 and output_value < 0.0) else 0 + class MySolver(nengo.solvers.Solver): def __init__(self, weights): self.weights = False self.my_weights = weights - self._paramdict = dict() + self._paramdict = {} def __call__(self, A, Y, rng=None, E=None): return self.my_weights.T, dict() -- 2.39.2