ONNX Runtime
Holds the state of the training session.
#include <onnxruntime_training_cxx_api.h>
Public Member Functions | |
CheckpointState ()=delete | |
Public Member Functions inherited from Ort::detail::Base< OrtCheckpointState > | |
constexpr | Base ()=default |
constexpr | Base (contained_type *p) noexcept |
Base (const Base &)=delete | |
Base (Base &&v) noexcept | |
~Base () | |
Base & | operator= (const Base &)=delete |
Base & | operator= (Base &&v) noexcept |
constexpr | operator contained_type * () const noexcept |
contained_type * | release () |
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed. | |
Accessing The Training Session State | |
void | AddProperty (const std::string &property_name, const Property &property_value) |
Adds the given property to the checkpoint state, or updates it if it is already present. | |
Property | GetProperty (const std::string &property_name) |
Gets the property value associated with the given name from the checkpoint state. | |
void | UpdateParameter (const std::string &parameter_name, const Value &parameter) |
Updates the data associated with the model parameter in the checkpoint state for the given parameter name. | |
Value | GetParameter (const std::string &parameter_name) |
Gets the data associated with the model parameter from the checkpoint state for the given parameter name. | |
static CheckpointState | LoadCheckpoint (const std::basic_string< char > &path_to_checkpoint) |
Load a checkpoint state from a file on disk and return it. | |
static CheckpointState | LoadCheckpointFromBuffer (const std::vector< uint8_t > &buffer) |
Load a checkpoint state from a buffer. | |
static void | SaveCheckpoint (const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false) |
Save the given state to a checkpoint file on disk. | |
Additional Inherited Members | |
Public Types inherited from Ort::detail::Base< OrtCheckpointState > | |
using | contained_type = OrtCheckpointState |
Protected Attributes inherited from Ort::detail::Base< OrtCheckpointState > | |
contained_type * | p_ |
Holds the state of the training session.
This class holds the entire training session state that includes model parameters, their gradients, optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState by accessing and updating the contained training state.
Ort::CheckpointState::CheckpointState | ( | ) | delete |
void Ort::CheckpointState::AddProperty | ( | const std::string & | property_name, |
const Property & | property_value | ||
) |
Adds the given property to the checkpoint state, or updates it if it is already present.
Runtime properties such as the epoch, training step, and best score can be stored in the checkpoint state by calling this function with the corresponding property name and value. The given property name must be unique for the property to be added successfully.
[in] | property_name | Name of the property being added or updated. |
[in] | property_value | Property value associated with the given name. |
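As a minimal sketch of preparing such runtime properties, the block below assumes that Property is a std::variant<int64_t, float, std::string>; verify this against the alias in your version of the training header. The local alias and the helper name MakeProgressProperties are illustrative and not part of ONNX Runtime. Naming the alternative explicitly avoids surprises when a plain int or a string literal is supplied.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <variant>

// Assumption: local stand-in for Ort::Property so the sketch compiles
// without the ONNX Runtime headers. Check the alias in your header version.
using Property = std::variant<int64_t, float, std::string>;

// Hypothetical helper: collects the runtime properties mentioned in the text.
// Each entry could later be passed to CheckpointState::AddProperty(name, value).
inline std::map<std::string, Property> MakeProgressProperties(
    int64_t epoch, int64_t training_step, float best_score) {
  std::map<std::string, Property> properties;
  // Name the alternative explicitly so the intended type is stored.
  properties["epoch"] = Property(std::in_place_type<int64_t>, epoch);
  properties["training_step"] = Property(std::in_place_type<int64_t>, training_step);
  properties["best_score"] = Property(std::in_place_type<float>, best_score);
  return properties;
}
```

Iterating over the returned map and calling AddProperty once per entry keeps the property names in a single place.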
Value Ort::CheckpointState::GetParameter | ( | const std::string & | parameter_name | ) |
Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
This function retrieves the model parameter data from the checkpoint state for the given parameter name. The parameter data is copied into the returned Ort::Value. The training session must already have been created with the checkpoint state that contains the parameter being retrieved, and the parameter must exist in the checkpoint state for the retrieval to succeed.
[in] | parameter_name | Name of the parameter being retrieved. |
Property Ort::CheckpointState::GetProperty | ( | const std::string & | property_name | ) |
Gets the property value associated with the given name from the checkpoint state.
Gets the property value from an existing entry in the checkpoint state. The property must exist in the checkpoint state to be able to retrieve it successfully.
[in] | property_name | Name of the property being retrieved. |
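A caller usually needs to find out which kind of value a retrieved property holds. The sketch below assumes that Property is a std::variant<int64_t, float, std::string>, which should be checked against the training header in use. The local alias and the helpers PropertyTypeName and AsInt64 are hypothetical names introduced for illustration.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <type_traits>
#include <variant>

// Assumption: local stand-in for Ort::Property so the sketch compiles
// without the ONNX Runtime headers.
using Property = std::variant<int64_t, float, std::string>;

// Hypothetical helper: names the alternative currently held by the property.
inline std::string PropertyTypeName(const Property& property) {
  return std::visit(
      [](const auto& value) -> std::string {
        using T = std::decay_t<decltype(value)>;
        if constexpr (std::is_same_v<T, int64_t>) {
          return "int64";
        } else if constexpr (std::is_same_v<T, float>) {
          return "float";
        } else {
          return "string";
        }
      },
      property);
}

// Hypothetical helper: returns the integer value, or std::nullopt when the
// property holds a different alternative (std::get would throw instead).
inline std::optional<int64_t> AsInt64(const Property& property) {
  if (const int64_t* value = std::get_if<int64_t>(&property)) {
    return *value;
  }
  return std::nullopt;
}
```

Using std::get_if rather than std::get lets the caller handle a type mismatch without exceptions, for example when a property was saved as a float by an earlier run.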
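CheckpointState Ort::CheckpointState::LoadCheckpoint | ( | const std::basic_string< char > & | path_to_checkpoint | ) | static |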
Load a checkpoint state from a file on disk and return it.
This function parses a checkpoint file, loads the training state from it, and returns an instance of Ort::CheckpointState. The returned checkpoint state can then be used to create a training session by instantiating Ort::TrainingSession, which resumes training from that state.
[in] | path_to_checkpoint | Path to the checkpoint file. |
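CheckpointState Ort::CheckpointState::LoadCheckpointFromBuffer | ( | const std::vector< uint8_t > & | buffer | ) | static |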
Load a checkpoint state from a buffer.
This function parses a checkpoint buffer, loads the training state from it, and returns an instance of Ort::CheckpointState. The returned checkpoint state can then be used to create a training session by instantiating Ort::TrainingSession, which resumes training from that state.
[in] | buffer | Buffer containing the checkpoint data. |
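One way to obtain such a buffer is to read a checkpoint file into memory. The following is a sketch using only the standard library. ReadFileBytes is a hypothetical helper name, not an ONNX Runtime function, and the error handling is an illustrative choice.

```cpp
#include <cstdint>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: reads an entire file into a byte buffer of the type
// that LoadCheckpointFromBuffer accepts (std::vector<uint8_t>).
inline std::vector<uint8_t> ReadFileBytes(const std::string& path) {
  // Binary mode prevents newline translation from altering the bytes.
  std::ifstream file(path, std::ios::binary);
  if (!file) {
    throw std::runtime_error("cannot open checkpoint file: " + path);
  }
  // istreambuf_iterator reads raw characters without skipping whitespace.
  std::vector<uint8_t> buffer((std::istreambuf_iterator<char>(file)),
                              std::istreambuf_iterator<char>());
  return buffer;
}
```

The returned vector can then be passed to Ort::CheckpointState::LoadCheckpointFromBuffer. Loading from memory is mainly of interest when the checkpoint bytes do not come from a regular file, such as data bundled with an application or received over a network.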
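void Ort::CheckpointState::SaveCheckpoint | ( | const CheckpointState & | checkpoint_state, |
const std::basic_string< char > & | path_to_checkpoint, | ||
const bool | include_optimizer_state = false | ||
) | static |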
Save the given state to a checkpoint file on disk.
This function serializes the provided checkpoint state to a file on disk. This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume training from this snapshot of the state.
[in] | checkpoint_state | The checkpoint state to save. |
[in] | path_to_checkpoint | Path to the checkpoint file. |
[in] | include_optimizer_state | Flag to indicate whether to save the optimizer state or not. |
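When snapshots are saved periodically, each one needs its own path. The sketch below shows one possible naming scheme. CheckpointPathForStep, the "checkpoint_step_" prefix, and the six-digit padding are all assumptions made for illustration; ONNX Runtime does not prescribe a naming convention.

```cpp
#include <cstdint>
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical helper: builds a per-step checkpoint path such as
// "<directory>/checkpoint_step_000120". The naming scheme is illustrative.
inline std::string CheckpointPathForStep(const std::string& directory,
                                         int64_t training_step) {
  std::ostringstream path;
  // Zero-pad the step so that paths sort in training order.
  path << directory << "/checkpoint_step_"
       << std::setw(6) << std::setfill('0') << training_step;
  return path.str();
}
```

The resulting string can be supplied as path_to_checkpoint, with include_optimizer_state set to true for snapshots that are meant to resume training rather than only preserve model parameters.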
void Ort::CheckpointState::UpdateParameter | ( | const std::string & | parameter_name, |
const Value & | parameter | ||
) |
Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
This function updates a model parameter in the checkpoint state with the given parameter data. The training session must already have been created with the checkpoint state that contains the parameter being updated. The given parameter data is copied to the device registered for the training session. The parameter must exist in the checkpoint state for the update to succeed.
[in] | parameter_name | Name of the parameter being updated. |
[in] | parameter | The parameter data that should replace the existing parameter data. |