Quantization Guided Training for PyTorch

This section documents how to use the Latent AI SDK to perform Quantization Guided Training (QGT) on PyTorch models. What follows is an example of using the QGT classes provided by the SDK within the PyTorch 1.6 framework. The example constructs a LeNet model, trains it on MNIST, and quantizes the resulting model to 4 bits.

For information on using the Keras version of QGT, see LEIP Quantization Guided Training.


In this example we will import the following modules:

import os
import torch
import tqdm
from torchvision import datasets, transforms
from leip.core.train.torch.util import search_absorb_batchnorm, AverageMeter, to_float, accuracy
from leip.core.train.torch.sample_models import create_model
from leip.core.train.torch.quantization_guided_training import QuantizationGuidedTraining

Load Data

Now let’s use the following function to load and preprocess the MNIST training and test data from the torchvision datasets module:

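def preprocess_data_mnist(output_path, batch_size, test_batch_size):
    # Resize to 28x28, convert the PIL image to a tensor, then normalize
    # with the standard MNIST mean and standard deviation
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    data_path = os.path.join(output_path, "data")

    train_dataset = datasets.MNIST(data_path, train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = datasets.MNIST(data_path, train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

    return train_loader, test_loader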
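If you want to smoke-test the training and evaluation functions without downloading MNIST, you can substitute loaders built from random tensors. The sketch below is not part of the SDK: make_fake_mnist_loaders is a hypothetical helper that produces batches with the same shapes as the MNIST loaders above (1×28×28 images, integer labels 0 to 9). Accuracy on this random data is meaningless; it only exercises the pipeline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_fake_mnist_loaders(batch_size, test_batch_size, n_train=256, n_test=128, seed=0):
    """Hypothetical helper: random MNIST-shaped data for offline smoke tests."""
    gen = torch.Generator().manual_seed(seed)
    # Images shaped (N, 1, 28, 28) with pixel values in [0, 1], normalized with
    # the same MNIST mean/std used by preprocess_data_mnist()
    mean, std = 0.1307, 0.3081
    train_x = (torch.rand(n_train, 1, 28, 28, generator=gen) - mean) / std
    test_x = (torch.rand(n_test, 1, 28, 28, generator=gen) - mean) / std
    # Integer class labels in [0, 10)
    train_y = torch.randint(0, 10, (n_train,), generator=gen)
    test_y = torch.randint(0, 10, (n_test,), generator=gen)

    train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=test_batch_size, shuffle=False)
    return train_loader, test_loader


train_loader, test_loader = make_fake_mnist_loaders(batch_size=64, test_batch_size=64)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```

Because TensorDataset supports len(), expressions such as len(train_loader.dataset) used for progress logging keep working with these loaders.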


The following function trains the model for one epoch. It accepts a model, an optimizer, a loss criterion, a PyTorch device, the train_loader (from the data loader function above), the current epoch number (used for logging), and an optional qgt object. When qgt is None, the model is trained as a PyTorch model normally is. When a qgt object is supplied, the loss terms related to Quantization Guided Training are added to the total loss that is minimized during training.

def train(model, optimizer, criterion, device, train_loader, epoch, qgt):
    model.train()
    losses = AverageMeter()
    quant_losses = AverageMeter()
    l2_losses = AverageMeter()
    top1 = AverageMeter()
    pbar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader))

    for batch_idx, (data, target) in pbar:
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)

        output = model(data)
        base_loss = criterion(output, target)
        # Get the QGT loss terms; both stay zero when training without QGT
        quant_loss, l2_loss = 0, 0
        if qgt is not None:
            quant_loss, l2_loss = qgt.get_qgt_loss()
        loss = base_loss + quant_loss + l2_loss

        # Standard backward pass and parameter update on the total loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        reduced_loss = base_loss.detach()
        losses.update(to_float(reduced_loss), data.size(0))
        prec1 = accuracy(output.detach(), target, topk=(1,))
        top1.update(to_float(prec1), data.size(0))
        # float() accepts both the integer 0 and a single-element tensor
        quant_losses.update(float(quant_loss), data.size(0))
        l2_losses.update(float(l2_loss), data.size(0))

        if batch_idx % 8 == 0:
            pbar.set_description(
                f"E{epoch}, {batch_idx * len(data):4d}/{len(train_loader.dataset):4d}, "
                f"base_loss={losses.avg:.3f}, "
                f"quant_loss={quant_losses.avg:.3f}, "
                f"l2_loss={l2_losses.avg:.3f}, "
                f"acc={top1.avg:.2f}%"
            )
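The key point of train() is that QGT only changes what gets back-propagated: the two QGT terms are summed with the task loss before loss.backward(). The sketch below shows that pattern in plain PyTorch. train_step and hypothetical_extra_losses are illustrative names; the latter stands in for qgt.get_qgt_loss() and returns a zero "quantization" term plus a small L2 weight penalty, which is not the SDK's actual QGT loss.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, criterion, data, target, extra_loss_fn=None):
    """One optimization step; extra_loss_fn mimics the (quant_loss, l2_loss) pair."""
    model.train()
    output = model(data)
    base_loss = criterion(output, target)
    quant_loss, l2_loss = extra_loss_fn() if extra_loss_fn is not None else (0.0, 0.0)
    loss = base_loss + quant_loss + l2_loss

    optimizer.zero_grad()
    loss.backward()  # gradients flow through the task loss and both extra terms
    optimizer.step()
    return float(base_loss), float(quant_loss), float(l2_loss)


torch.manual_seed(0)
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data, target = torch.randn(8, 4), torch.randint(0, 3, (8,))


def hypothetical_extra_losses():
    # Illustrative stand-in: a zero "quantization" term and a small L2 weight penalty
    l2 = 1e-3 * sum((p ** 2).sum() for p in model.parameters())
    return torch.zeros(()), l2


before = model.weight.detach().clone()
result = train_step(model, optimizer, criterion, data, target, hypothetical_extra_losses)
print(result)
```

Passing extra_loss_fn=None reproduces ordinary training, mirroring the qgt=None path of train().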


The following function evaluates a trained model. It accepts a model, a loss criterion, a PyTorch device, the test_loader (from the data loader function above), and a log_prefix string that is prepended to the printed summary. It returns the average top-1 accuracy and the average loss, which the QGT loop below uses to decide when to save a checkpoint.

def test(model, criterion, device, test_loader, log_prefix):
    losses = AverageMeter()
    top1 = AverageMeter()

    # Switch to evaluation mode, remembering the previous mode so it can be restored
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            output = model(data)
            loss = criterion(output, target)
            reduced_loss = loss.detach()
            prec1 = accuracy(output.detach(), target, topk=(1,))
            losses.update(to_float(reduced_loss), data.size(0))
            top1.update(to_float(prec1), data.size(0))
    model.train(was_training)

    print("{}: Avg loss: {:.4f}, Acc: {:.2f}%".format(log_prefix, losses.avg, top1.avg))
    return top1.avg, losses.avg
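train() and test() rely on AverageMeter and accuracy from leip.core.train.torch.util, whose implementations are not shown here. As an illustration of what such helpers typically compute, the hypothetical versions below keep a sample-weighted running mean (each update(val, n) counts val for n samples) and return top-k accuracy as a percentage. The SDK's versions may differ in details such as return types.

```python
import torch


class RunningAverage:
    """Hypothetical stand-in for AverageMeter: sample-weighted running mean."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is a per-batch mean, so weight it by the batch size n
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count


def topk_accuracy(output, target, topk=(1,)):
    """Hypothetical stand-in for accuracy(): percentage of correct top-k predictions."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)        # (batch, maxk) predicted class indices
    correct = pred.eq(target.unsqueeze(1))    # (batch, maxk) boolean hits
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in topk]


meter = RunningAverage()
meter.update(1.0, n=3)   # a batch of 3 samples with mean loss 1.0
meter.update(0.0, n=1)   # a batch of 1 sample with mean loss 0.0
print(meter.avg)         # 0.75

logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
labels = torch.tensor([1, 0, 0, 0])
print(topk_accuracy(logits, labels, topk=(1,)))  # [75.0]
```

Weighting by n matters when the last batch is smaller than the others: a plain mean of batch means would over-count it.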


Here we declare the variables for the training session. In general, these can be tuned to your needs. Because we are training a simple LeNet model on MNIST, the number of epochs is set to 2. We call the preprocess_data_mnist() function defined above to initialize the train and test data loaders, and we select the PyTorch device (a GPU if one is available, otherwise the CPU). The params dict holds the QGT loss weights (lamb2 and lamb3), the quantizer settings, and the bits field, which sets the number of quantization bits to 4.

output_path = "qgt_torch_lenet_mnist"
batch_size = 64
test_batch_size = 64
epochs = 2
lr = 0.1
train_loader, test_loader = preprocess_data_mnist(output_path, batch_size, test_batch_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The number of quantization bits can be specified on a per named-layer
# basis by using a dict with the format of the following example:
# custom_layer_bits = {
#     "features.final_block.conv.weight": 1
# }
custom_layer_bits = {}

# To train a baseline (non-QGT) model instead, set lamb2=0 and lamb3=0
params = {
    "bits": 4,
    "bias_bits": 8,  # biases are quantized to 8 bits though
    "lamb2": 0.1,
    "lamb3": 0.00001,
    "quantizer": "asymmetricpc",  # asymmetric per-channel
    "quantize_bias": False,
    "quantize_bn": False,
    "custom_layer_bits": custom_layer_bits
}
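The "asymmetricpc" quantizer name suggests asymmetric, per-channel quantization. The sketch below shows what that scheme generally means for a weight tensor: each output channel gets its own scale and zero point computed from that channel's min and max, so at bits=4 values map onto 2^4 = 16 integer levels and are then dequantized back to floats ("fake quantization"). It also resolves per-layer bit widths from a custom_layer_bits-style dict. This is a generic illustration under those assumptions, not the SDK's quantizer; fake_quantize_per_channel and bits_for are hypothetical names.

```python
import torch


def bits_for(param_name, params):
    """Hypothetical: per-layer override from custom_layer_bits, else the default bits."""
    return params["custom_layer_bits"].get(param_name, params["bits"])


def fake_quantize_per_channel(w, bits):
    """Asymmetric per-channel quantize/dequantize along dim 0 (output channels)."""
    qmax = 2 ** bits - 1
    flat = w.reshape(w.shape[0], -1)
    # Per-channel range, widened to include 0 so that 0.0 is exactly representable
    w_min = flat.min(dim=1).values.clamp(max=0.0)
    w_max = flat.max(dim=1).values.clamp(min=0.0)
    scale = ((w_max - w_min) / qmax).clamp(min=1e-8)
    zero_point = torch.round(-w_min / scale)
    # Reshape so the per-channel values broadcast over the remaining dims
    shape = [-1] + [1] * (w.dim() - 1)
    scale, zero_point = scale.view(shape), zero_point.view(shape)
    q = torch.clamp(torch.round(w / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale


params = {"bits": 4, "custom_layer_bits": {"features.final_block.conv.weight": 1}}
torch.manual_seed(0)
w = torch.randn(6, 3, 5, 5)  # e.g. a conv weight: (out_ch, in_ch, kH, kW)
w_q = fake_quantize_per_channel(w, bits_for("conv1.weight", params))
# At 4 bits, each output channel uses at most 16 distinct values
print([len(torch.unique(w_q[c])) for c in range(w.shape[0])])
```

Per-channel scales let a channel with a narrow weight range keep finer resolution than it would under a single tensor-wide scale, which matters most at low bit widths such as 4.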

Build Model

Next, we construct a LeNet model with create_model() from the sample_models module and move it to the selected device.

model = create_model("LeNet".lower(), resolution=224, nclasses=10, finetune_baseline=True)
model = model.to(device)
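create_model() returns the SDK's own LeNet variant, whose exact layers are not documented here. If you want a self-contained model for experimenting with the training and test functions above, a classic LeNet-5-style network for 1×28×28 inputs can be written in plain PyTorch as below. LeNet5 is a hypothetical class for illustration, not what create_model() returns; the BatchNorm2d layers are an assumption added so there is something to fold in the optional batchnorm folding step.

```python
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    """Hypothetical LeNet-5-style model for 1x28x28 MNIST inputs."""

    def __init__(self, nclasses=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2),  # 28x28 -> 28x28
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                            # -> 14x14
            nn.Conv2d(6, 16, kernel_size=5),            # -> 10x10
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                            # -> 5x5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, nclasses),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5(nclasses=10).to(device)
logits = model(torch.randn(2, 1, 28, 28, device=device))
print(logits.shape)  # torch.Size([2, 10])
```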

Next, we run a single epoch of unquantized training to move the LeNet model toward convergence on MNIST. The right split between pre-quantization epochs and quantization guided epochs varies with the model architecture, task, and dataset; experimenting with it and with the other hyperparameters is out of the scope of this tutorial. We then call test() to see how accurate the model is so far.

# Baseline training (without QGT)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)

print("Training baseline model:")
train(model, optimizer, criterion, device, train_loader, epoch=1, qgt=None)
test(model, criterion, device, test_loader, "Validation set")

Batchnorm Folding

This step is optional. Folding batchnorm layers into the preceding layers can significantly reduce a model’s inference latency. To accurately simulate the quantization effects of the folded model, batchnorm folding must also be applied to the graph used for training. To enable this optional step, do the following:

# Apply batchnorm folding if required, set the correct lamb2 and lamb3 values, and train with QGT
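Batchnorm folding merges a BatchNorm layer into the preceding convolution. In evaluation mode, BatchNorm computes y = γ(x − μ)/√(σ² + ε) + β from its running statistics μ and σ², so the folded convolution uses weights W' = W·γ/√(σ² + ε) (scaled per output channel) and bias b' = (b − μ)·γ/√(σ² + ε) + β. The example imports search_absorb_batchnorm from the SDK, which by its name appears to handle this step, but its signature is not shown here. The sketch below, built around a hypothetical fold_conv_bn function, only illustrates the arithmetic in plain PyTorch and checks that the folded layer matches the original in eval mode.

```python
import torch
import torch.nn as nn


def fold_conv_bn(conv, bn):
    """Hypothetical helper: return a Conv2d equivalent to bn(conv(x)) in eval mode."""
    folded = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                       stride=conv.stride, padding=conv.padding,
                       dilation=conv.dilation, groups=conv.groups, bias=True)
    with torch.no_grad():
        # Per-output-channel factor gamma / sqrt(running_var + eps)
        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        folded.weight.copy_(conv.weight * factor.view(-1, 1, 1, 1))
        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        folded.bias.copy_((conv_bias - bn.running_mean) * factor + bn.bias)
    return folded


torch.manual_seed(0)
conv, bn = nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8)
# Give the affine parameters non-trivial values
nn.init.uniform_(bn.weight, 0.5, 1.5)
nn.init.uniform_(bn.bias, -0.5, 0.5)
# Populate non-trivial running statistics with a few training-mode passes
bn.train()
for _ in range(5):
    bn(conv(torch.randn(4, 3, 16, 16)))
bn.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    folded_out = fold_conv_bn(conv, bn)(x)
print(torch.allclose(reference, folded_out, atol=1e-5))  # True
```

The equivalence only holds in eval mode, where BatchNorm uses fixed running statistics; in training mode it normalizes with per-batch statistics, which is why folding for QGT also requires transforming the training graph.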


Finally, we run Quantization Guided Training on the model with the following code. First, the qgt object is constructed. We then loop over the epochs and call train() for each one again, this time with the qgt object supplied. After each epoch, we evaluate the quantized model returned by qgt.get_quantized_model() by calling test(), and keep track of the lowest validation loss seen so far in q_loss. Whenever the current loss cur_q_loss drops below q_loss, we save a checkpoint of the quantized model.

qgt = QuantizationGuidedTraining(model, params)

optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)
print("Training quantized model with QGT:")
q_loss = float("inf")
for epoch in range(1, epochs + 1):
    train(model, optimizer, criterion, device, train_loader, epoch, qgt)

    # Evaluate the quantized model
    q_model = qgt.get_quantized_model()
    q_accuracy, cur_q_loss = test(q_model, criterion, device, test_loader, "Quantized model on validation set")

    if cur_q_loss < q_loss:
        q_loss = cur_q_loss
        file_name = f"Acc{q_accuracy:.2f}_L{q_loss:6.4f}_E{epoch:02d}.pt"
        output_folder = os.path.join(output_path, "ckpt")
        os.makedirs(output_folder, exist_ok=True)
        # Save the weights of the quantized model
        torch.save(q_model.state_dict(), os.path.join(output_folder, file_name))
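The checkpointing logic in this loop keeps only improvements: a file is written only when the quantized model's validation loss beats the best value seen so far. The sketch below isolates that pattern so it can be tested without training, and shows how a saved state_dict can be reloaded with torch.load and load_state_dict. save_if_improved is a hypothetical helper, and the (accuracy, loss) pairs are arbitrary test inputs, not results.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_if_improved(model, cur_loss, best_loss, accuracy, epoch, folder):
    """Hypothetical helper: save model weights only if cur_loss beats best_loss."""
    if cur_loss >= best_loss:
        return best_loss, None
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"Acc{accuracy:.2f}_L{cur_loss:6.4f}_E{epoch:02d}.pt")
    torch.save(model.state_dict(), path)
    return cur_loss, path


model = nn.Linear(4, 2)
best = float("inf")
saved = []
with tempfile.TemporaryDirectory() as tmp:
    folder = os.path.join(tmp, "ckpt")
    # Arbitrary (accuracy, loss) pairs standing in for per-epoch test() results
    for epoch, (acc, loss) in enumerate([(90.0, 0.30), (91.5, 0.25), (91.0, 0.27)], start=1):
        best, path = save_if_improved(model, loss, best, acc, epoch, folder)
        if path is not None:
            saved.append(os.path.basename(path))

    # Reload the most recent best checkpoint into a fresh model of the same shape
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(os.path.join(folder, saved[-1])))

print(saved)  # ['Acc90.00_L0.3000_E01.pt', 'Acc91.50_L0.2500_E02.pt']
```

Saving a state_dict rather than the whole module keeps checkpoints independent of the Python class layout, at the cost of needing a model instance with matching parameter names to load into.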
