banner
kanes

kanes

PyTorch `.pth` 轉 ONNX:從模型訓練到跨平台部署

PyTorch .pth 轉 ONNX:從模型訓練到跨平台部署#

在深度學習裡,模型的格式決定了它的可用性

如果你是 PyTorch 使用者,你可能熟悉 .pth 文件,它用於存儲訓練好的模型。

但當你想在不同的環境(如 TensorRT、OpenVINO、ONNX Runtime)部署模型時,.pth 可能並不適用。這時,ONNX(Open Neural Network Exchange)就必不可少。

本文目錄:

  • 什麼是 .pth 文件?
  • 什麼是 .onnx 文件?
  • 為什麼要轉換?
  • 如何轉換 .pth.onnx
  • 轉換後的好處和潛在風險

1. 什麼是 .pth 文件?#

.pthPyTorch 專屬的模型權重文件,用於存儲:

  1. 模型權重(state_dict):僅保存參數,不包含模型結構。
  2. 完整模型:包含模型結構和權重,適用於直接 torch.save(model, "model.pth") 保存的情況。

在 PyTorch 中,你可以用以下方式加載 .pth

import torch
from NestedUNet import NestedUNet  # 你的模型類

# 僅保存權重的加載方式
model = NestedUNet(num_classes=2, input_channels=3)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

.pth 文件只能在 PyTorch 運行的環境中使用,不能直接在 TensorFlow、OpenVINO 或 TensorRT 裡運行。


2. 什麼是 ONNX?#

ONNX(Open Neural Network Exchange)是 一個開放的神經網絡標準格式,它的目標是:

  1. 跨框架兼容:支持 PyTorch、TensorFlow、Keras、MXNet 等。
  2. 優化推理:可以用 ONNX Runtime 或 TensorRT 加速推理。
  3. 部署靈活:支持在 CPU、GPU、FPGA、TPU 等硬件上運行。

ONNX 文件是一個 .onnx 文件,它包含:

  • 模型的計算圖
  • 算子(OPs)定義
  • 模型權重

ONNX 讓你可以在不同平台上運行同一個模型,而不必依賴某個特定的深度學習框架。


3. 為什麼要轉換 .pth.onnx#

轉換為 ONNX 主要有以下好處:

跨平台兼容

  • .pth 只能在 PyTorch 裡用,而 .onnx 可以在 TensorRT、ONNX Runtime、OpenVINO、CoreML 等多種環境中運行。

推理速度更快

  • ONNX Runtime 使用圖優化(Graph Optimization),減少計算冗餘,提高推理速度。
  • TensorRT 可以將 ONNX 模型編譯為高度優化的 GPU 代碼,顯著提高吞吐量。

支持多種硬件

  • .pth 主要用於 CPU/GPU,而 .onnx 可用於 FPGA、TPU、ARM 設備,如 安卓手機、樹莓派、Jetson Nano 等。

更輕量級

  • PyTorch 運行時需要完整的 Python 解釋器,而 ONNX 可以直接用 C++/C 代碼運行,適用於嵌入式設備。

4. 如何轉換 .pth.onnx#

4.1 安裝依賴#

在轉換前,確保你已安裝 PyTorch 和 ONNX:

pip install torch torchvision onnx

4.2 編寫轉換代碼#

假設你有一個 NestedUNet 訓練好的 .pth 文件,轉換方式如下:

import torch
import torch.onnx
from NestedUNet import NestedUNet  # 你的模型文件

# 1. 加載 PyTorch 模型
model = NestedUNet(num_classes=2, input_channels=3, deep_supervision=False)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

# 2. 創建示例輸入(確保形狀正確)
dummy_input = torch.randn(1, 3, 256, 256)

# 3. 導出為 ONNX
onnx_path = "nested_unet.onnx"
torch.onnx.export(
    model, 
    dummy_input, 
    onnx_path,
    export_params=True,
    opset_version=11,  # 確保兼容性
    do_constant_folding=True,
    input_names=["input"], 
    output_names=["output"], 
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

print(f"✅ 模型已成功轉換為 {onnx_path}")

4.3 驗證 ONNX#

安裝 onnxruntime 並測試:

pip install onnxruntime

然後運行:

import onnxruntime as ort
import numpy as np

# 加載 ONNX
ort_session = ort.InferenceSession("nested_unet.onnx")

# 生成隨機輸入
input_data = np.random.randn(1, 3, 256, 256).astype(np.float32)
outputs = ort_session.run(None, {"input": input_data})

print("ONNX 推理結果:", outputs[0].shape)

5. 轉換後的好處和潛在風險#

5.1 好處#

提高推理速度

  • ONNX Runtime 和 TensorRT 可以顯著加速推理,尤其是在 GPU 上。

跨平台部署

  • .onnx 可用於 Windows、Linux、安卓、iOS、嵌入式設備。

減少依賴

  • 直接用 ONNX Runtime 運行,不需要完整的 PyTorch 依賴。

5.2 可能遇到的問題#

ONNX 可能不支持某些 PyTorch 操作

  • PyTorch 的某些自定義操作(如 grid_sample)可能在 ONNX 不支持,需要手動修改模型。

ONNX 的 Upsample 可能需要 align_corners=False

  • 如果 Upsample(scale_factor=2, mode='bilinear', align_corners=True),可能會導致 ONNX 兼容性問題,建議改為:

    self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    

ONNX 在 CPU 上的推理可能比 PyTorch 慢

  • 如果模型沒有經過優化,ONNX 可能不會比 PyTorch 快,尤其是在 CPU 上。

TensorRT 需要額外優化

  • 直接用 TensorRT 運行 ONNX 可能會報錯,需要 onnx-simplifier

    pip install onnx-simplifier
    python -m onnxsim nested_unet.onnx nested_unet_simplified.onnx
    

6. 對比#

比較項.pth (PyTorch).onnx (ONNX)
框架依賴僅支持 PyTorch兼容多框架
推理速度較慢更快(ONNX Runtime / TensorRT)
跨平台性僅支持 PyTorch可在多種設備上運行
部署難度需要完整 Python輕量級,適用於嵌入式

👉 建議

  • 如果模型只在 PyTorch 中用,不建議轉換
  • 如果要跨平台部署(如伺服器、移動端),轉換為 ONNX 是最佳方案
  • 如果要在 GPU 加速,建議用 TensorRT 進一步優化 ONNX

此文由 Mix Space 同步更新至 xLog
原始鏈接為 https://blog.kanes.top/posts/default/pth2ONNX


載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。