pytorch_nndct.Pruner - 1.3 Japanese

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID
UG1333
Release Date
2021-02-03
Version
1.3 Japanese

チャネル プルーニングをモジュール レベルでインプリメントします。

引数

Pruner(module, input_specs)

新しいプルーナー オブジェクトを作成します。

module
プルーニングする torch.nn.Module オブジェクト。
input_specs
モジュールの入力 (InputSpec オブジェクトまたは InputSpec のリスト)。

方法

  • ana(eval_fn, args=(), gpus=None)

    モデル解析を実行します。

    eval_fn
    最初の引数に torch.nn.Module オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。
    args
    eval_fn に渡される引数のタプル。
    gpus
    モデル解析に使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
  • prune(ratio=None, threshold=None, excludes=[], output_script='graph.py')

    指定された比またはしきい値により、ネットワークをプルーニングします。追加のプルーニング情報を指定した通常の torch.nn.Module と同じように機能する PruningModule オブジェクトを返します。

    ratio
    想定される FLOPs 削減率。この値は単なるヒントです。プルーニング後の実際の FLOPs の減少は、必ずしもこの値と一致しません。
    threshold
    許容できるモデルの精度低下の相対的比率。
    excludes
    プルーニングから除外する必要があるモジュール。
    output_script
    モデルの再構築に使用される生成済みスクリプトを保存するファイルパス。
  • summary(pruned_model)

    プルーニング済みモデルのプルーニング サマリを生成します。

    pruned_model
    prune() メソッドで返されるプルーニング済みモジュール。