- このリポジトリは、GitHub - jacobgil/pytorch-pruningの日本語による解説を目的としています。
- [1611.06440] Pruning Convolutional Neural Networks for Resource Efficient Inferenceのpytorch実装です。
- 畳み込み層のフィルタをpruningしてネットワークモデルを軽量化します。
- 重要度の低いn枚のフィルタを削除し、ファインチューニングを行います。
- VGG-16ベースの犬猫2クラス分類モデルをpruningした結果
フィルタ枚数 | 4224 | 1664 | 896 | 128 |
---|---|---|---|---|
認識精度 | 0.9875 | 0.9613 | 0.9625 | 0.72875 |
実行時間 | 12.01 s | 6.26s | 3.36s | 2.04s |
- 実行環境: Intel(R) Xeon(R) CPU E5-2690 v3 @ 2.60GHz, Tesla K40c, メモリ128GB
- テストデータ800枚,バッチサイズ32
PyTorchのImageFolderを使用しているため、分類クラスごとディレクトリ作成し、その中に画像ファイル 5FEC を保存します。
train/
+---dogs/
+---cats/
test
+---dogs/
+---cats/
画像はDogs vs. Cats | Kaggleからダウンロードします。
元リポジトリのブログに倣い、そのうち各クラス1000枚ずつをトレーニングデータ、400枚ずつをテストデータとしました。
学習
python finetune.py --train --save model
pruning
python finetune.py --prune --pruned_filters_per_iter 512 --pruning_iter 5 --model model --save prunned_model
テスト
python finetune.py --test --model model
python finetune.py --test --model prunned_model