PyTorch .pth to ONNX: From Model Training to Cross-Platform Deployment
In deep learning, the format in which a model is saved determines where it can be used. If you are a PyTorch user, you are probably familiar with .pth files, which store trained models. However, when you want to deploy a model in a different environment (such as TensorRT, OpenVINO, or ONNX Runtime), .pth may not be suitable. This is where ONNX (Open Neural Network Exchange) becomes essential.
Table of Contents:
- What is a .pth file?
- What is an .onnx file?
- Why convert?
- How to convert .pth to .onnx?
- Benefits and potential risks after conversion
1. What is a .pth file?
.pth is a PyTorch-specific model file. It can store either of the following (see the sketch below):
- Model weights (state_dict): saves only the parameters, not the model structure.
- Complete model: contains both the model structure and the weights, as saved directly with torch.save(model, "model.pth").
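The difference between the two is easiest to see in the save calls themselves. A minimal sketch, reusing the NestedUNet class from the examples in this post (any nn.Module works the same way):

```python
import torch
from NestedUNet import NestedUNet  # your model class (illustrative)

model = NestedUNet(num_classes=2, input_channels=3)

# Option 1: save only the weights (state_dict); loading later requires the model class
torch.save(model.state_dict(), "best_model.pth")

# Option 2: save the complete model object; the class definition must still be
# importable when the file is loaded back
torch.save(model, "model.pth")
```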
In PyTorch, you can load a .pth file as follows:
import torch
from NestedUNet import NestedUNet # Your model class
# Loading method for weights only
model = NestedUNet(num_classes=2, input_channels=3)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
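If the file was instead saved as a complete model (the second case above), loading is a single call. A minimal sketch; note that recent PyTorch versions may require weights_only=False because the file contains pickled Python objects, not just tensors:

```python
import torch

# Load a model that was saved with torch.save(model, "model.pth")
model = torch.load("model.pth", weights_only=False)  # the model class must be importable
model.eval()
```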
A .pth file can only be used in an environment where PyTorch is installed; it cannot be run directly in TensorFlow, OpenVINO, or TensorRT.
2. What is ONNX?
ONNX (Open Neural Network Exchange) is an open neural network standard format with the goal of:
- Cross-framework compatibility: Supports PyTorch, TensorFlow, Keras, MXNet, etc.
- Optimized inference: Can use ONNX Runtime or TensorRT to accelerate inference.
- Flexible deployment: Supports running on hardware such as CPU, GPU, FPGA, TPU, etc.
An ONNX model is saved as a .onnx file that contains:
- The computational graph of the model
- Operator (op) definitions
- Model weights
ONNX allows you to run the same model on different platforms without relying on a specific deep learning framework.
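These pieces can be inspected directly with the onnx Python package. A small sketch (the file name is just a placeholder for any exported model):

```python
import onnx

model = onnx.load("nested_unet.onnx")

# Which opset the graph's operators target
print("Opset version:", model.opset_import[0].version)

# Graph inputs and outputs
print("Inputs:", [i.name for i in model.graph.input])
print("Outputs:", [o.name for o in model.graph.output])

# The weights are stored as initializer tensors inside the graph
print("Weight tensors:", len(model.graph.initializer))
```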
3. Why convert .pth to .onnx?
The main benefits of converting to ONNX are:
✅ Cross-platform compatibility
- .pth can only be used in PyTorch, while .onnx can run in TensorRT, ONNX Runtime, OpenVINO, CoreML, and other environments.
✅ Faster inference speed
- ONNX Runtime uses graph optimization to reduce computational redundancy and improve inference speed (see the session-options sketch after this list).
- TensorRT can compile ONNX models into highly optimized GPU code, significantly increasing throughput.
✅ Support for various hardware
- .pth is mainly used on CPU/GPU, while .onnx can also run on FPGAs, TPUs, and ARM devices such as Android phones, Raspberry Pi, and Jetson Nano.
✅ More lightweight
- The PyTorch runtime requires a full Python interpreter, while ONNX models can be run from native C/C++ code, which makes them suitable for embedded devices.
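The graph optimizations mentioned under "Faster inference speed" are configurable when an ONNX Runtime session is created. A minimal sketch, assuming a CPU-only install (file name and provider choice are illustrative):

```python
import onnxruntime as ort

so = ort.SessionOptions()
# Enable all graph-level optimizations (constant folding, node fusion, ...)
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession(
    "nested_unet.onnx",
    sess_options=so,
    providers=["CPUExecutionProvider"],  # e.g. CUDAExecutionProvider on GPU builds
)
print(session.get_providers())
```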
4. How to convert .pth to .onnx?
4.1 Install dependencies
Before conversion, ensure you have installed PyTorch and ONNX:
pip install torch torchvision onnx
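A quick sanity check that the packages are importable and which versions were installed (version numbers will differ on your machine):

```python
import torch
import onnx

print("torch:", torch.__version__)
print("onnx:", onnx.__version__)
```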
4.2 Write conversion code
Assuming you have a trained .pth file for NestedUNet, the conversion is as follows:
import torch
import torch.onnx
from NestedUNet import NestedUNet # Your model file
# 1. Load the PyTorch model
model = NestedUNet(num_classes=2, input_channels=3, deep_supervision=False)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
# 2. Create a dummy input (ensure the shape is correct)
dummy_input = torch.randn(1, 3, 256, 256)
# 3. Export to ONNX
onnx_path = "nested_unet.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,  # Ensure compatibility
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
print(f"✅ Model successfully converted to {onnx_path}")
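Before handing the file to a runtime, it is worth confirming that the export produced a structurally valid model. A minimal sketch using the onnx checker:

```python
import onnx

# Load the exported file and run ONNX's structural validation
onnx_model = onnx.load("nested_unet.onnx")
onnx.checker.check_model(onnx_model)
print("✅ The exported ONNX model is well formed")
```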
4.3 Verify ONNX
Install onnxruntime and run a quick test:
pip install onnxruntime
Then run:
import onnxruntime as ort
import numpy as np
# Load ONNX
ort_session = ort.InferenceSession("nested_unet.onnx")
# Generate random input
input_data = np.random.randn(1, 3, 256, 256).astype(np.float32)
outputs = ort_session.run(None, {"input": input_data})
print("ONNX inference result:", outputs[0].shape)
5. Benefits and potential risks after conversion
5.1 Benefits
✅ Improved inference speed
- ONNX Runtime and TensorRT can significantly accelerate inference, especially on GPUs.
✅ Cross-platform deployment
- .onnx can be used on Windows, Linux, Android, iOS, and embedded devices.
✅ Reduced dependencies
- Run directly with ONNX Runtime without needing the complete PyTorch dependencies.
5.2 Potential issues
⚠ ONNX may not support certain PyTorch operations
- Some PyTorch operations (like grid_sample, which the exporter only supports in newer opset versions) may not be supported in ONNX and may require manually modifying the model.
⚠ ONNX's Upsample may require align_corners=False
- If you use Upsample(scale_factor=2, mode='bilinear', align_corners=True), it may cause ONNX compatibility issues; it is recommended to change it to:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
⚠ Inference on CPU with ONNX may be slower than PyTorch
- If the model has not been optimized, ONNX may not be faster than PyTorch, especially on CPUs.
⚠ TensorRT requires additional optimization
- Running ONNX directly with TensorRT may throw errors; you may need onnx-simplifier:
pip install onnx-simplifier
python -m onnxsim nested_unet.onnx nested_unet_simplified.onnx
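onnx-simplifier can also be used from Python; a hedged sketch, assuming the simplify() entry point exposed by recent onnxsim releases:

```python
import onnx
from onnxsim import simplify

model = onnx.load("nested_unet.onnx")
simplified_model, check = simplify(model)

assert check, "Simplified ONNX model could not be validated"
onnx.save(simplified_model, "nested_unet_simplified.onnx")
```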
6. Comparison
| Comparison Item | .pth (PyTorch) | .onnx (ONNX) |
| --- | --- | --- |
| Framework dependency | PyTorch only | Compatible with multiple frameworks |
| Inference speed | Slower | Faster (ONNX Runtime / TensorRT) |
| Cross-platform | PyTorch only | Runs on a wide range of devices |
| Deployment difficulty | Requires a full Python environment | Lightweight, suitable for embedded devices |
👉 Recommendation
- If the model is only used in PyTorch, conversion is not recommended.
- If cross-platform deployment is needed (e.g., server, mobile), converting to ONNX is the best option.
- If GPU acceleration is desired, it is recommended to use TensorRT for further optimization of ONNX.