鄭州網(wǎng)站建設(shè)推銷鏈接制作軟件
量化工具箱pytorch_quantization 通過提供一個(gè)方便的 PyTorch 庫(kù)來補(bǔ)充 TensorRT ,該庫(kù)有助于生成可優(yōu)化的 QAT 模型。該工具包提供了一個(gè) API 來自動(dòng)或手動(dòng)為 QAT 或 PTQ 準(zhǔn)備模型。
API 的核心是 TensorQuantizer 模塊,它可以量化、偽量化或收集張量的統(tǒng)計(jì)信息。它與 QuantDescriptor 一起使用,后者描述了如何量化張量。在 TensorQuantizer 之上的是量化模塊,這些模塊被設(shè)計(jì)為 PyTorch 全精度模塊的替代品。這些是使用 TensorQuantizer 對(duì)模塊的權(quán)重和輸入進(jìn)行偽量化或收集統(tǒng)計(jì)信息的方便模塊。
API 支持將 PyTorch 模塊自動(dòng)轉(zhuǎn)換為其量化版本。轉(zhuǎn)換也可以使用 API 手動(dòng)完成,這允許在不想量化所有模塊的情況下進(jìn)行部分量化。例如,一些層可能對(duì)量化更敏感,并且使其未量化可提高任務(wù)精度。
量化第一步是將量化器模塊添加到神經(jīng)網(wǎng)絡(luò)圖中。該包提供了許多量化層模塊,其中包含用于輸入和權(quán)重的量化器。例如quant_nn.QuantLinear,它可以用來代替nn.Linear。這些量化層可以通過猴子修補(bǔ)或手動(dòng)修改模型定義來自動(dòng)替換。自動(dòng)層替換是使用quant_module完成的。這應(yīng)該在創(chuàng)建模型之前調(diào)用。
首先看以下代碼:
from pytorch_quantization import quant_modules
quant_modules.initialize()
initialize()
會(huì)動(dòng)態(tài)地修改 PyTorch 代碼,適用于每個(gè)模塊的所有實(shí)例,將 torch.nn.module 的一些子類替換為對(duì)應(yīng)的量化版本。如果不希望所有模塊都量化,則應(yīng)手動(dòng)替換量化模塊。獨(dú)立量化器也可以添加到帶有quant_nn.TensorQuantizer的模型中。
initialize()
位于:tools\pytorch-quantization\pytorch_quantization\quant_modules.py
,作用使用使用monkey patching進(jìn)行動(dòng)態(tài)模塊更換為量化版本
什么是猴子補(bǔ)丁
- Python是一種典型的動(dòng)態(tài)腳本語(yǔ)言。它不僅具有 動(dòng)態(tài)類型(dynamic type) ,而且它的 對(duì)象模型(object model) 也是動(dòng)態(tài)的。Python的類是可變的(mutable),方法(methods)只是類的屬性(attributes);這允許我們?cè)?運(yùn)行時(shí)(run time) 修改其行為。這被稱為猴子補(bǔ)丁(Monkey Patching), 它指的是偷偷地更改代碼。
- Monkey Patching只是在 運(yùn)行時(shí)(run time) 動(dòng)態(tài)替換屬性(attributes)。而在Python中,術(shù)語(yǔ)monkey patch指的是對(duì)函數(shù)(function)、類(class)或模塊(module)的動(dòng)態(tài)(或運(yùn)行時(shí))修改。
def initialize(float_module_list=None, custom_quant_modules=None):"""用量化版本動(dòng)態(tài)地替換模塊。在內(nèi)部,狀態(tài)由helper類對(duì)象維護(hù),該對(duì)象有助于將原始模塊替換回去。參數(shù):float_module_list:列表,用戶提供的列表,其中指明哪些模塊不可執(zhí)行替換custom_quant_modules:一個(gè)字典。用戶提供的映射,用于指示除torch.nn及其相應(yīng)量化版本之外的任何其他模塊。Returns:空"""# 準(zhǔn)備monkey patching中使用的內(nèi)部變量quant_map和orginal_func_map_quant_module_helper_object.prepare_state(float_module_list, custom_quant_modules)#執(zhí)行量化模塊替換_quant_module_helper_object.apply_quant_modules()def deactivate():"""動(dòng)態(tài)模塊更換,可逆轉(zhuǎn)monkey patching使用維護(hù)狀態(tài)的helper類對(duì)象動(dòng)態(tài)地替換回先前在initialize()函數(shù)調(diào)用中被monkey patching的原始模塊。"""_quant_module_helper_object.restore_float_modules()# 維護(hù)被替換模塊狀態(tài)的全局對(duì)象。
_quant_module_helper_object = QuantModuleReplacementHelper()
自定義量化模塊使用示例:
# torch.nn模塊定義不可執(zhí)行替換列表
float_module_list = ["Linear"]
# torch.nn以外的模塊自定義映射
custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]
# Monkey修補(bǔ)模塊
pytorch_quantization.quant_modules.initialize(float_module_list, custom_modules)
# 使用量化模塊
pytorch_quantization.quant_modules.deactivate()
繼續(xù)看helper
類QuantModuleReplacementHelper
class QuantModuleReplacementHelper():"""幫助量化版本替換torch.nn模塊術(shù)語(yǔ)monkey patch指的是對(duì)函數(shù)(function)、類(class)或模塊(module)的動(dòng)態(tài)(或運(yùn)行時(shí))修改該模塊用工具內(nèi)部實(shí)現(xiàn)或任何其他用戶提供的自定義模塊提供的量化版 替換(通過monkey patching)torch.nn模塊屬性:orginal_func_map:一個(gè)dict.維護(hù)原始torch.nn模塊字典quant_support_list:列表,包含工具提供的量化版本的模塊名稱quant_map:一個(gè)字典,包含模塊名稱及其量化版本的字典quant_switch_opt:一個(gè)字典,用于指示哪些模塊不能替換其量化版本。該dict由用戶提供的列表更新,該列表指示在monkey patching中要忽略的模塊"""def __init__(self):# 保留要更換的原始模塊self.orginal_func_map = set()# 默認(rèn)情況下,維護(hù)工具支持的量化模塊列表self.default_quant_map = _DEFAULT_QUANT_MAP# 保存最終量化模塊。self.quant_map = set()
_DEFAULT_QUANT_MAP
是包含量化模塊映射的文件的全局成員
_DEFAULT_QUANT_MAP = [_quant_entry(torch.nn, "Conv1d", quant_nn.QuantConv1d),_quant_entry(torch.nn, "Conv2d", quant_nn.QuantConv2d),_quant_entry(torch.nn, "Conv3d", quant_nn.QuantConv3d),_quant_entry(torch.nn, "ConvTranspose1d", quant_nn.QuantConvTranspose1d),_quant_entry(torch.nn, "ConvTranspose2d", quant_nn.QuantConvTranspose2d),_quant_entry(torch.nn, "ConvTranspose3d", quant_nn.QuantConvTranspose3d),_quant_entry(torch.nn, "Linear", quant_nn.QuantLinear),_quant_entry(torch.nn, "LSTM", quant_nn.QuantLSTM),_quant_entry(torch.nn, "LSTMCell", quant_nn.QuantLSTMCell),_quant_entry(torch.nn, "AvgPool1d", quant_nn.QuantAvgPool1d),_quant_entry(torch.nn, "AvgPool2d", quant_nn.QuantAvgPool2d),_quant_entry(torch.nn, "AvgPool3d", quant_nn.QuantAvgPool3d),_quant_entry(torch.nn, "AdaptiveAvgPool1d", quant_nn.QuantAdaptiveAvgPool1d),_quant_entry(torch.nn, "AdaptiveAvgPool2d", quant_nn.QuantAdaptiveAvgPool2d),_quant_entry(torch.nn, "AdaptiveAvgPool3d", quant_nn.QuantAdaptiveAvgPool3d),]
_quant_entry
定義命名元組,用于存儲(chǔ)量化模塊映射,它擁有三個(gè)屬性orig_mod mod_name replace_mod
_quant_entry = namedtuple('quant_entry', 'orig_mod mod_name replace_mod')
QuantModuleReplacementHelper
類的屬性方法:
prepare_state
準(zhǔn)備稍后在monkey patching機(jī)制中使用的量化模塊的命名字典quant_map
和更換為原始模塊orginal_func_map
- 設(shè)置torch.nn工具支持的量化模塊列表
- 為torch.nn以外的模塊設(shè)置自定義映射
- 使用float_module_list關(guān)閉用戶指示模塊的monkey patching替換
def prepare_state(self, float_module_list=None, custom_map=None):""""""# 對(duì)于支持的默認(rèn)量化模塊,生成quant_mapfor item in self.default_quant_map:if float_module_list is not None and item.mod_name in float_module_list:# 如果float_module_list中存在此模塊,則跳過此模塊continueelse:# 將模塊追加到將在monkey patching中使用的變量中self.quant_map.add(item)# 存儲(chǔ)要在反向monkey patching中使用的原始模塊self.orginal_func_map.add(_quant_entry(item.orig_mod, item.mod_name,getattr(item.orig_mod, item.mod_name)))# 將自定義模塊添加到quant_mapif custom_map is not None:for item in custom_map:# 將自定義模塊附加到將在monkey補(bǔ)丁中使用的列表中# 將元組轉(zhuǎn)換為命名元組self.quant_map.add(_quant_entry(item[0], item[1], item[2]))# 將原始模塊存儲(chǔ)在另一個(gè)列表中,該列表將用于反向monkey patchingself.orginal_func_map.add(_quant_entry(item[0], item[1], getattr(item[0], item[1])))
- apply_quant_modules:根據(jù)quant_map,執(zhí)行替換為量化模塊
def apply_quant_modules(self):for entry in self.quant_map:# 用于設(shè)置屬性值,該屬性不一定是存在的,對(duì)應(yīng)函數(shù) getattr()setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
- restore_float_modules:通過使用orginal_func_map替換回原始模塊,反轉(zhuǎn)monkey patch的效果
def restore_float_modules(self):for entry in self.orginal_func_map:setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)