exec(open("conf.py").read())
+
def fmt_num(num, width=18):
"""
Format number to string.
elif value < 0.0:
output = -1.0
return output
- #return 1.0 if x[0] + x[1] > 0.0 else -1.0
+ # return 1.0 if x[0] + x[1] > 0.0 else -1.0
class Alpha():
self.biass = []
for i in range(len(steps)):
- self.gains.append(1 + self.gaind * self.y[i])
- self.biass.append(1 + self.biasd * self.y[i])
+ self.gains.append(self.gaind * self.y[i] + 1)
+ self.biass.append(self.biasd * self.y[i] + 1)
def plot(self):
out = f"./out/{self.__class__.__name__}"
self.a1 = Alpha1()
self.a2 = Alpha2()
self.num_spikes = np.ones(len(steps))
- self.biass = np.ones(len(steps))
- self.gains = np.ones(len(steps))
self.out = np.ones(3)
self.trial = 0
- self.perceived = np.ones(3) # correctly perceived (not necessarily remembered) cues
- rng=np.random.RandomState(seed=111)
- self.cues=2 * rng.randint(2, size=3)-1 # whether the cues is on the left or right
+ # correctly perceived (not necessarily remembered) cues
+ self.perceived = np.ones(3)
+ rng = np.random.RandomState(seed=seed)
+ # whether the cues is on the left or right
+ self.cues = 2 * rng.randint(2, size=3)-1
for n in range(len(self.perceived)):
if rng.rand() < misperceive:
self.perceived[n] = 0
plt.figure()
plt.plot(steps, self.num_spikes)
- plt.xscale("log")
+ #plt.xscale("log")
plt.xlabel("Norepinephrine concentration (nM)")
plt.ylabel("Spiking rate (spikes/time step)")
return 0
def run(self):
+ self.a1.plot()
+ self.a2.plot()
+
with nengo.Network() as net:
# Nodes
cue_node = nengo.Node(output=self.cue_function)
nengo.Connection(decision, output, function=decision_function)
# Probes
- probes_wm = nengo.Probe(wm[0], synapse=0.01)
- probe_spikes = nengo.Probe(wm.neurons)
- probe_output = nengo.Probe(output, synapse=None)
+ probes_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=probe_dt)
+ probe_spikes = nengo.Probe(wm.neurons, sample_every=probe_dt)
+ probe_output = nengo.Probe(
+ output, synapse=None, sample_every=probe_dt)
# Run simulation
with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
for i, _ in tqdm(enumerate(steps), total=len(steps)):
wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
- wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].gain
+ wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
+ wm_recurrent.solver = MySolver(
+ sim.model.params[wm_recurrent].weights)
+ wm_to_decision.solver = MySolver(
+ sim.model.params[wm_to_decision].weights)
+ sim = nengo.Simulator(net, dt=dt, progress_bar=False)
for self.trial in range(3):
- logging.debug(f"Simulating: trial: {self.trial}, gain: {fmt_num(self.gains[i])}, bias: {fmt_num(self.biass[i])}")
+ logging.info(
+ f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
sim.run(t_cue + t_delay)
- self.out[self.trial] = np.count_nonzero(sim.data[probe_spikes])
+ self.out[self.trial] = np.count_nonzero(
+ sim.data[probe_spikes])
self.num_spikes[i] = np.average(self.out)
with open(f"out/{datetime.now().isoformat()}-spikes.pkl", "wb") as pout:
- pickle.dump(self.num_spikes, pout)
+ pickle.dump(self, pout)
self.plot()
+
+class MySolver(nengo.solvers.Solver):
+ def __init__(self, weights):
+ self.weights = False
+ self.my_weights = weights
+ self._paramdict = dict()
+
+ def __call__(self, A, Y, rng=None, E=None):
+ return self.my_weights.T, dict()
+
+
def main():
logging.info("Initializing simulation")
plt.style.use("ggplot") # Nice looking and familiar style
- Simulation().run()
+
+ try:
+ data = open("simulation.pkl", "rb")
+ except FileNotFoundError:
+ Simulation().run()
+ else:
+ pickle.load(data).plot()
if __name__ == "__main__":
pass
logging.basicConfig(filename=f"out/{datetime.now().isoformat()}.log",
- level=logging.DEBUG)
+ level=logging.INFO)
main()