This project implements a Graph Isomorphism Network (GIN) for graph classification using the DGL library, based on the paper "How Powerful are Graph Neural Networks?" (Xu et al., ICLR 2019). The model is trained and evaluated on the benchmark datasets listed below, and its training and validation accuracies can be plotted for analysis.
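In the paper, a GIN layer updates each node v as h_v' = MLP((1 + eps) * h_v + sum of h_u over the neighbours u of v), and a graph-level embedding is obtained by summing node embeddings. The sketch below illustrates that update rule in plain Python so it runs without DGL or PyTorch. It is not this project's implementation: the function names, the dict-based graph representation, and the toy MLP are hypothetical. It also uses a single sum readout, whereas the paper concatenates readouts from every layer. DGL ships a GINConv module that implements this update on real graphs.

```python
# Hypothetical, dependency-free sketch of the GIN update rule; not this project's code.
from typing import Callable, Dict, List

Vector = List[float]


def gin_layer(features: Dict[int, Vector],
              adjacency: Dict[int, List[int]],
              mlp: Callable[[Vector], Vector],
              eps: float = 0.0) -> Dict[int, Vector]:
    """Apply h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u) to every node."""
    updated: Dict[int, Vector] = {}
    for node, h_v in features.items():
        # Start from the node's own features, scaled by (1 + eps).
        agg = [(1.0 + eps) * x for x in h_v]
        # Sum (not mean or max) over neighbour features; sum aggregation is
        # what lets GIN tell apart different multisets of neighbours.
        for nbr in adjacency.get(node, []):
            agg = [a + b for a, b in zip(agg, features[nbr])]
        updated[node] = mlp(agg)
    return updated


def sum_readout(features: Dict[int, Vector]) -> Vector:
    """Graph-level embedding: element-wise sum of all node embeddings."""
    dim = len(next(iter(features.values())))
    total = [0.0] * dim
    for h in features.values():
        total = [t + x for t, x in zip(total, h)]
    return total
```

For example, on a three-node path graph 0-1-2 with one-dimensional features and a toy MLP such as `lambda v: [max(0.0, 2 * x) for x in v]`, each node's new value is twice the sum of its own and its neighbours' values, and `sum_readout` adds those up into one graph vector.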
The following datasets are supported:
- MUTAG
- PTC
- NCI1
- PROTEINS
- COLLAB
- IMDBBINARY
- IMDBMULTI
To train the model on a specific dataset, run the following command:
python main.py --dataset <DATASET_NAME>
Replace <DATASET_NAME> with one of the supported dataset names.
To enable plotting of training and validation accuracy, add the --plot flag:
python main.py --dataset <DATASET_NAME> --plot
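The two flags above map naturally onto Python's standard argparse module. The sketch below is a hypothetical illustration of how such a command line could be parsed, not the actual contents of main.py: the `build_parser` name is made up, and the behaviour when --dataset is omitted is an assumption (here it simply defaults to None). Restricting --dataset to the supported names makes argparse reject anything else with a usage error.

```python
# Hypothetical sketch of parsing the README's --dataset and --plot flags;
# not the project's actual main.py.
import argparse

# Dataset names as listed in this README.
SUPPORTED_DATASETS = [
    "MUTAG", "PTC", "NCI1", "PROTEINS", "COLLAB", "IMDBBINARY", "IMDBMULTI",
]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Train a GIN on a graph classification dataset.")
    # choices= makes argparse exit with an error for unsupported names.
    # default=None is an assumption: the README does not say what happens
    # when --dataset is omitted.
    parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, default=None,
                        help="name of the dataset to train on")
    # store_true: the flag takes no value; its presence enables plotting.
    parser.add_argument("--plot", action="store_true",
                        help="plot training and validation accuracy")
    return parser
```

Calling `build_parser().parse_args(["--dataset", "MUTAG", "--plot"])` yields a namespace with `dataset == "MUTAG"` and `plot == True`; omitting --plot leaves `plot` as False.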
The training and validation accuracies are plotted against the epoch number and saved as an image file, showing how accuracy evolves over the course of training.
The best validation accuracy for each dataset is plotted and saved as an image file. This plot provides a comparison of the model's performance across different datasets.
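The cross-dataset comparison boils down to reducing each dataset's per-epoch validation curve to a single best value. The sketch below shows one way to do that reduction in plain Python; it is a hypothetical illustration, not code from this project. The `best_validation` and `accuracy_plot_filename` names are made up, and the `accuracy_plot_<name>.png` pattern is only inferred from the example file name accuracy_plot_MUTAG.png.

```python
# Hypothetical sketch: reducing per-epoch validation accuracy to the
# "best validation accuracy" compared across datasets. Not the project's code.
from typing import Dict, List, Tuple


def best_validation(history: Dict[str, List[float]]) -> Dict[str, Tuple[int, float]]:
    """Map each dataset name to (best epoch, best validation accuracy).

    `history` maps a dataset name to its validation accuracy per epoch,
    in epoch order, with epochs numbered from 1.
    """
    best: Dict[str, Tuple[int, float]] = {}
    for name, accs in history.items():
        # max() returns the first maximal item, so ties keep the earliest epoch.
        epoch_idx, acc = max(enumerate(accs), key=lambda pair: pair[1])
        best[name] = (epoch_idx + 1, acc)
    return best


def accuracy_plot_filename(dataset: str) -> str:
    """Assumed naming pattern, inferred from accuracy_plot_MUTAG.png."""
    return f"accuracy_plot_{dataset}.png"
```

Reporting the epoch alongside the accuracy makes it easy to see whether the peak came early (possible overfitting afterwards) or at the end of training.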
The following packages are required:
- Python 3.x
- DGL
- PyTorch
- NumPy
- scikit-learn
- Plotly
Install the required packages using pip:
pip install dgl torch numpy scikit-learn plotly
To train the model on the MUTAG dataset, run:
python main.py --dataset MUTAG
To plot the accuracies, run:
python main.py --plot
This will generate the following plots:
- accuracy_plot_MUTAG.png: Training and validation accuracy for each dataset.
- best_accuracy_plot.png: Best validation accuracy across all datasets.
This project is licensed under the MIT License.