git.armaanb.net Git - norepinephrine_wm.git/blob - old/model.py
final update
[norepinephrine_wm.git] / old / model.py
1 '''
2 Peter Duggins, Terry Stewart, Xuan Choo, Chris Eliasmith
3 Effects of Guanfacine and Phenylephrine on a Spiking Neuron Model of Working Memory
4 June-August 2016
5 Main Model File
6 '''
7
def run(params):
    '''
    Build the working-memory network for one (drug, trial) pair, simulate it
    for t_cue + t_delay seconds, and return the collected data.

    params: sequence [decision_type, drug_type, drug, trial, seed, P]
        decision_type -- 'default' or 'basal_ganglia' (selects the decision
                         rule and the shape of the decision noise)
        drug_type     -- 'neural' (shifts the mean of the noise injected into
                         wm.neurons), 'functional' (scales the recurrent wm
                         function) or 'biophysical' (disables the decoder
                         cache and calls helper.reset_gain_bias); any other
                         value applies no drug effect inside this function
        drug          -- key into P['drug_effect_neural'] and
                         P['drug_effect_functional']; also passed to helpers
        trial         -- index into P['cues'] and P['perceived']; it is also
                         added to seed to form the network seed
        seed          -- base seed for the nengo.Network
        P             -- parameter dictionary (see the keys read below)

    returns: [df_primary, df_firing], the values returned by
        helper.primary_dataframe and helper.firing_dataframe
    '''
    import nengo
    from nengo.rc import rc
    import numpy as np
    from helper import reset_gain_bias, primary_dataframe, firing_dataframe

    decision_type = params[0]
    drug_type = params[1]
    drug = params[2]
    trial = params[3]
    seed = params[4]
    P = params[5]

    # Only the parameters this function actually uses are unpacked; the
    # helpers receive the whole dictionary P and read what they need.
    dt = P['dt']
    dt_sample = P['dt_sample']
    t_cue = P['t_cue']
    t_delay = P['t_delay']
    drug_effect_neural = P['drug_effect_neural']
    drug_effect_functional = P['drug_effect_functional']
    neurons_inputs = P['neurons_inputs']
    neurons_wm = P['neurons_wm']
    neurons_decide = P['neurons_decide']
    time_scale = P['time_scale']
    cue_scale = P['cue_scale']
    tau = P['tau']
    tau_wm = P['tau_wm']
    noise_wm = P['noise_wm']
    noise_decision = P['noise_decision']
    perceived = P['perceived']
    cues = P['cues']

    #don't try to remember old decoders when the drug is 'biophysical'
    if drug_type == 'biophysical':
        rc.set("decoder_cache", "enabled", "False")
    else:
        rc.set("decoder_cache", "enabled", "True")

    def cue_function(t):
        '''Scaled cue during the cue period if the trial's cue was perceived, else 0.'''
        if t < t_cue and perceived[trial] != 0:
            return cue_scale * cues[trial]
        return 0

    def time_function(t):
        '''time_scale once the cue period is over, 0 before that.'''
        if t > t_cue:
            return time_scale
        return 0

    def noise_bias_function(t):
        '''Gaussian sample (std noise_wm); the mean is the drug effect when drug_type is 'neural'.'''
        if drug_type == 'neural':
            return np.random.normal(drug_effect_neural[drug], noise_wm)
        return np.random.normal(0.0, noise_wm)

    def noise_decision_function(t):
        '''Gaussian decision noise: a scalar for 'default', two samples for 'basal_ganglia'.'''
        # NOTE(review): any other decision_type falls through and returns
        # None, exactly as before -- confirm callers never pass other values.
        if decision_type == 'default':
            return np.random.normal(0.0, noise_decision)
        elif decision_type == 'basal_ganglia':
            # NOTE(review): this 2-element output is connected to the single
            # dimension decision[1] below -- verify the shapes are compatible.
            return np.random.normal(0.0, noise_decision, size=2)

    def inputs_function(x):
        '''Scale the input by tau_wm before it is fed into wm.'''
        return x * tau_wm

    def wm_recurrent_function(x):
        '''Recurrent wm function; scaled by the drug effect when drug_type is 'functional'.'''
        if drug_type == 'functional':
            return x * drug_effect_functional[drug]
        return x

    def decision_function(x):
        '''Map the 2-D decision state to +1.0, -1.0, or 0.0 (tie / unknown decision_type).'''
        output = 0.0
        if decision_type == 'default':
            value = x[0] + x[1]
            if value > 0.0:
                output = 1.0
            elif value < 0.0:
                output = -1.0
        elif decision_type == 'basal_ganglia':
            if x[0] > x[1]:
                output = 1.0
            elif x[0] < x[1]:
                output = -1.0
        return output

    def BG_rescale(x):
        '''Return the pair (0.4 + 0.6*p, 0.4 + 0.6*(1-p)) with p = 0.5*(x+1).'''
        # Not referenced anywhere else in this function; kept as in the original.
        pos_x = 0.5 * (x + 1)
        rescaled = 0.4 + 0.6 * pos_x, 0.4 + 0.6 * (1 - pos_x)
        return rescaled

    '''model definition'''
    # Object creation order is kept exactly as in the original so that the
    # seeded network is built the same way.
    with nengo.Network(seed=seed+trial) as model:

        #Ensembles
        cue = nengo.Node(output=cue_function)
        time = nengo.Node(output=time_function)
        inputs = nengo.Ensemble(neurons_inputs, 2)
        noise_wm_node = nengo.Node(output=noise_bias_function)
        noise_decision_node = nengo.Node(output=noise_decision_function)
        wm = nengo.Ensemble(neurons_wm, 2)
        decision = nengo.Ensemble(neurons_decide, 2)
        output = nengo.Ensemble(neurons_decide, 1)

        #Connections
        nengo.Connection(cue, inputs[0], synapse=None)
        nengo.Connection(time, inputs[1], synapse=None)
        nengo.Connection(inputs, wm, synapse=tau_wm, function=inputs_function)
        wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm,
                                        function=wm_recurrent_function)
        # the same noise sample is sent to every wm neuron, scaled by tau_wm
        nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                         transform=np.ones((neurons_wm, 1))*tau_wm)
        wm_to_decision = nengo.Connection(wm[0], decision[0], synapse=tau)
        nengo.Connection(noise_decision_node, decision[1], synapse=None)
        nengo.Connection(decision, output, function=decision_function)

        #Probes
        probe_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=dt_sample)
        probe_spikes = nengo.Probe(wm.neurons, 'spikes', sample_every=dt_sample)
        probe_output = nengo.Probe(output, synapse=None, sample_every=dt_sample)

    '''SIMULATION'''
    # A parenthesised single argument prints identically on Python 2 and 3.
    print('Running drug \"%s\", trial %s...' % (drug, trial+1))
    with nengo.Simulator(model, dt=dt) as sim:
        if drug_type == 'biophysical':
            # NOTE(review): sim is rebound to whatever reset_gain_bias
            # returns; the with-statement still closes the simulator it
            # originally opened -- confirm the helper returns a usable sim.
            sim = reset_gain_bias(
                P, model, sim, wm, wm_recurrent, wm_to_decision, drug)
        sim.run(t_cue+t_delay)
        df_primary = primary_dataframe(P, sim, drug, trial, probe_wm, probe_output)
        df_firing = firing_dataframe(P, sim, drug, trial, sim.data[wm], probe_spikes)
    return [df_primary, df_firing]
136
137
138
139 '''MAIN'''
def main():
    '''
    Load the parameters, run the experiments, export the data and plot it.

    Reads 'parameters.txt', builds one parameter list per (drug, trial) pair,
    runs the first and the last of those lists through run(), writes
    'primary_data.pkl', 'firing_data.pkl' and 'params.json', saves
    'primary_plots.png' and 'firing_plots.png', and finally shows the figures.
    '''
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    from helper import make_cues, empirical_dataframe
    from pathos.helpers import freeze_support #for Windows

    '''Import Parameters from File'''
    # SECURITY NOTE: eval() executes whatever expression parameters.txt
    # contains, so that file must be trusted. It is kept because the file
    # format is not visible from here.
    # The with-block closes the file handle, which the original left open.
    with open('parameters.txt') as param_file:
        P = eval(param_file.read()) #parameter dictionary
    seed = P['seed'] #sets tuning curves equal to control before drug application
    n_trials = P['n_trials']
    drug_type = str(P['drug_type'])
    decision_type = str(P['decision_type'])
    drugs = P['drugs']
    trials, perceived, cues = make_cues(P)
    P['timesteps'] = np.arange(0, int((P['t_cue']+P['t_delay'])/P['dt_sample']))
    P['cues'] = cues
    P['perceived'] = perceived

    '''Multiprocessing'''
    print("drug_type=%s, decision_type=%s, trials=%s..." % (drug_type, decision_type, n_trials))
    freeze_support()
    exp_params = [[decision_type, drug_type, drug, trial, seed, P]
                  for drug in drugs
                  for trial in trials]
    # NOTE(review): only the first and last experiments are run, and with a
    # single (drug, trial) pair the same experiment runs twice -- confirm
    # this is intended rather than a leftover from debugging.
    df_list = [run(exp_params[0]), run(exp_params[-1])]
    primary_dataframe = pd.concat([dfs[0] for dfs in df_list], ignore_index=True)
    firing_dataframe = pd.concat([dfs[1] for dfs in df_list], ignore_index=True)

    '''Plot and Export'''
    print('Exporting Data...')
    primary_dataframe.to_pickle('primary_data.pkl')
    firing_dataframe.to_pickle('firing_data.pkl')
    param_df = pd.DataFrame([P])
    param_df.reset_index().to_json('params.json', orient='records')

    print('Plotting...')
    emp_dataframe = empirical_dataframe()
    sns.set(context='poster')
    figure, (ax1, ax2) = plt.subplots(2, 1)
    sns.tsplot(time="time", value="wm", data=primary_dataframe, unit="trial",
               condition='drug', ax=ax1, ci=95)
    sns.tsplot(time="time", value="correct", data=primary_dataframe, unit="trial",
               condition='drug', ax=ax2, ci=95)
    # the empirical data is drawn twice: once as points, once as an
    # interpolated line without a second legend
    sns.tsplot(time="time", value="accuracy", data=emp_dataframe, unit='trial',
               condition='drug', interpolate=False, ax=ax2)
    sns.tsplot(time="time", value="accuracy", data=emp_dataframe, unit='trial',
               condition='drug', interpolate=True, ax=ax2, legend=False)
    # raw string: same label text, without the invalid "\h" escape sequence
    ax1.set(xlabel='', ylabel=r'decoded $\hat{cue}$', xlim=(0, 9.5), ylim=(0, 1),
            title="drug_type=%s, decision_type=%s, trials=%s" % (drug_type, decision_type, n_trials))
    ax2.set(xlabel='time (s)', xlim=(0, 9.5), ylim=(0.5, 1), ylabel='DRT accuracy')
    figure.savefig('primary_plots.png')

    figure2, (ax3, ax4) = plt.subplots(1, 2)
    weak = firing_dataframe.query("tuning=='weak'")
    if len(weak) > 0:
        sns.tsplot(time="time", value="firing_rate", unit="neuron-trial",
                   condition='drug', ax=ax3, ci=95, data=weak.reset_index())
    nonpreferred = firing_dataframe.query("tuning=='nonpreferred'")
    if len(nonpreferred) > 0:
        sns.tsplot(time="time", value="firing_rate", unit="neuron-trial",
                   condition='drug', ax=ax4, ci=95, data=nonpreferred.reset_index())
    ax3.set(xlabel='time (s)', xlim=(0.0, 9.5), ylim=(0, 250),
            ylabel='Normalized Firing Rate', title='Preferred Direction')
    ax4.set(xlabel='time (s)', xlim=(0.0, 9.5), ylim=(0, 250),
            ylabel='', title='Nonpreferred Direction')
    figure2.savefig('firing_plots.png')

    plt.show()
206
# Run the experiment only when this file is executed as a script.
if '__main__' == __name__:
    main()