Victoria BOURGEAIS

add notebooks to get the interpretation of a prediction and to build the GO laye…

…rs architecture of the NN
1 files/ 1 files/
2 .ipynb_checkpoints/ 2 .ipynb_checkpoints/
3 scripts/__pycache__/ 3 scripts/__pycache__/
4 +log/
5 +
......
...@@ -10,7 +10,7 @@ GraphGONet is a self-explaining neural network integrating the Gene Ontology int ...@@ -10,7 +10,7 @@ GraphGONet is a self-explaining neural network integrating the Gene Ontology int
10 10
11 ## Get started 11 ## Get started
12 12
13 -The code is implemented in Python (3.6.7) using the [PyTorch](https://pytorch.org/) framework v1.7.1 (see [requirements.txt](https://forge.ibisc.univ-evry.fr/vbourgeais/GraphGONet/blob/master/requirements.txt) for more details about the additional packages used). 13 +The code is implemented in Python (3.6.7) using [PyTorch v1.7.1](https://pytorch.org/) and [PyTorch-geometric v1.6.3](https://pytorch-geometric.readthedocs.io/en/1.6.3/modules/nn.html) (see [requirements.txt](https://forge.ibisc.univ-evry.fr/vbourgeais/GraphGONet/blob/master/requirements.txt) for more details about the additional packages used).
14 14
15 ## Dataset 15 ## Dataset
16 16
...@@ -31,49 +31,38 @@ There exists 3 functions (flag *processing*): one is dedicated to the training o ...@@ -31,49 +31,38 @@ There exists 3 functions (flag *processing*): one is dedicated to the training o
31 31
32 <!-- On the microarray dataset: 32 <!-- On the microarray dataset:
33 ```bash 33 ```bash
34 -python3 GraphGONet.py --save --n_inputs=36834 --n_nodes=10663 --n_nodes_annotated=8249 --n_classes=1 --mask="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight 34 +python3 scripts/GraphGONet.py --save --n_inputs=36834 --n_nodes=10663 --n_nodes_annotated=8249 --n_classes=1 --selection_op="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight
35 ``` 35 ```
36 --> 36 -->
37 37
38 ```bash 38 ```bash
39 -python3 GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --mask="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight 39 +python3 scripts/GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --selection_op="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight
40 ``` 40 ```
41 41
42 <!-- 42 <!--
43 ### 2) Evaluate 43 ### 2) Evaluate
44 44
45 -
46 -```bash
47 -python DeepGONet.py --type_training="LGO" --alpha=1e-2 --EPOCHS=600 --is_training=False --restore=True --processing="evaluate"
48 -```
49 -
50 ### 3) Predict 45 ### 3) Predict
51 46
52 -
53 -```bash
54 -python DeepGONet.py --type_training="LGO" --alpha=1e-2 --EPOCHS=600 --is_training=False --restore=True --processing="predict"
55 -```
56 -
57 -
58 The outcomes are saved into a numpy array. 47 The outcomes are saved into a numpy array.
59 --> 48 -->
60 49
61 ### Comparison with random selection 50 ### Comparison with random selection
62 51
63 ```bash 52 ```bash
64 -python GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --mask="random" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight 53 +python scripts/GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --selection_op="random" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight
65 ``` 54 ```
66 55
67 ### Comparison with no selection 56 ### Comparison with no selection
68 57
69 ```bash 58 ```bash
70 -python GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --n_epochs=50 --es --patience=5 --class_weight 59 +python scripts/GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --n_epochs=50 --es --patience=5 --class_weight
71 ``` 60 ```
72 61
73 ### Train the model with a small number of training samples 62 ### Train the model with a small number of training samples
74 63
75 ```bash 64 ```bash
76 -python GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --mask="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight 65 +python scripts/GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --selection_op="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight
77 ``` 66 ```
78 67
79 ### Help 68 ### Help
...@@ -81,7 +70,7 @@ python GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_ ...@@ -81,7 +70,7 @@ python GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_
81 All the details about the command line flags can be provided by the following command: 70 All the details about the command line flags can be provided by the following command:
82 71
83 ```bash 72 ```bash
84 -python GraphGONet.py --help 73 +python scripts/GraphGONet.py --help
85 ``` 74 ```
86 75
87 For most of the flags, the default values can be employed. *dir_data*, *dir_files*, and *dir_log* can be set to your own repositories. Only the flags in the command lines displayed have to be adjusted to reproduce the results from the paper. If you have enough GPU memory, you can choose to switch to the entire GO graph (argument *type_graph* set to "entire"). The graph can be reconstructed by following the notebooks: Build_GONet_graph_part{1,2,3}.ipynb located in the notebooks directory. Then, you should change the value of the arguments *n_nodes* and *n_nodes_annotated* in the command line. 76 For most of the flags, the default values can be employed. *dir_data*, *dir_files*, and *dir_log* can be set to your own repositories. Only the flags in the command lines displayed have to be adjusted to reproduce the results from the paper. If you have enough GPU memory, you can choose to switch to the entire GO graph (argument *type_graph* set to "entire"). The graph can be reconstructed by following the notebooks: Build_GONet_graph_part{1,2,3}.ipynb located in the notebooks directory. Then, you should change the value of the arguments *n_nodes* and *n_nodes_annotated* in the command line.
......
This diff is collapsed. Click to expand it.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
...@@ -90,7 +90,7 @@ def train(args): ...@@ -90,7 +90,7 @@ def train(args):
90 print("Launching the learning") 90 print("Launching the learning")
91 device = torch.device(args.device) 91 device = torch.device(args.device)
92 model = Net(n_genes=args.n_inputs,n_nodes=args.n_nodes,n_nodes_annot=args.n_nodes_annotated,n_nodes_emb=args.dim_init,n_classes=args.n_classes, 92 model = Net(n_genes=args.n_inputs,n_nodes=args.n_nodes,n_nodes_annot=args.n_nodes_annotated,n_nodes_emb=args.dim_init,n_classes=args.n_classes,
93 - n_prop1=args.n_prop1,adj_mat_fc1=connection_matrix.values,mask=args.mask,ratio=args.selection_ratio).to(device) 93 + n_prop1=args.n_prop1,adj_mat_fc1=connection_matrix.values,selection=args.selection_op,ratio=args.selection_ratio).to(device)
94 print(model) 94 print(model)
95 print("(model mem allocation) - Memory available : {:.2e}".format(torch.cuda.memory_reserved(0)-torch.cuda.memory_allocated(0))) 95 print("(model mem allocation) - Memory available : {:.2e}".format(torch.cuda.memory_reserved(0)-torch.cuda.memory_allocated(0)))
96 96
...@@ -309,7 +309,7 @@ def main(): ...@@ -309,7 +309,7 @@ def main():
309 parser.add_argument('--n_classes', type=int, default=1, help="number of classes") 309 parser.add_argument('--n_classes', type=int, default=1, help="number of classes")
310 310
311 # -- Learning and Hyperparameters -- 311 # -- Learning and Hyperparameters --
312 - parser.add_argument('--mask', type=str, default=None, help='type of selection (random,top)') 312 + parser.add_argument('--selection_op', type=str, default=None, help='type of selection (random,top)')
313 parser.add_argument('--selection_ratio', type=float, default=0.5, help='selection ratio') 313 parser.add_argument('--selection_ratio', type=float, default=0.5, help='selection ratio')
314 parser.add_argument('--optimizer', type=str, default='adam', help="optimizer {adam, momentum, adagrad, rmsprop}") 314 parser.add_argument('--optimizer', type=str, default='adam', help="optimizer {adam, momentum, adagrad, rmsprop}")
315 parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 315 parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
...@@ -329,10 +329,10 @@ def main(): ...@@ -329,10 +329,10 @@ def main():
329 if not(os.path.isdir(args.dir_log)): 329 if not(os.path.isdir(args.dir_log)):
330 os.mkdir(args.dir_log) 330 os.mkdir(args.dir_log)
331 331
332 - if args.mask: 332 + if args.selection_op:
333 - args.dir_save=os.path.join(args.dir_log,'GraphGONet_MASK={}_SELECTRATIO={}'.format(args.mask,args.selection_ratio)) 333 + args.dir_save=os.path.join(args.dir_log,'GraphGONet_SELECTOP={}_SELECTRATIO={}'.format(args.selection_op,args.selection_ratio))
334 else: 334 else:
335 - args.dir_save=os.path.join(args.dir_log,'GraphGONet_MASK={}'.format(args.mask)) 335 + args.dir_save=os.path.join(args.dir_log,'GraphGONet_SELECTOP={}'.format(args.selection_op))
336 336
337 if args.n_samples: 337 if args.n_samples:
338 args.dir_save+="_N_SAMPLES={}".format(args.n_samples) 338 args.dir_save+="_N_SAMPLES={}".format(args.n_samples)
......
...@@ -189,9 +189,9 @@ def concatenate_and_mask(x: Tensor, batch: Tensor, idx_nodes_kept : Tensor, num_ ...@@ -189,9 +189,9 @@ def concatenate_and_mask(x: Tensor, batch: Tensor, idx_nodes_kept : Tensor, num_
189 output[i,mask]=x[i*num_nodes_kept[i]:(i+1)*num_nodes_kept[i]].view(-1) #shape : (num_nodes_by_graph,1) -> (_,max_num_nodes) 189 output[i,mask]=x[i*num_nodes_kept[i]:(i+1)*num_nodes_kept[i]].view(-1) #shape : (num_nodes_by_graph,1) -> (_,max_num_nodes)
190 return output 190 return output
191 191
192 -class NoSelection(torch.nn.Module): 192 +class Mask(torch.nn.Module):
193 def __init__(self, in_channels, method, n_nodes, **kwargs): 193 def __init__(self, in_channels, method, n_nodes, **kwargs):
194 - super(NoSelection, self).__init__() 194 + super(Mask, self).__init__()
195 self.method = method 195 self.method = method
196 self.in_channels = in_channels 196 self.in_channels = in_channels
197 if self.method.__name__ == "global_mean_pool": 197 if self.method.__name__ == "global_mean_pool":
...@@ -211,8 +211,8 @@ class NoSelection(torch.nn.Module): ...@@ -211,8 +211,8 @@ class NoSelection(torch.nn.Module):
211 211
212 class Net(torch.nn.Module): 212 class Net(torch.nn.Module):
213 def __init__(self,n_genes,n_nodes,n_nodes_annot,n_nodes_emb,n_prop1,n_classes,adj_mat_fc1, 213 def __init__(self,n_genes,n_nodes,n_nodes_annot,n_nodes_emb,n_prop1,n_classes,adj_mat_fc1,
214 - propagation="DAGProp",mask=None,ratio=1.0, 214 + propagation="DAGProp",selection=None,ratio=1.0,
215 - selection="concatenate_and_mask"): 215 + mask="concatenate_and_mask"):
216 super(Net, self).__init__() 216 super(Net, self).__init__()
217 self.n_genes = n_genes 217 self.n_genes = n_genes
218 self.n_nodes = n_nodes 218 self.n_nodes = n_nodes
...@@ -226,15 +226,15 @@ class Net(torch.nn.Module): ...@@ -226,15 +226,15 @@ class Net(torch.nn.Module):
226 with torch.no_grad(): 226 with torch.no_grad():
227 self.fc1.weight.mul_(self.adj_mat_fc1) #mask all the connections btw genes and neurons that do not represent GO annotations 227 self.fc1.weight.mul_(self.adj_mat_fc1) #mask all the connections btw genes and neurons that do not represent GO annotations
228 self.propagation = eval(propagation)(in_channels=n_nodes_emb, out_channels=n_prop1,aggr = "mean") # expected dim: [nSamples, nNodes, nChannels] 228 self.propagation = eval(propagation)(in_channels=n_nodes_emb, out_channels=n_prop1,aggr = "mean") # expected dim: [nSamples, nNodes, nChannels]
229 - if mask: 229 + if selection:
230 self.ratio = ratio 230 self.ratio = ratio
231 - if mask=="random": 231 + if selection=="random":
232 - self.mask = RandomSelection(in_channels=n_prop1,ratio=ratio) 232 + self.selection = RandomSelection(in_channels=n_prop1,ratio=ratio)
233 - elif mask=="top": 233 + elif selection=="top":
234 - self.mask = TopSelection(in_channels=n_prop1,ratio=ratio) 234 + self.selection = TopSelection(in_channels=n_prop1,ratio=ratio)
235 else: 235 else:
236 - selection="concatenate" 236 + mask="concatenate"
237 - self.selection = NoSelection(method=globals()[selection],in_channels=n_prop1,n_nodes=n_nodes) #option no selection => concatenate 237 + self.mask = Mask(method=globals()[mask],in_channels=n_prop1,n_nodes=n_nodes) #option no selection => concatenate
238 self.fc2 = Linear(in_features=n_nodes,out_features=n_classes) 238 self.fc2 = Linear(in_features=n_nodes,out_features=n_classes)
239 239
240 def forward(self,transcriptomic_data,graph_data): 240 def forward(self,transcriptomic_data,graph_data):
...@@ -247,10 +247,13 @@ class Net(torch.nn.Module): ...@@ -247,10 +247,13 @@ class Net(torch.nn.Module):
247 247
248 num_nodes = scatter_add(batch.new_ones(x.size(0),dtype=torch.int16), batch, dim=0) 248 num_nodes = scatter_add(batch.new_ones(x.size(0),dtype=torch.int16), batch, dim=0)
249 249
250 - if self.mask: 250 + if self.selection:
251 - x, edge_index, _, batch,idx_nodes_kept,_ = self.mask(x, edge_index, None, batch) 251 + x, edge_index, _, batch,idx_nodes_kept,_ = self.selection(x, edge_index, None, batch)
252 - 252 + if self.mask.method.__name__ == "concatenate_and_mask":
253 - x = self.selection(x,batch,idx_nodes_kept,num_nodes) 253 + x = self.mask(x,batch,idx_nodes_kept,num_nodes)
254 + else:
255 + x = self.mask(x,batch)
256 +
254 x = self.fc2(x) 257 x = self.fc2(x)
255 258
256 if self.n_classes>=2: 259 if self.n_classes>=2:
......