- with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
- for i, _ in tqdm(enumerate(steps), total=len(steps)):
- wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
- wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
- wm_recurrent.solver = MySolver(
- sim.model.params[wm_recurrent].weights)
- wm_to_decision.solver = MySolver(
- sim.model.params[wm_to_decision].weights)
- sim = nengo.Simulator(net, dt=dt, progress_bar=False)
- for self.trial in range(3):
- logging.info(
- f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
- sim.run(t_cue + t_delay)
- self.out[self.trial] = np.count_nonzero(
- sim.data[probe_spikes])
- self.num_spikes[i] = np.average(self.out)
+ for i, _ in tqdm(enumerate(steps), total=len(steps), unit="step"):
+ sim = nengo.Simulator(net, dt=dt, progress_bar=False)
+ wm.gain = (self.a1.gains[i] + self.a2.gains[i]) * sim.data[wm].gain
+ wm.bias = (self.a1.biass[i] + self.a2.biass[i]) * sim.data[wm].bias
+ wm_recurrent.solver = MySolver(
+ sim.model.params[wm_recurrent].weights)
+ wm_to_decision.solver = MySolver(
+ sim.model.params[wm_to_decision].weights)
+ sim = nengo.Simulator(net, dt=dt, progress_bar=False)
+ for self.trial in range(n_trials):
+ logging.info(
+ f"Simulating: trial: {self.trial}, gain: {fmt_num(wm.gain)}, bias: {fmt_num(wm.bias)}")
+ sim.run(t_cue + t_delay)
+
+ # Firing rate
+ self.out[self.trial] = np.count_nonzero(
+ sim.data[spikes_probe])
+
+ cue = self.cues[self.trial]
+ # Correctness
+ out = sim.data[output_probe][int(t_cue + t_delay)][0]
+ if (out * cue) > 0: # check if same sign
+ self.num_correct[i] += np.abs(1 / (out - cue))
+
+ self.num_spikes[i] = np.average(self.out)