この API には次のメソッドがあります。
-
__init__(model, inputs)
- model
- プルーニングする
torch.nn.Module
オブジェクト。 - inputs
- モデル推論用の入力として使用される単一の torch.Tensor または torch.Tensor のリスト。実際のデータである必要はありません。形状とデータ型が実際のデータと同じであれば、無作為に生成されるテンソルでもかまいません。
-
sparse_model(w_sparsity=0.5, a_sparsity=0, block_size=16,excludes=None)
- w_sparsity
- ['0', '0.5', '0.75'] のいずれか。たたみ込み層および完全接続層の重みのスパース度を示す float。デフォルトでは、w_sparsity は 0.5 に設定されます。
- a_sparsity
- ['0', '0.5'] のいずれか。活性化のスパース度を示す float。ここで、活性化値はスパース層の入力特徴マップを表します。デフォルトでは、a_sparsity は 0 に設定されます。a_sparsity が 0.5 に設定されている場合、w_sparsity は 0.75 にする必要があります。
- block_size
- 入力チャネル (または重みと活性化値に従って展開されたチャネル) の連続する要素の int 数。
- excludes
- スパース プルーニングから除外されるモジュールのリスト。
-
export_sparse_model(model)
指定したスパース演算用にハードウェア上での推論用のスパース型重みを含むスパース ネットワークから変換したネットワークを返します。
- model
- スパース モデル。