provide supplementary functions to the main scripts
modify the .gitignore
Showing
4 changed files
with
573 additions
and
360 deletions
DeepGONet.py
0 → 100644
1 | +# Useful packages | ||
2 | + | ||
3 | +import warnings | ||
4 | +import os | ||
5 | +import time | ||
6 | +import signal | ||
7 | +import sys | ||
8 | +import copy | ||
9 | +import h5py | ||
10 | +import pickle | ||
11 | +import random | ||
12 | +from tqdm import tqdm | ||
13 | + | ||
14 | +import numpy as np | ||
15 | +import matplotlib.pyplot as plt | ||
16 | +import seaborn | ||
17 | +import pandas as pd | ||
18 | +import tensorflow as tf | ||
19 | +from tensorflow.keras.utils import to_categorical | ||
20 | + | ||
21 | + | ||
22 | +# Experiment setting | ||
23 | +FLAGS = tf.app.flags.FLAGS | ||
24 | + | ||
25 | +# -- Configuration of the environnement -- | ||
26 | +tf.app.flags.DEFINE_string('log_dir', "../log", "log_dir") | ||
27 | +tf.app.flags.DEFINE_string('dir_data', "", "repository for all the files needed for the training and the evaluation") | ||
28 | +tf.app.flags.DEFINE_bool('save', False, "Do you need to save the model?") | ||
29 | +tf.app.flags.DEFINE_bool('restore', False, "Do you want to restore a previous model?") | ||
30 | +tf.app.flags.DEFINE_bool('is_training', True, "Is the model trainable?") | ||
31 | +tf.app.flags.DEFINE_string('processing', "train", "What to do with the model? {train,evaluate,predict}") | ||
32 | + | ||
33 | +# -- Architecture of the neural network -- | ||
34 | +tf.app.flags.DEFINE_integer('n_input', 54675, "number of features") | ||
35 | +tf.app.flags.DEFINE_integer('n_classes', 1, "number of classes") | ||
36 | +tf.app.flags.DEFINE_integer('n_layers', 6, "number of layers") | ||
37 | +tf.app.flags.DEFINE_integer('n_hidden_1', 1574, "number of neurons for the first hidden layer") #Level 7 | ||
38 | +tf.app.flags.DEFINE_integer('n_hidden_2', 1386, "number of neurons for the second hidden layer") #Level 6 | ||
39 | +tf.app.flags.DEFINE_integer('n_hidden_3', 951, "number of neurons for the third hidden layer") #Level 5 | ||
40 | +tf.app.flags.DEFINE_integer('n_hidden_4', 515, "number of neurons for the fourth hidden layer") #Level 4 | ||
41 | +tf.app.flags.DEFINE_integer('n_hidden_5', 255, "number of neurons for the fifth hidden layer") #Level 3 | ||
42 | +tf.app.flags.DEFINE_integer('n_hidden_6', 90, "number of neurons for the sixth hidden layer") #Level 2 | ||
43 | + | ||
44 | +# -- Learning and Hyperparameters -- | ||
45 | +tf.app.flags.DEFINE_string('lr_method', 'adam', "optimizer {adam, momentum, adagrad, rmsprop}") | ||
46 | +tf.app.flags.DEFINE_float('learning_rate', 0.001, "initial learning rate") | ||
47 | +tf.app.flags.DEFINE_bool('bn', False, "use of batch normalization") | ||
48 | +tf.app.flags.DEFINE_float('keep_prob', 0.4, "keep probability for the dropout") | ||
49 | +tf.app.flags.DEFINE_string('type_training', 'LGO', "regularization term {"", LGO, L2, L1}") | ||
50 | +tf.app.flags.DEFINE_float('alpha', 1, "value of the hyperparameter alpha") | ||
51 | +tf.app.flags.DEFINE_integer('display_step', 5, "when to print the performances") | ||
52 | +tf.app.flags.DEFINE_integer('batch_size', 2**9, "the number of examples in a batch") | ||
53 | +tf.app.flags.DEFINE_integer('epochs', 20, "the number of epochs for training") | ||
54 | +tf.app.flags.DEFINE_string('GPU_device', '/gpu:0', "GPU device") | ||
55 | + | ||
56 | +from base_model import BaseModel | ||
57 | + | ||
58 | +def l1_loss_func(x): | ||
59 | + return tf.reduce_sum(tf.math.abs(x)) | ||
60 | + | ||
61 | +def l2_loss_func(x): | ||
62 | + return tf.reduce_sum(tf.square(x)) | ||
63 | + | ||
64 | + | ||
65 | +def train(save_dir): | ||
66 | + | ||
67 | + warnings.filterwarnings("ignore") | ||
68 | + os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" | ||
69 | + os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.GPU_device[len(FLAGS.GPU_device)-1] | ||
70 | + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | ||
71 | + | ||
72 | + # Load the useful files to build the architecture | ||
73 | + print("Loading the connection matrix...") | ||
74 | + start = time.time() | ||
75 | + | ||
76 | + adj_matrix = pd.read_csv(os.path.abspath(os.path.join(FLAGS.dir_data,"adj_matrix.csv")),index_col=0) | ||
77 | + first_matrix_connection = pd.read_csv(os.path.abspath(os.path.join(FLAGS.dir_data,"first_matrix_connection_GO.csv")),index_col=0) | ||
78 | + csv_go = pd.read_csv(os.path.abspath(os.path.join(FLAGS.dir_data,"go_level.csv")),index_col=0) | ||
79 | + | ||
80 | + connection_matrix = [] | ||
81 | + connection_matrix.append(np.array(first_matrix_connection.values,dtype=np.float32)) | ||
82 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(7)].loc[lambda x: x==1].index,csv_go[str(6)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
83 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(6)].loc[lambda x: x==1].index,csv_go[str(5)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
84 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(5)].loc[lambda x: x==1].index,csv_go[str(4)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
85 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(4)].loc[lambda x: x==1].index,csv_go[str(3)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
86 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(3)].loc[lambda x: x==1].index,csv_go[str(2)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
87 | + connection_matrix.append(np.ones((FLAGS.n_hidden_6, FLAGS.n_classes),dtype=np.float32)) | ||
88 | + | ||
89 | + end = time.time() | ||
90 | + elapsed=end - start | ||
91 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
92 | + time.gmtime(elapsed).tm_min, | ||
93 | + time.gmtime(elapsed).tm_sec)) | ||
94 | + | ||
95 | + # Load the data | ||
96 | + print("Loading the data...") | ||
97 | + | ||
98 | + start = time.time() | ||
99 | + loaded = np.load(os.path.abspath(os.path.join(FLAGS.dir_data,"X_train.npz"))) | ||
100 | + X_train = loaded['x'] | ||
101 | + y_train = loaded['y'] | ||
102 | + if FLAGS.n_classes>=2: | ||
103 | + y_train=to_categorical(y_train) | ||
104 | + | ||
105 | + loaded = np.load(os.path.abspath(os.path.join(FLAGS.dir_data,"X_test.npz"))) | ||
106 | + X_test = loaded['x'] | ||
107 | + y_test = loaded['y'] | ||
108 | + if FLAGS.n_classes>=2: | ||
109 | + y_test=to_categorical(y_test) | ||
110 | + | ||
111 | + end = time.time() | ||
112 | + elapsed=end - start | ||
113 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
114 | + time.gmtime(elapsed).tm_min, | ||
115 | + time.gmtime(elapsed).tm_sec)) | ||
116 | + | ||
117 | + | ||
118 | + # Launch the model | ||
119 | + print("Launching the learning") | ||
120 | + if FLAGS.type_training != "": | ||
121 | + print("with {} and ALPHA={}".format(FLAGS.type_training,FLAGS.alpha)) | ||
122 | + | ||
123 | + tf.reset_default_graph() | ||
124 | + | ||
125 | + # -- Inputs of the model -- | ||
126 | + X = tf.placeholder(tf.float32, shape=[None, FLAGS.n_input]) | ||
127 | + Y = tf.placeholder(tf.float32, shape=[None, FLAGS.n_classes]) | ||
128 | + | ||
129 | + # -- Hyperparameters of the neural network -- | ||
130 | + is_training = tf.placeholder(tf.bool,name="is_training") # Batch Norm hyperparameter | ||
131 | + learning_rate = tf.placeholder(tf.float32, name="learning_rate") # Optimizer hyperparameter | ||
132 | + keep_prob = tf.placeholder(tf.float32, name="keep_prob") # Dropout hyperparameter | ||
133 | + total_batches=len(X_train)//FLAGS.batch_size | ||
134 | + | ||
135 | + network=BaseModel(X=X,n_input=FLAGS.n_input,n_classes=FLAGS.n_classes, | ||
136 | + n_hidden_1=FLAGS.n_hidden_1,n_hidden_2=FLAGS.n_hidden_2,n_hidden_3=FLAGS.n_hidden_3,n_hidden_4=FLAGS.n_hidden_4, | ||
137 | + n_hidden_5=FLAGS.n_hidden_5,n_hidden_6=FLAGS.n_hidden_6,keep_prob=keep_prob,is_training=is_training) # Model instantiation | ||
138 | + pred = network() | ||
139 | + | ||
140 | + # -- Loss function -- | ||
141 | + | ||
142 | + # ---- CE loss ---- | ||
143 | + # Compute the average of the loss across all the dimensions | ||
144 | + if FLAGS.n_classes>=2: | ||
145 | + ce_loss = f.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=Y)) | ||
146 | + else: | ||
147 | + ce_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=Y)) | ||
148 | + | ||
149 | + # ---- Regularization loss (LGO, L2, L1) ---- | ||
150 | + additional_loss = 0 | ||
151 | + if FLAGS.type_training=="LGO": | ||
152 | + for idx,weight in enumerate(network.weights.values()): | ||
153 | + additional_loss+=l2_loss_func(weight*(1-connection_matrix[idx])) # Penalization of the noGO connections | ||
154 | + elif FLAGS.type_training=="L2" : | ||
155 | + for weight in network.weights.values(): | ||
156 | + additional_loss += l2_loss_func(weight) | ||
157 | + elif FLAGS.type_training=="L1" : | ||
158 | + for idx,weight in enumerate(network.weights.values()): | ||
159 | + additional_loss+=l1_loss_func(weight) | ||
160 | + | ||
161 | + # ---- Total loss ---- | ||
162 | + if FLAGS.type_training!='' : | ||
163 | + total_loss = ce_loss + FLAGS.alpha*additional_loss | ||
164 | + else: | ||
165 | + total_loss = ce_loss | ||
166 | + | ||
167 | + | ||
168 | + # ---- Norm of the weights of the connections ---- | ||
169 | + norm_no_go_connections=0 | ||
170 | + norm_go_connections=0 | ||
171 | + for idx,weight in enumerate(list(network.weights.values())[:-1]): | ||
172 | + norm_no_go_connections+=tf.norm((weight*(1-connection_matrix[idx])),ord=1)/np.count_nonzero(1-connection_matrix[idx]) | ||
173 | + norm_go_connections+=tf.norm((weight*connection_matrix[idx]),ord=1)/np.count_nonzero(connection_matrix[idx]) | ||
174 | + norm_no_go_connections/=FLAGS.n_layers | ||
175 | + norm_go_connections/=FLAGS.n_layers | ||
176 | + | ||
177 | + # -- Optimizer -- | ||
178 | + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): | ||
179 | + if FLAGS.lr_method=="adam": | ||
180 | + trainer = tf.train.AdamOptimizer(learning_rate = learning_rate) | ||
181 | + elif FLAGS.lr_method=="momentum": | ||
182 | + trainer = tf.train.MomentumOptimizer(learning_rate = learning_rate, momentum=0.09, use_nesterov=True) | ||
183 | + elif FLAGS.lr_method=="adagrad": | ||
184 | + trainer = tf.train.AdagradOptimizer(learning_rate=learning_rate) | ||
185 | + elif FLAGS.lr_method=="rmsprop": | ||
186 | + trainer = tf.train.RMSPropOptimizer(learning_rate = learning_rate) | ||
187 | + optimizer = trainer.minimize(total_loss) | ||
188 | + | ||
189 | + # -- Compute the prediction error -- | ||
190 | + if FLAGS.n_classes>=2: | ||
191 | + correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(Y, 1)) | ||
192 | + else: | ||
193 | + sig_pred=tf.nn.sigmoid(pred) | ||
194 | + sig_pred=tf.cast(sig_pred>0.5,dtype=tf.int64) | ||
195 | + ground_truth=tf.cast(Y,dtype=tf.int64) | ||
196 | + correct_prediction = tf.equal(sig_pred,ground_truth) | ||
197 | + | ||
198 | + # -- Calculate the accuracy across all the given batches and average them out -- | ||
199 | + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | ||
200 | + | ||
201 | + # -- Initialize the variables -- | ||
202 | + init = tf.global_variables_initializer() | ||
203 | + | ||
204 | + # -- Configure the use of the gpu -- | ||
205 | + config = tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) | ||
206 | + #config.gpu_options.allow_growth = True, log_device_placement=True | ||
207 | + | ||
208 | + if FLAGS.save or FLAGS.restore : saver = tf.train.Saver() | ||
209 | + | ||
210 | + start = time.time() | ||
211 | + | ||
212 | + with tf.device(FLAGS.GPU_device): | ||
213 | + with tf.Session(config=config) as sess: | ||
214 | + sess.run(init) | ||
215 | + | ||
216 | + train_c_accuracy=[] | ||
217 | + train_c_total_loss=[] | ||
218 | + | ||
219 | + test_c_accuracy=[] | ||
220 | + test_c_total_loss=[] | ||
221 | + | ||
222 | + c_l1_norm_go=[] | ||
223 | + c_l1_norm_no_go=[] | ||
224 | + | ||
225 | + if FLAGS.type_training!="": | ||
226 | + train_c_ce_loss=[] | ||
227 | + test_c_ce_loss=[] | ||
228 | + train_c_additional_loss=[] | ||
229 | + test_c_additional_loss=[] | ||
230 | + | ||
231 | + for epoch in tqdm(np.arange(0,FLAGS.epochs)): | ||
232 | + | ||
233 | + index = np.arange(X_train.shape[0]) | ||
234 | + np.random.shuffle(index) | ||
235 | + batch_X = np.array_split(X_train[index], total_batches) | ||
236 | + batch_Y = np.array_split(y_train[index], total_batches) | ||
237 | + | ||
238 | + # -- Optimization -- | ||
239 | + for batch in range(total_batches): | ||
240 | + batch_x,batch_y=batch_X[batch],batch_Y[batch] | ||
241 | + sess.run(optimizer, feed_dict={X: batch_x,Y: batch_y,is_training:FLAGS.is_training,keep_prob:FLAGS.keep_prob,learning_rate:FLAGS.learning_rate}) | ||
242 | + | ||
243 | + if ((epoch+1) % FLAGS.display_step == 0) or (epoch==0) : | ||
244 | + if not((FLAGS.display_step==FLAGS.epochs) and (epoch==0)): | ||
245 | + | ||
246 | + # -- Calculate batch loss and accuracy after a specific epoch on the train and test set -- | ||
247 | + | ||
248 | + avg_cost,avg_acc,l1_norm_no_go,l1_norm_go = sess.run([total_loss, accuracy,norm_no_go_connections,norm_go_connections], feed_dict={X: X_train,Y: y_train, | ||
249 | + is_training:False,keep_prob:1.0}) | ||
250 | + train_c_total_loss.append(avg_cost) | ||
251 | + train_c_accuracy.append(avg_acc) | ||
252 | + c_l1_norm_go.append(l1_norm_go) | ||
253 | + c_l1_norm_no_go.append(l1_norm_no_go) | ||
254 | + | ||
255 | + if FLAGS.type_training!="": | ||
256 | + avg_ce_loss,avg_additional_loss= sess.run([ce_loss, additional_loss], feed_dict={X: X_train,Y: y_train,is_training:False,keep_prob:1.0}) | ||
257 | + train_c_additional_loss.append(avg_additional_loss) | ||
258 | + train_c_ce_loss.append(avg_ce_loss) | ||
259 | + | ||
260 | + avg_cost,avg_acc = sess.run([total_loss, accuracy], feed_dict={X: X_test,Y: y_test,is_training:False,keep_prob:1.0}) | ||
261 | + test_c_total_loss.append(avg_cost) | ||
262 | + test_c_accuracy.append(avg_acc) | ||
263 | + | ||
264 | + if FLAGS.type_training!="": | ||
265 | + avg_ce_loss,avg_additional_loss= sess.run([ce_loss, additional_loss], feed_dict={X: X_test,Y: y_test,is_training:False,keep_prob:1.0}) | ||
266 | + test_c_additional_loss.append(avg_additional_loss) | ||
267 | + test_c_ce_loss.append(avg_ce_loss) | ||
268 | + | ||
269 | + current_idx=len(train_c_total_loss)-1 | ||
270 | + print('| Epoch: {}/{} | Train: Loss {:.6f} Accuracy : {:.6f} '\ | ||
271 | + '| Test: Loss {:.6f} Accuracy : {:.6f}\n'.format( | ||
272 | + epoch+1, FLAGS.epochs,train_c_total_loss[current_idx], train_c_accuracy[current_idx],test_c_total_loss[current_idx],test_c_accuracy[current_idx])) | ||
273 | + | ||
274 | + if FLAGS.save: saver.save(sess=sess, save_path=os.path.join(save_dir,"model")) | ||
275 | + | ||
276 | + end = time.time() | ||
277 | + elapsed=end - start | ||
278 | + print("Total time: {}h {}min {}sec ".format(time.gmtime(elapsed).tm_hour, | ||
279 | + time.gmtime(elapsed).tm_min, | ||
280 | + time.gmtime(elapsed).tm_sec)) | ||
281 | + | ||
282 | + performances = { | ||
283 | + 'total_loss':train_c_total_loss,'test_total_loss':test_c_total_loss, | ||
284 | + 'acc':train_c_accuracy,'test_acc':test_c_accuracy | ||
285 | + } | ||
286 | + | ||
287 | + performances['norm_go']=c_l1_norm_go | ||
288 | + performances['norm_no_go']=c_l1_norm_no_go | ||
289 | + | ||
290 | + if FLAGS.type_training!="": | ||
291 | + performances['additional_loss']=train_c_additional_loss | ||
292 | + performances['test_additional_loss']=test_c_additional_loss | ||
293 | + performances['ce_loss']=train_c_ce_loss | ||
294 | + performances['test_ce_loss']=test_c_ce_loss | ||
295 | + | ||
296 | + | ||
297 | + return performances | ||
298 | + | ||
299 | +def evaluate(save_dir): | ||
300 | + | ||
301 | + warnings.filterwarnings("ignore") | ||
302 | + os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" | ||
303 | + os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.GPU_device[len(FLAGS.GPU_device)-1] | ||
304 | + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | ||
305 | + | ||
306 | + # Load the useful files to build the architecture | ||
307 | + print("Loading the connection matrix...") | ||
308 | + start = time.time() | ||
309 | + | ||
310 | + adj_matrix = pd.read_csv(os.path.join(FLAGS.dir_data,"adj_matrix.csv"),index_col=0) | ||
311 | + first_matrix_connection = pd.read_csv(os.path.join(FLAGS.dir_data,"first_matrix_connection_GO.csv"),index_col=0) | ||
312 | + csv_go = pd.read_csv(os.path.join(FLAGS.dir_data,"go_level.csv"),index_col=0) | ||
313 | + | ||
314 | + connection_matrix = [] | ||
315 | + connection_matrix.append(np.array(first_matrix_connection.values,dtype=np.float32)) | ||
316 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(7)].loc[lambda x: x==1].index,csv_go[str(6)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
317 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(6)].loc[lambda x: x==1].index,csv_go[str(5)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
318 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(5)].loc[lambda x: x==1].index,csv_go[str(4)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
319 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(4)].loc[lambda x: x==1].index,csv_go[str(3)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
320 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(3)].loc[lambda x: x==1].index,csv_go[str(2)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
321 | + connection_matrix.append(np.ones((FLAGS.n_hidden_6, FLAGS.n_classes),dtype=np.float32)) | ||
322 | + | ||
323 | + end = time.time() | ||
324 | + elapsed=end - start | ||
325 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
326 | + time.gmtime(elapsed).tm_min, | ||
327 | + time.gmtime(elapsed).tm_sec)) | ||
328 | + | ||
329 | + # Load the data | ||
330 | + print("Loading the test dataset...") | ||
331 | + | ||
332 | + loaded = np.load(os.path.join(FLAGS.dir_data,"X_test.npz")) | ||
333 | + X_test = loaded['x'] | ||
334 | + y_test = loaded['y'] | ||
335 | + if FLAGS.n_classes>=2: | ||
336 | + y_test=to_categorical(y_test) | ||
337 | + | ||
338 | + end = time.time() | ||
339 | + elapsed=end - start | ||
340 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
341 | + time.gmtime(elapsed).tm_min, | ||
342 | + time.gmtime(elapsed).tm_sec)) | ||
343 | + | ||
344 | + | ||
345 | + # Launch the model | ||
346 | + print("Launching the evaluation") | ||
347 | + if FLAGS.type_training != "": | ||
348 | + print("with {} and ALPHA={}".format(FLAGS.type_training,FLAGS.alpha)) | ||
349 | + | ||
350 | + tf.reset_default_graph() | ||
351 | + | ||
352 | + # -- Inputs of the model -- | ||
353 | + X = tf.placeholder(tf.float32, shape=[None, FLAGS.n_input]) | ||
354 | + Y = tf.placeholder(tf.float32, shape=[None, FLAGS.n_classes]) | ||
355 | + | ||
356 | + # -- Hyperparameters of the neural network -- | ||
357 | + is_training = tf.placeholder(tf.bool,name="is_training") # Batch Norm hyperparameter | ||
358 | + keep_prob = tf.placeholder(tf.float32, name="keep_prob") # Dropout hyperparameter | ||
359 | + | ||
360 | + network=BaseModel(X=X,n_input=FLAGS.n_input,n_classes=FLAGS.n_classes, | ||
361 | + n_hidden_1=FLAGS.n_hidden_1,n_hidden_2=FLAGS.n_hidden_2,n_hidden_3=FLAGS.n_hidden_3,n_hidden_4=FLAGS.n_hidden_4, | ||
362 | + n_hidden_5=FLAGS.n_hidden_5,n_hidden_6=FLAGS.n_hidden_6,keep_prob=keep_prob,is_training=is_training) # Model instantiation | ||
363 | + pred = network() | ||
364 | + | ||
365 | + # -- Loss function -- | ||
366 | + | ||
367 | + # ---- CE loss ---- | ||
368 | + # Compute the average of the loss across all the dimensions | ||
369 | + if FLAGS.n_classes>=2: | ||
370 | + ce_loss = f.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=Y)) | ||
371 | + else: | ||
372 | + ce_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=Y)) | ||
373 | + | ||
374 | + # ---- Regularization loss (LGO, L2, L1) ---- | ||
375 | + additional_loss = 0 | ||
376 | + if FLAGS.type_training=="LGO": | ||
377 | + for idx,weight in enumerate(network.weights.values()): | ||
378 | + additional_loss+=l2_loss_func(weight*(1-connection_matrix[idx])) # Penalization of the noGO connections | ||
379 | + elif FLAGS.type_training=="L2" : | ||
380 | + for weight in network.weights.values(): | ||
381 | + additional_loss += l2_loss_func(weight) | ||
382 | + elif FLAGS.type_training=="L1" : | ||
383 | + for idx,weight in enumerate(network.weights.values()): | ||
384 | + additional_loss+=l1_loss_func(weight) | ||
385 | + | ||
386 | + # ---- Total loss ---- | ||
387 | + if FLAGS.type_training!='' : | ||
388 | + total_loss = ce_loss + FLAGS.alpha*additional_loss | ||
389 | + else: | ||
390 | + total_loss = ce_loss | ||
391 | + | ||
392 | + # -- Compute the prediction error -- | ||
393 | + if FLAGS.n_classes>=2: | ||
394 | + correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(Y, 1)) | ||
395 | + else: | ||
396 | + sig_pred=tf.nn.sigmoid(pred) | ||
397 | + sig_pred=tf.cast(sig_pred>0.5,dtype=tf.int64) | ||
398 | + ground_truth=tf.cast(Y,dtype=tf.int64) | ||
399 | + correct_prediction = tf.equal(sig_pred,ground_truth) | ||
400 | + | ||
401 | + # -- Calculate the accuracy across all the given batches and average them out -- | ||
402 | + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | ||
403 | + | ||
404 | + # -- Configure the use of the gpu -- | ||
405 | + config = tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) | ||
406 | + #config.gpu_options.allow_growth = True, log_device_placement=True | ||
407 | + | ||
408 | + if FLAGS.restore : saver = tf.train.Saver() | ||
409 | + | ||
410 | + start = time.time() | ||
411 | + | ||
412 | + with tf.device(FLAGS.GPU_device): | ||
413 | + with tf.Session(config=config) as sess: | ||
414 | + if FLAGS.restore: | ||
415 | + saver.restore(sess,os.path.join(save_dir,"model")) | ||
416 | + | ||
417 | + # -- Calculate the final loss and the final accuracy on the test set -- | ||
418 | + | ||
419 | + avg_cost,avg_acc = sess.run([total_loss, accuracy], feed_dict={X: X_test,Y: y_test,is_training:FLAGS.is_training,keep_prob:1}) | ||
420 | + | ||
421 | + print('Test loss {:.6f}, test accuracy : {:.6f}\n'.format(avg_cost,avg_acc)) | ||
422 | + | ||
423 | + end = time.time() | ||
424 | + elapsed=end - start | ||
425 | + print("Total time: {}h {}min {}sec ".format(time.gmtime(elapsed).tm_hour, | ||
426 | + time.gmtime(elapsed).tm_min, | ||
427 | + time.gmtime(elapsed).tm_sec)) | ||
428 | + | ||
429 | + return | ||
430 | + | ||
431 | +def predict(save_dir): | ||
432 | + | ||
433 | + warnings.filterwarnings("ignore") | ||
434 | + os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" | ||
435 | + os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.GPU_device[len(FLAGS.GPU_device)-1] | ||
436 | + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | ||
437 | + | ||
438 | + # Load the useful files to build the architecture | ||
439 | + print("Loading the connection matrix...") | ||
440 | + start = time.time() | ||
441 | + | ||
442 | + adj_matrix = pd.read_csv(os.path.join(FLAGS.dir_data,"adj_matrix.csv"),index_col=0) | ||
443 | + first_matrix_connection = pd.read_csv(os.path.join(FLAGS.dir_data,"first_matrix_connection_GO.csv"),index_col=0) | ||
444 | + csv_go = pd.read_csv(os.path.join(FLAGS.dir_data,"go_level.csv"),index_col=0) | ||
445 | + | ||
446 | + connection_matrix = [] | ||
447 | + connection_matrix.append(np.array(first_matrix_connection.values,dtype=np.float32)) | ||
448 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(7)].loc[lambda x: x==1].index,csv_go[str(6)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
449 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(6)].loc[lambda x: x==1].index,csv_go[str(5)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
450 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(5)].loc[lambda x: x==1].index,csv_go[str(4)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
451 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(4)].loc[lambda x: x==1].index,csv_go[str(3)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
452 | + connection_matrix.append(np.array(adj_matrix.loc[csv_go[str(3)].loc[lambda x: x==1].index,csv_go[str(2)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
453 | + connection_matrix.append(np.ones((FLAGS.n_hidden_6, FLAGS.n_classes),dtype=np.float32)) | ||
454 | + | ||
455 | + end = time.time() | ||
456 | + elapsed=end - start | ||
457 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
458 | + time.gmtime(elapsed).tm_min, | ||
459 | + time.gmtime(elapsed).tm_sec)) | ||
460 | + | ||
461 | + # Load the data | ||
462 | + print("Loading the test dataset...") | ||
463 | + | ||
464 | + loaded = np.load(os.path.join(FLAGS.dir_data,"X_test.npz")) | ||
465 | + X_test = loaded['x'] | ||
466 | + y_test = loaded['y'] | ||
467 | + if FLAGS.n_classes>=2: | ||
468 | + y_test=to_categorical(y_test) | ||
469 | + | ||
470 | + end = time.time() | ||
471 | + elapsed=end - start | ||
472 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
473 | + time.gmtime(elapsed).tm_min, | ||
474 | + time.gmtime(elapsed).tm_sec)) | ||
475 | + | ||
476 | + | ||
477 | + # Launch the model | ||
478 | + print("Launching the evaluation") | ||
479 | + if FLAGS.type_training != "": | ||
480 | + print("with {} and ALPHA={}".format(FLAGS.type_training,FLAGS.alpha)) | ||
481 | + | ||
482 | + tf.reset_default_graph() | ||
483 | + | ||
484 | + # -- Inputs of the model -- | ||
485 | + X = tf.placeholder(tf.float32, shape=[None, FLAGS.n_input]) | ||
486 | + Y = tf.placeholder(tf.float32, shape=[None, FLAGS.n_classes]) | ||
487 | + | ||
488 | + # -- Hyperparameters of the neural network -- | ||
489 | + is_training = tf.placeholder(tf.bool,name="is_training") # Batch Norm hyperparameter | ||
490 | + keep_prob = tf.placeholder(tf.float32, name="keep_prob") # Dropout hyperparameter | ||
491 | + | ||
492 | + network=BaseModel(X=X,n_input=FLAGS.n_input,n_classes=FLAGS.n_classes, | ||
493 | + n_hidden_1=FLAGS.n_hidden_1,n_hidden_2=FLAGS.n_hidden_2,n_hidden_3=FLAGS.n_hidden_3,n_hidden_4=FLAGS.n_hidden_4, | ||
494 | + n_hidden_5=FLAGS.n_hidden_5,n_hidden_6=FLAGS.n_hidden_6,keep_prob=keep_prob,is_training=is_training) # Model instantiation | ||
495 | + pred = network() | ||
496 | + # -- Compute the prediction error -- | ||
497 | + if FLAGS.n_classes>=2: | ||
498 | + y_hat = tf.argmax(pred,1) | ||
499 | + else: | ||
500 | + y_hat = tf.nn.sigmoid(pred) | ||
501 | + y_hat = tf.cast(pred>0.5,dtype=tf.int64) | ||
502 | + | ||
503 | + # -- Configure the use of the gpu -- | ||
504 | + config = tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) | ||
505 | + #config.gpu_options.allow_growth = True, log_device_placement=True | ||
506 | + | ||
507 | + if FLAGS.restore : saver = tf.train.Saver() | ||
508 | + | ||
509 | + start = time.time() | ||
510 | + | ||
511 | + with tf.device(FLAGS.GPU_device): | ||
512 | + with tf.Session(config=config) as sess: | ||
513 | + if FLAGS.restore: | ||
514 | + saver.restore(sess,os.path.join(save_dir,"model")) | ||
515 | + | ||
516 | + # -- Predict the outcome predictions of the samples from the test set -- | ||
517 | + | ||
518 | + y_hat = sess.run([y_hat], feed_dict={X: X_test,Y: y_test,is_training:FLAGS.is_training,keep_prob:1}) | ||
519 | + | ||
520 | + end = time.time() | ||
521 | + elapsed=end - start | ||
522 | + print("Total time: {}h {}min {}sec ".format(time.gmtime(elapsed).tm_hour, | ||
523 | + time.gmtime(elapsed).tm_min, | ||
524 | + time.gmtime(elapsed).tm_sec)) | ||
525 | + | ||
526 | + return y_hat | ||
527 | + | ||
528 | + | ||
529 | +def main(_): | ||
530 | + | ||
531 | + save_dir=os.path.join(FLAGS.log_dir,'MLP_DP={}_BN={}_EPOCHS={}_OPT={}'.format(FLAGS.keep_prob,FLAGS.bn,FLAGS.epochs,FLAGS.lr_method)) | ||
532 | + | ||
533 | + if FLAGS.type_training!="" : | ||
534 | + save_dir=save_dir+'_{}_ALPHA={}'.format(FLAGS.type_training,FLAGS.alpha) | ||
535 | + | ||
536 | + if FLAGS.processing=="train": | ||
537 | + | ||
538 | + start_full = time.time() | ||
539 | + | ||
540 | + if not(os.path.isdir(save_dir)): | ||
541 | + os.mkdir(save_dir) | ||
542 | + | ||
543 | + performances=train(save_dir=save_dir) | ||
544 | + | ||
545 | + with open(os.path.join(save_dir,"histories.txt"), "wb") as fp: | ||
546 | + #Pickling | ||
547 | + pickle.dump(performances, fp) | ||
548 | + | ||
549 | + end = time.time() | ||
550 | + elapsed =end - start_full | ||
551 | + print("Total time full process: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
552 | + time.gmtime(elapsed).tm_min, | ||
553 | + time.gmtime(elapsed).tm_sec)) | ||
554 | + | ||
555 | + elif FLAGS.processing=="evaluate": | ||
556 | + | ||
557 | + evaluate(save_dir=save_dir) | ||
558 | + | ||
559 | + elif FLAGS.processing=="predict": | ||
560 | + | ||
561 | + np.savez_compressed(os.path.join(save_dir,'y_test_hat'),y_hat=predict(save_dir=save_dir)) | ||
562 | + | ||
563 | +if __name__ == "__main__": | ||
564 | + tf.app.run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -6,13 +6,11 @@ FLAGS = tf.app.flags.FLAGS | ... | @@ -6,13 +6,11 @@ FLAGS = tf.app.flags.FLAGS |
6 | class BaseModel(): | 6 | class BaseModel(): |
7 | 7 | ||
8 | def __init__(self,X,n_input,n_classes,n_hidden_1,n_hidden_2,n_hidden_3,n_hidden_4,n_hidden_5,n_hidden_6,is_training,keep_prob): | 8 | def __init__(self,X,n_input,n_classes,n_hidden_1,n_hidden_2,n_hidden_3,n_hidden_4,n_hidden_5,n_hidden_6,is_training,keep_prob): |
9 | + | ||
10 | + | ||
11 | + # Parameters | ||
9 | self.X = X | 12 | self.X = X |
10 | self.n_input = n_input | 13 | self.n_input = n_input |
11 | - self.is_training = is_training | ||
12 | - | ||
13 | - #Hyperparameters | ||
14 | - self.keep_prob = keep_prob # Dropout | ||
15 | - | ||
16 | self.n_classes=n_classes | 14 | self.n_classes=n_classes |
17 | self.n_hidden_1=n_hidden_1 | 15 | self.n_hidden_1=n_hidden_1 |
18 | self.n_hidden_2=n_hidden_2 | 16 | self.n_hidden_2=n_hidden_2 |
... | @@ -21,6 +19,11 @@ class BaseModel(): | ... | @@ -21,6 +19,11 @@ class BaseModel(): |
21 | self.n_hidden_5=n_hidden_5 | 19 | self.n_hidden_5=n_hidden_5 |
22 | self.n_hidden_6=n_hidden_6 | 20 | self.n_hidden_6=n_hidden_6 |
23 | 21 | ||
22 | + # Hyperparameters | ||
23 | + self.keep_prob = keep_prob # Dropout | ||
24 | + self.is_training = is_training # BN | ||
25 | + | ||
26 | + | ||
24 | def store_layer_weights_and_bias(self): | 27 | def store_layer_weights_and_bias(self): |
25 | self.weights = { | 28 | self.weights = { |
26 | 'h1_w': tf.get_variable('W1', shape=(self.n_input, self.n_hidden_1), initializer=tf.contrib.layers.variance_scaling_initializer()), | 29 | 'h1_w': tf.get_variable('W1', shape=(self.n_input, self.n_hidden_1), initializer=tf.contrib.layers.variance_scaling_initializer()), | ... | ... |
train.py
deleted
100644 → 0
1 | -import warnings | ||
2 | -import os | ||
3 | -import time | ||
4 | -import signal | ||
5 | -import sys | ||
6 | -import copy | ||
7 | -import h5py | ||
8 | - | ||
9 | -import pickle | ||
10 | -import random | ||
11 | -import seaborn | ||
12 | -import numpy as np | ||
13 | -import matplotlib.pyplot as plt | ||
14 | -import pandas as pd | ||
15 | -from sklearn import preprocessing | ||
16 | -from sklearn.model_selection import train_test_split | ||
17 | -from sklearn.utils import class_weight | ||
18 | -import tensorflow as tf | ||
19 | -from tensorflow.keras.utils import to_categorical | ||
20 | -from tqdm import tqdm | ||
21 | - | ||
22 | -# Configuration | ||
23 | -FLAGS = tf.app.flags.FLAGS | ||
24 | - | ||
25 | -tf.app.flags.DEFINE_string('GPU_device', '/gpu:0', "GPU device") | ||
26 | - | ||
27 | -tf.app.flags.DEFINE_bool('save', False, "Do you need to save the trained model?") | ||
28 | -tf.app.flags.DEFINE_bool('restore', False, "Do you want to restore a previous trained model?") | ||
29 | - | ||
30 | -tf.app.flags.DEFINE_string('dir', "/nhome/siniac/vbourgeais/Documents/PhD/1ère année/Thèse/Interprétation", "dir") | ||
31 | -tf.app.flags.DEFINE_string('log_dir', "/nhome/siniac/vbourgeais/Documents/PhD/1ère année/Thèse/Interprétation/log", "log_dir") | ||
32 | -tf.app.flags.DEFINE_string('file_extension', "", "file_extension {sigmoid,softmax,without_bn}") | ||
33 | -tf.app.flags.DEFINE_string('dir_data', "/home/vbourgeais/Stage/data/MicroArray", "dir_data") | ||
34 | -tf.app.flags.DEFINE_string('temp_dir', "/nhome/siniac/vbourgeais/Documents/PhD/1ère année/Thèse/Interprétation", "temp_dir") | ||
35 | -tf.app.flags.DEFINE_integer('seed', 42, "initial random seed") | ||
36 | - | ||
37 | -#EVALUATION PART | ||
38 | -tf.app.flags.DEFINE_float('ref_value', 0.1, "value to test") | ||
39 | -tf.app.flags.DEFINE_string('ref_layer', "h1", "layer to analyze") | ||
40 | -tf.app.flags.DEFINE_string('ref_GO', "", "GO to examine") | ||
41 | - | ||
42 | -tf.app.flags.DEFINE_integer('display_step', 5, "when to print the performances") | ||
43 | - | ||
44 | -tf.app.flags.DEFINE_integer('batch_size', 2**9, "the number of examples in a batch") | ||
45 | -tf.app.flags.DEFINE_integer('EPOCHS', 20, "the number of epochs for training") | ||
46 | - | ||
47 | -tf.app.flags.DEFINE_integer('epoch_decay_start', 100, "epoch of starting learning rate decay") | ||
48 | -tf.app.flags.DEFINE_bool('early_stopping', False, "early_stopping") | ||
49 | - | ||
50 | -tf.app.flags.DEFINE_integer('n_input', 54675, "number of features") | ||
51 | -tf.app.flags.DEFINE_integer('n_classes', 1, "number of classes") | ||
52 | -tf.app.flags.DEFINE_integer('n_layers', 6, "number of layers") | ||
53 | -tf.app.flags.DEFINE_integer('n_hidden_1', 1574, "number of nodes for the first hidden layer") #Level 7 | ||
54 | -tf.app.flags.DEFINE_integer('n_hidden_2', 1386, "number of nodes for the second hidden layer") #Level 6 | ||
55 | -tf.app.flags.DEFINE_integer('n_hidden_3', 951, "number of nodes for the third hidden layer") #Level 5 | ||
56 | -tf.app.flags.DEFINE_integer('n_hidden_4', 515, "number of nodes for the fourth hidden layer") #Level 4 | ||
57 | -tf.app.flags.DEFINE_integer('n_hidden_5', 255, "number of nodes for the fifth hidden layer") #Level 3 | ||
58 | -tf.app.flags.DEFINE_integer('n_hidden_6', 90, "number of nodes for the sixth hidden layer") #Level 2 | ||
59 | - | ||
60 | -tf.app.flags.DEFINE_float('learning_rate', 0.001, "initial learning rate") | ||
61 | -tf.app.flags.DEFINE_bool('bn', False, "BN use") | ||
62 | -tf.app.flags.DEFINE_bool('is_training', True, "Is it trainable?") | ||
63 | -tf.app.flags.DEFINE_float('keep_prob', 0.4, "probability for the dropout") | ||
64 | -tf.app.flags.DEFINE_string('type_training', 'LGO', "{"", LGO, L2, L1}") | ||
65 | -tf.app.flags.DEFINE_float('alpha', 1, "alpha") | ||
66 | -tf.app.flags.DEFINE_bool('weighted_loss', False, "balance the data in the total loss") | ||
67 | -tf.app.flags.DEFINE_string('lr_method', 'adam', "{adam, momentum, adagrad, rmsprop}") | ||
68 | - | ||
69 | -from base_model import BaseModel | ||
70 | - | ||
71 | -def l1_loss_func(x): | ||
72 | - return tf.reduce_sum(tf.math.abs(x)) | ||
73 | - | ||
74 | -def l2_loss_func(x): | ||
75 | - return tf.reduce_sum(tf.square(x)) | ||
76 | - | ||
77 | - | ||
78 | -def train(save_dir): | ||
79 | - | ||
80 | - warnings.filterwarnings("ignore") | ||
81 | - os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" | ||
82 | - os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.GPU_device[len(FLAGS.GPU_device)-1] | ||
83 | - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | ||
84 | - | ||
85 | - # Load the files useful | ||
86 | - print("Loading the connexion matrix...") | ||
87 | - start = time.time() | ||
88 | - | ||
89 | - adj_matrix = pd.read_csv(os.path.join(FLAGS.dir,"adj_matrix_cropped.csv"),index_col=0) | ||
90 | - first_matrix_connection = pd.read_csv(os.path.join(FLAGS.dir,"first_matrix_connection_GO.csv"),index_col=0) | ||
91 | - csv_go = pd.read_csv(os.path.join(FLAGS.dir,"go_level_v2.csv"),index_col=0) | ||
92 | - | ||
93 | - connexion_matrix = [] | ||
94 | - connexion_matrix.append(np.array(first_matrix_connection.values,dtype=np.float32)) | ||
95 | - connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(7)].loc[lambda x: x==1].index,csv_go[str(6)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
96 | - connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(6)].loc[lambda x: x==1].index,csv_go[str(5)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
97 | - connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(5)].loc[lambda x: x==1].index,csv_go[str(4)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
98 | - connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(4)].loc[lambda x: x==1].index,csv_go[str(3)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
99 | - connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(3)].loc[lambda x: x==1].index,csv_go[str(2)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
100 | - connexion_matrix.append(np.ones((FLAGS.n_hidden_6, FLAGS.n_classes),dtype=np.float32)) | ||
101 | - | ||
102 | - end = time.time() | ||
103 | - elapsed=end - start | ||
104 | - print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
105 | - time.gmtime(elapsed).tm_min, | ||
106 | - time.gmtime(elapsed).tm_sec)) | ||
107 | - | ||
108 | - # Load the data | ||
109 | - print("Loading the data...") | ||
110 | - start = time.time() | ||
111 | - loaded = np.load(os.path.join(FLAGS.dir_data,"X_train.npz")) | ||
112 | - X_train = loaded['x'] | ||
113 | - | ||
114 | - y_train = loaded['y'] | ||
115 | - if FLAGS.n_classes>=2: | ||
116 | - y_train=to_categorical(y_train) | ||
117 | - | ||
118 | - loaded = np.load(os.path.join(FLAGS.dir_data,"X_test.npz")) | ||
119 | - X_test = loaded['x'] | ||
120 | - y_test = loaded['y'] | ||
121 | - if FLAGS.n_classes>=2: | ||
122 | - y_test=to_categorical(y_test) | ||
123 | - | ||
124 | - | ||
125 | - | ||
126 | - end = time.time() | ||
127 | - elapsed=end - start | ||
128 | - print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
129 | - time.gmtime(elapsed).tm_min, | ||
130 | - time.gmtime(elapsed).tm_sec)) | ||
131 | - | ||
132 | - | ||
133 | - # Launch the model | ||
134 | - print("Launch the learning with the "+FLAGS.type_training) | ||
135 | - if FLAGS.type_training != "baseline": | ||
136 | - print("for ALPHA={}".format(FLAGS.alpha)) | ||
137 | - | ||
138 | - tf.reset_default_graph() | ||
139 | - | ||
140 | - | ||
141 | - #Inputs of the model | ||
142 | - X = tf.placeholder(tf.float32, shape=[None, FLAGS.n_input]) | ||
143 | - Y = tf.placeholder(tf.float32, shape=[None, FLAGS.n_classes]) | ||
144 | - | ||
145 | - #Hyperparameters | ||
146 | - is_training = tf.placeholder(tf.bool,name="is_training") #batch Norm | ||
147 | - learning_rate = tf.placeholder(tf.float32, name="learning_rate") | ||
148 | - keep_prob = tf.placeholder(tf.float32, name="keep_prob") # Dropout | ||
149 | - total_batches=len(X_train)//FLAGS.batch_size | ||
150 | - | ||
151 | - network=BaseModel(X=X,n_input=FLAGS.n_input,n_classes=FLAGS.n_classes, | ||
152 | - n_hidden_1=FLAGS.n_hidden_1,n_hidden_2=FLAGS.n_hidden_2,n_hidden_3=FLAGS.n_hidden_3,n_hidden_4=FLAGS.n_hidden_4, | ||
153 | - n_hidden_5=FLAGS.n_hidden_5,n_hidden_6=FLAGS.n_hidden_6,keep_prob=keep_prob,is_training=is_training) | ||
154 | - #here we can compute the model both for l2 custom and no-custom | ||
155 | - | ||
156 | - pred = network() | ||
157 | - | ||
158 | - #Compute the average of the loss across all the dimensions | ||
159 | - if FLAGS.weighted_loss: | ||
160 | - ce_loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=pred, targets=Y,pos_weight=class_weights[1])) | ||
161 | - else: | ||
162 | - ce_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=Y)) | ||
163 | - | ||
164 | - additional_loss = 0 | ||
165 | - if FLAGS.type_training=="LGO": | ||
166 | - for idx,weight in enumerate(network.weights.values()): | ||
167 | - additional_loss+=l2_loss_func(weight*(1-connexion_matrix[idx])) | ||
168 | - elif FLAGS.type_training=="L2" : | ||
169 | - for weight in network.weights.values(): | ||
170 | - additional_loss += l2_loss_func(weight) | ||
171 | - elif FLAGS.type_training=="L1" : | ||
172 | - for idx,weight in enumerate(network.weights.values()): | ||
173 | - additional_loss+=l1_loss_func(weight) | ||
174 | - | ||
175 | - | ||
176 | - norm_no_go_connexions=0 | ||
177 | - norm_go_connexions=0 | ||
178 | - for idx,weight in enumerate(list(network.weights.values())[:-1]): | ||
179 | - norm_no_go_connexions+=tf.norm((weight*(1-connexion_matrix[idx])),ord=1)/np.count_nonzero(1-connexion_matrix[idx]) | ||
180 | - norm_go_connexions+=tf.norm((weight*connexion_matrix[idx]),ord=1)/np.count_nonzero(connexion_matrix[idx]) | ||
181 | - norm_no_go_connexions/=FLAGS.n_layers | ||
182 | - norm_go_connexions/=FLAGS.n_layers | ||
183 | - | ||
184 | - if FLAGS.type_training!='' : | ||
185 | - total_loss = ce_loss + FLAGS.alpha*additional_loss | ||
186 | - else: | ||
187 | - total_loss = ce_loss | ||
188 | - | ||
189 | - #optimizer | ||
190 | - with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): | ||
191 | - if FLAGS.lr_method=="adam": | ||
192 | - trainer = tf.train.AdamOptimizer(learning_rate = learning_rate) | ||
193 | - elif FLAGS.lr_method=="momentum": | ||
194 | - trainer = tf.train.MomentumOptimizer(learning_rate = learning_rate, momentum=0.09, use_nesterov=True) | ||
195 | - elif FLAGS.lr_method=="adagrad": | ||
196 | - trainer = tf.train.AdagradOptimizer(learning_rate=learning_rate) | ||
197 | - elif FLAGS.lr_method=="rmsprop": | ||
198 | - trainer = tf.train.RMSPropOptimizer(learning_rate = learning_rate) | ||
199 | - optimizer = trainer.minimize(total_loss) | ||
200 | - | ||
201 | - if FLAGS.n_classes>=2: | ||
202 | - correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(Y, 1)) | ||
203 | - else: | ||
204 | - sig_pred=tf.nn.sigmoid(pred) | ||
205 | - sig_pred=tf.cast(sig_pred>0.5,dtype=tf.int64) | ||
206 | - ground_truth=tf.cast(Y,dtype=tf.int64) | ||
207 | - correct_prediction = tf.equal(sig_pred,ground_truth) | ||
208 | - | ||
209 | - #Calculate the accuracy across all the given batch and average them out. | ||
210 | - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | ||
211 | - | ||
212 | - # Initializing the variables | ||
213 | - init = tf.global_variables_initializer() | ||
214 | - | ||
215 | - config = tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) | ||
216 | - #config.gpu_options.allow_growth = True, log_device_placement=True | ||
217 | - #to use the tensorboard | ||
218 | - | ||
219 | - if FLAGS.save or FLAGS.restore : saver = tf.train.Saver() | ||
220 | - | ||
221 | - start = time.time() | ||
222 | - | ||
223 | - with tf.device(FLAGS.GPU_device): | ||
224 | - with tf.Session(config=config) as sess: | ||
225 | - sess.run(init) | ||
226 | - | ||
227 | - train_c_accuracy=[] | ||
228 | - train_c_total_loss=[] | ||
229 | - | ||
230 | - test_c_accuracy=[] | ||
231 | - test_c_total_loss=[] | ||
232 | - | ||
233 | - c_l1_norm_go=[] | ||
234 | - c_l1_norm_no_go=[] | ||
235 | - | ||
236 | - if FLAGS.type_training!="": | ||
237 | - test_c_ce_loss=[] | ||
238 | - test_c_additional_loss=[] | ||
239 | - train_c_ce_loss=[] | ||
240 | - train_c_additional_loss=[] | ||
241 | - | ||
242 | - for epoch in tqdm(np.arange(0,FLAGS.EPOCHS)): | ||
243 | - | ||
244 | - index = np.arange(X_train.shape[0]) | ||
245 | - np.random.shuffle(index) | ||
246 | - batch_X = np.array_split(X_train[index], total_batches) | ||
247 | - batch_Y = np.array_split(y_train[index], total_batches) | ||
248 | - | ||
249 | - # Optimization | ||
250 | - for batch in range(total_batches): | ||
251 | - batch_x,batch_y=batch_X[batch],batch_Y[batch] | ||
252 | - sess.run(optimizer, feed_dict={X: batch_x,Y: batch_y,is_training:FLAGS.is_training,keep_prob:FLAGS.keep_prob,learning_rate:FLAGS.learning_rate}) | ||
253 | - | ||
254 | - if ((epoch+1) % FLAGS.display_step == 0) or (epoch==0) : | ||
255 | - if not((FLAGS.display_step==FLAGS.EPOCHS) and (epoch==0)): | ||
256 | - # Calculate batch loss and accuracy after an epoch on the train and validation set | ||
257 | - avg_cost,avg_acc,l1_norm_no_go,l1_norm_go = sess.run([total_loss, accuracy,norm_no_go_connexions,norm_go_connexions], feed_dict={X: X_train,Y: y_train, | ||
258 | - is_training:False,keep_prob:1.0}) | ||
259 | - train_c_total_loss.append(avg_cost) | ||
260 | - train_c_accuracy.append(avg_acc) | ||
261 | - c_l1_norm_go.append(l1_norm_go) | ||
262 | - c_l1_norm_no_go.append(l1_norm_no_go) | ||
263 | - | ||
264 | - if FLAGS.type_training!="": | ||
265 | - avg_ce_loss,avg_additional_loss= sess.run([ce_loss, additional_loss], feed_dict={X: X_train,Y: y_train,is_training:False,keep_prob:1.0}) | ||
266 | - train_c_additional_loss.append(avg_additional_loss) | ||
267 | - train_c_ce_loss.append(avg_ce_loss) | ||
268 | - | ||
269 | - avg_cost,avg_acc = sess.run([total_loss, accuracy], feed_dict={X: X_test,Y: y_test,is_training:False,keep_prob:1.0}) | ||
270 | - test_c_total_loss.append(avg_cost) | ||
271 | - test_c_accuracy.append(avg_acc) | ||
272 | - | ||
273 | - if FLAGS.type_training!="": | ||
274 | - avg_ce_loss,avg_additional_loss= sess.run([ce_loss, additional_loss], feed_dict={X: X_test,Y: y_test,is_training:False,keep_prob:1.0}) | ||
275 | - test_c_additional_loss.append(avg_additional_loss) | ||
276 | - test_c_ce_loss.append(avg_ce_loss) | ||
277 | - | ||
278 | - current_idx=len(train_c_total_loss)-1 | ||
279 | - print('| Epoch: {}/{} | Train: Loss {:.6f} Accuracy : {:.6f} '\ | ||
280 | - '| Test: Loss {:.6f} Accuracy : {:.6f}\n'.format( | ||
281 | - epoch+1, FLAGS.EPOCHS,train_c_total_loss[current_idx], train_c_accuracy[current_idx],test_c_total_loss[current_idx],test_c_accuracy[current_idx])) | ||
282 | - | ||
283 | - if FLAGS.save: saver.save(sess=sess, save_path=os.path.join(save_dir,"model")) | ||
284 | - | ||
285 | - end = time.time() | ||
286 | - elapsed=end - start | ||
287 | - print("Total time: {}h {}min {}sec ".format(time.gmtime(elapsed).tm_hour, | ||
288 | - time.gmtime(elapsed).tm_min, | ||
289 | - time.gmtime(elapsed).tm_sec)) | ||
290 | - | ||
291 | - performances = { | ||
292 | - 'type_training': FLAGS.type_training, | ||
293 | - 'total_loss':train_c_total_loss,'test_total_loss':test_c_total_loss, | ||
294 | - 'acc':train_c_accuracy,'test_acc':test_c_accuracy | ||
295 | - } | ||
296 | - | ||
297 | - performances['norm_go']=c_l1_norm_go | ||
298 | - performances['norm_no_go']=c_l1_norm_no_go | ||
299 | - | ||
300 | - if FLAGS.type_training!="baseline": | ||
301 | - performances['additional_loss']=train_c_additional_loss | ||
302 | - performances['test_additional_loss']=test_c_additional_loss | ||
303 | - performances['ce_loss']=train_c_ce_loss | ||
304 | - performances['test_ce_loss']=test_c_ce_loss | ||
305 | - | ||
306 | - | ||
307 | - return performances | ||
308 | - | ||
309 | - | ||
310 | -def main(_): | ||
311 | - | ||
312 | - save_dir=os.path.join(FLAGS.log_dir,'MLP_DP={}_BN={}_EPOCHS={}_OPT={}'.format(FLAGS.keep_prob,FLAGS.bn,FLAGS.EPOCHS,FLAGS.lr_method)) | ||
313 | - | ||
314 | - if FLAGS.type_training=="LGO" : | ||
315 | - save_dir=save_dir+'_LGO_ALPHA={}{}'.format(FLAGS.alpha,FLAGS.file_extension) | ||
316 | - elif FLAGS.type_training=="L2" : | ||
317 | - save_dir=save_dir+'_L2_ALPHA={}{}'.format(FLAGS.alpha,FLAGS.file_extension) | ||
318 | - elif FLAGS.type_training=="" : | ||
319 | - save_dir=save_dir+'_{}'.format(FLAGS.file_extension) | ||
320 | - elif FLAGS.type_training=="L1" : | ||
321 | - save_dir=save_dir+'_L1_ALPHA={}{}'.format(FLAGS.alpha,FLAGS.file_extension) | ||
322 | - | ||
323 | - if FLAGS.is_training: | ||
324 | - | ||
325 | - start_full = time.time() | ||
326 | - | ||
327 | - if not(os.path.isdir(save_dir)): | ||
328 | - os.mkdir(save_dir) | ||
329 | - | ||
330 | - performances=train(save_dir=save_dir) | ||
331 | - | ||
332 | - with open(os.path.join(save_dir,"histories.txt"), "wb") as fp: | ||
333 | - #Pickling | ||
334 | - pickle.dump(performances, fp) | ||
335 | - | ||
336 | - end = time.time() | ||
337 | - elapsed =end - start_full | ||
338 | - print("Total time full process: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
339 | - time.gmtime(elapsed).tm_min, | ||
340 | - time.gmtime(elapsed).tm_sec)) | ||
341 | - else: | ||
342 | - | ||
343 | - # ---------------------------------TO MODIFY :------------------------------ | ||
344 | - | ||
345 | - start_full = time.time() | ||
346 | - evaluate(save_dir=save_dir,ref_layer="h{}".format(1)) #TO DEFINE | ||
347 | - end = time.time() | ||
348 | - elapsed =end - start_full | ||
349 | - print("Total time full process: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
350 | - time.gmtime(elapsed).tm_min, | ||
351 | - time.gmtime(elapsed).tm_sec)) | ||
352 | - | ||
353 | - | ||
354 | -if __name__ == "__main__": | ||
355 | - tf.app.run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment