チャネル プルーニングをモジュール レベルでインプリメントします。
引数
Pruner(module, input_specs)
新しいプルーナー オブジェクトを作成します。
- module
- プルーニングする
torch.nn.Module
オブジェクト。 - input_specs
- モジュールの入力 (
InputSpec
オブジェクトまたはInputSpec
のリスト)。
方法
-
ana(eval_fn, args=(), gpus=None)
モデル解析を実行します。
- eval_fn
- 最初の引数に
torch.nn.Module
オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。 - args
-
eval_fn
に渡される引数のタプル。 - gpus
- モデル解析に使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
-
prune(ratio=None, threshold=None, excludes=[], output_script='graph.py')
指定された比またはしきい値により、ネットワークをプルーニングします。追加のプルーニング情報を指定した通常の
torch.nn.Module
と同じように機能するPruningModule
オブジェクトを返します。- ratio
- 想定される FLOPs 削減率。この値は単なるヒントです。プルーニング後の実際の FLOPs の減少は、必ずしもこの値と一致しません。
- threshold
- 許容できるモデルの精度低下の相対的比率。
- excludes
- プルーニングから除外する必要があるモジュール。
- output_script
- モデルの再構築に使用される生成済みスクリプトを保存するファイルパス。
-
summary(pruned_model)
プルーニング済みモデルのプルーニング サマリを生成します。
- pruned_model
-
prune()
メソッドで返されるプルーニング済みモジュール。