pytorch_nndct.Pruner - 1.4.1 Japanese

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID
UG1333
Release Date
2021-10-29
Version
1.4.1 Japanese

チャネル プルーニングをモジュール レベルでインプリメントします。

引数

Pruner(module, inputs)

新しいプルーナー オブジェクトを作成します。

module
プルーニングする torch.nn.Module オブジェクト。
inputs
モジュールの入力。

方法

  • ana(eval_fn, args=(), gpus=None)

    モデル解析を実行します。

    eval_fn
    最初の引数に torch.nn.Module オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。
    args
    eval_fn に渡される引数のタプル。
    gpus
    モデル解析に使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
  • prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')

    指定された比またはしきい値によってネットワークをプルーニングし、‘torch.nn.Module’ オブジェクトを返します。返されるオブジェクトと torch ネイティブ モジュールの違いは、‘pruned_state_dict()’ という名前のメソッドが 1 つ追加されていることです。これにより、プルーニング済みデンス (密) モデルのパラメーターを取得できます。‘pruned_state_dict()’ によって返される重みは、‘output_script’ ファイル内の Python で作成されたモデルにロードできます。

    ratio
    想定される FLOPs 削減率。これは概算値です。プルーニング後の実際の削減率は、厳密にはこの値に達しないことがあります。
    threshold
    許容できるモデルの精度低下の相対的比率。
    excludes
    プルーニングから除外する必要があるモジュール。
    output_script
    モデルの再構築に使用される生成済みスクリプトを保存するファイルパス。