ONNX Runtime
Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
#include <onnxruntime_training_cxx_api.h>
Public Member Functions | |
Constructing the Training Session | |
TrainingSession (const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::basic_string< char > &train_model_path, const std::optional< std::basic_string< char > > &eval_model_path=std::nullopt, const std::optional< std::basic_string< char > > &optimizer_model_path=std::nullopt) | |
Create a training session that can be used to begin or resume training. | |
TrainingSession (const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::vector< uint8_t > &train_model_data, const std::vector< uint8_t > &eval_model_data={}, const std::vector< uint8_t > &optim_model_data={}) | |
Create a training session that can be used to begin or resume training. This constructor allows users to load the models from buffers instead of files. | |
Implementing The Training Loop | |
std::vector< Value > | TrainStep (const std::vector< Value > &input_values) |
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. | |
void | LazyResetGrad () |
Reset the gradients of all trainable parameters to zero lazily. | |
std::vector< Value > | EvalStep (const std::vector< Value > &input_values) |
Computes the outputs for the eval model for the given inputs. | |
void | SetLearningRate (float learning_rate) |
Sets the learning rate for this training session. | |
float | GetLearningRate () const |
Gets the current learning rate for this training session. | |
void | RegisterLinearLRScheduler (int64_t warmup_step_count, int64_t total_step_count, float initial_lr) |
Registers a linear learning rate scheduler for the training session. | |
void | SchedulerStep () |
Update the learning rate based on the registered learning rate scheduler. | |
void | OptimizerStep () |
Performs the weight updates for the trainable parameters using the optimizer model. | |
Prepare For Inferencing | |
void | ExportModelForInferencing (const std::basic_string< char > &inference_model_path, const std::vector< std::string > &graph_output_names) |
Export a model that can be used for inferencing. | |
Model IO Information | |
std::vector< std::string > | InputNames (const bool training) |
Retrieves the names of the user inputs for the training and eval models. | |
std::vector< std::string > | OutputNames (const bool training) |
Retrieves the names of the user outputs for the training and eval models. | |
Accessing The Training Session State | |
Value | ToBuffer (const bool only_trainable) |
Returns a contiguous buffer that holds a copy of all training state parameters. | |
void | FromBuffer (Value &buffer) |
Loads the training session model parameters from a contiguous buffer. | |
Public Member Functions inherited from Ort::detail::Base< OrtTrainingSession > | |
constexpr | Base ()=default |
constexpr | Base (contained_type *p) noexcept |
Base (const Base &)=delete | |
Base (Base &&v) noexcept | |
~Base () | |
Base & | operator= (const Base &)=delete |
Base & | operator= (Base &&v) noexcept |
constexpr | operator contained_type * () const noexcept |
contained_type * | release () |
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed. | |
Additional Inherited Members | |
Public Types inherited from Ort::detail::Base< OrtTrainingSession > | |
using | contained_type = OrtTrainingSession |
Protected Attributes inherited from Ort::detail::Base< OrtTrainingSession > | |
contained_type * | p_ |
Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
The training session requires four training artifacts: the training ONNX model, the evaluation ONNX model (optional), the optimizer ONNX model, and the checkpoint file.
These artifacts can be generated using the onnxruntime-training python utility.
Ort::TrainingSession::TrainingSession(const Env& env,
    const SessionOptions& session_options,
    CheckpointState& checkpoint_state,
    const std::basic_string<char>& train_model_path,
    const std::optional<std::basic_string<char>>& eval_model_path = std::nullopt,
    const std::optional<std::basic_string<char>>& optimizer_model_path = std::nullopt)
Create a training session that can be used to begin or resume training.
This constructor creates a training session from the provided env and session options. The session can begin or resume training from the given checkpoint state using the given ONNX models. The checkpoint state holds the parameters of the training session; if necessary, they are moved to the device specified by the user through the session options.
[in] | env | Env to be used for the training session. |
[in] | session_options | SessionOptions that the user can customize for this training session. |
[in] | checkpoint_state | Training states that the training session uses as a starting point for training. |
[in] | train_model_path | Model to be used to perform training. |
[in] | eval_model_path | Optional model to be used to perform evaluation. |
[in] | optimizer_model_path | Optional model to be used to update the trainable parameters (gradient descent). |
Ort::TrainingSession::TrainingSession(const Env& env,
    const SessionOptions& session_options,
    CheckpointState& checkpoint_state,
    const std::vector<uint8_t>& train_model_data,
    const std::vector<uint8_t>& eval_model_data = {},
    const std::vector<uint8_t>& optim_model_data = {})
Create a training session that can be used to begin or resume training. This constructor allows users to load the models from buffers instead of files.
[in] | env | Env to be used for the training session. |
[in] | session_options | SessionOptions that the user can customize for this training session. |
[in] | checkpoint_state | Training states that the training session uses as a starting point for training. |
[in] | train_model_data | Buffer containing training model data. |
[in] | eval_model_data | Buffer containing the evaluation model data (optional). |
[in] | optim_model_data | Buffer containing the optimizer model data, used for performing weight/parameter updates (optional). |
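To use the buffer-based constructor, each model file must first be read into a std::vector<uint8_t>. The following is a minimal sketch of such a helper; ReadModelBytes is a hypothetical name, and the .onnx file names in the usage comment are placeholders for the artifacts generated by the onnxruntime-training utility.

```cpp
#include <cstdint>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: reads a whole file (for example, one of the .onnx
// training artifacts) into the byte vector accepted by the buffer-based
// Ort::TrainingSession constructor.
std::vector<std::uint8_t> ReadModelBytes(const std::string& path) {
  std::ifstream file(path, std::ios::binary);
  if (!file) {
    throw std::runtime_error("cannot open " + path);
  }
  // istreambuf_iterator reads raw bytes without skipping whitespace.
  return std::vector<std::uint8_t>(std::istreambuf_iterator<char>(file),
                                   std::istreambuf_iterator<char>());
}

// Usage with ONNX Runtime (placeholder file names; env, session_options and
// checkpoint_state are created as for the path-based constructor):
//   Ort::TrainingSession session(env, session_options, checkpoint_state,
//                                ReadModelBytes("training_model.onnx"),
//                                ReadModelBytes("eval_model.onnx"),
//                                ReadModelBytes("optimizer_model.onnx"));
```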
std::vector<Value> Ort::TrainingSession::EvalStep(const std::vector<Value>& input_values)
Computes the outputs of the eval model for the given inputs.
This function performs an eval step that computes the outputs of the eval model for the given inputs. The eval step uses the eval model that was provided to the training session.
[in] | input_values | The user inputs to the eval model. |
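An evaluation pass typically calls EvalStep once per batch and aggregates a metric. The sketch below is written as a template so that it compiles without ONNX Runtime; with Ort::TrainingSession, Inputs would be std::vector<Ort::Value>. MeanEvalLoss and the loss_of callback are hypothetical, and which output holds the loss depends on how the eval model was built.

```cpp
#include <vector>

// Sketch of an evaluation pass. Session must expose EvalStep(const Inputs&),
// as Ort::TrainingSession does with Inputs = std::vector<Ort::Value>.
// loss_of is a caller-supplied function that extracts a scalar loss from the
// outputs of one EvalStep call.
template <typename Session, typename Inputs, typename LossFn>
double MeanEvalLoss(Session& session, const std::vector<Inputs>& batches,
                    LossFn loss_of) {
  if (batches.empty()) return 0.0;
  double total = 0.0;
  for (const Inputs& batch : batches) {
    // Runs the eval model only; the trainable parameters are not updated.
    total += loss_of(session.EvalStep(batch));
  }
  return total / static_cast<double>(batches.size());
}
```

With ONNX Runtime, loss_of could be a lambda such as `[](const std::vector<Ort::Value>& out) { return *out[0].GetTensorData<float>(); }`, assuming the loss is the first output.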
void Ort::TrainingSession::ExportModelForInferencing(const std::basic_string<char>& inference_model_path,
    const std::vector<std::string>& graph_output_names)
Export a model that can be used for inferencing.
If the training session was provided with an eval model, it can generate an inference model once it knows the desired inference graph outputs. The provided output names are used to prune the eval model so that the inference model's outputs match them. The exported model is saved at the given path and can be used for inferencing with Ort::Session.
[in] | inference_model_path | Path where the inference model should be serialized to. |
[in] | graph_output_names | Names of the outputs that are needed in the inference model. |
void Ort::TrainingSession::FromBuffer(Value& buffer)
Loads the training session model parameters from a contiguous buffer.
If the training session was created with a nominal checkpoint, this function must be called to load the updated parameters into the checkpoint and complete it.
[in] | buffer | Contiguous buffer to load the parameters from. |
float Ort::TrainingSession::GetLearningRate() const
Gets the current learning rate for this training session.
This function returns the learning rate currently used by the training session. The learning rate is maintained by the training session, and users can query it, for example, to implement their own learning rate schedulers.
std::vector<std::string> Ort::TrainingSession::InputNames(const bool training)
Retrieves the names of the user inputs for the training and eval models.
This function returns the names of inputs of the training or eval model that can be associated with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep function.
[in] | training | Whether to return the input names of the training model (true) or of the eval model (false). |
void Ort::TrainingSession::LazyResetGrad()
Reset the gradients of all trainable parameters to zero lazily.
This function sets the internal state of the training session so that the gradients of the trainable parameters in the OrtCheckpointState are scheduled to be reset just before new gradients are computed on the next invocation of Ort::TrainingSession::TrainStep.
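Because the reset is deferred, several TrainStep calls can be made before one OptimizerStep to accumulate gradients over micro-batches. This sketch assumes, as the LazyResetGrad and TrainStep descriptions imply, that gradients accumulate across TrainStep calls until a reset is requested. AccumulateAndStep is a hypothetical helper, written as a template over any type that exposes the same three member functions as Ort::TrainingSession.

```cpp
#include <vector>

// Sketch of gradient accumulation over micro-batches. With ONNX Runtime,
// Session is Ort::TrainingSession and Inputs is std::vector<Ort::Value>.
template <typename Session, typename Inputs>
void AccumulateAndStep(Session& session,
                       const std::vector<Inputs>& micro_batches) {
  for (const Inputs& batch : micro_batches) {
    // Forward + backward; the new gradients are added to the stored ones.
    session.TrainStep(batch);
  }
  // A single parameter update using the accumulated gradients.
  session.OptimizerStep();
  // Zeroing is deferred until the next TrainStep computes new gradients.
  session.LazyResetGrad();
}
```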
void Ort::TrainingSession::OptimizerStep()
Performs the weight updates for the trainable parameters using the optimizer model.
This function performs the weight update step, which updates the trainable parameters so that they take a step in the direction opposite to their gradients (gradient descent). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next Ort::TrainingSession::TrainStep call.
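To make the direction of the update concrete, the sketch below applies plain stochastic gradient descent, p = p - lr * g, to a flat parameter vector. It is only an illustration: SgdUpdate is hypothetical, and the rule OptimizerStep actually applies is the one encoded in the optimizer model, which may differ (for example, by keeping per-parameter optimizer state).

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Illustration only: one plain SGD step on a flat parameter vector.
void SgdUpdate(std::vector<float>& params, const std::vector<float>& grads,
               float learning_rate) {
  if (params.size() != grads.size()) {
    throw std::invalid_argument("params and grads must have the same size");
  }
  for (std::size_t i = 0; i < params.size(); ++i) {
    // Move each parameter against its gradient so that the loss decreases.
    params[i] -= learning_rate * grads[i];
  }
}
```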
std::vector<std::string> Ort::TrainingSession::OutputNames(const bool training)
Retrieves the names of the user outputs for the training and eval models.
This function returns the names of outputs of the training or eval model that can be associated with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep function.
[in] | training | Whether to return the output names of the training model (true) or of the eval model (false). |
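TrainStep and EvalStep return their outputs as a vector, so a caller often needs the position of a particular output such as the loss. The sketch below looks it up by name. It assumes the returned values are in the same order as OutputNames(training); FindOutputIndex and any output name such as "loss" are hypothetical.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Sketch: position of a named output in the vector returned by TrainStep
// (training == true) or EvalStep (training == false), or std::nullopt if the
// model has no output with that name.
template <typename Session>
std::optional<std::size_t> FindOutputIndex(Session& session,
                                           const std::string& name,
                                           bool training) {
  const std::vector<std::string> names = session.OutputNames(training);
  for (std::size_t i = 0; i < names.size(); ++i) {
    if (names[i] == name) return i;
  }
  return std::nullopt;
}
```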
void Ort::TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count,
    int64_t total_step_count,
    float initial_lr)
Registers a linear learning rate scheduler for the training session.
Registers a linear learning rate scheduler. During an initial warmup phase, the learning rate increases linearly from 0 to the provided initial learning rate. After warmup, the learning rate decays from the initial learning rate to 0 by a linearly decreasing multiplicative factor.
[in] | warmup_step_count | Number of steps in the learning rate warmup phase. |
[in] | total_step_count | Total number of steps covered by the schedule. |
[in] | initial_lr | The initial learning rate to be used by the training session. |
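The schedule described above can be written as a closed-form function of the step count. This sketch assumes the common warmup-then-linear-decay formulation: the factor rises as step / warmup_step_count, then falls as (total_step_count - step) / (total_step_count - warmup_step_count), clamped at 0. ONNX Runtime's internal step counting may differ, so LinearScheduleLR is a hypothetical approximation for planning or logging, not the library's implementation.

```cpp
#include <algorithm>
#include <cstdint>

// Sketch: learning rate at a given (non-negative) step under linear warmup
// followed by linear decay to 0.
float LinearScheduleLR(std::int64_t step, std::int64_t warmup_step_count,
                       std::int64_t total_step_count, float initial_lr) {
  if (step < warmup_step_count) {
    // Warmup: the factor rises linearly from 0 toward 1.
    return initial_lr * static_cast<float>(step) /
           static_cast<float>(warmup_step_count);
  }
  // Decay: the factor falls linearly from 1 to 0 and stays at 0 afterwards.
  const std::int64_t decay_steps =
      std::max<std::int64_t>(1, total_step_count - warmup_step_count);
  const float factor = static_cast<float>(total_step_count - step) /
                       static_cast<float>(decay_steps);
  return initial_lr * std::max(0.0f, factor);
}
```

For example, with warmup_step_count = 2, total_step_count = 10 and initial_lr = 1.0, this gives 0.5 at step 1, 1.0 at step 2, 0.5 at step 6, and 0 from step 10 on.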
void Ort::TrainingSession::SchedulerStep()
Update the learning rate based on the registered learning rate scheduler.
Takes a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as determined necessary to update the learning rate being used by the training session.
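A single training round combines the calls described in this reference. The sketch follows the ordering suggested above (SchedulerStep before OptimizerStep) and resets gradients lazily at the end of the round. TrainOneRound is a hypothetical template; with Ort::TrainingSession it assumes a scheduler was registered through RegisterLinearLRScheduler and that Inputs is std::vector<Ort::Value>.

```cpp
#include <vector>

// Sketch of one training round; returns the training model's outputs
// (for example, the loss) from TrainStep.
template <typename Session, typename Inputs>
auto TrainOneRound(Session& session, const Inputs& inputs) {
  auto outputs = session.TrainStep(inputs);  // forward + backward
  session.SchedulerStep();                   // update the learning rate
  session.OptimizerStep();                   // apply the parameter update
  session.LazyResetGrad();                   // fresh gradients next round
  return outputs;
}
```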
void Ort::TrainingSession::SetLearningRate(float learning_rate)
Sets the learning rate for this training session.
This function allows users to set the learning rate for the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this function with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.
[in] | learning_rate | Desired learning rate to be set. |
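When no scheduler is registered, a custom schedule can be driven through GetLearningRate and SetLearningRate. The sketch below implements a simple step decay that multiplies the current rate by gamma every decay_every steps. MaybeDecayLearningRate, decay_every and gamma are hypothetical names, and the template only relies on the two member functions documented here.

```cpp
#include <cstdint>

// Sketch of a custom step-decay schedule. Call once per training step,
// only when no built-in learning rate scheduler has been registered.
template <typename Session>
void MaybeDecayLearningRate(Session& session, std::int64_t step,
                            std::int64_t decay_every, float gamma) {
  if (decay_every > 0 && step > 0 && step % decay_every == 0) {
    // Scale the rate currently maintained by the session.
    session.SetLearningRate(session.GetLearningRate() * gamma);
  }
}
```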
Value Ort::TrainingSession::ToBuffer(const bool only_trainable)
Returns a contiguous buffer that holds a copy of all training state parameters.
[in] | only_trainable | Whether to only copy trainable parameters or to copy all parameters. |
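ToBuffer and FromBuffer can be combined to keep the best parameters seen during training, for example as measured by evaluation loss, and restore them at the end. This is a sketch under assumptions: BestParameterTracker is hypothetical, Buffer stands for Ort::Value, and it assumes that a buffer returned by ToBuffer(false) (all parameters) is accepted by FromBuffer.

```cpp
#include <optional>

// Sketch: remember the parameters with the lowest evaluation loss and
// restore them later. With ONNX Runtime, Session is Ort::TrainingSession and
// Buffer is Ort::Value.
template <typename Session, typename Buffer>
class BestParameterTracker {
 public:
  // Copies the session's parameters if eval_loss beats the best seen so far.
  void Observe(Session& session, float eval_loss) {
    if (!best_loss_ || eval_loss < *best_loss_) {
      best_loss_ = eval_loss;
      best_ = session.ToBuffer(/*only_trainable=*/false);
    }
  }

  // Loads the best snapshot back into the session; false if none was taken.
  bool Restore(Session& session) {
    if (!best_) return false;
    session.FromBuffer(*best_);
    return true;
  }

 private:
  std::optional<float> best_loss_;
  std::optional<Buffer> best_;
};
```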
std::vector<Value> Ort::TrainingSession::TrainStep(const std::vector<Value>& input_values)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs, using the training model that was provided to the training session. Ort::TrainingSession::TrainStep is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the Ort::TrainingSession::OptimizerStep function. The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.
[in] | input_values | The user inputs to the training model. |