Exporting PyTorch Lightning model to ONNX format
In this article, we will convert a deep learning model to ONNX format. We will a Lightning module based on the Efficientnet B1 and we will export it to onyx format. We will show two approaches: 1) Standard torch way of exporting the model to ONNX 2) Export using a torch lighting method
ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers. It is adopted and developed by several top-tier tech companies, such as Facebook, Microsoft, Amazon, and others. Models in onyx format can be easily deployed to various cloud platforms as well as to IoT devices.
# !pip install pytorch-lightning==0.9.1rc3
# !pip install efficientnet_pytorch
# !pip install onnx
# !pip install onnxruntime-gpu
import torch
from torch import nn
from torch.optim import lr_scheduler, Adam
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics import Recall, Accuracy
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from efficientnet_pytorch import EfficientNet
import os
from tqdm.auto import tqdm
from sklearn.metrics import classification_report
import numpy as np
class ImageClassificationDatamodule(pl.LightningDataModule):
def __init__(self, batch_size, train_transform, val_transform):
super().__init__()
self.batch_size = batch_size
self.train_transform = train_transform
self.val_transform = val_transform
def setup(self, stage=None):
self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=self.train_transform)
self.val_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.val_transform)
def train_dataloader(self):
return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4)
def val_dataloader(self):
return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, num_workers=4)
class ImageClassifier(pl.LightningModule):
lr = 1e-3
def __init__(self):
super().__init__()
self.criterion = nn.CrossEntropyLoss()
self.metrics = {"accuracy": Accuracy(), "recall": Recall()}
self.model = EfficientNet.from_pretrained('efficientnet-b1',
num_classes=10,
in_channels=3)
def forward(self, x):
return self.model(x)
def configure_optimizers(self):
optimizer = Adam(self.parameters(), self.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
return [optimizer], [scheduler]
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
tensorboard_logs = {'train_loss': loss}
return {'loss': loss, 'log': tensorboard_logs}
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
metrics_dict = {f"val_{name}": metric(logits, y) for name, metric in self.metrics.items()}
return {**{"val_loss": loss}, **metrics_dict}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
tensorboard_logs = {name: torch.stack([x[f"val_{name}"] for x in outputs]).mean()
for name, metric in self.metrics.items()}
tensorboard_logs["val_loss"] = avg_loss
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
dm = ImageClassificationDatamodule(128, train_transform, val_transform)
model = ImageClassifier()
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min',
)
TENSORBOARD_DIRECTORY = "logs/"
EXPERIMENT_NAME = "efficienet_b1"
logger = TensorBoardLogger(TENSORBOARD_DIRECTORY, name=EXPERIMENT_NAME)
#And then actual training
trainer = Trainer(max_epochs=10,
logger=logger,
gpus=1,
# precision=16,
accumulate_grad_batches=4,
deterministic=True,
early_stop_callback=True,
checkpoint_callback=checkpoint_callback,
# resume_from_checkpoint = 'my_checkpoint.ckpt'
)
trainer.fit(model, dm)
%load_ext tensorboard
%tensorboard --logdir logs
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
val_dl = DataLoader(torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform), batch_size=128, shuffle=True, num_workers=4)
model.to("cuda")
model.eval()
print()
predictions = []
targets = []
for X, y in tqdm(val_dl):
outputs = torch.nn.Softmax(dim=1)(model(X.to("cuda")))
_, predicted = torch.max(outputs, 1)
predictions.append(predicted.detach())
targets.append(y.detach())
predictions = (torch.cat(predictions)).cpu().numpy()
targets = (torch.cat(targets)).cpu().numpy()
print(classification_report(targets, predictions, target_names=classes))
X, y = next(iter(val_dl))
print(f"Model input: {X.size()}")
torch_out = model(X.to("cuda"))
print(f"Model output: {torch_out.detach().cpu().size()}")
Convert model to ONNX - standard torch approach
Inspired by: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
We export the PyTorch Lightning model similarly as we would do with a normal torch model
## a necessary fix, applicable only for Efficientnet
model.model.set_swish(memory_efficient=False)
# # Export the model
torch.onnx.export(model, # model being run
##since model is in the cuda mode, input also need to be
X.to("cuda"), # model input (or a tuple for multiple inputs)
"model_troch_export.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes
'output' : {0 : 'batch_size'}})
ls
## a necessary fix, applicable only for Efficientnet
model.model.set_swish(memory_efficient=False)
model.to_onnx("model_lightnining_export.onnx",
X.to("cuda"),
export_params=True,
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size'},'output' : {0 : 'batch_size'}})
ls
import onnxruntime
import onnx
ort_session = onnxruntime.InferenceSession("model_lightnining_export.onnx")
ort_session.get_providers()
onnx_model = onnx.load("model_lightnining_export.onnx")
onnx.checker.check_model(onnx_model)
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
##We set the batch size to the dynamic_axes, so we can use batch of any size we like, 10 in this example
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(X[:10])}
ort_outs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out[:10]), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
ort_outs[0].shape
We have shown how to easily export the PyTorch Lightning module to ONNX format. Neural networks in such format can be easily deployed as a production model both on the cloud and on IoT devices. It can also be used to effortlessly migrate between different frameworks such as PyTorch, Tensorflow, or Caffe2.