ONNX Runtime
|
The Training C API that holds onnxruntime training function pointers. More...
#include <onnxruntime_training_c_api.h>
Accessing The Training Session State | |
OrtStatus * | LoadCheckpoint (const char *checkpoint_path, OrtCheckpointState **checkpoint_state) |
Load a checkpoint state from a file on disk into checkpoint_state. | |
OrtStatus * | SaveCheckpoint (OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state) |
Save the given state to a checkpoint file on disk. | |
OrtStatus * | GetParametersSize (OrtTrainingSession *sess, size_t *out, bool trainable_only) |
Retrieves the size of all the parameters. | |
OrtStatus * | CopyParametersToBuffer (OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only) |
Copy all parameters to a contiguous buffer held by the argument parameters_buffer. | |
OrtStatus * | CopyBufferToParameters (OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only) |
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state. | |
OrtStatus * | AddProperty (OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value) |
Adds or updates the given property to/in the checkpoint state. | |
OrtStatus * | GetProperty (const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value) |
Gets the property value associated with the given name from the checkpoint state. | |
OrtStatus * | LoadCheckpointFromBuffer (const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state) |
Load a checkpoint state from a buffer into checkpoint_state. | |
OrtStatus * | GetParameterTypeAndShape (const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtTensorTypeAndShapeInfo **parameter_type_and_shape) |
Retrieves the type and shape information of the parameter associated with the given parameter name. | |
OrtStatus * | UpdateParameter (OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtValue *parameter) |
Updates the data associated with the model parameter in the checkpoint state for the given parameter name. | |
OrtStatus * | GetParameter (const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtAllocator *allocator, OrtValue **parameter) |
Gets the data associated with the model parameter from the checkpoint state for the given parameter name. | |
Implementing The Training Loop | |
OrtStatus * | CreateTrainingSession (const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out) |
Create a training session that can be used to begin or resume training. | |
OrtStatus * | CreateTrainingSessionFromBuffer (const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const void *train_model_data, size_t train_data_length, const void *eval_model_data, size_t eval_data_length, const void *optim_model_data, size_t optim_data_length, OrtTrainingSession **out) |
Create a training session that can be used to begin or resume training. This api provides a way to load all the training artifacts from buffers instead of files. | |
OrtStatus * | LazyResetGrad (OrtTrainingSession *session) |
Reset the gradients of all trainable parameters to zero lazily. | |
OrtStatus * | TrainStep (OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs) |
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. | |
OrtStatus * | EvalStep (const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs) |
Computes the outputs for the eval model for the given inputs. | |
OrtStatus * | SetLearningRate (OrtTrainingSession *sess, float learning_rate) |
Sets the learning rate for this training session. | |
OrtStatus * | GetLearningRate (OrtTrainingSession *sess, float *learning_rate) |
Gets the current learning rate for this training session. | |
OrtStatus * | OptimizerStep (OrtTrainingSession *sess, const OrtRunOptions *run_options) |
Performs the weight updates for the trainable parameters using the optimizer model. | |
OrtStatus * | RegisterLinearLRScheduler (OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr) |
Registers a linear learning rate scheduler for the training session. | |
OrtStatus * | SchedulerStep (OrtTrainingSession *sess) |
Update the learning rate based on the registered learning rate scheduler. | |
Model IO Information | |
OrtStatus * | TrainingSessionGetTrainingModelOutputCount (const OrtTrainingSession *sess, size_t *out) |
Retrieves the number of user outputs in the training model. | |
OrtStatus * | TrainingSessionGetEvalModelOutputCount (const OrtTrainingSession *sess, size_t *out) |
Retrieves the number of user outputs in the eval model. | |
OrtStatus * | TrainingSessionGetTrainingModelOutputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output) |
Retrieves the names of user outputs in the training model. | |
OrtStatus * | TrainingSessionGetEvalModelOutputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output) |
Retrieves the names of user outputs in the eval model. | |
OrtStatus * | TrainingSessionGetTrainingModelInputCount (const OrtTrainingSession *sess, size_t *out) |
Retrieves the number of user inputs in the training model. | |
OrtStatus * | TrainingSessionGetEvalModelInputCount (const OrtTrainingSession *sess, size_t *out) |
Retrieves the number of user inputs in the eval model. | |
OrtStatus * | TrainingSessionGetTrainingModelInputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output) |
Retrieves the name of the user input at given index in the training model. | |
OrtStatus * | TrainingSessionGetEvalModelInputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output) |
Retrieves the name of the user input at given index in the eval model. | |
Release Training Resources | |
void | ReleaseTrainingSession (OrtTrainingSession *input) |
Frees up the memory used up by the training session. | |
void | ReleaseCheckpointState (OrtCheckpointState *input) |
Frees up the memory used up by the checkpoint state. | |
Prepare For Inferencing | |
OrtStatus * | ExportModelForInferencing (OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names) |
Export a model that can be used for inferencing. | |
Training Utilities | |
OrtStatus * | SetSeed (const int64_t seed) |
Sets the seed used for random number generation in Onnxruntime. | |
The Training C API that holds onnxruntime training function pointers.
All the Training C API functions are defined inside this structure as pointers to functions. Call OrtApi::GetTrainingApi to get a pointer to this struct.
OrtStatus * OrtTrainingApi::AddProperty | ( | OrtCheckpointState * | checkpoint_state, |
const char * | property_name, | ||
enum OrtPropertyType | property_type, | ||
void * | property_value | ||
) |
Adds or updates the given property to/in the checkpoint state.
Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint state by the user by calling this function with the corresponding property name and value. The given property name must be unique to be able to successfully add the property.
[in] | checkpoint_state | The checkpoint state which should hold the property. |
[in] | property_name | Name of the property being added or updated. |
[in] | property_type | Type of the property associated with the given name. |
[in] | property_value | Property value associated with the given name. |
OrtStatus * OrtTrainingApi::CopyBufferToParameters | ( | OrtTrainingSession * | sess, |
OrtValue * | parameters_buffer, | ||
bool | trainable_only | ||
) |
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state.
The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call, with matching setting for trainable_only argument. All the target parameters must be of the same datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer and can be used to load updated buffer values onto the training state. Parameter ordering is preserved. User is responsible for allocating and freeing the resources used by the parameters_buffer. In case the training session was created with a nominal checkpoint, invoking this function is required to load the updated parameters onto the checkpoint to complete it.
[in] | sess | The `this` pointer to the training session. |
[in] | trainable_only | Whether to skip non-trainable parameters |
[in] | parameters_buffer | The pre-allocated OrtValue buffer to copy from. |
OrtStatus * OrtTrainingApi::CopyParametersToBuffer | ( | OrtTrainingSession * | sess, |
OrtValue * | parameters_buffer, | ||
bool | trainable_only | ||
) |
Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
The parameters_buffer has to be of the size given by GetParametersSize api call, with matching setting for the argument trainable_only. All the target parameters must be of the same datatype. The OrtValue must be pre-allocated onto the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters. Parameter ordering is preserved. User is responsible for allocating and freeing the resources used by the parameters_buffer.
[in] | sess | The `this` pointer to the training session. |
[in] | trainable_only | Whether to skip non-trainable parameters |
[out] | parameters_buffer | The pre-allocated OrtValue buffer to copy onto. |
OrtStatus * OrtTrainingApi::CreateTrainingSession | ( | const OrtEnv * | env, |
const OrtSessionOptions * | options, | ||
OrtCheckpointState * | checkpoint_state, | ||
const char * | train_model_path, | ||
const char * | eval_model_path, | ||
const char * | optimizer_model_path, | ||
OrtTrainingSession ** | out | ||
) |
Create a training session that can be used to begin or resume training.
This function creates a training session based on the env and session options provided that can begin or resume training from a given checkpoint state for the given onnx models. The checkpoint state represents the parameters of the training session which will be moved to the device specified by the user through the session options (if necessary). The training session requires four training artifacts:
- The training onnx model
- The evaluation onnx model (optional)
- The optimizer onnx model
- The checkpoint file
These artifacts can be generated using the onnxruntime-training python utility.
[in] | env | Environment to be used for the training session. |
[in] | options | Session options that the user can customize for this training session. |
[in] | checkpoint_state | Training states that the training session uses as a starting point for training. |
[in] | train_model_path | Model to be used to perform training. |
[in] | eval_model_path | Model to be used to perform evaluation. |
[in] | optimizer_model_path | Model to be used to perform gradient descent. |
[out] | out | Created training session. |
OrtStatus * OrtTrainingApi::CreateTrainingSessionFromBuffer | ( | const OrtEnv * | env, |
const OrtSessionOptions * | options, | ||
OrtCheckpointState * | checkpoint_state, | ||
const void * | train_model_data, | ||
size_t | train_data_length, | ||
const void * | eval_model_data, | ||
size_t | eval_data_length, | ||
const void * | optim_model_data, | ||
size_t | optim_data_length, | ||
OrtTrainingSession ** | out | ||
) |
Create a training session that can be used to begin or resume training. This api provides a way to load all the training artifacts from buffers instead of files.
[in] | env | Environment to be used for the training session. |
[in] | options | Session options that the user can customize for this training session. |
[in] | checkpoint_state | Training states that the training session uses as a starting point for training. |
[in] | train_model_data | Buffer containing the model data to be used to perform training |
[in] | train_data_length | Length of the buffer containing train_model_data |
[in] | eval_model_data | Buffer containing the model data to be used to perform evaluation |
[in] | eval_data_length | Length of the buffer containing eval_model_data |
[in] | optim_model_data | Buffer containing the model data to be used to perform weight update |
[in] | optim_data_length | Length of the buffer containing optim_model_data |
[out] | out | Created training session. |
OrtStatus * OrtTrainingApi::EvalStep | ( | const OrtTrainingSession * | sess, |
const OrtRunOptions * | run_options, | ||
size_t | inputs_len, | ||
const OrtValue *const * | inputs, | ||
size_t | outputs_len, | ||
OrtValue ** | outputs | ||
) |
Computes the outputs for the eval model for the given inputs.
This function performs an eval step that computes the outputs of the eval model for the given inputs. The eval step is performed based on the eval model that was provided to the training session.
[in] | sess | The `this` pointer to the training session. |
[in] | run_options | Run options for this eval step. |
[in] | inputs_len | Number of user inputs to the eval model. |
[in] | inputs | The user inputs to the eval model. |
[in] | outputs_len | Number of user outputs expected from this eval step. |
[out] | outputs | User outputs computed by eval step. |
OrtStatus * OrtTrainingApi::ExportModelForInferencing | ( | OrtTrainingSession * | sess, |
const char * | inference_model_path, | ||
size_t | graph_outputs_len, | ||
const char *const * | graph_output_names | ||
) |
Export a model that can be used for inferencing.
If the training session was provided with an eval model, the training session can generate an inference model if it knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the inference model's outputs align with the provided outputs. The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
[in] | sess | The `this` pointer to the training session. |
[in] | inference_model_path | Path where the inference model should be serialized to. |
[in] | graph_outputs_len | Size of the graph output names array. |
[in] | graph_output_names | Names of the outputs that are needed in the inference model. |
OrtStatus * OrtTrainingApi::GetLearningRate | ( | OrtTrainingSession * | sess, |
float * | learning_rate | ||
) |
Gets the current learning rate for this training session.
This function allows users to get the learning rate for the training session. The current learning rate is maintained by the training session, and users can query it for the purpose of implementing their own learning rate schedulers.
[in] | sess | The `this` pointer to the training session. |
[out] | learning_rate | Learning rate currently in use by the training session. |
OrtStatus * OrtTrainingApi::GetParameter | ( | const OrtCheckpointState * | checkpoint_state, |
const char * | parameter_name, | ||
OrtAllocator * | allocator, | ||
OrtValue ** | parameter | ||
) |
Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
This function retrieves the model parameter data from the checkpoint state for the given parameter name. The parameter is copied over and returned as an OrtValue. The training session must already have been created with the checkpoint state that contains the parameter being retrieved. The parameter must exist in the checkpoint state to be able to retrieve it successfully.
[in] | checkpoint_state | The checkpoint state. |
[in] | parameter_name | Name of the parameter being retrieved. |
[in] | allocator | Allocator used to allocate the memory for the parameter. |
[out] | parameter | The parameter data that is retrieved from the checkpoint state. |
OrtStatus * OrtTrainingApi::GetParametersSize | ( | OrtTrainingSession * | sess, |
size_t * | out, | ||
bool | trainable_only | ||
) |
Retrieves the size of all the parameters.
Calculates the total number of primitive elements (of the parameters' datatype) across all the parameters in the training state. When the trainable_only argument is true, the size is calculated for trainable parameters only.
[in] | sess | The `this` pointer to the training session. |
[out] | out | Size of all parameter elements. |
[in] | trainable_only | Whether to skip non-trainable parameters |
OrtStatus * OrtTrainingApi::GetParameterTypeAndShape | ( | const OrtCheckpointState * | checkpoint_state, |
const char * | parameter_name, | ||
OrtTensorTypeAndShapeInfo ** | parameter_type_and_shape | ||
) |
Retrieves the type and shape information of the parameter associated with the given parameter name.
This function retrieves the type and shape of the parameter associated with the given parameter name. The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
[in] | checkpoint_state | The checkpoint state. |
[in] | parameter_name | Name of the parameter being retrieved. |
[out] | parameter_type_and_shape | The type and shape of the parameter being retrieved. |
OrtStatus * OrtTrainingApi::GetProperty | ( | const OrtCheckpointState * | checkpoint_state, |
const char * | property_name, | ||
OrtAllocator * | allocator, | ||
enum OrtPropertyType * | property_type, | ||
void ** | property_value | ||
) |
Gets the property value associated with the given name from the checkpoint state.
Gets the property value from an existing entry in the checkpoint state. The property must exist in the checkpoint state to be able to retrieve it successfully.
[in] | checkpoint_state | The checkpoint state that is currently holding the property. |
[in] | property_name | Name of the property being retrieved. |
[in] | allocator | Allocator used to allocate the memory for the property_value. |
[out] | property_type | Type of the property associated with the given name. |
[out] | property_value | Property value associated with the given name. |
OrtStatus * OrtTrainingApi::LazyResetGrad | ( | OrtTrainingSession * | session | ) |
Reset the gradients of all trainable parameters to zero lazily.
This function sets the internal state of the training session such that the gradients of the trainable parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are computed on the next invocation of OrtTrainingApi::TrainStep.
[in] | session | The `this` pointer to the training session. |
OrtStatus * OrtTrainingApi::LoadCheckpoint | ( | const char * | checkpoint_path, |
OrtCheckpointState ** | checkpoint_state | ||
) |
Load a checkpoint state from a file on disk into checkpoint_state.
This function will parse a checkpoint file, pull relevant data and load the training state into the checkpoint_state. This checkpoint state can then be used to create the training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training session will resume training from the given checkpoint state.
[in] | checkpoint_path | Path to the checkpoint file |
[out] | checkpoint_state | Checkpoint state that contains the states of the training session. |
OrtStatus * OrtTrainingApi::LoadCheckpointFromBuffer | ( | const void * | checkpoint_buffer, |
const size_t | num_bytes, | ||
OrtCheckpointState ** | checkpoint_state | ||
) |
Load a checkpoint state from a buffer into checkpoint_state.
This function will parse a checkpoint bytes buffer, pull relevant data and load the training state into the checkpoint_state. This checkpoint state can then be used to create the training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training session will resume training from the given checkpoint state.
[in] | checkpoint_buffer | Buffer containing the checkpoint bytes. |
[in] | num_bytes | Number of bytes in the checkpoint buffer. |
[out] | checkpoint_state | Checkpoint state that contains the states of the training session. |
OrtStatus * OrtTrainingApi::OptimizerStep | ( | OrtTrainingSession * | sess, |
const OrtRunOptions * | run_options | ||
) |
Performs the weight updates for the trainable parameters using the optimizer model.
This function performs the weight update step that updates the trainable parameters such that they take a step in the direction opposite to their gradients (gradient descent). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next OrtTrainingApi::TrainStep function call.
[in] | sess | The `this` pointer to the training session. |
[in] | run_options | Run options for this optimizer step. |
OrtStatus * OrtTrainingApi::RegisterLinearLRScheduler | ( | OrtTrainingSession * | sess, |
const int64_t | warmup_step_count, | ||
const int64_t | total_step_count, | ||
const float | initial_lr | ||
) |
Registers a linear learning rate scheduler for the training session.
Register a linear learning rate scheduler that decays the learning rate by a linearly updated multiplicative factor from the initial learning rate set on the training session to 0. The decay is performed after the initial warm up phase where the learning rate is linearly incremented from 0 to the initial learning rate provided.
[in] | sess | The `this` pointer to the training session. |
[in] | warmup_step_count | Warmup steps for LR warmup. |
[in] | total_step_count | Total step count. |
[in] | initial_lr | The initial learning rate to be used by the training session. |
void OrtTrainingApi::ReleaseCheckpointState | ( | OrtCheckpointState * | input | ) |
Frees up the memory used up by the checkpoint state.
This function frees up any memory that was allocated in the checkpoint state. The checkpoint state can no longer be used after this call.
void OrtTrainingApi::ReleaseTrainingSession | ( | OrtTrainingSession * | input | ) |
Frees up the memory used up by the training session.
This function frees up any memory that was allocated in the training session. The training session can no longer be used after this call.
OrtStatus * OrtTrainingApi::SaveCheckpoint | ( | OrtCheckpointState * | checkpoint_state, |
const char * | checkpoint_path, | ||
const bool | include_optimizer_state | ||
) |
Save the given state to a checkpoint file on disk.
This function serializes the provided checkpoint state to a file on disk. This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume training from this snapshot of the state.
[in] | checkpoint_state | The checkpoint state to save. |
[in] | checkpoint_path | Path to the checkpoint file. |
[in] | include_optimizer_state | Flag to indicate whether to save the optimizer state or not. |
OrtStatus * OrtTrainingApi::SchedulerStep | ( | OrtTrainingSession * | sess | ) |
Update the learning rate based on the registered learning rate scheduler.
Takes a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as determined necessary to update the learning rate being used by the training session.
[in] | sess | The `this` pointer to the training session. |
OrtStatus * OrtTrainingApi::SetLearningRate | ( | OrtTrainingSession * | sess, |
float | learning_rate | ||
) |
Sets the learning rate for this training session.
This function allows users to set the learning rate for the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this function with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.
[in] | sess | The `this` pointer to the training session. |
[in] | learning_rate | Desired learning rate to be set. |
OrtStatus * OrtTrainingApi::SetSeed | ( | const int64_t | seed | ) |
Sets the seed used for random number generation in Onnxruntime.
Use this function to generate reproducible results. It should be noted that completely reproducible results are not guaranteed.
[in] | seed | The seed to be set. |
OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelInputCount | ( | const OrtTrainingSession * | sess, |
size_t * | out | ||
) |
Retrieves the number of user inputs in the eval model.
This function returns the number of inputs of the eval model so that the user can accordingly allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
[in] | sess | The `this` pointer to the training session. |
[out] | out | Number of user inputs in the eval model. |
OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelInputName | ( | const OrtTrainingSession * | sess, |
size_t | index, | ||
OrtAllocator * | allocator, | ||
char ** | output | ||
) |
Retrieves the name of the user input at given index in the eval model.
This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
[in] | sess | The `this` pointer to the training session. |
[in] | index | The index of the eval model input name requested. |
[in] | allocator | The allocator to use to allocate the memory for the requested name. |
[out] | output | Name of the user input for the eval model at the given index. |
OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelOutputCount | ( | const OrtTrainingSession * | sess, |
size_t * | out | ||
) |
Retrieves the number of user outputs in the eval model.
This function returns the number of outputs of the eval model so that the user can allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
[in] | sess | The `this` pointer to the training session. |
[out] | out | Number of user outputs in the eval model. |
OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelOutputName | ( | const OrtTrainingSession * | sess, |
size_t | index, | ||
OrtAllocator * | allocator, | ||
char ** | output | ||
) |
Retrieves the names of user outputs in the eval model.
This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned by the OrtTrainingApi::EvalStep function.
[in] | sess | The `this` pointer to the training session. |
[in] | index | Index of the output name requested. |
[in] | allocator | Allocator to use to allocate the memory for the name. |
[out] | output | Name of the eval model output at the given index. |
OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelInputCount | ( | const OrtTrainingSession * | sess, |
size_t * | out | ||
) |
Retrieves the number of user inputs in the training model.
This function returns the number of inputs of the training model so that the user can accordingly allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
[in] | sess | The `this` pointer to the training session. |
[out] | out | Number of user inputs in the training model. |
OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelInputName | ( | const OrtTrainingSession * | sess, |
size_t | index, | ||
OrtAllocator * | allocator, | ||
char ** | output | ||
) |
Retrieves the name of the user input at given index in the training model.
This function returns the names of inputs of the training model that can be associated with the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
[in] | sess | The `this` pointer to the training session. |
[in] | index | The index of the training model input name requested. |
[in] | allocator | The allocator to use to allocate the memory for the requested name. |
[out] | output | Name of the user input for the training model at the given index. |
OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelOutputCount | ( | const OrtTrainingSession * | sess, |
size_t * | out | ||
) |
Retrieves the number of user outputs in the training model.
This function returns the number of outputs of the training model so that the user can allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
[in] | sess | The `this` pointer to the training session. |
[out] | out | Number of user outputs in the training model. |
OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelOutputName | ( | const OrtTrainingSession * | sess, |
size_t | index, | ||
OrtAllocator * | allocator, | ||
char ** | output | ||
) |
Retrieves the names of user outputs in the training model.
This function returns the names of outputs of the training model that can be associated with the OrtValue(s) returned by the OrtTrainingApi::TrainStep function.
[in] | sess | The `this` pointer to the training session. |
[in] | index | Index of the output name requested. |
[in] | allocator | Allocator to use to allocate the memory for the name. |
[out] | output | Name of the training model output at the given index. |
OrtStatus * OrtTrainingApi::TrainStep | ( | OrtTrainingSession * | sess, |
const OrtRunOptions * | run_options, | ||
size_t | inputs_len, | ||
const OrtValue *const * | inputs, | ||
size_t | outputs_len, | ||
OrtValue ** | outputs | ||
) |
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The OrtTrainingApi::TrainStep is equivalent to running forward propagation and backward propagation in a single step. The gradients computed are stored inside the training session state so they can be later consumed by the OrtTrainingApi::OptimizerStep function. The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
[in] | sess | The `this` pointer to the training session. |
[in] | run_options | Run options for this training step. |
[in] | inputs_len | Number of user inputs to the training model. |
[in] | inputs | The user inputs to the training model. |
[in] | outputs_len | Number of user outputs expected from this training step. |
[out] | outputs | User outputs computed by train step. |
OrtStatus * OrtTrainingApi::UpdateParameter | ( | OrtCheckpointState * | checkpoint_state, |
const char * | parameter_name, | ||
OrtValue * | parameter | ||
) |
Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
This function updates a model parameter in the checkpoint state with the given parameter data. The training session must already have been created with the checkpoint state that contains the parameter being updated. The given parameter is copied over to the registered device for the training session. The parameter must exist in the checkpoint state to be able to update it successfully.
[in] | checkpoint_state | The checkpoint state. |
[in] | parameter_name | Name of the parameter being updated. |
[in] | parameter | The parameter data that should replace the existing parameter data. |