From 5aa6fb93f367d0ed8ae6d9947aa3ff0ed7e209cf Mon Sep 17 00:00:00 2001 From: Armaan Bhojwani Date: Tue, 10 Aug 2021 22:18:08 -0400 Subject: [PATCH] update --- README | 8 ++++--- conf.py | 6 ++++-- model.py | 65 ++++++++++++++++++++++++++++++++++++++++---------------- 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/README b/README index 312e3c3..d916536 100644 --- a/README +++ b/README @@ -8,8 +8,7 @@ Marianne Bezaire. The work in this repository was in-part derived from "Effects of Guanfacine and Phenylephrine on a Spiking Neuron Model of Working Memory" by Peter Duggins, -Terry Stewart, Xuan Choo, Chris Eliasmith. Some of those original files are -preserved in old/. +Terry Stewart, Xuan Choo, Chris Eliasmith. Usage @@ -32,4 +31,7 @@ Run simulation: python3 model.py -Results are stored in ./out/. +An overall progress bar is shown during the simulation, and more detailed +results are stored in a log file. Logs and charts are stored in ./out/. The +simulation will be skipped if a file named simulation.pkl is present. It will be +loaded and used instead of the simulation being run. diff --git a/conf.py b/conf.py index b966498..ed17397 100644 --- a/conf.py +++ b/conf.py @@ -1,12 +1,13 @@ import numpy as np -dt = 0.01 # Time step +dt = 0.05 # Time step +probe_dt = 0.1 # Time step to sample from probe t_cue = 1.0 # Duration of cue presentation t_delay = 8.0 # Duration of delay period between cue and decision cue_scale = 1.0 # How strong the cuelus is from the visual system misperceive = 0.1 # ??? time_scale = 0.4 # ??? -steps = np.arange(750) # Steps to use +steps = np.logspace(0, 3, 500) # Steps to use noise_wm = 0.005 # Standard deviation of white noise added to WM noise_decision = 0.005 # Standard deviation of white noise added to decision neurons_decide = 100 # Number of neurons for decision @@ -14,3 +15,4 @@ neurons_inputs = 100 # Number of neurons for inputs ensemble neurons_wm = 100 # Number of neurons for working memory ensemble tau_wm = 0.1 # Synapse on recurrent connection in wm tau = 0.01 # Synaptic time constant between ensembles +seed = 3 # Seed for RNG diff --git a/model.py b/model.py index db3c691..7feeee2 100644 --- a/model.py +++ b/model.py @@ -12,6 +12,7 @@ from tqdm import tqdm exec(open("conf.py").read()) + def fmt_num(num, width=18): """ Format number to string. @@ -44,7 +45,7 @@ def decision_function(x): elif value < 0.0: output = -1.0 return output - #return 1.0 if x[0] + x[1] > 0.0 else -1.0 + # return 1.0 if x[0] + x[1] > 0.0 else -1.0 class Alpha(): @@ -60,8 +61,8 @@ class Alpha(): self.biass = [] for i in range(len(steps)): - self.gains.append(1 + self.gaind * self.y[i]) - self.biass.append(1 + self.biasd * self.y[i]) + self.gains.append(self.gaind * self.y[i] + 1) + self.biass.append(self.biasd * self.y[i] + 1) def plot(self): out = f"./out/{self.__class__.__name__}" @@ -154,14 +155,14 @@ class Simulation(): self.a1 = Alpha1() self.a2 = Alpha2() self.num_spikes = np.ones(len(steps)) - self.biass = np.ones(len(steps)) - self.gains = np.ones(len(steps)) self.out = np.ones(3) self.trial = 0 - self.perceived = np.ones(3) # correctly perceived (not necessarily remembered) cues - rng=np.random.RandomState(seed=111) - self.cues=2 * rng.randint(2, size=3)-1 # whether the cues is on the left or right + # correctly perceived (not necessarily remembered) cues + self.perceived = np.ones(3) + rng = np.random.RandomState(seed=seed) + # whether the cues is on the left or right + self.cues = 2 * rng.randint(2, size=3)-1 for n in range(len(self.perceived)): if rng.rand() < misperceive: self.perceived[n] = 0 @@ -172,7 +173,7 @@ class Simulation(): plt.figure() plt.plot(steps, self.num_spikes) - plt.xscale("log") + #plt.xscale("log") plt.xlabel("Norepinephrine concentration (nM)") plt.ylabel("Spiking rate (spikes/time step)") @@ -188,6 +189,9 @@ class Simulation(): return 0 def run(self): + self.a1.plot() + self.a2.plot() + with nengo.Network() as net: # Nodes cue_node = nengo.Node(output=self.cue_function) @@ -217,30 +221,55 @@ class Simulation(): nengo.Connection(decision, output, function=decision_function) # Probes - probes_wm = nengo.Probe(wm[0], synapse=0.01) - probe_spikes = nengo.Probe(wm.neurons) - probe_output = nengo.Probe(output, synapse=None) + probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt) + probe_spikes = nengo.Probe(wm.neurons, sample_every=probe_dt) + probe_output = nengo.Probe( + output, synapse=None, sample_every=probe_dt) # Run simulation with nengo.Simulator(net, dt=dt, progress_bar=False) as sim: for i, _ in tqdm(enumerate(steps), total=len(steps)): wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain - wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].gain + wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias + wm_recurrent.solver = MySolver( + sim.model.params[wm_recurrent].weights) + wm_to_decision.solver = MySolver( + sim.model.params[wm_to_decision].weights) + sim = nengo.Simulator(net, dt=dt, progress_bar=False) for self.trial in range(3): - logging.debug(f"Simulating: trial: {self.trial}, gain: {fmt_num(self.gains[i])}, bias: {fmt_num(self.biass[i])}") + logging.info( + f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}") sim.run(t_cue + t_delay) - self.out[self.trial] = np.count_nonzero(sim.data[probe_spikes]) + self.out[self.trial] = np.count_nonzero( + sim.data[probe_spikes]) self.num_spikes[i] = np.average(self.out) with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout: - pickle.dump(self.num_spikes, pout) + pickle.dump(self, pout) self.plot() + +class MySolver(nengo.solvers.Solver): + def __init__(self, weights): + self.weights = False + self.my_weights = weights + self._paramdict = dict() + + def __call__(self, A, Y, rng=None, E=None): + return self.my_weights.T, dict() + + def main(): logging.info("Initializing simulation") plt.style.use("ggplot") # Nice looking and familiar style - Simulation().run() + + try: + data = open("simulation.pkl", "rb") + except FileNotFoundError: + Simulation().run() + else: + pickle.load(data).plot() if __name__ == "__main__": @@ -250,6 +279,6 @@ if __name__ == "__main__": pass logging.basicConfig(filename=f"out/{datetime.now().isoformat()}.log", - level=logging.DEBUG) + level=logging.INFO) main() -- 2.39.2