この API には次のメソッドがあります。
-
__init__(model, inputs)
- model
- プルーニングする
torch.nn.Module
オブジェクト。 - inputs
- モデル推論用の入力として使用される単一の torch.Tensor または torch.Tensor のリスト。実際のデータである必要はありません。形状とデータ型が実際のデータと同じであれば、無作為に生成されるテンソルでもかまいません。
-
ofa_model(expand_ratio, channel_divisble=8, excludes=None, auto_add_excludes=True, save_search_space=False)
- expand_ratio
- 各たたみ込み層のプルーニング率のリスト。OFA モデル内の各たたみ込み層に対して、出力チャネルに任意のプルーニング率を使用できます。
このリストの最大値と最小値は、モデルの最大圧縮率と最小圧縮率を表します。その他の値は、最適化されるサブネットワークを表します。デフォルトでは、プルーニング率は [0.5, 0.75, 0.1] に設定されます。
- channel_divisible
- 特定の分周値で割り切れるチャネル数。
- excludes
- プルーニングから除外されるモジュールのリスト。
- auto_add_excludes
- ブール型。TRUE の場合、最初と最後のたたみ込みを自動で特定し、これらを除外リストに追加します。FALSE の場合、このような動作は実行されません。デフォルトは TRUE。
- save_search_space
- ブール型。TRUE の場合、モデルの検索空間を searchspace.config ファイルとして保存します。各レイヤーの検索空間は、ユーザーによるチェックが可能です。デフォルトは FALSE。
-
sample_subnet(model, mode)
特定のモードのサブネットワークとその設定を返します。サブネットワークは、OFA モデルおよびその設定から得られる重みの一部を使用して順方向/逆方向のプロセスを実行できます。
- model
- OFA モデル。
- mode
- ['random', 'max', 'min'] のいずれか。
-
reset_bn_running_stats_for_calibration(model)
BatchNormalization レイヤーの実行中の統計をリセットします。
- model
- OFA モデル。
-
run_evolutionary_search(model, calibration_fn, calib_args, eval_fn, eval_args, evaluation_metric, min_or_max_metric, min_flops, max_flops, flops_step=10, parent_popu_size=16, iteration=10, mutate_size=8, mutate_prob=0.2, crossover_size=4)
進化的検索を実行し、flops が特定の範囲内にある最も条件の良いサブネットワークを見つけます。
- model
- OFA モデル。
- calibration_fn
- BatchNormalization キャリブレーション関数。すべてのサブネットワークは OFA モデル内で重みを共有しますが、OFA モデルに学習させる際にバッチ正規化統計 (平均値および分散) は保存されません。学習の完了後、評価のためにサンプリングした各サブネットワークの学習データを使用して、バッチ正規化統計を再キャリブレーションする必要があります。
- calib_args
- calibration_fn の引数。
- eval_fn
- モデルを評価する関数。
- eval_args
- eval_fn の引数。
- evaluation_metric
- 結果を記録する evaluation_metric の文字列。
- min_or_max_metric
- ['max', 'min'] のいずれか。進化的検索で記録する評価メトリクスの最大値または最小値。たとえば、評価メトリクスの精度が top1 の場合、進化的検索の各反復の最大値を記録します。ただし、評価メトリクスが平均二乗誤差 (mse) または平均絶対誤差 (mae) の場合は、最小値を記録します。
- min_flops
- 検索されるサブネットワークの最小 flops。
- max_flops
- 検索されるサブネットワークの最大 flops。
- flops_step
- 検索に使用する flops のステップ。[min_flops, max flops] の区間を flops_step で割り、セグメントに分割します。各セグメントについて、FLOPs/精度の最適なトレードオフが得られるサブネットワークを検索します。
- parent_popu_size
- flops が特定の範囲内に収まる、特定の数のランダムなサブネットワークのサンプリングに使用する初期母集団の数。この値が大きいほど、検索に長い時間かかり、最善の結果が得られる可能性が高くなります。
- iteration
- 検索の反復回数またはアルゴリズム全体のサイクル数。
- mutate_size
- 突然変異のサイズ。サブネットワークの各設定値は、mutate_prob の確率で候補リストの値に置き換えられます。
- mutate_prob
- 突然変異の確率。
- crossover_size
- 交叉のサイズ。2 つのサブネットワーク設定をサンプリングし、2 つのサブネットワークの任意の設定値を無作為に入れ替えます。
-
save_subnet_config(setting_config, file_name)
動的/静的なサブネットワーク設定を JSON で保存します。
- setting_config
- 動的なサブネットワーク設定のコンフィギュレーション。
- file_name
- サブネットワーク設定を保存するファイルパス。
-
load_subnet_config(file_name)
- file_name
- サブネットワーク設定を読み込むファイルパス。