この API には次のメソッドがあります。
-
__init__(model, inputs)
- model
- プルーニングする
torch.nn.Module
オブジェクト。 - inputs
- モデル推論用の入力として使用される単一の torch.Tensor または torch.Tensor のリスト。実際のデータである必要はありません。形状とデータ型が実際のデータと同じであれば、無作為に生成されるテンソルでもかまいません。
-
search(gpus=['0'], calibration_fn=None, calib_args=(), num_subnet=10, removal_ratio=0.5, excludes=[], eval_fn=None, eval_args=())
- gpus
- 使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
- calibration_fn
- 最初の引数に torch.nn.Module オブジェクトを取る呼び出し可能オブジェクト。BatchNormalization レイヤーの統計のキャリブレーションに使用されます。
- calib_args
- calibration_fn に渡される引数のタプル。
- num_subnet
- MAC の制約を満たすサブネットワークの数。
- removal_ratio
- 想定される MAC 削減率。
- excludes
- プルーニングから除外する必要があるモジュール。
- eval_fn
- 最初の引数に
torch.nn.Module
オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。 - eval_args
- eval_fn に渡される引数のタプル。
-
prune(mode='slim', index=None, removal_ratio=None, pruning_info_path=None)
- mode
- ['sparse', 'slim'] のいずれか。ワンステップ プルーニングの場合は必ず 'slim' モードを使用します。
- index
- サブネットワークのインデックス。デフォルトでは、最適なサブネットワークが自動的に選択されます。
- removal_ratio
- 想定される MAC 削減率。
- pruning_info_path
- .json ファイル。現在のモデルの詳細なプルーニング情報を保存します。このファイルと元のモデルを使用して、スリム モデルを生成できます。