この API には次のメソッドがあります。
-
__init__(model, inputs)
- model
- プルーニングする
torch.nn.Module
オブジェクト。 - inputs
- モデル推論用の入力として使用される単一の torch.Tensor または torch.Tensor のリスト。実際のデータである必要はありません。形状とデータ型が実際のデータと同じであれば、無作為に生成されるテンソルでもかまいません。
-
ana(eval_fn, args=(), gpus=None, excludes=None, forced=False)
- eval_fn
- 最初の引数に
torch.nn.Module
オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。 - args
- eval_fn に渡される引数のタプル。
- gpus
- 使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
- excludes
- プルーニングから除外されるノード名または torch モジュールのリスト。
- forced
- FALSE の場合、モデル解析を省略し、キャッシュされた結果を使用します。
-
prune(removal_ratio=None, threshold=None, spec_path=None, excludes=None, mode='sparse')
- removal_ratio
- 想定される MAC 削減率。
- threshold
- 許容できるモデルの精度低下の相対的比率。
- spec_path
- 定義済みのプルーニング仕様。
- excludes
- プルーニングから除外されるノード名または torch モジュールのリスト。
- mode
- ['sparse', 'slim'] のいずれか。反復ループ内では常に 'sparse' を使用する必要があります。スリム モデルは量子化を意識した学習に使用されます。