]> git.armaanb.net Git - norepinephrine_wm.git/blobdiff - old/model.py
Restructure repository
[norepinephrine_wm.git] / old / model.py
diff --git a/old/model.py b/old/model.py
new file mode 100644 (file)
index 0000000..a81eb76
--- /dev/null
@@ -0,0 +1,208 @@
+'''
+Peter Duggins, Terry Stewart, Xuan Choo, Chris Eliasmith
+Effects of Guanfacine and Phenylephrine on a Spiking Neuron Model of Working Memory
+June-August 2016
+Main Model File
+'''
+
+def run(params):
+    import nengo
+    from nengo.rc import rc
+    import numpy as np
+    import pandas as pd
+    from helper import reset_gain_bias, primary_dataframe, firing_dataframe
+
+    decision_type=params[0]
+    drug_type=params[1]
+    drug = params[2]
+    trial = params[3]
+    seed = params[4]
+    P = params[5]
+    dt=P['dt']
+    dt_sample=P['dt_sample']
+    t_cue=P['t_cue']
+    t_delay=P['t_delay']
+    drug_effect_neural=P['drug_effect_neural']
+    drug_effect_functional=P['drug_effect_functional']
+    drug_effect_biophysical=P['drug_effect_biophysical']
+    enc_min_cutoff=P['enc_min_cutoff']
+    enc_max_cutoff=P['enc_max_cutoff']
+    sigma_smoothing=P['sigma_smoothing']
+    frac=P['frac']
+    neurons_inputs=P['neurons_inputs']
+    neurons_wm=P['neurons_wm']
+    neurons_decide=P['neurons_decide']
+    time_scale=P['time_scale']
+    cue_scale=P['cue_scale']
+    tau=P['tau']
+    tau_wm=P['tau_wm']
+    noise_wm=P['noise_wm']
+    noise_decision=P['noise_decision']
+    perceived=P['perceived']
+    cues=P['cues']
+
+    if drug_type == 'biophysical': rc.set("decoder_cache", "enabled", "False") #don't try to remember old decoders
+    else: rc.set("decoder_cache", "enabled", "True")
+
+    def cue_function(t):
+        if t < t_cue and perceived[trial]!=0:
+            return cue_scale * cues[trial]
+        else: return 0
+
+    def time_function(t):
+        if t > t_cue:
+            return time_scale
+        else: return 0
+
+    def noise_bias_function(t):
+        import numpy as np
+        if drug_type=='neural':
+            return np.random.normal(drug_effect_neural[drug],noise_wm)
+        else:
+            return np.random.normal(0.0,noise_wm)
+
+    def noise_decision_function(t):
+        import numpy as np
+        if decision_type == 'default':
+            return np.random.normal(0.0,noise_decision)
+        elif decision_type == 'basal_ganglia':
+            return np.random.normal(0.0,noise_decision,size=2)
+
+    def inputs_function(x):
+        return x * tau_wm
+
+    def wm_recurrent_function(x):
+        if drug_type == 'functional':
+            return x * drug_effect_functional[drug]
+        else:
+            return x
+
+    def decision_function(x):
+        output=0.0
+        if decision_type=='default':
+            value=x[0]+x[1]
+            if value > 0.0: output = 1.0
+            elif value < 0.0: output = -1.0
+        elif decision_type=='basal_ganglia':
+            if x[0] > x[1]: output = 1.0
+            elif x[0] < x[1]: output = -1.0
+        return output
+
+    def BG_rescale(x): #rescales -1 to 1 into 0.3 to 1, makes 2-dimensional
+        pos_x = 0.5 * (x + 1)
+        rescaled = 0.4 + 0.6 * pos_x, 0.4 + 0.6 * (1 - pos_x)
+        return rescaled
+
+    '''model definition'''
+    with nengo.Network(seed=seed+trial) as model:
+
+        #Ensembles
+        cue = nengo.Node(output=cue_function)
+        time = nengo.Node(output=time_function)
+        inputs = nengo.Ensemble(neurons_inputs,2)
+        noise_wm_node = nengo.Node(output=noise_bias_function)
+        noise_decision_node = nengo.Node(output=noise_decision_function)
+        wm = nengo.Ensemble(neurons_wm,2)
+        decision = nengo.Ensemble(neurons_decide,2)
+        output = nengo.Ensemble(neurons_decide,1)
+
+        #Connections
+        nengo.Connection(cue,inputs[0],synapse=None)
+        nengo.Connection(time,inputs[1],synapse=None)
+        nengo.Connection(inputs,wm,synapse=tau_wm,function=inputs_function)
+        wm_recurrent=nengo.Connection(wm,wm,synapse=tau_wm,function=wm_recurrent_function)
+        nengo.Connection(noise_wm_node,wm.neurons,synapse=tau_wm,transform=np.ones((neurons_wm,1))*tau_wm)
+        wm_to_decision=nengo.Connection(wm[0],decision[0],synapse=tau)
+        nengo.Connection(noise_decision_node,decision[1],synapse=None)
+        nengo.Connection(decision,output,function=decision_function)
+
+        #Probes
+        probe_wm=nengo.Probe(wm[0],synapse=0.01,sample_every=dt_sample)
+        probe_spikes=nengo.Probe(wm.neurons, 'spikes', sample_every=dt_sample)
+        probe_output=nengo.Probe(output,synapse=None,sample_every=dt_sample)
+
+
+
+
+    '''SIMULATION'''
+    print 'Running drug \"%s\", trial %s...' %(drug,trial+1)
+    with nengo.Simulator(model,dt=dt) as sim:
+        if drug_type == 'biophysical': sim=reset_gain_bias(
+                P,model,sim,wm,wm_recurrent,wm_to_decision,drug)
+        sim.run(t_cue+t_delay)
+        df_primary=primary_dataframe(P,sim,drug,trial,probe_wm,probe_output)
+        df_firing=firing_dataframe(P,sim,drug,trial,sim.data[wm],probe_spikes)
+    return [df_primary, df_firing]
+
+
+
+'''MAIN'''
+def main():
+    import matplotlib.pyplot as plt
+    import seaborn as sns
+    import pandas as pd
+    import numpy as np
+    from helper import make_cues, empirical_dataframe
+    from pathos.helpers import freeze_support #for Windows
+    # import ipdb
+
+    '''Import Parameters from File'''
+    P=eval(open('parameters.txt').read()) #parameter dictionary
+    seed=P['seed'] #sets tuning curves equal to control before drug application
+    n_trials=P['n_trials']
+    drug_type=str(P['drug_type'])
+    decision_type=str(P['decision_type'])
+    drugs=P['drugs']
+    trials, perceived, cues = make_cues(P)
+    P['timesteps']=np.arange(0,int((P['t_cue']+P['t_delay'])/P['dt_sample']))
+    P['cues']=cues
+    P['perceived']=perceived
+
+    '''Multiprocessing'''
+    print "drug_type=%s, decision_type=%s, trials=%s..." %(drug_type,decision_type,n_trials)
+    freeze_support()
+    exp_params=[]
+    for drug in drugs:
+        for trial in trials:
+            exp_params.append([decision_type, drug_type, drug, trial, seed, P])
+    df_list=[run(exp_params[0]),run(exp_params[-1])]
+    primary_dataframe = pd.concat([df_list[i][0] for i in range(len(df_list))], ignore_index=True)
+    firing_dataframe = pd.concat([df_list[i][1] for i in range(len(df_list))], ignore_index=True)
+
+    '''Plot and Export'''
+    print 'Exporting Data...'
+    primary_dataframe.to_pickle('primary_data.pkl')
+    firing_dataframe.to_pickle('firing_data.pkl')
+    param_df=pd.DataFrame([P])
+    param_df.reset_index().to_json('params.json',orient='records')
+
+    print 'Plotting...'
+    emp_dataframe=empirical_dataframe()
+    sns.set(context='poster')
+    figure, (ax1, ax2) = plt.subplots(2, 1)
+    sns.tsplot(time="time",value="wm",data=primary_dataframe,unit="trial",condition='drug',ax=ax1,ci=95)
+    sns.tsplot(time="time",value="correct",data=primary_dataframe,unit="trial",condition='drug',ax=ax2,ci=95)
+    sns.tsplot(time="time",value="accuracy",data=emp_dataframe,unit='trial',condition='drug',
+                interpolate=False,ax=ax2)
+    sns.tsplot(time="time",value="accuracy",data=emp_dataframe, unit='trial',condition='drug',
+                interpolate=True,ax=ax2, legend=False)
+    ax1.set(xlabel='',ylabel='decoded $\hat{cue}$',xlim=(0,9.5),ylim=(0,1),
+                title="drug_type=%s, decision_type=%s, trials=%s" %(drug_type,decision_type,n_trials))
+    ax2.set(xlabel='time (s)',xlim=(0,9.5),ylim=(0.5,1),ylabel='DRT accuracy')
+    figure.savefig('primary_plots.png')
+
+    figure2, (ax3, ax4) = plt.subplots(1, 2)
+    if len(firing_dataframe.query("tuning=='weak'"))>0:
+        sns.tsplot(time="time",value="firing_rate",unit="neuron-trial",condition='drug',ax=ax3,ci=95,
+                data=firing_dataframe.query("tuning=='weak'").reset_index())
+    if len(firing_dataframe.query("tuning=='nonpreferred'"))>0:
+        sns.tsplot(time="time",value="firing_rate",unit="neuron-trial",condition='drug',ax=ax4,ci=95,
+                data=firing_dataframe.query("tuning=='nonpreferred'").reset_index())
+    ax3.set(xlabel='time (s)',xlim=(0.0,9.5),ylim=(0,250),ylabel='Normalized Firing Rate',title='Preferred Direction')
+    ax4.set(xlabel='time (s)',xlim=(0.0,9.5),ylim=(0,250),ylabel='',title='Nonpreferred Direction')
+    figure2.savefig('firing_plots.png')
+
+    plt.show()
+
+if __name__=='__main__':
+    main()
\ No newline at end of file