量子化 API
wego_torch.quantize(
module: torch.nn.Module,
input_shapes: Sequence[Sequence],
dataloader: Iterable,
calibrator: Callable[[torch.nn.Module, Any, int, torch.device], None], export_dataloader: Iterable = None,
device: torch.device = torch.device("cpu"),
output_dir: str = "quantize_result",
bitwidth: int = None,
quant_config_file: Optional[str] = None,
*args, **kwargs) -> torch.jit.ScriptModule
この関数は、torch 浮動小数点モデルを PTQ (トレーニング後の量子化) 法で量子化し、WeGO のコンパイルで使用する量子化された TorchScript モジュールを返します。
PTQ で必要な精度が得られない場合は、Vitis AI クオンタイザー API で量子化認識トレーニング (QAT) を使用することを検討してください。量子化プロセスの詳細は、『Vitis AI ユーザー ガイド』の モデルの量子化 を参照してください。
パラメーター
- モジュール
-
(torch.nn.Module)
入力 PyTorch 浮動小数点モデル。 - input_shapes
-
(Sequence[Sequence]
モデルの入力形状。リストまたはタプルのシーケンス。 - dataloader
-
(Iterable)
キャリブレーション データセット用のデータローダー。iterable とする必要があります。API は実行を反復し、返された値をキャリブレーターに渡します。 - calibrator
-
(Callable)
バッチ データの前処理と転送を実行する呼び出し可能オブジェクト。データローダーからバッチ データを取得し、必要に応じて前処理を実行し、モジュールを使用して転送します。このキャリブレーターは、キャリブレーションおよびエクスポート ステージで N + 1 回呼び出されます。- ステージ 1 はキャリブレーションです。このステージではデータローダーが反復実行され、データがモジュールを経由して渡されて量子化の統計情報を収集します。キャリブレーターは N 回呼び出されます (N = len(dataloader))。ステージ 1 でオプションの export_dataloader を渡さない場合、データローダーから返される最初のバッチが保存され、ステージ 2 で使用されます。この場合、最初のバッチがキャリブレーターまたは反復の副作用によって変化しないことを確認してください。
- ステージ 2 は量子化された TorchScript モジュールをエクスポートするためのものです。このステージでは、キャリブレーターは 1 つのバッチ データで 1 回だけ呼び出されます。export_dataloader を指定した場合、この export_dataloader は反復され、最初のバッチのみが使用されます。プログラムは、最初のバッチの処理が完了すると反復を終了します。export_dataloader を渡さない場合、ステージ 1 で保存された最初のバッチが使用されます。calibrator の引数:
- モジュール
-
(torch.nn.Module)
量子化用のモジュール。これはユーザーが渡したモジュールを一部変更したもので、統計データの収集に必要なメカニズムが追加されています。データを転送するには、元の浮動小数点モデルではなく、このモジュールを使用してください。 - batch_data
-
(Any)
データローダーから返されるバッチ データ。 - batch_index
-
(int)
バッチのインデックス。転送には、必要に応じてデバイス (torch.device) を使用します。現在サポートされているのは CPU のみです。
注記: 量子化 API に対する追加の位置引数およびキーワード引数がキャリブレーターに転送されます。詳細は、モデルの量子化 を参照してください。
- export_dataloader
-
(Iterable)
エクスポート ステージ用のオプションのデータローダー。デフォルト値は None です。None の場合、ステージ 1 で保存した最初のバッチが使用されます。 - device
-
(torch.device)
キャリブレーションに使用するデバイス。現在サポートされているのは CPU のみです。 - output_dir
-
(str)
作業用の一時ディレクトリ。デフォルト値は quantize_result です。いくつかの中間ファイルがここに保存されます。 - bitwidth
-
(int)
グローバルな量子化ビット幅。デフォルト値は 8 です。 - quant_config_file
-
(str)
クオンタイザー コンフィギュレーション ファイルのパス。デフォルト値は None です。 - args
- キャリブレーターに渡す追加の位置引数。
- kwargs
- キャリブレーターに渡す追加のキーワード引数。
WeGO でのオンザフライ量子化の使用方法の詳細は、WeGO の例を参照してください。