Showing
2 changed files
with
420 additions
and
0 deletions
base_model.py
0 → 100644
1 | +import numpy as np | ||
2 | +import tensorflow as tf | ||
3 | + | ||
4 | +FLAGS = tf.app.flags.FLAGS | ||
5 | + | ||
6 | +class BaseModel(): | ||
7 | + | ||
8 | + def __init__(self,X,n_input,n_classes,n_hidden_1,n_hidden_2,n_hidden_3,n_hidden_4,n_hidden_5,n_hidden_6,is_training,keep_prob): | ||
9 | + self.X = X | ||
10 | + self.n_input = n_input | ||
11 | + self.is_training = is_training | ||
12 | + | ||
13 | + #Hyperparameters | ||
14 | + self.keep_prob = keep_prob # Dropout | ||
15 | + | ||
16 | + self.n_classes=n_classes | ||
17 | + self.n_hidden_1=n_hidden_1 | ||
18 | + self.n_hidden_2=n_hidden_2 | ||
19 | + self.n_hidden_3=n_hidden_3 | ||
20 | + self.n_hidden_4=n_hidden_4 | ||
21 | + self.n_hidden_5=n_hidden_5 | ||
22 | + self.n_hidden_6=n_hidden_6 | ||
23 | + | ||
24 | + def store_layer_weights_and_bias(self): | ||
25 | + self.weights = { | ||
26 | + 'h1_w': tf.get_variable('W1', shape=(self.n_input, self.n_hidden_1), initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
27 | + 'h2_w': tf.get_variable('W2', shape=(self.n_hidden_1, self.n_hidden_2), initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
28 | + 'h3_w': tf.get_variable('W3', shape=(self.n_hidden_2, self.n_hidden_3), initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
29 | + 'h4_w': tf.get_variable('W4', shape=(self.n_hidden_3, self.n_hidden_4), initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
30 | + 'h5_w': tf.get_variable('W5', shape=(self.n_hidden_4, self.n_hidden_5), initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
31 | + 'h6_w': tf.get_variable('W6', shape=(self.n_hidden_5, self.n_hidden_6), initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
32 | + 'out_w': tf.get_variable('W_out',shape=(self.n_hidden_6, self.n_classes), initializer=tf.contrib.layers.variance_scaling_initializer()) | ||
33 | + } | ||
34 | + self.biases = { | ||
35 | + 'h1_b': tf.get_variable('B1',shape=(self.n_hidden_1),initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
36 | + 'h2_b': tf.get_variable('B2',shape=(self.n_hidden_2),initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
37 | + 'h3_b': tf.get_variable('B3',shape=(self.n_hidden_3),initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
38 | + 'h4_b': tf.get_variable('B4',shape=(self.n_hidden_4),initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
39 | + 'h5_b': tf.get_variable('B5',shape=(self.n_hidden_5),initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
40 | + 'h6_b': tf.get_variable('B6',shape=(self.n_hidden_6),initializer=tf.contrib.layers.variance_scaling_initializer()), | ||
41 | + 'out_b': tf.get_variable('B_out',shape=(self.n_classes),initializer=tf.contrib.layers.variance_scaling_initializer()) | ||
42 | + } | ||
43 | + | ||
44 | + def fc(self,input,weights,biases,name,dim): | ||
45 | + h = tf.add(tf.matmul(input, weights), biases) | ||
46 | + if FLAGS.bn: | ||
47 | + h = tf.layers.batch_normalization(h,training=self.is_training,name='bn_'+name) | ||
48 | + h = tf.nn.relu(h, name=name) | ||
49 | + h = tf.nn.dropout(h, self.keep_prob) | ||
50 | + return h | ||
51 | + | ||
52 | + def net(self): | ||
53 | + self.h1 = self.fc(self.X,self.weights['h1_w'],self.biases['h1_b'],name='layer1',dim=self.n_hidden_1) | ||
54 | + self.h2 = self.fc(self.h1,self.weights['h2_w'],self.biases['h2_b'],name='layer2',dim=self.n_hidden_2) | ||
55 | + self.h3 = self.fc(self.h2, self.weights['h3_w'], self.biases['h3_b'],name='layer3',dim=self.n_hidden_3) | ||
56 | + self.h4 = self.fc(self.h3,self.weights['h4_w'],self.biases['h4_b'],name='layer4',dim=self.n_hidden_4) | ||
57 | + self.h5 = self.fc(self.h4,self.weights['h5_w'],self.biases['h5_b'],name='layer5',dim=self.n_hidden_5) | ||
58 | + self.h6 = self.fc(self.h5, self.weights['h6_w'], self.biases['h6_b'],name='layer6',dim=self.n_hidden_6) | ||
59 | + output_layer = tf.add(tf.matmul(self.h6, self.weights['out_w']), self.biases['out_b'],name='output') | ||
60 | + return output_layer | ||
61 | + | ||
62 | + def __call__(self): | ||
63 | + self.store_layer_weights_and_bias() | ||
64 | + return self.net() | ||
65 | + |
train.py
0 → 100644
1 | +import warnings | ||
2 | +import os | ||
3 | +import time | ||
4 | +import signal | ||
5 | +import sys | ||
6 | +import copy | ||
7 | +import h5py | ||
8 | + | ||
9 | +import pickle | ||
10 | +import random | ||
11 | +import seaborn | ||
12 | +import numpy as np | ||
13 | +import matplotlib.pyplot as plt | ||
14 | +import pandas as pd | ||
15 | +from sklearn import preprocessing | ||
16 | +from sklearn.model_selection import train_test_split | ||
17 | +from sklearn.utils import class_weight | ||
18 | +import tensorflow as tf | ||
19 | +from tensorflow.keras.utils import to_categorical | ||
20 | +from tqdm import tqdm | ||
21 | + | ||
22 | +# Configuration | ||
23 | +FLAGS = tf.app.flags.FLAGS | ||
24 | + | ||
25 | +tf.app.flags.DEFINE_string('GPU_device', '/gpu:0', "GPU device") | ||
26 | + | ||
27 | +tf.app.flags.DEFINE_bool('save', False, "Do you need to save the trained model?") | ||
28 | +tf.app.flags.DEFINE_bool('restore', False, "Do you want to restore a previous trained model?") | ||
29 | + | ||
30 | +tf.app.flags.DEFINE_string('dir', "/nhome/siniac/vbourgeais/Documents/PhD/1ère année/Thèse/Interprétation", "dir") | ||
31 | +tf.app.flags.DEFINE_string('log_dir', "/nhome/siniac/vbourgeais/Documents/PhD/1ère année/Thèse/Interprétation/log", "log_dir") | ||
32 | +tf.app.flags.DEFINE_string('file_extension', "", "file_extension {sigmoid,softmax,without_bn}") | ||
33 | +tf.app.flags.DEFINE_string('dir_data', "/home/vbourgeais/Stage/data/MicroArray", "dir_data") | ||
34 | +tf.app.flags.DEFINE_string('temp_dir', "/nhome/siniac/vbourgeais/Documents/PhD/1ère année/Thèse/Interprétation", "temp_dir") | ||
35 | +tf.app.flags.DEFINE_integer('seed', 42, "initial random seed") | ||
36 | + | ||
37 | +#EVALUATION PART | ||
38 | +tf.app.flags.DEFINE_float('ref_value', 0.1, "value to test") | ||
39 | +tf.app.flags.DEFINE_string('ref_layer', "h1", "layer to analyze") | ||
40 | +tf.app.flags.DEFINE_string('ref_GO', "", "GO to examine") | ||
41 | + | ||
42 | +tf.app.flags.DEFINE_integer('display_step', 5, "when to print the performances") | ||
43 | + | ||
44 | +tf.app.flags.DEFINE_integer('batch_size', 2**9, "the number of examples in a batch") | ||
45 | +tf.app.flags.DEFINE_integer('EPOCHS', 20, "the number of epochs for training") | ||
46 | + | ||
47 | +tf.app.flags.DEFINE_integer('epoch_decay_start', 100, "epoch of starting learning rate decay") | ||
48 | +tf.app.flags.DEFINE_bool('early_stopping', False, "early_stopping") | ||
49 | + | ||
50 | +tf.app.flags.DEFINE_integer('n_input', 54675, "number of features") | ||
51 | +tf.app.flags.DEFINE_integer('n_classes', 1, "number of classes") | ||
52 | +tf.app.flags.DEFINE_integer('n_layers', 6, "number of layers") | ||
53 | +tf.app.flags.DEFINE_integer('n_hidden_1', 1574, "number of nodes for the first hidden layer") #Level 7 | ||
54 | +tf.app.flags.DEFINE_integer('n_hidden_2', 1386, "number of nodes for the second hidden layer") #Level 6 | ||
55 | +tf.app.flags.DEFINE_integer('n_hidden_3', 951, "number of nodes for the third hidden layer") #Level 5 | ||
56 | +tf.app.flags.DEFINE_integer('n_hidden_4', 515, "number of nodes for the fourth hidden layer") #Level 4 | ||
57 | +tf.app.flags.DEFINE_integer('n_hidden_5', 255, "number of nodes for the fifth hidden layer") #Level 3 | ||
58 | +tf.app.flags.DEFINE_integer('n_hidden_6', 90, "number of nodes for the sixth hidden layer") #Level 2 | ||
59 | + | ||
60 | +tf.app.flags.DEFINE_float('learning_rate', 0.001, "initial learning rate") | ||
61 | +tf.app.flags.DEFINE_bool('bn', False, "BN use") | ||
62 | +tf.app.flags.DEFINE_bool('is_training', True, "Is it trainable?") | ||
63 | +tf.app.flags.DEFINE_float('keep_prob', 0.4, "probability for the dropout") | ||
64 | +tf.app.flags.DEFINE_string('type_training', 'LGO', "{"", LGO, L2, L1}") | ||
65 | +tf.app.flags.DEFINE_float('alpha', 1, "alpha") | ||
66 | +tf.app.flags.DEFINE_bool('weighted_loss', False, "balance the data in the total loss") | ||
67 | +tf.app.flags.DEFINE_string('lr_method', 'adam', "{adam, momentum, adagrad, rmsprop}") | ||
68 | + | ||
69 | +from base_model import BaseModel | ||
70 | + | ||
71 | +def l1_loss_func(x): | ||
72 | + return tf.reduce_sum(tf.math.abs(x)) | ||
73 | + | ||
74 | +def l2_loss_func(x): | ||
75 | + return tf.reduce_sum(tf.square(x)) | ||
76 | + | ||
77 | + | ||
78 | +def train(save_dir): | ||
79 | + | ||
80 | + warnings.filterwarnings("ignore") | ||
81 | + os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" | ||
82 | + os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.GPU_device[len(FLAGS.GPU_device)-1] | ||
83 | + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | ||
84 | + | ||
85 | + # Load the files useful | ||
86 | + print("Loading the connexion matrix...") | ||
87 | + start = time.time() | ||
88 | + | ||
89 | + adj_matrix = pd.read_csv(os.path.join(FLAGS.dir,"adj_matrix_cropped.csv"),index_col=0) | ||
90 | + first_matrix_connection = pd.read_csv(os.path.join(FLAGS.dir,"first_matrix_connection_GO.csv"),index_col=0) | ||
91 | + csv_go = pd.read_csv(os.path.join(FLAGS.dir,"go_level_v2.csv"),index_col=0) | ||
92 | + | ||
93 | + connexion_matrix = [] | ||
94 | + connexion_matrix.append(np.array(first_matrix_connection.values,dtype=np.float32)) | ||
95 | + connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(7)].loc[lambda x: x==1].index,csv_go[str(6)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
96 | + connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(6)].loc[lambda x: x==1].index,csv_go[str(5)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
97 | + connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(5)].loc[lambda x: x==1].index,csv_go[str(4)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
98 | + connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(4)].loc[lambda x: x==1].index,csv_go[str(3)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
99 | + connexion_matrix.append(np.array(adj_matrix.loc[csv_go[str(3)].loc[lambda x: x==1].index,csv_go[str(2)].loc[lambda x: x==1].index].values,dtype=np.float32)) | ||
100 | + connexion_matrix.append(np.ones((FLAGS.n_hidden_6, FLAGS.n_classes),dtype=np.float32)) | ||
101 | + | ||
102 | + end = time.time() | ||
103 | + elapsed=end - start | ||
104 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
105 | + time.gmtime(elapsed).tm_min, | ||
106 | + time.gmtime(elapsed).tm_sec)) | ||
107 | + | ||
108 | + # Load the data | ||
109 | + print("Loading the data...") | ||
110 | + start = time.time() | ||
111 | + loaded = np.load(os.path.join(FLAGS.dir_data,"X_train.npz")) | ||
112 | + X_train = loaded['x'] | ||
113 | + | ||
114 | + y_train = loaded['y'] | ||
115 | + if FLAGS.n_classes>=2: | ||
116 | + y_train=to_categorical(y_train) | ||
117 | + | ||
118 | + loaded = np.load(os.path.join(FLAGS.dir_data,"X_test.npz")) | ||
119 | + X_test = loaded['x'] | ||
120 | + y_test = loaded['y'] | ||
121 | + if FLAGS.n_classes>=2: | ||
122 | + y_test=to_categorical(y_test) | ||
123 | + | ||
124 | + | ||
125 | + | ||
126 | + end = time.time() | ||
127 | + elapsed=end - start | ||
128 | + print("Total time: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
129 | + time.gmtime(elapsed).tm_min, | ||
130 | + time.gmtime(elapsed).tm_sec)) | ||
131 | + | ||
132 | + | ||
133 | + # Launch the model | ||
134 | + print("Launch the learning with the "+FLAGS.type_training) | ||
135 | + if FLAGS.type_training != "baseline": | ||
136 | + print("for ALPHA={}".format(FLAGS.alpha)) | ||
137 | + | ||
138 | + tf.reset_default_graph() | ||
139 | + | ||
140 | + | ||
141 | + #Inputs of the model | ||
142 | + X = tf.placeholder(tf.float32, shape=[None, FLAGS.n_input]) | ||
143 | + Y = tf.placeholder(tf.float32, shape=[None, FLAGS.n_classes]) | ||
144 | + | ||
145 | + #Hyperparameters | ||
146 | + is_training = tf.placeholder(tf.bool,name="is_training") #batch Norm | ||
147 | + learning_rate = tf.placeholder(tf.float32, name="learning_rate") | ||
148 | + keep_prob = tf.placeholder(tf.float32, name="keep_prob") # Dropout | ||
149 | + total_batches=len(X_train)//FLAGS.batch_size | ||
150 | + | ||
151 | + network=BaseModel(X=X,n_input=FLAGS.n_input,n_classes=FLAGS.n_classes, | ||
152 | + n_hidden_1=FLAGS.n_hidden_1,n_hidden_2=FLAGS.n_hidden_2,n_hidden_3=FLAGS.n_hidden_3,n_hidden_4=FLAGS.n_hidden_4, | ||
153 | + n_hidden_5=FLAGS.n_hidden_5,n_hidden_6=FLAGS.n_hidden_6,keep_prob=keep_prob,is_training=is_training) | ||
154 | + #here we can compute the model both for l2 custom and no-custom | ||
155 | + | ||
156 | + pred = network() | ||
157 | + | ||
158 | + #Compute the average of the loss across all the dimensions | ||
159 | + if FLAGS.weighted_loss: | ||
160 | + ce_loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=pred, targets=Y,pos_weight=class_weights[1])) | ||
161 | + else: | ||
162 | + ce_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=Y)) | ||
163 | + | ||
164 | + additional_loss = 0 | ||
165 | + if FLAGS.type_training=="LGO": | ||
166 | + for idx,weight in enumerate(network.weights.values()): | ||
167 | + additional_loss+=l2_loss_func(weight*(1-connexion_matrix[idx])) | ||
168 | + elif FLAGS.type_training=="L2" : | ||
169 | + for weight in network.weights.values(): | ||
170 | + additional_loss += l2_loss_func(weight) | ||
171 | + elif FLAGS.type_training=="L1" : | ||
172 | + for idx,weight in enumerate(network.weights.values()): | ||
173 | + additional_loss+=l1_loss_func(weight) | ||
174 | + | ||
175 | + | ||
176 | + norm_no_go_connexions=0 | ||
177 | + norm_go_connexions=0 | ||
178 | + for idx,weight in enumerate(list(network.weights.values())[:-1]): | ||
179 | + norm_no_go_connexions+=tf.norm((weight*(1-connexion_matrix[idx])),ord=1)/np.count_nonzero(1-connexion_matrix[idx]) | ||
180 | + norm_go_connexions+=tf.norm((weight*connexion_matrix[idx]),ord=1)/np.count_nonzero(connexion_matrix[idx]) | ||
181 | + norm_no_go_connexions/=FLAGS.n_layers | ||
182 | + norm_go_connexions/=FLAGS.n_layers | ||
183 | + | ||
184 | + if FLAGS.type_training!='' : | ||
185 | + total_loss = ce_loss + FLAGS.alpha*additional_loss | ||
186 | + else: | ||
187 | + total_loss = ce_loss | ||
188 | + | ||
189 | + #optimizer | ||
190 | + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): | ||
191 | + if FLAGS.lr_method=="adam": | ||
192 | + trainer = tf.train.AdamOptimizer(learning_rate = learning_rate) | ||
193 | + elif FLAGS.lr_method=="momentum": | ||
194 | + trainer = tf.train.MomentumOptimizer(learning_rate = learning_rate, momentum=0.09, use_nesterov=True) | ||
195 | + elif FLAGS.lr_method=="adagrad": | ||
196 | + trainer = tf.train.AdagradOptimizer(learning_rate=learning_rate) | ||
197 | + elif FLAGS.lr_method=="rmsprop": | ||
198 | + trainer = tf.train.RMSPropOptimizer(learning_rate = learning_rate) | ||
199 | + optimizer = trainer.minimize(total_loss) | ||
200 | + | ||
201 | + if FLAGS.n_classes>=2: | ||
202 | + correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(Y, 1)) | ||
203 | + else: | ||
204 | + sig_pred=tf.nn.sigmoid(pred) | ||
205 | + sig_pred=tf.cast(sig_pred>0.5,dtype=tf.int64) | ||
206 | + ground_truth=tf.cast(Y,dtype=tf.int64) | ||
207 | + correct_prediction = tf.equal(sig_pred,ground_truth) | ||
208 | + | ||
209 | + #Calculate the accuracy across all the given batch and average them out. | ||
210 | + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | ||
211 | + | ||
212 | + # Initializing the variables | ||
213 | + init = tf.global_variables_initializer() | ||
214 | + | ||
215 | + config = tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) | ||
216 | + #config.gpu_options.allow_growth = True, log_device_placement=True | ||
217 | + #to use the tensorboard | ||
218 | + | ||
219 | + if FLAGS.save or FLAGS.restore : saver = tf.train.Saver() | ||
220 | + | ||
221 | + start = time.time() | ||
222 | + | ||
223 | + with tf.device(FLAGS.GPU_device): | ||
224 | + with tf.Session(config=config) as sess: | ||
225 | + sess.run(init) | ||
226 | + | ||
227 | + train_c_accuracy=[] | ||
228 | + train_c_total_loss=[] | ||
229 | + | ||
230 | + test_c_accuracy=[] | ||
231 | + test_c_total_loss=[] | ||
232 | + | ||
233 | + c_l1_norm_go=[] | ||
234 | + c_l1_norm_no_go=[] | ||
235 | + | ||
236 | + if FLAGS.type_training!="": | ||
237 | + test_c_ce_loss=[] | ||
238 | + test_c_additional_loss=[] | ||
239 | + train_c_ce_loss=[] | ||
240 | + train_c_additional_loss=[] | ||
241 | + | ||
242 | + for epoch in tqdm(np.arange(0,FLAGS.EPOCHS)): | ||
243 | + | ||
244 | + index = np.arange(X_train.shape[0]) | ||
245 | + np.random.shuffle(index) | ||
246 | + batch_X = np.array_split(X_train[index], total_batches) | ||
247 | + batch_Y = np.array_split(y_train[index], total_batches) | ||
248 | + | ||
249 | + # Optimization | ||
250 | + for batch in range(total_batches): | ||
251 | + batch_x,batch_y=batch_X[batch],batch_Y[batch] | ||
252 | + sess.run(optimizer, feed_dict={X: batch_x,Y: batch_y,is_training:FLAGS.is_training,keep_prob:FLAGS.keep_prob,learning_rate:FLAGS.learning_rate}) | ||
253 | + | ||
254 | + if ((epoch+1) % FLAGS.display_step == 0) or (epoch==0) : | ||
255 | + if not((FLAGS.display_step==FLAGS.EPOCHS) and (epoch==0)): | ||
256 | + # Calculate batch loss and accuracy after an epoch on the train and validation set | ||
257 | + avg_cost,avg_acc,l1_norm_no_go,l1_norm_go = sess.run([total_loss, accuracy,norm_no_go_connexions,norm_go_connexions], feed_dict={X: X_train,Y: y_train, | ||
258 | + is_training:False,keep_prob:1.0}) | ||
259 | + train_c_total_loss.append(avg_cost) | ||
260 | + train_c_accuracy.append(avg_acc) | ||
261 | + c_l1_norm_go.append(l1_norm_go) | ||
262 | + c_l1_norm_no_go.append(l1_norm_no_go) | ||
263 | + | ||
264 | + if FLAGS.type_training!="": | ||
265 | + avg_ce_loss,avg_additional_loss= sess.run([ce_loss, additional_loss], feed_dict={X: X_train,Y: y_train,is_training:False,keep_prob:1.0}) | ||
266 | + train_c_additional_loss.append(avg_additional_loss) | ||
267 | + train_c_ce_loss.append(avg_ce_loss) | ||
268 | + | ||
269 | + avg_cost,avg_acc = sess.run([total_loss, accuracy], feed_dict={X: X_test,Y: y_test,is_training:False,keep_prob:1.0}) | ||
270 | + test_c_total_loss.append(avg_cost) | ||
271 | + test_c_accuracy.append(avg_acc) | ||
272 | + | ||
273 | + if FLAGS.type_training!="": | ||
274 | + avg_ce_loss,avg_additional_loss= sess.run([ce_loss, additional_loss], feed_dict={X: X_test,Y: y_test,is_training:False,keep_prob:1.0}) | ||
275 | + test_c_additional_loss.append(avg_additional_loss) | ||
276 | + test_c_ce_loss.append(avg_ce_loss) | ||
277 | + | ||
278 | + current_idx=len(train_c_total_loss)-1 | ||
279 | + print('| Epoch: {}/{} | Train: Loss {:.6f} Accuracy : {:.6f} '\ | ||
280 | + '| Test: Loss {:.6f} Accuracy : {:.6f}\n'.format( | ||
281 | + epoch+1, FLAGS.EPOCHS,train_c_total_loss[current_idx], train_c_accuracy[current_idx],test_c_total_loss[current_idx],test_c_accuracy[current_idx])) | ||
282 | + | ||
283 | + if FLAGS.save: saver.save(sess=sess, save_path=os.path.join(save_dir,"model")) | ||
284 | + | ||
285 | + end = time.time() | ||
286 | + elapsed=end - start | ||
287 | + print("Total time: {}h {}min {}sec ".format(time.gmtime(elapsed).tm_hour, | ||
288 | + time.gmtime(elapsed).tm_min, | ||
289 | + time.gmtime(elapsed).tm_sec)) | ||
290 | + | ||
291 | + performances = { | ||
292 | + 'type_training': FLAGS.type_training, | ||
293 | + 'total_loss':train_c_total_loss,'test_total_loss':test_c_total_loss, | ||
294 | + 'acc':train_c_accuracy,'test_acc':test_c_accuracy | ||
295 | + } | ||
296 | + | ||
297 | + performances['norm_go']=c_l1_norm_go | ||
298 | + performances['norm_no_go']=c_l1_norm_no_go | ||
299 | + | ||
300 | + if FLAGS.type_training!="baseline": | ||
301 | + performances['additional_loss']=train_c_additional_loss | ||
302 | + performances['test_additional_loss']=test_c_additional_loss | ||
303 | + performances['ce_loss']=train_c_ce_loss | ||
304 | + performances['test_ce_loss']=test_c_ce_loss | ||
305 | + | ||
306 | + | ||
307 | + return performances | ||
308 | + | ||
309 | + | ||
310 | +def main(_): | ||
311 | + | ||
312 | + save_dir=os.path.join(FLAGS.log_dir,'MLP_DP={}_BN={}_EPOCHS={}_OPT={}'.format(FLAGS.keep_prob,FLAGS.bn,FLAGS.EPOCHS,FLAGS.lr_method)) | ||
313 | + | ||
314 | + if FLAGS.type_training=="LGO" : | ||
315 | + save_dir=save_dir+'_LGO_ALPHA={}{}'.format(FLAGS.alpha,FLAGS.file_extension) | ||
316 | + elif FLAGS.type_training=="L2" : | ||
317 | + save_dir=save_dir+'_L2_ALPHA={}{}'.format(FLAGS.alpha,FLAGS.file_extension) | ||
318 | + elif FLAGS.type_training=="" : | ||
319 | + save_dir=save_dir+'_{}'.format(FLAGS.file_extension) | ||
320 | + elif FLAGS.type_training=="L1" : | ||
321 | + save_dir=save_dir+'_L1_ALPHA={}{}'.format(FLAGS.alpha,FLAGS.file_extension) | ||
322 | + | ||
323 | + if FLAGS.is_training: | ||
324 | + | ||
325 | + start_full = time.time() | ||
326 | + | ||
327 | + if not(os.path.isdir(save_dir)): | ||
328 | + os.mkdir(save_dir) | ||
329 | + | ||
330 | + performances=train(save_dir=save_dir) | ||
331 | + | ||
332 | + with open(os.path.join(save_dir,"histories.txt"), "wb") as fp: | ||
333 | + #Pickling | ||
334 | + pickle.dump(performances, fp) | ||
335 | + | ||
336 | + end = time.time() | ||
337 | + elapsed =end - start_full | ||
338 | + print("Total time full process: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
339 | + time.gmtime(elapsed).tm_min, | ||
340 | + time.gmtime(elapsed).tm_sec)) | ||
341 | + else: | ||
342 | + | ||
343 | + # ---------------------------------TO MODIFY :------------------------------ | ||
344 | + | ||
345 | + start_full = time.time() | ||
346 | + evaluate(save_dir=save_dir,ref_layer="h{}".format(1)) #TO DEFINE | ||
347 | + end = time.time() | ||
348 | + elapsed =end - start_full | ||
349 | + print("Total time full process: {}h {}min {}sec".format(time.gmtime(elapsed).tm_hour, | ||
350 | + time.gmtime(elapsed).tm_min, | ||
351 | + time.gmtime(elapsed).tm_sec)) | ||
352 | + | ||
353 | + | ||
354 | +if __name__ == "__main__": | ||
355 | + tf.app.run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment