X-Git-Url: https://git.armaanb.net/?p=norepinephrine_wm.git;a=blobdiff_plain;f=model.py;h=ec89294142909c93a4ad2da27ed6ad8961a42661;hp=cd6002207c20f0c1f5b6246163964ce490c9e937;hb=HEAD;hpb=c1c79514649a54ce6af82f30a2bd3e1a5401f11d diff --git a/model.py b/model.py index cd60022..ec89294 100644 --- a/model.py +++ b/model.py @@ -1,11 +1,14 @@ from datetime import datetime from os import mkdir import logging +import pickle import matplotlib.pyplot as plt import matplotlib.ticker as mtick import nengo import numpy as np +import pandas as pd +from tqdm import tqdm exec(open("conf.py").read()) @@ -35,30 +38,37 @@ def time_function(t): def decision_function(x): - return 1.0 if x[0] + x[1] > 0.0 else -1.0 + output = 0.0 + value = x[0] + x[1] + if value > 0.0: + output = 1.0 + elif value < 0.0: + output = -1.0 + return output + # return 1.0 if x[0] + x[1] > 0.0 else -1.0 -class Alpha(object): +class Alpha(): """ Base class for alpha receptors. Not to be used directly. """ def __init__(self): - self.x = np.logspace(0, 3, steps) + self.x = steps self.y = 1 / (1 + (999 * np.exp(-0.1233 * (self.x / self.offset)))) self.gains = [] self.biass = [] - for i in range(steps): - y = self.y[i] - self.gains.append(1 + self.gaind * y) - self.biass.append(1 + self.biasd * y) + for i in range(len(steps)): + self.gains.append(self.gaind * self.y[i] + 1) + self.biass.append(self.biasd * self.y[i] + 1) def plot(self): out = f"./out/{self.__class__.__name__}" - title = "Norepinepherine Concentration vs Neuron Activity in " + self.pretty + title = "Norepinepherine Concentration vs Neuron Activity in " + \ + self.pretty logging.info("Plotting " + title) plt.figure() plt.plot(self.x, self.y) @@ -81,35 +91,21 @@ class Alpha(object): ####################################################################### - title = "Concentration vs Gain Scalar in" + self.pretty + title = "Concentration vs Gain/Bias scalar in " + self.pretty logging.info("Plotting " + title) plt.figure() - plt.plot(self.x, self.gains) - - plt.xlabel("Norepinephrine concentration (nM)") - plt.ylabel("Gain") - plt.title(title) - - plt.xscale("log") - - plt.draw() - plt.savefig(f"{out}-concentration-gain.png") - - ####################################################################### - - title = "Concentration vs Bias scalar in " + self.pretty - logging.info("Plotting " + title) - plt.figure() - plt.plot(self.x, self.biass) + plt.plot(self.x, self.biass, label="Bias scalar") + plt.plot(self.x, self.gains, label="Gain scalar") plt.xscale("log") plt.xlabel("Norepinephrine concentration (nM)") - plt.ylabel("Bias") + plt.ylabel("Level") plt.title(title) + plt.legend() plt.draw() - plt.savefig(f"{out}-concentration-bias.png") + plt.savefig(f"{out}-concentration-bias-gains.png") class Alpha1(Alpha): @@ -120,9 +116,11 @@ class Alpha1(Alpha): def __init__(self): self.ki = 330 self.offset = 5.895 - self.pretty = "α1 Receptor" - self.gaind = -0.02 - self.biasd = 0.04 + self.pretty = "α1" + #self.gaind = -0.02 + self.gaind = -0.1 + #self.biasd = 0.04 + self.biasd = 0.1 super().__init__() @@ -134,19 +132,71 @@ class Alpha2(Alpha): def __init__(self): self.ki = 56 self.offset = 1 - self.pretty = "α2 Receptor" + self.pretty = "α2" self.gaind = 0.1 self.biasd = -0.1 super().__init__() -def simulate(a1, a2): - for i in range(steps): - gain = a1.gains[i] + a2.gains[i] - 1 - bias = a1.biass[i] + a2.biass[i] - 1 - logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}") +class Simulation(): + def __init__(self): + self.a1 = Alpha1() + self.a1.plot() + self.a2 = Alpha2() + self.a2.plot() + + self.num_spikes = np.zeros(len(steps)) + self.num_correct = np.zeros(len(steps)) + self.out = np.zeros(n_trials) + self.trial = 0 + + # correctly perceived (not necessarily remembered) cues + self.perceived = np.ones(n_trials) + rng = np.random.RandomState(seed=seed) + # whether the cues is on the left or right + self.cues = 2 * rng.randint(2, size=n_trials)-1 + for n in range(len(self.perceived)): + if rng.rand() < misperceive: + self.perceived[n] = 0 + + def plot(self): + title = "Norepinephrine Concentration vs Spiking Rate" + logging.info("Plotting " + title) + plt.figure() + plt.plot(steps, self.num_spikes) + + plt.xlabel("Norepinephrine concentration (nM)") + plt.ylabel("Spiking rate (spikes/time step)") + plt.title(title) + + plt.draw() + plt.savefig("./out/concentration-spiking.png") + + ######################################################################## + + title = "Norepinephrine Concentration vs Accuracy" + logging.info("Plotting " + title) + plt.figure() + correct_df = pd.DataFrame(np.clip(self.num_correct, 0.5, 1.0)).rolling(20).mean() + plt.plot(steps, correct_df) + + plt.xlabel("Norepinephrine concentration (nM)") + plt.ylabel("Accuracy") + plt.title(title) + + plt.draw() + plt.savefig("./out/concentration-correct.png") + + def cue_function(self, t): + if t < t_cue and self.perceived[self.trial] != 0: + return cue_scale * self.cues[self.trial] + else: + return 0 + + def run(self): with nengo.Network() as net: # Nodes + cue_node = nengo.Node(output=self.cue_function) time_node = nengo.Node(output=time_function) noise_wm_node = nengo.Node(output=noise_bias_function) noise_decision_node = nengo.Node( @@ -154,13 +204,12 @@ def simulate(a1, a2): # Ensembles wm = nengo.Ensemble(neurons_wm, 2) - wm.gain = np.full(wm.n_neurons, gain) - wm.bias = np.full(wm.n_neurons, bias) decision = nengo.Ensemble(neurons_decide, 2) inputs = nengo.Ensemble(neurons_inputs, 2) output = nengo.Ensemble(neurons_decide, 1) # Connections + nengo.Connection(cue_node, inputs[0], synapse=None) nengo.Connection(time_node, inputs[1], synapse=None) nengo.Connection(inputs, wm, synapse=tau_wm, function=inputs_function) @@ -174,25 +223,67 @@ def simulate(a1, a2): nengo.Connection(decision, output, function=decision_function) # Probes - probes_wm = nengo.Probe(wm[0], synapse=0.01) - probe_output = nengo.Probe(output, synapse=None) + wm_probe = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt) + spikes_probe = nengo.Probe(wm.neurons, sample_every=probe_dt) + output_probe = nengo.Probe( + output, synapse=None, sample_every=probe_dt) + + # Run simulation + for i, _ in tqdm(enumerate(steps), total=len(steps), unit="step"): + sim = nengo.Simulator(net, dt=dt, progress_bar=False) + wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain + wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias + wm_recurrent.solver = MySolver( + sim.model.params[wm_recurrent].weights) + wm_to_decision.solver = MySolver( + sim.model.params[wm_to_decision].weights) + sim = nengo.Simulator(net, dt=dt, progress_bar=False) + for self.trial in range(n_trials): + logging.info( + f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}") + sim.run(t_cue + t_delay) - # Run simulation - with nengo.Simulator(net, dt=dt, progress_bar=False) as sim: - sim.run(t_cue + t_delay) + # Firing rate + self.out[self.trial] = np.count_nonzero( + sim.data[spikes_probe]) + + cue = self.cues[self.trial] + # Correctness + out = sim.data[output_probe][int(t_cue + t_delay)][0] + if (out * cue) > 0: # check if same sign + self.num_correct[i] += np.abs(1 / (out - cue)) + + self.num_spikes[i] = np.average(self.out) + + with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout: + pickle.dump(self, pout) + + self.plot() + +def get_correct(cue, output_value): + return 1 if (cue > 0.0 and output_value > 0.0) or (cue < 0.0 and output_value < 0.0) else 0 + + +class MySolver(nengo.solvers.Solver): + def __init__(self, weights): + self.weights = False + self.my_weights = weights + self._paramdict = {} + + def __call__(self, A, Y, rng=None, E=None): + return self.my_weights.T, dict() def main(): logging.info("Initializing simulation") plt.style.use("ggplot") # Nice looking and familiar style - a1 = Alpha1() - a1.plot() - - a2 = Alpha2() - a2.plot() - - simulate(a1, a2) + try: + data = open("simulation.pkl", "rb") + except FileNotFoundError: + Simulation().run() + else: + pickle.load(data).plot() if __name__ == "__main__": @@ -202,9 +293,6 @@ if __name__ == "__main__": pass logging.basicConfig(filename=f"out/{datetime.now().isoformat()}.log", - level=logging.DEBUG) - console = logging.StreamHandler() - console.setLevel(logging.INFO) - logging.getLogger("").addHandler(console) + level=logging.INFO) main()