From: Armaan Bhojwani Date: Mon, 9 Aug 2021 22:04:29 +0000 (-0400) Subject: Factor simulation into its own class X-Git-Url: https://git.armaanb.net/?p=norepinephrine_wm.git;a=commitdiff_plain;h=864e1ff852299ff175ebc67ee48766d565c4bb29 Factor simulation into its own class --- diff --git a/model.py b/model.py index cd60022..0ea66c6 100644 --- a/model.py +++ b/model.py @@ -38,7 +38,7 @@ def decision_function(x): return 1.0 if x[0] + x[1] > 0.0 else -1.0 -class Alpha(object): +class Alpha(): """ Base class for alpha receptors. Not to be used directly. """ @@ -58,7 +58,8 @@ class Alpha(object): def plot(self): out = f"./out/{self.__class__.__name__}" - title = "Norepinepherine Concentration vs Neuron Activity in " + self.pretty + title = "Norepinepherine Concentration vs Neuron Activity in " + \ + self.pretty logging.info("Plotting " + title) plt.figure() plt.plot(self.x, self.y) @@ -140,59 +141,58 @@ class Alpha2(Alpha): super().__init__() -def simulate(a1, a2): - for i in range(steps): - gain = a1.gains[i] + a2.gains[i] - 1 - bias = a1.biass[i] + a2.biass[i] - 1 - logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}") - with nengo.Network() as net: - # Nodes - time_node = nengo.Node(output=time_function) - noise_wm_node = nengo.Node(output=noise_bias_function) - noise_decision_node = nengo.Node( - output=noise_decision_function) - - # Ensembles - wm = nengo.Ensemble(neurons_wm, 2) - wm.gain = np.full(wm.n_neurons, gain) - wm.bias = np.full(wm.n_neurons, bias) - decision = nengo.Ensemble(neurons_decide, 2) - inputs = nengo.Ensemble(neurons_inputs, 2) - output = nengo.Ensemble(neurons_decide, 1) - - # Connections - nengo.Connection(time_node, inputs[1], synapse=None) - nengo.Connection(inputs, wm, synapse=tau_wm, - function=inputs_function) - wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm) - nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm, - transform=np.ones((neurons_wm, 1)) * tau_wm) - wm_to_decision = nengo.Connection( - wm[0], decision[0], synapse=tau) - nengo.Connection(noise_decision_node, - decision[1], synapse=None) - nengo.Connection(decision, output, function=decision_function) - - # Probes - probes_wm = nengo.Probe(wm[0], synapse=0.01) - probe_output = nengo.Probe(output, synapse=None) - - # Run simulation - with nengo.Simulator(net, dt=dt, progress_bar=False) as sim: - sim.run(t_cue + t_delay) +class Simulation(): + def __init__(self): + self.a1 = Alpha1() + self.a2 = Alpha2() + + def run(self): + for i in range(steps): + gain = self.a1.gains[i] + self.a2.gains[i] - 1 + bias = self.a1.biass[i] + self.a2.biass[i] - 1 + logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}") + + with nengo.Network() as net: + # Nodes + time_node = nengo.Node(output=time_function) + noise_wm_node = nengo.Node(output=noise_bias_function) + noise_decision_node = nengo.Node( + output=noise_decision_function) + + # Ensembles + wm = nengo.Ensemble(neurons_wm, 2) + wm.gain = np.full(wm.n_neurons, gain) + wm.bias = np.full(wm.n_neurons, bias) + decision = nengo.Ensemble(neurons_decide, 2) + inputs = nengo.Ensemble(neurons_inputs, 2) + output = nengo.Ensemble(neurons_decide, 1) + + # Connections + nengo.Connection(time_node, inputs[1], synapse=None) + nengo.Connection(inputs, wm, synapse=tau_wm, + function=inputs_function) + wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm) + nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm, + transform=np.ones((neurons_wm, 1)) * tau_wm) + wm_to_decision = nengo.Connection( + wm[0], decision[0], synapse=tau) + nengo.Connection(noise_decision_node, + decision[1], synapse=None) + nengo.Connection(decision, output, function=decision_function) + + # Probes + probes_wm = nengo.Probe(wm[0], synapse=0.01) + probe_output = nengo.Probe(output, synapse=None) + + # Run simulation + with nengo.Simulator(net, dt=dt, progress_bar=False) as sim: + sim.run(t_cue + t_delay) def main(): logging.info("Initializing simulation") plt.style.use("ggplot") # Nice looking and familiar style - - a1 = Alpha1() - a1.plot() - - a2 = Alpha2() - a2.plot() - - simulate(a1, a2) + Simulation().run() if __name__ == "__main__":