チャネル プルーニングをモジュール レベルでインプリメントします。
引数
Pruner(module, inputs)
新しいプルーナー オブジェクトを作成します。
- module
- プルーニングする
torch.nn.Module
オブジェクト。 - inputs
- モジュールの入力。
方法
-
ana(eval_fn, args=(), gpus=None)
モデル解析を実行します。
- eval_fn
- 最初の引数に
torch.nn.Module
オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。 - args
-
eval_fn
に渡される引数のタプル。 - gpus
- モデル解析に使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
-
prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')
指定された比またはしきい値によってネットワークをプルーニングし、
‘torch.nn.Module’
オブジェクトを返します。返されるオブジェクトと torch ネイティブ モジュールの違いは、‘pruned_state_dict()’
という名前のメソッドが 1 つ追加されていることです。これにより、プルーニング済みデンス (密) モデルのパラメーターを取得できます。‘pruned_state_dict()’
によって返される重みは、‘output_script’ ファイル内の Python で作成されたモデルにロードできます。- ratio
- 想定される FLOPs 削減率。これは概算値です。プルーニング後の実際の削減率は、厳密にはこの値に達しないことがあります。
- threshold
- 許容できるモデルの精度低下の相対的比率。
- excludes
- プルーニングから除外する必要があるモジュール。
- output_script
- モデルの再構築に使用される生成済みスクリプトを保存するファイルパス。