ONNX-LRE
C++ API documentation
Loading...
Searching...
No Matches
ONNX-LRE (Latent Runtime Engine)

The ONNX-LRE library provides a machine learning runtime environment for executing ONNX (Open Neural Network Exchange) models.

ONNX-LRE C++ APIs offer an easy-to-use interface to onboard and execute ONNX models from LEIP Optimize.

Inference Options

ONNX-LRE supports three different input formats for inference (each is illustrated by a worked example below):

Each approach offers different tradeoffs between ease of use, performance, and integration complexity. See the examples below for practical usage patterns.

These examples demonstrate usage of the ONNX-LRE.

Examples

Example 1: DLPack Tensors with Smart Pointers

#include <onnx_lre/onnx_lre.hpp>
#include <memory>
#include <functional>
#include <iostream>
// Custom deleter for DLManagedTensor
struct DLTensorDeleter {
void operator()(DLManagedTensor* tensor) const {
if (tensor && tensor->deleter) {
tensor->deleter(tensor);
}
}
};
// Use unique_ptr to manage DLManagedTensor lifecycle
using DLTensorPtr = std::unique_ptr<DLManagedTensor, DLTensorDeleter>;
// Example 1 entry point: run a model and print each output tensor's shape,
// managing DLPack output lifetimes with unique_ptr.
int main() {
    try {
        // Fix: the original example referenced `options` without declaring it.
        // A default-constructed Options picks the best available execution
        // provider and precision (see the Options documentation).
        OnnxLre::Options options;
        // Engine is cleaned up automatically when it goes out of scope.
        OnnxLre::LatentRuntimeEngine engine("/path/to/model.onnx", options);
        // NOTE(review): a complete program would prepare inputs and call
        // engine.infer(...) here before reading outputs (see Example 3) —
        // confirm the intended input path for this example.
        // Get results and immediately wrap them in smart pointers so every
        // DLManagedTensor is released exactly once, even if printing throws.
        std::vector<DLTensorPtr> outputTensors;
        for (auto* tensor : engine.getOutput()) {
            outputTensors.emplace_back(tensor);
        }
        // Process results safely - everything cleaned up automatically.
        for (const auto& tensor : outputTensors) {
            if (!tensor) continue; // skip null entries defensively
            const auto& dl_tensor = tensor->dl_tensor;
            std::cout << "Shape: [";
            for (int j = 0; j < dl_tensor.ndim; j++) {
                std::cout << dl_tensor.shape[j] << " ";
            }
            std::cout << "]" << std::endl;
        }
        // All resources automatically freed when outputTensors goes out of scope.
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
The LatentRuntimeEngine class provides a C++ interface to load and run ONNX models using ONNX Runtime...
Definition onnx_lre.hpp:253
@ CPU
CPU execution - universal fallback with no special hardware requirements.
Definition onnx_lre.hpp:174
Configuration parameters for the inference engine.
Definition onnx_lre.hpp:233
ExecutionProvider executionProvider
Specifies the execution provider (e.g., CPU, CUDA, TensorRT). Defaults to the best available EP.
Definition onnx_lre.hpp:235

Example 2: Using ONNX Runtime Tensors with RAII

This approach uses ONNX Runtime's tensor types with automatic memory management:

#include <onnx_lre/onnx_lre.hpp>
#include <memory>
#include <iostream>
// Example 2 entry point: build ONNX Runtime input tensors over caller-owned
// buffers, run inference, and print the shape of each output.
int main() {
    try {
        // Fix: the original example referenced `options` without declaring it.
        OnnxLre::Options options;
        // Create engine (automatically cleaned up when going out of scope)
        OnnxLre::LatentRuntimeEngine engine("/path/to/model.onnx", options);
        // Fetch model requirements
        const auto& inputShapes = engine.getInputShapes();
        // Create environment for memory management
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "example");
        Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(
            OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
        // Fix: Ort::Value::CreateTensor does NOT copy the buffer — the tensor
        // only views it. The original allocated `data` inside the loop, so the
        // buffer died at the end of each iteration and every tensor dangled.
        // Keep the buffers alive alongside the tensors instead.
        std::vector<std::vector<float>> inputBuffers;
        inputBuffers.reserve(engine.getNumberOfInputs());
        std::vector<Ort::Value> inputTensors;
        for (size_t i = 0; i < engine.getNumberOfInputs(); i++) {
            // Calculate elements needed
            size_t totalElements = 1;
            for (auto dim : inputShapes[i]) {
                totalElements *= (dim > 0) ? dim : 1; // Handle dynamic dimensions
            }
            // Buffer is owned by inputBuffers and outlives the tensor view.
            inputBuffers.emplace_back(totalElements, 0.5f);
            auto& data = inputBuffers.back();
            // Fix: the templated CreateTensor<float> expects the ELEMENT count;
            // the original passed a byte count (data.size() * sizeof(float)).
            inputTensors.push_back(Ort::Value::CreateTensor<float>(
                memInfo, data.data(), data.size(),
                inputShapes[i].data(), inputShapes[i].size()));
        }
        // Run inference (Ort::Value has proper move semantics)
        engine.infer(inputTensors);
        // Get results with ownership transfer
        auto outputTensors = engine.getOutputOrt();
        // Process results (no cleanup needed - RAII handles it)
        for (size_t i = 0; i < outputTensors.size(); i++) {
            // Tensors automatically released when going out of scope
            auto info = outputTensors[i].GetTensorTypeAndShapeInfo();
            std::cout << "Output " << i << " shape: [";
            for (auto dim : info.GetShape()) {
                std::cout << dim << " ";
            }
            std::cout << "]" << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
@ Float16
16-bit floating point - reduced precision, ~50% memory reduction, faster on compatible hardware
Definition onnx_lre.hpp:185
@ CUDA
NVIDIA CUDA - GPU acceleration without TensorRT optimizations.
Definition onnx_lre.hpp:173
Precision precision
Specifies the precision type for model execution. Defaults to the best precision runtime can run.
Definition onnx_lre.hpp:236

Example 3: Using CUDA Graphs for Optimized Inference

This example demonstrates how to leverage CUDA graphs for optimized inference performance with static input shapes:

#include <onnx_lre/onnx_lre.hpp>
#include <dlpack/dlpack.h>
#include <torch/torch.h>
#include <ATen/DLConvertor.h> // For DLPack conversion
#include <memory>
#include <iostream>
// Example 3 entry point: enable CUDA Graph capture, warm up once, then run
// repeated inferences with freshly wrapped DLPack inputs each iteration.
int main() {
    try {
        // Fix: the original example referenced `options` without declaring it.
        OnnxLre::Options options;
        // Configure options for CUDA execution with graph capture
        options.enableCudaGraph = true; // Enable CUDA graph support
        // Initialize engine with graph capture enabled
        OnnxLre::LatentRuntimeEngine engine("/path/to/model.onnx", options);
        // Get input requirements
        const auto& inputShapes = engine.getInputShapes();
        torch::Tensor dummy_input;
        std::vector<DLManagedTensor *> input_tensors, output_tensors;
        // Create random input on CUDA
        for (const auto& shape : inputShapes) {
            dummy_input = torch::randn(shape, torch::device(torch::kCUDA)); // Static shaped models only
            input_tensors.push_back(at::toDLPack(dummy_input));
        }
        // Warm-up run to capture CUDA graph
        engine.infer(input_tensors);
        // Multiple inference runs using captured graph
        const int numInferences = 100;
        for (int i = 0; i < numInferences; i++) {
            // Update input data as needed. Clearing is safe here because
            // infer() takes ownership of the DLPack tensors (see Example 4).
            input_tensors.clear();
            for (const auto& shape : inputShapes) {
                dummy_input = torch::randn(shape, torch::device(torch::kCUDA));
                input_tensors.push_back(at::toDLPack(dummy_input));
            }
            // Execute inference using captured graph
            engine.infer(input_tensors);
            // Get output for post processing.
            // NOTE(review): lifetime of the returned DLManagedTensor* values
            // follows getOutput()'s contract — see Example 1 for wrapping them
            // in smart pointers if the caller must release them.
            output_tensors = engine.getOutput();
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
@ Float32
32-bit floating point - highest precision, largest memory footprint
Definition onnx_lre.hpp:184
std::optional< bool > enableCudaGraph
Enables CUDA Graph optimization for inference. When true, static models use CUDA Graphs for faster ex...
Definition onnx_lre.hpp:240

Example 4: Using CUDA Streams for multi-stream inference

This example demonstrates how to leverage CUDA Streams to run two models in parallel:

#include <onnx_lre/onnx_lre.hpp>
#include <dlpack/dlpack.h>
#include <torch/torch.h>
#include <ATen/DLConvertor.h> // For DLPack conversion
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <iostream>
#include <thread>
// Example 4 entry point: run two engines concurrently, each pinned to its own
// CUDA stream, re-wrapping DLPack inputs before every inference.
int main() {
try {
// Create CUDA streams for concurrent inference
auto stream1 = at::cuda::getStreamFromPool();
auto stream2 = at::cuda::getStreamFromPool();
// Configure options for each engine
OnnxLre::Options options1, options2;
options1.enableCudaGraph = false; // CUDA Graph support with streams is not supported
options1.cudaStream = stream1;
options2 = options1;
options2.cudaStream = stream2;
// Initialize both engines
OnnxLre::LatentRuntimeEngine engine1("/path/to/model1.onnx", options1);
OnnxLre::LatentRuntimeEngine engine2("/path/to/model2.onnx", options2);
// Query input shapes
const auto& inputShapes1 = engine1.getInputShapes();
const auto& inputShapes2 = engine2.getInputShapes();
// Create dummy input tensors using input shapes.
// Each scope pins the allocating thread to the engine's stream so the
// tensors are associated with the stream they will be consumed on.
std::vector<torch::Tensor> dummy_inputs1;
std::vector<torch::Tensor> dummy_inputs2;
{
c10::cuda::CUDAGuard guard1(stream1.device());
at::cuda::setCurrentCUDAStream(stream1);
for (const auto& shape : inputShapes1) {
dummy_inputs1.push_back(torch::randn(shape, torch::device(torch::kCUDA)));
}
}
{
c10::cuda::CUDAGuard guard2(stream2.device());
at::cuda::setCurrentCUDAStream(stream2);
for (const auto& shape : inputShapes2) {
dummy_inputs2.push_back(torch::randn(shape, torch::device(torch::kCUDA)));
}
}
// Wrap input tensors using DLPack (ownership transferred to engine.infer)
std::vector<DLManagedTensor*> input_tensors1, input_tensors2, output_tensors1, output_tensors2;
for (const auto& tensor : dummy_inputs1) {
input_tensors1.push_back(at::toDLPack(tensor));
}
for (const auto& tensor : dummy_inputs2) {
input_tensors2.push_back(at::toDLPack(tensor));
}
// Warm-up runs (note: CUDA Graph capture is disabled above; these simply
// prime each engine before the timed/concurrent loop)
engine1.infer(input_tensors1);
engine2.infer(input_tensors2);
// Run multiple concurrent inferences
const int numInferences = 50;
for (int i = 0; i < numInferences; ++i) {
// Optionally update input values in-place
for (auto& tensor : dummy_inputs1) {
tensor.normal_();
}
for (auto& tensor : dummy_inputs2) {
tensor.normal_();
}
// Re-wrap inputs for each inference (infer deletes the DLPack wrappers,
// so the previous DLManagedTensor* entries are already invalid)
input_tensors1.clear();
input_tensors2.clear();
for (const auto& tensor : dummy_inputs1) {
input_tensors1.push_back(at::toDLPack(tensor));
}
for (const auto& tensor : dummy_inputs2) {
input_tensors2.push_back(at::toDLPack(tensor));
}
// Run inference in parallel. Each thread sets its own current stream;
// output vectors are disjoint and joined below, so no data race.
std::thread thread1([&]() {
c10::cuda::CUDAGuard guard(stream1.device());
at::cuda::setCurrentCUDAStream(stream1);
engine1.infer(input_tensors1);
// NOTE(review): other examples call getOutput(); confirm getOutputs()
// is the intended API name here.
output_tensors1 = engine1.getOutputs();
// Post-process outputs here if needed
});
std::thread thread2([&]() {
c10::cuda::CUDAGuard guard(stream2.device());
at::cuda::setCurrentCUDAStream(stream2);
engine2.infer(input_tensors2);
output_tensors2 = engine2.getOutputs();
// Post-process outputs here if needed
});
thread1.join();
thread2.join();
}
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
return 0;
}
@ TensorRT
NVIDIA TensorRT - highest performance for supported operations with optimization passes.
Definition onnx_lre.hpp:172
void * cudaStream
Definition onnx_lre.hpp:243