In this article, we will convert a deep learning model to the ONNX format. We will build a Lightning module based on EfficientNet-B1 and export it to ONNX. We will show two approaches: 1) the standard torch way of exporting a model to ONNX, and 2) export using a PyTorch Lightning method.

ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers. It is adopted and developed by several top-tier tech companies, such as Facebook, Microsoft, Amazon, and others. Models in the ONNX format can be easily deployed to various cloud platforms as well as to IoT devices.

# !pip install pytorch-lightning==0.9.1rc3
# !pip install efficientnet_pytorch
# !pip install onnx
# !pip install onnxruntime-gpu

Train the model

We will train a simple classification model on CIFAR-10. The model reaches roughly 84% validation accuracy (see the classification report below), which is far from perfect, but accuracy is not the goal of this article.


import torch
from torch import nn
from torch.optim import lr_scheduler, Adam
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import pytorch_lightning as pl
from pytorch_lightning.metrics import Recall, Accuracy
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from efficientnet_pytorch import EfficientNet

import os
from tqdm.auto import tqdm
from sklearn.metrics import classification_report
import numpy as np


class ImageClassificationDatamodule(pl.LightningDataModule):
    def __init__(self, batch_size, train_transform, val_transform):
        super().__init__()
        self.batch_size = batch_size
        self.train_transform = train_transform
        self.val_transform = val_transform

    def setup(self, stage=None):
        self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=self.train_transform)

        self.val_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, num_workers=4)

Lightning module

class ImageClassifier(pl.LightningModule):
    lr = 1e-3

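    def __init__(self):
        super().__init__()

        self.criterion = nn.CrossEntropyLoss()
        self.metrics = {"accuracy": Accuracy(), "recall": Recall()}

        # CIFAR-10 has 10 classes, so the classifier head gets 10 outputs
        self.model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=10)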

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)

        loss = self.criterion(logits, y)

        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)

        loss = self.criterion(logits, y)

        metrics_dict = {f"val_{name}": metric(logits, y) for name, metric in self.metrics.items()}

        return {**{"val_loss": loss}, **metrics_dict}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()

        tensorboard_logs = {name: torch.stack([x[f"val_{name}"] for x in outputs]).mean()
                            for name, metric in self.metrics.items()}

        tensorboard_logs["val_loss"] = avg_loss

        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
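
To make the learning-rate schedule in configure_optimizers concrete: StepLR multiplies the learning rate by gamma every step_size epochs. The sketch below reproduces that rule in plain Python for the values used above (initial rate 1e-3, step_size=7, gamma=0.1). The helper name step_lr is hypothetical and not part of PyTorch; it only illustrates the arithmetic.

```python
def step_lr(initial_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate at a given epoch under a step-decay schedule."""
    return initial_lr * gamma ** (epoch // step_size)


# Learning rate for each of the 10 training epochs
schedule = [step_lr(1e-3, epoch) for epoch in range(10)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

With 10 epochs, the rate stays at 1e-3 for epochs 0 to 6 and drops to 1e-4 for the last three.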

Training loop

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

dm = ImageClassificationDatamodule(128, train_transform, val_transform)

model = ImageClassifier()

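# Arguments inferred from the training log below (best val_loss, top 1, saved to the working directory)
checkpoint_callback = ModelCheckpoint(filepath=os.getcwd(), monitor='val_loss',
                                      save_top_k=1, mode='min', verbose=True)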

EXPERIMENT_NAME = "efficienet_b1"

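# And then the actual training
# (gpus, logger and checkpoint_callback are inferred from the log output below; adjust to your setup)
trainer = Trainer(max_epochs=10,
                  gpus=1,
                  logger=TensorBoardLogger("logs", name=EXPERIMENT_NAME),
                  checkpoint_callback=checkpoint_callback,
                  # precision=16,
                  # resume_from_checkpoint = 'my_checkpoint.ckpt'
                  )
trainer.fit(model, dm)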
Downloading: "" to /root/.cache/torch/hub/checkpoints/efficientnet-b1-f1951068.pth
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/ UserWarning: Checkpoint directory /content exists and is not empty with save_top_k != 0.All files in this directory will be deleted when a checkpoint is saved!
  warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Loaded pretrained weights for efficientnet-b1
Downloading to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
  | Name      | Type             | Params
0 | criterion | CrossEntropyLoss | 0     
1 | model     | EfficientNet     | 6 M   

Epoch 00000: val_loss reached 1.04839 (best 1.04839), saving model to /content/epoch=0.ckpt as top 1

Epoch 00001: val_loss reached 0.81029 (best 0.81029), saving model to /content/epoch=1.ckpt as top 1

Epoch 00002: val_loss reached 0.65911 (best 0.65911), saving model to /content/epoch=2.ckpt as top 1

Epoch 00003: val_loss reached 0.63399 (best 0.63399), saving model to /content/epoch=3.ckpt as top 1

Epoch 00004: val_loss reached 0.56925 (best 0.56925), saving model to /content/epoch=4.ckpt as top 1

Epoch 00005: val_loss reached 0.54162 (best 0.54162), saving model to /content/epoch=5.ckpt as top 1

Epoch 00006: val_loss reached 0.52157 (best 0.52157), saving model to /content/epoch=6.ckpt as top 1

Epoch 00007: val_loss reached 0.46342 (best 0.46342), saving model to /content/epoch=7.ckpt as top 1

Epoch 00008: val_loss reached 0.45730 (best 0.45730), saving model to /content/epoch=8.ckpt as top 1

Epoch 00009: val_loss reached 0.45397 (best 0.45397), saving model to /content/epoch=9.ckpt as top 1
Saving latest checkpoint..


Show some training statistics

%load_ext tensorboard
%tensorboard --logdir logs
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
val_dl = DataLoader(torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform), batch_size=128, shuffle=True, num_workers=4)
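Files already downloaded and verified
model.to("cuda")
model.eval()  # disable dropout and use running batch-norm statistics for evaluation

predictions = []
targets = []
for X, y in tqdm(val_dl):
    outputs = torch.nn.Softmax(dim=1)(model(X.to("cuda")))
    _, predicted = torch.max(outputs, 1)
    # collect the predicted and true labels of this batch as numpy arrays
    predictions.append(predicted.detach().cpu().numpy())
    targets.append(y.numpy())

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)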
print(classification_report(targets, predictions, target_names=classes))
              precision    recall  f1-score   support

       plane       0.84      0.86      0.85      1000
         car       0.93      0.91      0.92      1000
        bird       0.84      0.78      0.81      1000
         cat       0.69      0.72      0.70      1000
        deer       0.79      0.85      0.82      1000
         dog       0.80      0.72      0.76      1000
        frog       0.87      0.92      0.90      1000
       horse       0.90      0.86      0.88      1000
        ship       0.90      0.91      0.90      1000
       truck       0.88      0.90      0.89      1000

    accuracy                           0.84     10000
   macro avg       0.84      0.84      0.84     10000
weighted avg       0.84      0.84      0.84     10000
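
The summary rows of the report can be cross-checked by hand. Because every class has the same support (1000 images), overall accuracy equals the unweighted mean of the per-class recalls, and the macro and weighted averages coincide. A small sketch using the rounded per-class values from the table above, so the result is only accurate to about two decimals:

```python
# Per-class precision and recall copied from the report above (rounded values)
precision = [0.84, 0.93, 0.84, 0.69, 0.79, 0.80, 0.87, 0.90, 0.90, 0.88]
recall = [0.86, 0.91, 0.78, 0.72, 0.85, 0.72, 0.92, 0.86, 0.91, 0.90]
support = [1000] * 10

macro_precision = sum(precision) / len(precision)
macro_recall = sum(recall) / len(recall)

# Accuracy = correctly classified samples / all samples.
# recall * support is the number of correct predictions for each class.
accuracy = sum(r * s for r, s in zip(recall, support)) / sum(support)

print(f"macro precision: {macro_precision:.2f}")
print(f"macro recall:    {macro_recall:.2f}")
print(f"accuracy:        {accuracy:.2f}")
```

All three values round to 0.84, matching the accuracy and macro avg rows of the report.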

Converting the trained model to ONNX

After the model is trained, we can convert it to the ONNX format.

Do the example prediction

The ONNX exporter runs the model once on a sample input to trace it, so it needs a tensor of the correct shape. Since we already have a dataloader, we can take a real batch instead of creating random dummy data.

X, y = next(iter(val_dl))
print(f"Model input: {X.size()}")
torch_out = model(X.to("cuda"))
print(f"Model output: {torch_out.detach().cpu().size()}")
Model input: torch.Size([128, 3, 32, 32])
Model output: torch.Size([128, 10])

Convert model to ONNX - standard torch approach

Inspired by:

We export the PyTorch Lightning model the same way we would export a plain torch model.

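## a necessary fix, applicable only for EfficientNet (presumably the set_swish
## call that efficientnet_pytorch recommends before an ONNX export)
model.model.set_swish(memory_efficient=False)
# Export the model
torch.onnx.export(model,                     # model being run
                  ## since the model is on the GPU, the input needs to be there too
                  X.to("cuda"),              # model input (or a tuple for multiple inputs)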
                  "model_troch_export.onnx", # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})
 data/  'epoch=9.ckpt'   logs/   model_troch_export.onnx   sample_data/
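
The dynamic_axes argument marks axis 0 of both the input and the output as a symbolic batch_size, so the exported graph accepts any batch size while the remaining dimensions stay fixed at the values seen during export. The plain-Python sketch below illustrates that contract. The helper shape_matches is hypothetical, written only for illustration, and is not part of torch or onnx.

```python
def shape_matches(actual, declared):
    """Check a concrete shape against a declared shape.

    Integer entries in `declared` must match exactly; string entries
    (symbolic names such as 'batch_size') accept any positive size.
    """
    if len(actual) != len(declared):
        return False
    for size, expected in zip(actual, declared):
        if isinstance(expected, str):
            if size < 1:
                return False
        elif size != expected:
            return False
    return True


# Input signature of the exported model: axis 0 is dynamic, the rest is fixed
declared_input = ("batch_size", 3, 32, 32)

print(shape_matches((128, 3, 32, 32), declared_input))  # batch used for export
print(shape_matches((10, 3, 32, 32), declared_input))   # smaller batch
print(shape_matches((10, 3, 64, 64), declared_input))   # different spatial size
```

The first two shapes are accepted and the third is rejected. Without dynamic_axes, the batch axis would be fixed at 128 as well.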

Convert model to ONNX - PyTorch Lightning approach

PyTorch Lightning has its own method for exporting the model. Its arguments are similar to those of the torch.onnx.export method.

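## a necessary fix, applicable only for EfficientNet
model.model.set_swish(memory_efficient=False)
model.to_onnx("model_lightnining_export.onnx",  # where to save the model
              X.to("cuda"),                     # input sample, on the same device as the model
              export_params=True,
              input_names = ['input'],
              output_names = ['output'],
              dynamic_axes={'input' : {0 : 'batch_size'},'output' : {0 : 'batch_size'}})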
 data/           logs/                           model_troch_export.onnx
'epoch=9.ckpt'   model_lightnining_export.onnx   sample_data/

Do the prediction with ONNX

Since we want to run the prediction on the GPU, we need to make sure that CUDA is available as an onnxruntime provider and that it is the first provider in the list.
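
ONNX Runtime tries execution providers in the order in which they are listed, which is why the position of the CUDA provider matters. Here is a minimal sketch of that ordering logic in plain Python. The helper order_providers is hypothetical, and the provider names are the ones printed in the output below.

```python
def order_providers(available,
                    preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Return the preferred providers that are available, in priority order."""
    return [name for name in preferred if name in available]


# With a GPU build of onnxruntime, CUDA comes first and CPU is the fallback
gpu_build = ["CPUExecutionProvider", "CUDAExecutionProvider"]
print(order_providers(gpu_build))

# With a CPU-only build, only the CPU provider remains
cpu_build = ["CPUExecutionProvider"]
print(order_providers(cpu_build))
```

In recent onnxruntime releases, a list like this can be passed as the providers argument of onnxruntime.InferenceSession, and onnxruntime.get_available_providers() reports what the installed build supports.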

import onnxruntime
import onnx

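ort_session = onnxruntime.InferenceSession("model_lightnining_export.onnx")

# list the execution providers of the session, in priority order
print(ort_session.get_providers())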
['CUDAExecutionProvider', 'CPUExecutionProvider']
onnx_model = onnx.load("model_lightnining_export.onnx")
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

## The batch axis was exported as dynamic, so we can use a batch of any size, 10 in this example
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(X[:10])}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out[:10]), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!
(10, 10)
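
The comparison above relies on the rule that np.testing.assert_allclose applies to each element: a pair passes when |actual - desired| <= atol + rtol * |desired|. Small floating-point differences between the PyTorch and ONNX Runtime outputs therefore pass, while real discrepancies fail. A plain-Python sketch of that rule with the tolerances used above (the helper within_tolerance is hypothetical and the sample logits are made up for illustration):

```python
def within_tolerance(actual, desired, rtol=1e-3, atol=1e-5):
    """Element-wise closeness rule used by numpy's assert_allclose."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


# Made-up logits: the first pair differs only by float noise, the second does not
print(within_tolerance(2.41370, 2.41375))  # difference 5e-5, allowed about 2.4e-3
print(within_tolerance(2.41370, 2.45000))  # difference about 3.6e-2, too large
```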


We have shown how to export a PyTorch Lightning module to the ONNX format. Neural networks in this format can be deployed as production models both in the cloud and on IoT devices. The format can also ease migration between frameworks such as PyTorch, TensorFlow, or Caffe2.