add notebooks to get the interpretation of a prediction and to build the GO laye…
…rs architecture of the NN
Showing
8 changed files
with
32 additions
and
38 deletions
... | @@ -10,7 +10,7 @@ GraphGONet is a self-explaining neural network integrating the Gene Ontology int | ... | @@ -10,7 +10,7 @@ GraphGONet is a self-explaining neural network integrating the Gene Ontology int |
10 | 10 | ||
11 | ## Get started | 11 | ## Get started |
12 | 12 | ||
13 | -The code is implemented in Python (3.6.7) using the [PyTorch](https://pytorch.org/) framework v1.7.1 (see [requirements.txt](https://forge.ibisc.univ-evry.fr/vbourgeais/GraphGONet/blob/master/requirements.txt) for more details about the additional packages used). | 13 | +The code is implemented in Python (3.6.7) using [PyTorch v1.7.1](https://pytorch.org/) and [PyTorch-geometric v1.6.3](https://pytorch-geometric.readthedocs.io/en/1.6.3/modules/nn.html) (see [requirements.txt](https://forge.ibisc.univ-evry.fr/vbourgeais/GraphGONet/blob/master/requirements.txt) for more details about the additional packages used). |
14 | 14 | ||
15 | ## Dataset | 15 | ## Dataset |
16 | 16 | ||
... | @@ -31,49 +31,38 @@ There exists 3 functions (flag *processing*): one is dedicated to the training o | ... | @@ -31,49 +31,38 @@ There exists 3 functions (flag *processing*): one is dedicated to the training o |
31 | 31 | ||
32 | <!-- On the microarray dataset: | 32 | <!-- On the microarray dataset: |
33 | ```bash | 33 | ```bash |
34 | -python3 GraphGONet.py --save --n_inputs=36834 --n_nodes=10663 --n_nodes_annotated=8249 --n_classes=1 --mask="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight | 34 | +python3 scripts/GraphGONet.py --save --n_inputs=36834 --n_nodes=10663 --n_nodes_annotated=8249 --n_classes=1 --selection_op="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight |
35 | ``` | 35 | ``` |
36 | --> | 36 | --> |
37 | 37 | ||
38 | ```bash | 38 | ```bash |
39 | -python3 GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --mask="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight | 39 | +python3 scripts/GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --selection_op="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight |
40 | ``` | 40 | ``` |
41 | 41 | ||
42 | <!-- | 42 | <!-- |
43 | ### 2) Evaluate | 43 | ### 2) Evaluate |
44 | 44 | ||
45 | - | ||
46 | -```bash | ||
47 | -python DeepGONet.py --type_training="LGO" --alpha=1e-2 --EPOCHS=600 --is_training=False --restore=True --processing="evaluate" | ||
48 | -``` | ||
49 | - | ||
50 | ### 3) Predict | 45 | ### 3) Predict |
51 | 46 | ||
52 | - | ||
53 | -```bash | ||
54 | -python DeepGONet.py --type_training="LGO" --alpha=1e-2 --EPOCHS=600 --is_training=False --restore=True --processing="predict" | ||
55 | -``` | ||
56 | - | ||
57 | - | ||
58 | The outcomes are saved into a numpy array. | 47 | The outcomes are saved into a numpy array. |
59 | --> | 48 | --> |
60 | 49 | ||
61 | ### Comparison with random selection | 50 | ### Comparison with random selection |
62 | 51 | ||
63 | ```bash | 52 | ```bash |
64 | -python GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --mask="random" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight | 53 | +python scripts/GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --selection_op="random" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight |
65 | ``` | 54 | ``` |
66 | 55 | ||
67 | ### Comparison with no selection | 56 | ### Comparison with no selection |
68 | 57 | ||
69 | ```bash | 58 | ```bash |
70 | -python GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --n_epochs=50 --es --patience=5 --class_weight | 59 | +python scripts/GraphGONet.py --save --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --n_epochs=50 --es --patience=5 --class_weight |
71 | ``` | 60 | ``` |
72 | 61 | ||
73 | ### Train the model with a small number of training samples | 62 | ### Train the model with a small number of training samples |
74 | 63 | ||
75 | ```bash | 64 | ```bash |
76 | -python GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --mask="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight | 65 | +python scripts/GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_nodes_annotated=8288 --n_classes=12 --selection_op="top" --selection_ratio=0.001 --n_epochs=50 --es --patience=5 --class_weight |
77 | ``` | 66 | ``` |
78 | 67 | ||
79 | ### Help | 68 | ### Help |
... | @@ -81,7 +70,7 @@ python GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_ | ... | @@ -81,7 +70,7 @@ python GraphGONet.py --save --n_samples=50 --n_inputs=18427 --n_nodes=10636 --n_ |
81 | All the details about the command line flags can be provided by the following command: | 70 | All the details about the command line flags can be provided by the following command: |
82 | 71 | ||
83 | ```bash | 72 | ```bash |
84 | -python GraphGONet.py --help | 73 | +python scripts/GraphGONet.py --help |
85 | ``` | 74 | ``` |
86 | 75 | ||
87 | For most of the flags, the default values can be employed. *dir_data*, *dir_files*, and *dir_log* can be set to your own repositories. Only the flags in the command lines displayed have to be adjusted to reproduce the results from the paper. If you have enough GPU memory, you can choose to switch to the entire GO graph (argument *type_graph* set to "entire"). The graph can be reconstructed by following the notebooks: Build_GONet_graph_part{1,2,3}.ipynb located in the notebooks directory. Then, you should change the value of the arguments *n_nodes* and *n_nodes_annotated* in the command line. | 76 | For most of the flags, the default values can be employed. *dir_data*, *dir_files*, and *dir_log* can be set to your own repositories. Only the flags in the command lines displayed have to be adjusted to reproduce the results from the paper. If you have enough GPU memory, you can choose to switch to the entire GO graph (argument *type_graph* set to "entire"). The graph can be reconstructed by following the notebooks: Build_GONet_graph_part{1,2,3}.ipynb located in the notebooks directory. Then, you should change the value of the arguments *n_nodes* and *n_nodes_annotated* in the command line. | ... | ... |
notebooks/1-Build_GONet_graph_part1.ipynb
0 → 100644
This diff is collapsed. Click to expand it.
notebooks/1-Build_GONet_graph_part2.ipynb
0 → 100644
This diff could not be displayed because it is too large.
notebooks/1-Build_GONet_graph_part3.ipynb
0 → 100644
This diff could not be displayed because it is too large.
notebooks/Interpretation.ipynb
0 → 100644
This diff could not be displayed because it is too large.
... | @@ -90,7 +90,7 @@ def train(args): | ... | @@ -90,7 +90,7 @@ def train(args): |
90 | print("Launching the learning") | 90 | print("Launching the learning") |
91 | device = torch.device(args.device) | 91 | device = torch.device(args.device) |
92 | model = Net(n_genes=args.n_inputs,n_nodes=args.n_nodes,n_nodes_annot=args.n_nodes_annotated,n_nodes_emb=args.dim_init,n_classes=args.n_classes, | 92 | model = Net(n_genes=args.n_inputs,n_nodes=args.n_nodes,n_nodes_annot=args.n_nodes_annotated,n_nodes_emb=args.dim_init,n_classes=args.n_classes, |
93 | - n_prop1=args.n_prop1,adj_mat_fc1=connection_matrix.values,mask=args.mask,ratio=args.selection_ratio).to(device) | 93 | + n_prop1=args.n_prop1,adj_mat_fc1=connection_matrix.values,selection=args.selection_op,ratio=args.selection_ratio).to(device) |
94 | print(model) | 94 | print(model) |
95 | print("(model mem allocation) - Memory available : {:.2e}".format(torch.cuda.memory_reserved(0)-torch.cuda.memory_allocated(0))) | 95 | print("(model mem allocation) - Memory available : {:.2e}".format(torch.cuda.memory_reserved(0)-torch.cuda.memory_allocated(0))) |
96 | 96 | ||
... | @@ -309,7 +309,7 @@ def main(): | ... | @@ -309,7 +309,7 @@ def main(): |
309 | parser.add_argument('--n_classes', type=int, default=1, help="number of classes") | 309 | parser.add_argument('--n_classes', type=int, default=1, help="number of classes") |
310 | 310 | ||
311 | # -- Learning and Hyperparameters -- | 311 | # -- Learning and Hyperparameters -- |
312 | - parser.add_argument('--mask', type=str, default=None, help='type of selection (random,top)') | 312 | + parser.add_argument('--selection_op', type=str, default=None, help='type of selection (random,top)') |
313 | parser.add_argument('--selection_ratio', type=float, default=0.5, help='selection ratio') | 313 | parser.add_argument('--selection_ratio', type=float, default=0.5, help='selection ratio') |
314 | parser.add_argument('--optimizer', type=str, default='adam', help="optimizer {adam, momentum, adagrad, rmsprop}") | 314 | parser.add_argument('--optimizer', type=str, default='adam', help="optimizer {adam, momentum, adagrad, rmsprop}") |
315 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') | 315 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') |
... | @@ -329,10 +329,10 @@ def main(): | ... | @@ -329,10 +329,10 @@ def main(): |
329 | if not(os.path.isdir(args.dir_log)): | 329 | if not(os.path.isdir(args.dir_log)): |
330 | os.mkdir(args.dir_log) | 330 | os.mkdir(args.dir_log) |
331 | 331 | ||
332 | - if args.mask: | 332 | + if args.selection_op: |
333 | - args.dir_save=os.path.join(args.dir_log,'GraphGONet_MASK={}_SELECTRATIO={}'.format(args.mask,args.selection_ratio)) | 333 | + args.dir_save=os.path.join(args.dir_log,'GraphGONet_SELECTOP={}_SELECTRATIO={}'.format(args.selection_op,args.selection_ratio)) |
334 | else: | 334 | else: |
335 | - args.dir_save=os.path.join(args.dir_log,'GraphGONet_MASK={}'.format(args.mask)) | 335 | + args.dir_save=os.path.join(args.dir_log,'GraphGONet_SELECTOP={}'.format(args.selection_op)) |
336 | 336 | ||
337 | if args.n_samples: | 337 | if args.n_samples: |
338 | args.dir_save+="_N_SAMPLES={}".format(args.n_samples) | 338 | args.dir_save+="_N_SAMPLES={}".format(args.n_samples) | ... | ... |
... | @@ -189,9 +189,9 @@ def concatenate_and_mask(x: Tensor, batch: Tensor, idx_nodes_kept : Tensor, num_ | ... | @@ -189,9 +189,9 @@ def concatenate_and_mask(x: Tensor, batch: Tensor, idx_nodes_kept : Tensor, num_ |
189 | output[i,mask]=x[i*num_nodes_kept[i]:(i+1)*num_nodes_kept[i]].view(-1) #shape : (num_nodes_by_graph,1) -> (_,max_num_nodes) | 189 | output[i,mask]=x[i*num_nodes_kept[i]:(i+1)*num_nodes_kept[i]].view(-1) #shape : (num_nodes_by_graph,1) -> (_,max_num_nodes) |
190 | return output | 190 | return output |
191 | 191 | ||
192 | -class NoSelection(torch.nn.Module): | 192 | +class Mask(torch.nn.Module): |
193 | def __init__(self, in_channels, method, n_nodes, **kwargs): | 193 | def __init__(self, in_channels, method, n_nodes, **kwargs): |
194 | - super(NoSelection, self).__init__() | 194 | + super(Mask, self).__init__() |
195 | self.method = method | 195 | self.method = method |
196 | self.in_channels = in_channels | 196 | self.in_channels = in_channels |
197 | if self.method.__name__ == "global_mean_pool": | 197 | if self.method.__name__ == "global_mean_pool": |
... | @@ -211,8 +211,8 @@ class NoSelection(torch.nn.Module): | ... | @@ -211,8 +211,8 @@ class NoSelection(torch.nn.Module): |
211 | 211 | ||
212 | class Net(torch.nn.Module): | 212 | class Net(torch.nn.Module): |
213 | def __init__(self,n_genes,n_nodes,n_nodes_annot,n_nodes_emb,n_prop1,n_classes,adj_mat_fc1, | 213 | def __init__(self,n_genes,n_nodes,n_nodes_annot,n_nodes_emb,n_prop1,n_classes,adj_mat_fc1, |
214 | - propagation="DAGProp",mask=None,ratio=1.0, | 214 | + propagation="DAGProp",selection=None,ratio=1.0, |
215 | - selection="concatenate_and_mask"): | 215 | + mask="concatenate_and_mask"): |
216 | super(Net, self).__init__() | 216 | super(Net, self).__init__() |
217 | self.n_genes = n_genes | 217 | self.n_genes = n_genes |
218 | self.n_nodes = n_nodes | 218 | self.n_nodes = n_nodes |
... | @@ -226,15 +226,15 @@ class Net(torch.nn.Module): | ... | @@ -226,15 +226,15 @@ class Net(torch.nn.Module): |
226 | with torch.no_grad(): | 226 | with torch.no_grad(): |
227 | self.fc1.weight.mul_(self.adj_mat_fc1) #mask all the connections btw genes and neurons that do not represent GO annotations | 227 | self.fc1.weight.mul_(self.adj_mat_fc1) #mask all the connections btw genes and neurons that do not represent GO annotations |
228 | self.propagation = eval(propagation)(in_channels=n_nodes_emb, out_channels=n_prop1,aggr = "mean") # expected dim: [nSamples, nNodes, nChannels] | 228 | self.propagation = eval(propagation)(in_channels=n_nodes_emb, out_channels=n_prop1,aggr = "mean") # expected dim: [nSamples, nNodes, nChannels] |
229 | - if mask: | 229 | + if selection: |
230 | self.ratio = ratio | 230 | self.ratio = ratio |
231 | - if mask=="random": | 231 | + if selection=="random": |
232 | - self.mask = RandomSelection(in_channels=n_prop1,ratio=ratio) | 232 | + self.selection = RandomSelection(in_channels=n_prop1,ratio=ratio) |
233 | - elif mask=="top": | 233 | + elif selection=="top": |
234 | - self.mask = TopSelection(in_channels=n_prop1,ratio=ratio) | 234 | + self.selection = TopSelection(in_channels=n_prop1,ratio=ratio) |
235 | else: | 235 | else: |
236 | - selection="concatenate" | 236 | + mask="concatenate" |
237 | - self.selection = NoSelection(method=globals()[selection],in_channels=n_prop1,n_nodes=n_nodes) #option no selection => concatenate | 237 | + self.mask = Mask(method=globals()[mask],in_channels=n_prop1,n_nodes=n_nodes) #option no selection => concatenate |
238 | self.fc2 = Linear(in_features=n_nodes,out_features=n_classes) | 238 | self.fc2 = Linear(in_features=n_nodes,out_features=n_classes) |
239 | 239 | ||
240 | def forward(self,transcriptomic_data,graph_data): | 240 | def forward(self,transcriptomic_data,graph_data): |
... | @@ -247,10 +247,13 @@ class Net(torch.nn.Module): | ... | @@ -247,10 +247,13 @@ class Net(torch.nn.Module): |
247 | 247 | ||
248 | num_nodes = scatter_add(batch.new_ones(x.size(0),dtype=torch.int16), batch, dim=0) | 248 | num_nodes = scatter_add(batch.new_ones(x.size(0),dtype=torch.int16), batch, dim=0) |
249 | 249 | ||
250 | - if self.mask: | 250 | + if self.selection: |
251 | - x, edge_index, _, batch,idx_nodes_kept,_ = self.mask(x, edge_index, None, batch) | 251 | + x, edge_index, _, batch,idx_nodes_kept,_ = self.selection(x, edge_index, None, batch) |
252 | - | 252 | + if self.mask.method.__name__ == "concatenate_and_mask": |
253 | - x = self.selection(x,batch,idx_nodes_kept,num_nodes) | 253 | + x = self.mask(x,batch,idx_nodes_kept,num_nodes) |
254 | + else: | ||
255 | + x = self.mask(x,batch) | ||
256 | + | ||
254 | x = self.fc2(x) | 257 | x = self.fc2(x) |
255 | 258 | ||
256 | if self.n_classes>=2: | 259 | if self.n_classes>=2: | ... | ... |
-
Please register or login to post a comment