- Post-training Static Quantization
# Load the float checkpoint, then prepare, calibrate, and convert.
self.model.eval()
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
load_model_weight(self.model, checkpoint)
# Attach the quantization config before preparing the model.
self.model.qconfig = torch.quantization.get_default_qconfig(self.qconfig_name)
# Fuse conv+bn(+relu) patterns so they quantize as single units.
fuse_module(self.model)
# Insert observers that record activation ranges.
torch.quantization.prepare(self.model, inplace=True)
# Calibration pass. A random tensor only exercises the observers;
# feed representative data here to get meaningful activation ranges.
dummy_input = torch.randn(1, 3, *self.cfg.data.eval.pipeline.input_size).to(self.device)
_ = self.model(dummy_input)
# Freeze the observers, then replace float modules with quantized ones.
self.model.apply(torch.quantization.disable_observer)
torch.quantization.convert(self.model, inplace=True)
In this case the model was trained in normal floating-point mode. Note that for inference the module's forward must be wrapped with a QuantStub at the entry and a DeQuantStub at the exit, so tensors are converted between float and quantized representations, as sketched below.
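A minimal sketch of such a wrapper, assuming a generic float model; the class name QuantizableWrapper is illustrative, not from the original code:

import torch
import torch.nn as nn

class QuantizableWrapper(nn.Module):
    # Wraps a float model so inputs are quantized on entry and outputs
    # are dequantized on exit, as eager-mode static quantization requires.
    def __init__(self, model):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        return self.dequant(x)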
- Loading a QAT model
# Rebuild the QAT graph first so the state_dict keys match the checkpoint.
self.model.qconfig = torch.quantization.get_default_qat_qconfig(self.qconfig_name)
self.model.train()  # prepare_qat expects the model in training mode
fuse_module(self.model)
# Insert fake-quantize modules and observers, as during QAT training.
torch.quantization.prepare_qat(self.model, inplace=True)
# One forward pass to initialize the fake-quant/observer buffers.
dummy_input = torch.randn(1, 3, *self.cfg.data.eval.pipeline.input_size).to(self.device)
_ = self.model(dummy_input)
self.model.eval()
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
load_model_weight(self.model, checkpoint)
# Freeze observers and convert fake-quant modules into real quantized kernels.
self.model.apply(torch.quantization.disable_observer)
self.model = torch.quantization.convert(self.model)
In this case the model was trained with QAT, so it must be rebuilt through the same QAT preparation steps (qconfig, fuse, prepare_qat) before the checkpoint can be loaded. A sketch of the fuse_module helper used in both snippets follows.
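Both snippets call a fuse_module helper that is not defined here. A minimal sketch of what such a helper might do, built on torch.quantization.fuse_modules; the traversal logic below is an assumption, not the original implementation:

import torch
import torch.nn as nn

def fuse_module(model):
    # Hypothetical sketch: fuse (Conv2d, BatchNorm2d) and (Conv2d, BatchNorm2d,
    # ReLU) runs found inside nn.Sequential containers; recurse elsewhere.
    for child in model.children():
        if isinstance(child, nn.Sequential):
            names = [n for n, _ in child.named_children()]
            mods = [m for _, m in child.named_children()]
            i = 0
            while i < len(mods) - 1:
                group = None
                if isinstance(mods[i], nn.Conv2d) and isinstance(mods[i + 1], nn.BatchNorm2d):
                    group = names[i:i + 2]
                    if i + 2 < len(mods) and isinstance(mods[i + 2], nn.ReLU):
                        group = names[i:i + 3]
                if group:
                    torch.quantization.fuse_modules(child, group, inplace=True)
                    i += len(group)
                else:
                    i += 1
        else:
            fuse_module(child)

Note that on recent PyTorch versions, fusing a model that is in training mode (as in the QAT path above) is expected to go through torch.ao.quantization.fuse_modules_qat instead of fuse_modules.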