ONNX Runtime
|
This class wraps the raw OrtKernelContext* pointer that is passed to a custom kernel's Compute() method. Use it to safely access context attributes and input and output parameters with exception-safety guarantees. See the usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc.
#include <onnxruntime_cxx_api.h>
Public Member Functions

explicit KernelContext(OrtKernelContext* context)
size_t GetInputCount() const
size_t GetOutputCount() const
ConstValue GetInput(size_t index) const
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const
void* GetGPUComputeStream() const
Logger GetLogger() const
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const
OrtKernelContext* GetOrtKernelContext() const
void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const
Detailed Description

This class wraps the raw OrtKernelContext* pointer that is passed to a custom kernel's Compute() method. Use it to safely access context attributes and input and output parameters with exception-safety guarantees. See the usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc.
Ort::KernelContext::KernelContext(OrtKernelContext* context) [explicit]
OrtAllocator* Ort::KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const
void* Ort::KernelContext::GetGPUComputeStream() const
ConstValue Ort::KernelContext::GetInput(size_t index) const
size_t Ort::KernelContext::GetInputCount() const
Logger Ort::KernelContext::GetLogger() const
[inline]
UnownedValue Ort::KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const
UnownedValue Ort::KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const
size_t Ort::KernelContext::GetOutputCount() const
void Ort::KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const