PyTorch .pth
轉 ONNX:從模型訓練到跨平台部署#
在深度學習裡,模型的格式決定了它的可用性。
如果你是 PyTorch 使用者,你可能熟悉 .pth
文件,它用於存儲訓練好的模型。
但當你想在不同的環境(如 TensorRT、OpenVINO、ONNX Runtime)部署模型時,.pth
可能並不適用。這時,ONNX(Open Neural Network Exchange)就必不可少。
本文目錄:
- 什麼是
.pth
文件? - 什麼是
.onnx
文件? - 為什麼要轉換?
- 如何轉換
.pth
到.onnx
? - 轉換後的好處和潛在風險
1. 什麼是 .pth
文件?#
.pth
是 PyTorch 專屬的模型權重文件,用於存儲:
- 模型權重(state_dict):僅保存參數,不包含模型結構。
- 完整模型:包含模型結構和權重,適用於直接
torch.save(model, "model.pth")
保存的情況。
在 PyTorch 中,你可以用以下方式加載 .pth
:
import torch
from NestedUNet import NestedUNet # 你的模型類
# 僅保存權重的加載方式
model = NestedUNet(num_classes=2, input_channels=3)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
.pth
文件只能在 PyTorch 運行的環境中使用,不能直接在 TensorFlow、OpenVINO 或 TensorRT 裡運行。
2. 什麼是 ONNX?#
ONNX(Open Neural Network Exchange)是 一個開放的神經網絡標準格式,它的目標是:
- 跨框架兼容:支持 PyTorch、TensorFlow、Keras、MXNet 等。
- 優化推理:可以用 ONNX Runtime 或 TensorRT 加速推理。
- 部署靈活:支持在 CPU、GPU、FPGA、TPU 等硬件上運行。
ONNX 文件是一個 .onnx
文件,它包含:
- 模型的計算圖
- 算子(OPs)定義
- 模型權重
ONNX 讓你可以在不同平台上運行同一個模型,而不必依賴某個特定的深度學習框架。
3. 為什麼要轉換 .pth
到 .onnx
?#
轉換為 ONNX 主要有以下好處:
✅ 跨平台兼容
.pth
只能在 PyTorch 裡用,而.onnx
可以在 TensorRT、ONNX Runtime、OpenVINO、CoreML 等多種環境中運行。
✅ 推理速度更快
- ONNX Runtime 使用圖優化(Graph Optimization),減少計算冗餘,提高推理速度。
- TensorRT 可以將 ONNX 模型編譯為高度優化的 GPU 代碼,顯著提高吞吐量。
✅ 支持多種硬件
.pth
主要用於 CPU/GPU,而.onnx
可用於 FPGA、TPU、ARM 設備,如 安卓手機、樹莓派、Jetson Nano 等。
✅ 更輕量級
- PyTorch 運行時需要完整的 Python 解釋器,而 ONNX 可以直接用 C++/C 代碼運行,適用於嵌入式設備。
4. 如何轉換 .pth
到 .onnx
?#
4.1 安裝依賴#
在轉換前,確保你已安裝 PyTorch 和 ONNX:
pip install torch torchvision onnx
4.2 編寫轉換代碼#
假設你有一個 NestedUNet
訓練好的 .pth
文件,轉換方式如下:
import torch
import torch.onnx
from NestedUNet import NestedUNet # 你的模型文件
# 1. 加載 PyTorch 模型
model = NestedUNet(num_classes=2, input_channels=3, deep_supervision=False)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
# 2. 創建示例輸入(確保形狀正確)
dummy_input = torch.randn(1, 3, 256, 256)
# 3. 導出為 ONNX
onnx_path = "nested_unet.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_path,
export_params=True,
opset_version=11, # 確保兼容性
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
print(f"✅ 模型已成功轉換為 {onnx_path}")
4.3 驗證 ONNX#
安裝 onnxruntime
並測試:
pip install onnxruntime
然後運行:
import onnxruntime as ort
import numpy as np
# 加載 ONNX
ort_session = ort.InferenceSession("nested_unet.onnx")
# 生成隨機輸入
input_data = np.random.randn(1, 3, 256, 256).astype(np.float32)
outputs = ort_session.run(None, {"input": input_data})
print("ONNX 推理結果:", outputs[0].shape)
5. 轉換後的好處和潛在風險#
5.1 好處#
✅ 提高推理速度
- ONNX Runtime 和 TensorRT 可以顯著加速推理,尤其是在 GPU 上。
✅ 跨平台部署
.onnx
可用於 Windows、Linux、安卓、iOS、嵌入式設備。
✅ 減少依賴
- 直接用 ONNX Runtime 運行,不需要完整的 PyTorch 依賴。
5.2 可能遇到的問題#
⚠ ONNX 可能不支持某些 PyTorch 操作
- PyTorch 的某些自定義操作(如
grid_sample
)可能在 ONNX 不支持,需要手動修改模型。
⚠ ONNX 的 Upsample
可能需要 align_corners=False
-
如果
Upsample(scale_factor=2, mode='bilinear', align_corners=True)
,可能會導致 ONNX 兼容性問題,建議改為:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
⚠ ONNX 在 CPU 上的推理可能比 PyTorch 慢
- 如果模型沒有經過優化,ONNX 可能不會比 PyTorch 快,尤其是在 CPU 上。
⚠ TensorRT 需要額外優化
-
直接用 TensorRT 運行 ONNX 可能會報錯,需要
onnx-simplifier
:pip install onnx-simplifier python -m onnxsim nested_unet.onnx nested_unet_simplified.onnx
6. 對比#
比較項 | .pth (PyTorch) | .onnx (ONNX) |
---|---|---|
框架依賴 | 僅支持 PyTorch | 兼容多框架 |
推理速度 | 較慢 | 更快(ONNX Runtime / TensorRT) |
跨平台性 | 僅支持 PyTorch | 可在多種設備上運行 |
部署難度 | 需要完整 Python | 輕量級,適用於嵌入式 |
👉 建議
- 如果模型只在 PyTorch 中用,不建議轉換。
- 如果要跨平台部署(如伺服器、移動端),轉換為 ONNX 是最佳方案。
- 如果要在 GPU 加速,建議用 TensorRT 進一步優化 ONNX。
此文由 Mix Space 同步更新至 xLog
原始鏈接為 https://blog.kanes.top/posts/default/pth2ONNX