要在C++中調用PyTorch模型,可以使用LibTorch庫。以下是一個簡單的示例代碼,演示了如何加載一個PyTorch模型并使用輸入數據進行推理:
#include <torch/torch.h>
#include <iostream>
int main() {
// 加載模型
torch::jit::script::Module module;
try {
module = torch::jit::load("path/to/model.pt");
} catch (const c10::Error& e) {
std::cerr << "Error loading the model\n";
return -1;
}
// 準備輸入數據
torch::Tensor input = torch::ones({1, 3, 224, 224}); // 示例輸入數據
// 運行推理
at::Tensor output = module.forward({input}).toTensor();
// 輸出結果
std::cout << "Output tensor: " << output << std::endl;
return 0;
}
在這個示例中,首先加載了一個PyTorch模型(假設模型保存在model.pt
文件中)。然后創建了一個示例輸入張量input
,并將其傳遞給模型進行推理。最后,輸出了模型的輸出張量。
請注意,為了能夠編譯這段代碼,需要在項目中鏈接LibTorch庫并設置正確的包含路徑。更多關于LibTorch的用法和配置信息,請參考PyTorch官方文檔。