5#include "onnxruntime_training_c_api.h"
11#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
12 void OrtRelease(Ort##NAME* ptr);
16ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
17ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
21#include "onnxruntime_cxx_api.h"
34#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
35 inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
40#undef ORT_DECLARE_TRAINING_RELEASE
41#undef ORT_DEFINE_TRAINING_RELEASE
45using Property = std::variant<int64_t, float, std::string>;
114 const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
115 const bool include_optimizer_state =
false);
182 size_t training_model_output_count_, eval_model_output_count_;
203 const std::basic_string<ORTCHAR_T>& train_model_path,
204 const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
205 const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
219 const std::vector<uint8_t>& train_model_data,
const std::vector<uint8_t>& eval_model_data = {},
220 const std::vector<uint8_t>& optim_model_data = {});
241 std::vector<Value>
TrainStep(
const std::vector<Value>& input_values);
261 std::vector<Value>
EvalStep(
const std::vector<Value>& input_values);
347 const std::vector<std::string>& graph_output_names);
418#include "onnxruntime_training_cxx_inline.h"
Holds the state of the training session.
Definition onnxruntime_training_cxx_api.h:65
Value GetParameter(const std::string &parameter_name)
Gets the data associated with the model parameter from the checkpoint state for the given parameter n...
static CheckpointState LoadCheckpointFromBuffer(const std::vector< uint8_t > &buffer)
Load a checkpoint state from a buffer.
void AddProperty(const std::string &property_name, const Property &property_value)
Adds or updates the given property to/in the checkpoint state.
static CheckpointState LoadCheckpoint(const std::basic_string< char > &path_to_checkpoint)
Load a checkpoint state from a file on disk into checkpoint_state.
void UpdateParameter(const std::string &parameter_name, const Value &parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter ...
static void SaveCheckpoint(const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
Save the given state to a checkpoint file on disk.
Property GetProperty(const std::string &property_name)
Gets the property value associated with the given name from the checkpoint state.
Trainer class that provides training, evaluation and optimizer methods for training an ONNX model.
Definition onnxruntime_training_cxx_api.h:180
void OptimizerStep()
Performs the weight updates for the trainable parameters using the optimizer model.
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, float initial_lr)
Registers a linear learning rate scheduler for the training session.
std::vector< Value > EvalStep(const std::vector< Value > &input_values)
Computes the outputs for the eval model for the given inputs.
std::vector< std::string > InputNames(const bool training)
Retrieves the names of the user inputs for the training and eval models.
float GetLearningRate() const
Gets the current learning rate for this training session.
void ExportModelForInferencing(const std::basic_string< char > &inference_model_path, const std::vector< std::string > &graph_output_names)
Export a model that can be used for inferencing.
void LazyResetGrad()
Reset the gradients of all trainable parameters to zero lazily.
Value ToBuffer(const bool only_trainable)
Returns a contiguous buffer that holds a copy of all training state parameters.
std::vector< Value > TrainStep(const std::vector< Value > &input_values)
Computes the outputs of the training model and the gradients of the trainable parameters for the give...
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::basic_string< char > &train_model_path, const std::optional< std::basic_string< char > > &eval_model_path=std::nullopt, const std::optional< std::basic_string< char > > &optimizer_model_path=std::nullopt)
Create a training session that can be used to begin or resume training.
void SchedulerStep()
Update the learning rate based on the registered learning rate scheduler.
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::vector< uint8_t > &train_model_data, const std::vector< uint8_t > &eval_model_data={}, const std::vector< uint8_t > &optim_model_data={})
Create a training session that can be used to begin or resume training. This constructor allows the u...
std::vector< std::string > OutputNames(const bool training)
Retrieves the names of the user outputs for the training and eval models.
void FromBuffer(Value &buffer)
Loads the training session model parameters from a contiguous buffer.
void SetLearningRate(float learning_rate)
Sets the learning rate for this training session.
#define ORT_API_VERSION
The API version defined in this header.
Definition onnxruntime_c_api.h:41
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
void SetSeed(const int64_t seed)
This function sets the seed for generating random numbers.
Definition onnxruntime_cxx_api.h:499
All C++ Onnxruntime APIs are defined inside this namespace.
Definition onnxruntime_cxx_api.h:47
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition onnxruntime_cxx_api.h:124
std::variant< int64_t, float, std::string > Property
Definition onnxruntime_training_cxx_api.h:45
const OrtTrainingApi & GetTrainingApi()
This function returns the C training api struct with the pointers to the ort training C functions....
Definition onnxruntime_training_cxx_api.h:30
The Env (Environment)
Definition onnxruntime_cxx_api.h:697
Wrapper around OrtSessionOptions.
Definition onnxruntime_cxx_api.h:919
Wrapper around OrtValue.
Definition onnxruntime_cxx_api.h:1614
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition onnxruntime_cxx_api.h:556
contained_type * p_
Definition onnxruntime_cxx_api.h:584
const OrtTrainingApi *(* GetTrainingApi)(uint32_t version)
Gets the Training C Api struct.
Definition onnxruntime_c_api.h:3719
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122