]> git.armaanb.net Git - norepinephrine_wm.git/commitdiff
Factor simulation into its own class
authorArmaan Bhojwani <me@armaanb.net>
Mon, 9 Aug 2021 22:04:29 +0000 (18:04 -0400)
committerArmaan Bhojwani <me@armaanb.net>
Mon, 9 Aug 2021 22:36:05 +0000 (18:36 -0400)
model.py

index cd6002207c20f0c1f5b6246163964ce490c9e937..0ea66c629cfde8c0818d0cc5db8bd4f97fe5862d 100644 (file)
--- a/model.py
+++ b/model.py
@@ -38,7 +38,7 @@ def decision_function(x):
     return 1.0 if x[0] + x[1] > 0.0 else -1.0
 
 
-class Alpha(object):
+class Alpha():
     """
     Base class for alpha receptors. Not to be used directly.
     """
@@ -58,7 +58,8 @@ class Alpha(object):
     def plot(self):
         out = f"./out/{self.__class__.__name__}"
 
-        title = "Norepinepherine Concentration vs Neuron Activity in " + self.pretty
+        title = "Norepinepherine Concentration vs Neuron Activity in " + \
+            self.pretty
         logging.info("Plotting " + title)
         plt.figure()
         plt.plot(self.x, self.y)
@@ -140,59 +141,58 @@ class Alpha2(Alpha):
         super().__init__()
 
 
-def simulate(a1, a2):
-    for i in range(steps):
-        gain = a1.gains[i] + a2.gains[i] - 1
-        bias = a1.biass[i] + a2.biass[i] - 1
-        logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}")
-        with nengo.Network() as net:
-            # Nodes
-            time_node = nengo.Node(output=time_function)
-            noise_wm_node = nengo.Node(output=noise_bias_function)
-            noise_decision_node = nengo.Node(
-                output=noise_decision_function)
-
-            # Ensembles
-            wm = nengo.Ensemble(neurons_wm, 2)
-            wm.gain = np.full(wm.n_neurons, gain)
-            wm.bias = np.full(wm.n_neurons, bias)
-            decision = nengo.Ensemble(neurons_decide, 2)
-            inputs = nengo.Ensemble(neurons_inputs, 2)
-            output = nengo.Ensemble(neurons_decide, 1)
-
-            # Connections
-            nengo.Connection(time_node, inputs[1], synapse=None)
-            nengo.Connection(inputs, wm, synapse=tau_wm,
-                             function=inputs_function)
-            wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm)
-            nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
-                             transform=np.ones((neurons_wm, 1)) * tau_wm)
-            wm_to_decision = nengo.Connection(
-                wm[0], decision[0], synapse=tau)
-            nengo.Connection(noise_decision_node,
-                             decision[1], synapse=None)
-            nengo.Connection(decision, output, function=decision_function)
-
-            # Probes
-            probes_wm = nengo.Probe(wm[0], synapse=0.01)
-            probe_output = nengo.Probe(output, synapse=None)
-
-        # Run simulation
-        with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
-            sim.run(t_cue + t_delay)
+class Simulation():
+    def __init__(self):
+        self.a1 = Alpha1()
+        self.a2 = Alpha2()
+
+    def run(self):
+        for i in range(steps):
+            gain = self.a1.gains[i] + self.a2.gains[i] - 1
+            bias = self.a1.biass[i] + self.a2.biass[i] - 1
+            logging.info(f"gain: {fmt_num(gain)}, bias: {fmt_num(bias)}")
+
+            with nengo.Network() as net:
+                # Nodes
+                time_node = nengo.Node(output=time_function)
+                noise_wm_node = nengo.Node(output=noise_bias_function)
+                noise_decision_node = nengo.Node(
+                    output=noise_decision_function)
+
+                # Ensembles
+                wm = nengo.Ensemble(neurons_wm, 2)
+                wm.gain = np.full(wm.n_neurons, gain)
+                wm.bias = np.full(wm.n_neurons, bias)
+                decision = nengo.Ensemble(neurons_decide, 2)
+                inputs = nengo.Ensemble(neurons_inputs, 2)
+                output = nengo.Ensemble(neurons_decide, 1)
+
+                # Connections
+                nengo.Connection(time_node, inputs[1], synapse=None)
+                nengo.Connection(inputs, wm, synapse=tau_wm,
+                                 function=inputs_function)
+                wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm)
+                nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
+                                 transform=np.ones((neurons_wm, 1)) * tau_wm)
+                wm_to_decision = nengo.Connection(
+                    wm[0], decision[0], synapse=tau)
+                nengo.Connection(noise_decision_node,
+                                 decision[1], synapse=None)
+                nengo.Connection(decision, output, function=decision_function)
+
+                # Probes
+                probes_wm = nengo.Probe(wm[0], synapse=0.01)
+                probe_output = nengo.Probe(output, synapse=None)
+
+            # Run simulation
+            with nengo.Simulator(net, dt=dt, progress_bar=False) as sim:
+                sim.run(t_cue + t_delay)
 
 
 def main():
     logging.info("Initializing simulation")
     plt.style.use("ggplot")  # Nice looking and familiar style
-
-    a1 = Alpha1()
-    a1.plot()
-
-    a2 = Alpha2()
-    a2.plot()
-
-    simulate(a1, a2)
+    Simulation().run()
 
 
 if __name__ == "__main__":