pytorch_nndct.IterativePruningRunner - 2.5 Japanese

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID
UG1333
Release Date
2022-06-15
Version
2.5 Japanese

この API には次のメソッドがあります。

  • __init__(model, inputs)
    model
    プルーニングする torch.nn.Module オブジェクト。
    inputs
    モデル推論用の入力として使用される単一の torch.Tensor または torch.Tensor のリスト。実際のデータである必要はありません。形状とデータ型が実際のデータと同じであれば、無作為に生成されるテンソルでもかまいません。
  • ana(eval_fn, args=(), gpus=None, excludes=None, forced=False)
    eval_fn
    最初の引数に torch.nn.Module オブジェクトを取り、評価スコアを返す呼び出し可能オブジェクト。
    args
    eval_fn に渡される引数のタプル。
    gpus
    使用する GPU インデックスのタプルまたはリスト。設定しない場合、デフォルトの GPU が使用されます。
    excludes
    プルーニングから除外されるノード名または torch モジュールのリスト。
    forced
    FALSE の場合、モデル解析を省略し、キャッシュされた結果を使用します。
  • prune(removal_ratio=None, threshold=None, spec_path=None, excludes=None, mode='sparse')
    removal_ratio
    想定される MAC 削減率。
    threshold
    許容できるモデルの精度低下の相対的比率。
    spec_path
    定義済みのプルーニング仕様。
    excludes
    プルーニングから除外されるノード名または torch モジュールのリスト。
    mode
    ['sparse', 'slim'] のいずれか。反復ループ内では常に 'sparse' を使用する必要があります。スリム モデルは量子化を意識した学習に使用されます。