Official PyTorch implementation of the ICML 2022 paper "TAM: Topology-Aware Margin Loss for Class-Imbalanced Node Classification"
This work investigates a phenomenon in which imbalance-handling algorithms for node classification excessively increase the false positives of minor classes. To mitigate this problem, we propose TAM, which adjusts the margin of each node according to how far its local topology deviates from the class-averaged topology.
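To make "deviation from class-averaged topology" concrete, here is a minimal, stdlib-only sketch. It assumes a simple undirected graph in which every node's label is known, describes each node by the label distribution of its neighbors, averages that distribution per class, and returns the per-node, per-class difference scaled by a strength `alpha`. The function names are hypothetical, and this is an illustration of the idea only; it is not the paper's ACM/ADM formulation, whose exact definitions (including signs and how the terms enter the loss) are given in the paper and code.

```python
def neighbor_label_distribution(edges, labels, num_classes):
    """Fraction of each node's neighbors that carry each class label."""
    counts = [[0.0] * num_classes for _ in labels]
    for u, v in edges:  # undirected edge list: count each edge in both directions
        counts[u][labels[v]] += 1
        counts[v][labels[u]] += 1
    dist = []
    for row in counts:
        total = sum(row)
        dist.append([c / total if total else 0.0 for c in row])
    return dist


def topology_deviation(edges, labels, num_classes, alpha=1.0):
    """alpha * (node's neighbor-label distribution - mean distribution of its class)."""
    dist = neighbor_label_distribution(edges, labels, num_classes)
    sums = [[0.0] * num_classes for _ in range(num_classes)]
    sizes = [0] * num_classes
    for node, y in enumerate(labels):
        sizes[y] += 1
        for j in range(num_classes):
            sums[y][j] += dist[node][j]
    # Class-averaged neighbor-label distribution ("class-averaged topology").
    means = [[s / sizes[c] if sizes[c] else 0.0 for s in sums[c]]
             for c in range(num_classes)]
    return [[alpha * (dist[v][j] - means[labels[v]][j]) for j in range(num_classes)]
            for v in range(len(labels))]


# Toy path graph 0-1-2-3-4 with two classes.
labels = [0, 0, 0, 1, 1]
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
for node, row in enumerate(topology_deviation(edges, labels, num_classes=2)):
    print(node, [round(x, 3) for x in row])
```

In this toy graph, node 2 belongs to class 0 but has one neighbor of each class, so its class-1 entry is positive: it is more connected to class 1 than a typical class-0 node. In a semi-supervised setting, unlabeled neighbors would need estimated labels before such statistics could be computed.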
This repository contains the code for semi-supervised node classification and is built mainly on PyTorch Geometric.
-
Running commands for TAM:
- Balanced Softmax + TAM
python main_bs.py \
    --loss_type bs \
    --dataset [dataset] \
    --net [net] \
    --n_layer [n_layer] \
    --feat_dim [feat_dim] \
    --tam \
    --tam_alpha [tam_alpha] \
    --tam_beta [tam_beta] \
    --temp_phi [temp_phi]
- GraphENS + TAM
python main_ens.py --ens --loss_type ce --tam
- ReNode + TAM
python main_renode.py --renode \
    --loss_type ce \
    --loss_name [loss_name] \
    --rn_base [rn_base] \
    --rn_max [rn_max] \
    --tam
-
Running commands for baselines:
- Cross Entropy
python main_bs.py --loss_type ce
- Re-Weight
python main_rw.py --reweight --loss_type ce
- PC Softmax
python main_pc.py --pc_softmax --loss_type ce
- Balanced Softmax
python main_bs.py --loss_type bs
- GraphENS
python main_ens.py --ens --loss_type ce
- ReNode
python main_renode.py --renode \
    --loss_type ce \
    --loss_name [loss_name] \
    --rn_base [rn_base] \
    --rn_max [rn_max]
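When running many of the commands listed above, it can help to keep each method's script and fixed flags in one table. The sketch below copies those scripts and flags from this README; the method keys and the `build_command` helper are hypothetical names introduced for illustration. It assumes the scripts parse flags with a standard argument parser, so the order in which options are appended should not matter; that is an assumption, not something this README states.

```python
import shlex

# Script and fixed flags for each method, copied from the commands above.
# Bracketed options such as [dataset] or [rn_base] are passed via `options`.
METHODS = {
    "bs_tam": ["main_bs.py", "--loss_type", "bs", "--tam"],
    "ens_tam": ["main_ens.py", "--ens", "--loss_type", "ce", "--tam"],
    "renode_tam": ["main_renode.py", "--renode", "--loss_type", "ce", "--tam"],
    "ce": ["main_bs.py", "--loss_type", "ce"],
    "reweight": ["main_rw.py", "--reweight", "--loss_type", "ce"],
    "pc_softmax": ["main_pc.py", "--pc_softmax", "--loss_type", "ce"],
    "bs": ["main_bs.py", "--loss_type", "bs"],
    "ens": ["main_ens.py", "--ens", "--loss_type", "ce"],
    "renode": ["main_renode.py", "--renode", "--loss_type", "ce"],
}


def build_command(method, options=None):
    """Return the argv list for `method`, appending one --key value pair per option."""
    argv = ["python"] + METHODS[method]
    for key, value in (options or {}).items():
        argv += ["--" + key, str(value)]
    return argv


cmd = build_command("bs_tam", {
    "dataset": "Cora", "net": "GCN", "n_layer": 2, "feat_dim": 256,
    "tam_alpha": 2.5, "tam_beta": 0.5, "temp_phi": 1.2,
})
print(shlex.join(cmd))
# prints: python main_bs.py --loss_type bs --tam --dataset Cora --net GCN --n_layer 2 --feat_dim 256 --tam_alpha 2.5 --tam_beta 0.5 --temp_phi 1.2
```

The example values are simply drawn from the candidate sets in the argument description; they are not recommended settings. The resulting list can be passed directly to `subprocess.run`.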
-
Argument description for TAM:
- Experiment dataset (the dataset is downloaded automatically on the first run):
  Set [dataset] as one of ['Cora', 'Citeseer', 'PubMed', 'chameleon', 'squirrel', 'Wisconsin']
- Backbone GNN architecture:
  Set [net] as one of ['GCN', 'GAT', 'SAGE']
- Number of GNN layers:
  Set [n_layer] as one of [1, 2, 3]
- Hidden dimension of the GNN:
  Set [feat_dim] as one of [64, 128, 256]
- Strength of ACM, α:
  Set [tam_alpha] as one of [0.5, 1.5, 2.5]
- Strength of ADM, β:
  Set [tam_beta] as one of [0.25, 0.5]
- Class-wise temperature hyperparameter, 𝜙:
  Set [temp_phi] as one of [0.8, 1.2]
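The argument description gives a small candidate set for each TAM hyperparameter. As a sketch, the Balanced Softmax + TAM command can be expanded over every (α, β, 𝜙) combination for a fixed dataset and backbone. The `tam_grid_commands` helper is hypothetical, and this README does not say whether the authors searched these values jointly or one at a time; the exhaustive grid is an assumption for illustration.

```python
import itertools
import shlex

# Candidate values copied from the argument description.
GRID = {
    "tam_alpha": [0.5, 1.5, 2.5],
    "tam_beta": [0.25, 0.5],
    "temp_phi": [0.8, 1.2],
}


def tam_grid_commands(dataset, net, n_layer, feat_dim):
    """Yield one Balanced Softmax + TAM command per (alpha, beta, phi) combination."""
    keys = list(GRID)
    for values in itertools.product(*(GRID[k] for k in keys)):
        argv = ["python", "main_bs.py", "--loss_type", "bs",
                "--dataset", dataset, "--net", net,
                "--n_layer", str(n_layer), "--feat_dim", str(feat_dim), "--tam"]
        for key, value in zip(keys, values):
            argv += ["--" + key, str(value)]
        yield argv


commands = list(tam_grid_commands("Cora", "GCN", 2, 256))
print(len(commands))  # 3 * 2 * 2 = 12
print(shlex.join(commands[0]))
```

Generating argv lists rather than shell strings avoids quoting problems if the commands are later launched with `subprocess.run`.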
This code has been tested with:
- Python == 3.8.0
- PyTorch == 1.8.0
- PyTorch Geometric == 2.0.1
- torch_scatter == 2.0.8
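A quick way to compare a local environment against the tested versions is to read installed package metadata with `importlib.metadata` (available from Python 3.8). The distribution names below (`torch`, `torch_geometric`, `torch_scatter`) are assumptions about how the packages are registered; the check ignores local suffixes such as `+cu111` so that CUDA builds of the same release still match.

```python
from importlib import metadata

# Versions listed above as tested; distribution names are assumed.
TESTED = {
    "torch": "1.8.0",
    "torch_geometric": "2.0.1",
    "torch_scatter": "2.0.8",
}


def compare_versions(tested=TESTED):
    """Map each distribution to (tested version, installed version or None)."""
    report = {}
    for dist, expected in tested.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            installed = None
        report[dist] = (expected, installed)
    return report


for dist, (expected, installed) in compare_versions().items():
    # Drop a local version label such as "+cu111" before comparing.
    same = installed is not None and installed.split("+")[0] == expected
    status = "matches" if same else "differs"
    print(f"{dist}: tested {expected}, installed {installed} ({status})")
```

A mismatch does not necessarily mean the code will fail; it only flags where the environment departs from the tested setup.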
@InProceedings{pmlr-v162-song22a,
title = {{TAM}: Topology-Aware Margin Loss for Class-Imbalanced Node Classification},
author = {Song, Jaeyun and Park, Joonhyung and Yang, Eunho},
booktitle = {Proceedings of the 39th International Conference on Machine Learning},
pages = {20369--20383},
year = {2022},
volume = {162},
series = {Proceedings of Machine Learning Research},
month = {17--23 Jul},
publisher = {PMLR},
}
This work was supported by an Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST)).