Train the Model on the Device
Once the training artifacts are generated, the model can be trained on the device using the ONNX Runtime training Python API.
The expected training artifacts are:
The training ONNX model
The checkpoint state
The optimizer ONNX model
The eval ONNX model (optional)
Sample usage:
from onnxruntime.training.api import CheckpointState, Module, Optimizer
# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)
# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")
optimizer = Optimizer(path_to_the_optimizer_model, module)
# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()
# Eval
module.eval()
eval_loss = module(...)
# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
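The order of calls in the training loop above can be sketched as a small helper. This is only an illustration: `run_training`, `_FakeModule` and `_FakeOptimizer` are hypothetical names, the stand-ins replace the real Module and Optimizer so the snippet runs without any training artifacts, and it assumes the training model returns a single scalar loss.

```python
from typing import Any, Iterable, List, Tuple


def run_training(module: Any, optimizer: Any,
                 batches: Iterable[Tuple[Any, ...]]) -> List[float]:
    """Run one pass over `batches` and return the loss reported at each step."""
    losses = []
    module.train()                    # switch the module to training mode
    for batch in batches:
        loss = module(*batch)         # training step: loss and gradients are computed
        optimizer.step()              # update the parameters from the gradients
        module.lazy_reset_grad()      # schedule a gradient reset for the next step
        losses.append(float(loss))
    return losses


class _FakeModule:
    """Hypothetical stand-in for Module that records the order of calls."""

    def __init__(self, log):
        self.log = log

    def train(self, mode=True):
        self.log.append("train")
        return self

    def __call__(self, *inputs):
        self.log.append("forward")
        return 0.5                    # pretend loss value

    def lazy_reset_grad(self):
        self.log.append("lazy_reset_grad")


class _FakeOptimizer:
    """Hypothetical stand-in for Optimizer."""

    def __init__(self, log):
        self.log = log

    def step(self):
        self.log.append("step")


log = []
losses = run_training(_FakeModule(log), _FakeOptimizer(log), [(1, 2), (3, 4)])
print(log)
```

With the real API, the same helper would presumably receive the Module and Optimizer created in the sample usage, with each batch given as a tuple of numpy arrays or OrtValues.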
- class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)
Bases: object
Class that represents a model parameter
This class represents a model parameter and provides access to its data, gradient and other properties. This class is not expected to be instantiated directly. Instead, it is returned by the CheckpointState object.
- Parameters:
parameter – The C.Parameter object that holds the underlying parameter data.
state – The C.CheckpointState object that holds the underlying session state.
- class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)
Bases: object
Class that holds all the model parameters
This class holds all the model parameters and provides access to them. This class is not expected to be instantiated directly. Instead, it is returned by the CheckpointState’s parameters attribute. This class behaves like a dictionary and provides access to the parameters by name.
- Parameters:
state – The C.CheckpointState object that holds the underlying session state.
- __getitem__(name: str) -> Parameter
Gets the parameter associated with the given name
Searches for the name in the parameters of the checkpoint state.
- Parameters:
name – The name of the parameter
- Returns:
The value of the parameter
- Raises:
KeyError – If the parameter is not found
- __setitem__(name: str, value: ndarray) -> None
Sets the parameter value for the given name
Searches for the name in the parameters of the checkpoint state. If the name is found in parameters, the value is updated.
- Parameters:
name – The name of the parameter
value – The value of the parameter as a numpy array
- Raises:
KeyError – If the parameter is not found
- class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)
Bases: object
- __getitem__(name: str) -> int | float | str
Gets the property associated with the given name
Searches for the name in the properties of the checkpoint state.
- Parameters:
name – The name of the property
- Returns:
The value of the property
- Raises:
KeyError – If the property is not found
- __setitem__(name: str, value: int | float | str) -> None
Sets the property value for the given name
Searches for the name in the properties of the checkpoint state. The value is added or updated in the properties.
- Parameters:
name – The name of the property
value – The value of the property. Properties only support int, float and str values.
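The difference between the two dictionary-like classes can be illustrated with plain Python stand-ins. `ParametersSketch` and `PropertiesSketch` are hypothetical names and not part of onnxruntime. The sketch only mirrors the documented behavior: setting an unknown parameter raises KeyError, while setting a property adds or updates it. The real `Parameters.__getitem__` returns a Parameter object rather than a raw value, and the TypeError for unsupported property types is an assumption.

```python
class ParametersSketch:
    """Hypothetical stand-in: known names can be updated, unknown names are rejected."""

    def __init__(self, initial):
        self._values = dict(initial)

    def __getitem__(self, name):
        return self._values[name]          # raises KeyError when the name is missing

    def __setitem__(self, name, value):
        if name not in self._values:
            raise KeyError(name)           # parameters cannot be added, only updated
        self._values[name] = value


class PropertiesSketch:
    """Hypothetical stand-in: names are added or updated; only int, float, str values."""

    def __init__(self):
        self._values = {}

    def __getitem__(self, name):
        return self._values[name]          # raises KeyError when the name is missing

    def __setitem__(self, name, value):
        if not isinstance(value, (int, float, str)):
            # Assumption: the exception type used by the real class may differ.
            raise TypeError("properties only support int, float and str values")
        self._values[name] = value


params = ParametersSketch({"fc1.weight": [0.0, 0.0]})
params["fc1.weight"] = [1.0, 2.0]          # known name: the value is updated
try:
    params["not_in_model"] = [0.0]         # unknown name: rejected
except KeyError as err:
    rejected = err.args[0]

props = PropertiesSketch()
props["epoch"] = 3                          # new name: added
props["epoch"] = 4                          # existing name: updated
```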
- class onnxruntime.training.api.CheckpointState(state: CheckpointState)
Bases: object
Class that holds the state of the training session
This class holds all the state information of the training session such as the model parameters, its gradients, the optimizer state and user defined properties.
To create the CheckpointState, use the CheckpointState.load_checkpoint method.
- Parameters:
state – The C.CheckpointState object that holds the underlying session state.
- classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) -> CheckpointState
Loads the checkpoint state from the checkpoint file
The checkpoint file can either be the complete checkpoint or the nominal checkpoint.
- Parameters:
checkpoint_uri – The path to the checkpoint file.
- Returns:
The checkpoint state object.
- Return type:
CheckpointState
- classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) -> None
Saves the checkpoint state to the checkpoint file
- Parameters:
state – The checkpoint state object.
checkpoint_uri – The path to the checkpoint file.
include_optimizer_state – If True, the optimizer state is also saved to the checkpoint file.
- property parameters: Parameters
Returns the model parameters from the checkpoint state
- property properties: Properties
Returns the properties from the checkpoint state
- class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)
Bases: object
Trainer class that provides training and evaluation methods for ONNX models.
Before instantiating the Module class, it is expected that the training artifacts have been generated using the onnxruntime.training.artifacts.generate_artifacts utility.
- The training artifacts include:
The training model
The evaluation model (optional)
The optimizer model (optional)
The checkpoint file
- Parameters:
train_model_uri – The path to the training model.
state – The checkpoint state object.
eval_model_uri – The path to the evaluation model.
device – The device to run the model on. Default is “cpu”.
session_options – The session options to use for the model.
- __call__(*user_inputs) -> tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue
Invokes either the training or the evaluation step of the model.
- Parameters:
*user_inputs – The inputs to the model. The user inputs can be either numpy arrays or OrtValues.
- Returns:
The outputs of the model.
- train(mode: bool = True) -> Module
Sets the Module in training mode.
- Parameters:
mode – Whether to set the module to training mode (True) or evaluation mode (False). Default: True.
- Returns:
self
- lazy_reset_grad()
Lazily resets the training gradients.
This function sets the internal state of the module so that the gradients are reset just before new gradients are computed on the next training step (the next call to the module in training mode).
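The "lazy" part of the reset can be sketched in plain Python. `GradientBufferSketch` is a hypothetical stand-in and not how onnxruntime stores gradients. It assumes gradients accumulate across training steps until a reset is requested, and that the request only takes effect when the next gradients are computed.

```python
class GradientBufferSketch:
    """Hypothetical gradient buffer with a deferred (lazy) reset."""

    def __init__(self, size):
        self.grads = [0.0] * size
        self._reset_pending = False

    def lazy_reset_grad(self):
        # Only record the request; the buffer is left untouched for now.
        self._reset_pending = True

    def accumulate(self, new_grads):
        # Stands in for the gradient computation of a training step.
        if self._reset_pending:
            self.grads = [0.0] * len(self.grads)   # the reset happens here
            self._reset_pending = False
        self.grads = [g + n for g, n in zip(self.grads, new_grads)]


buf = GradientBufferSketch(2)
buf.accumulate([1.0, 2.0])
buf.accumulate([1.0, 2.0])          # no reset requested: gradients accumulate
accumulated = list(buf.grads)

buf.lazy_reset_grad()
after_request = list(buf.grads)     # still the old values, nothing was cleared yet

buf.accumulate([0.5, 0.5])          # reset is applied first, then new gradients land
after_next_step = list(buf.grads)
```

Under these assumptions, skipping the reset for a few steps is what makes gradient accumulation over several batches possible.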
- get_contiguous_parameters(trainable_only: bool = False) -> OrtValue
Creates a contiguous buffer of the training session parameters
- Parameters:
trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.
- Returns:
The contiguous buffer of the training session parameters.
- get_parameters_size(trainable_only: bool = True) -> int
Returns the size of the parameters.
- Parameters:
trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.
- Returns:
The number of primitive (for example, floating point) elements in the parameters.
- copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) -> None
Copies the OrtValue buffer to the training session parameters.
If the module was loaded from a nominal checkpoint, this function must be called to load the updated parameters onto the checkpoint and complete it.
- Parameters:
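buffer – The OrtValue buffer to copy to the training session parameters.
trainable_only – If True, the buffer is copied to the trainable parameters only. Otherwise, it is copied to all parameters.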
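How get_parameters_size, get_contiguous_parameters and copy_buffer_to_parameters relate to each other can be sketched with plain Python lists. All function names ending in `_sketch` and the dictionary layout of `params` are hypothetical. The real methods work on OrtValue buffers, and the ordering of parameters inside the real buffer is not specified here.

```python
def parameters_size_sketch(params, trainable_only=True):
    """Count the elements across the selected parameters."""
    return sum(len(p["data"]) for p in params
               if p["requires_grad"] or not trainable_only)


def contiguous_parameters_sketch(params, trainable_only=False):
    """Concatenate the selected parameters into one flat list."""
    buffer = []
    for p in params:
        if p["requires_grad"] or not trainable_only:
            buffer.extend(p["data"])
    return buffer


def copy_buffer_to_parameters_sketch(params, buffer, trainable_only=True):
    """Write a flat buffer back into the selected parameters, in order."""
    if len(buffer) != parameters_size_sketch(params, trainable_only):
        raise ValueError("buffer size does not match the selected parameters")
    offset = 0
    for p in params:
        if p["requires_grad"] or not trainable_only:
            count = len(p["data"])
            p["data"] = list(buffer[offset:offset + count])
            offset += count


params = [
    {"name": "fc.weight", "data": [1.0, 2.0, 3.0], "requires_grad": True},
    {"name": "frozen.bias", "data": [9.0], "requires_grad": False},
]
trainable = contiguous_parameters_sketch(params, trainable_only=True)
updated = [value / 2 for value in trainable]    # e.g. an update computed elsewhere
copy_buffer_to_parameters_sketch(params, updated, trainable_only=True)
```

Note that the buffer passed back must be built with the same `trainable_only` setting that is used for the copy, otherwise the sizes do not line up.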
- export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) -> None
Exports the model for inferencing.
Once training is complete, this function can be used to drop the training-specific nodes in the ONNX model. In particular, this function does the following:
Parses the training graph and identifies the nodes that generate the given output names.
Drops all subsequent nodes in the graph, since they are not relevant to the inference graph.
- Parameters:
inference_model_uri – The path to the inference model.
graph_output_names – The list of output names that are required for inferencing.
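The pruning idea described above can be sketched on a toy graph. This is one interpretation and not the actual export logic: `prune_to_outputs` is a hypothetical helper, nodes are plain dictionaries instead of ONNX nodes, and the node list is assumed to be topologically sorted.

```python
def prune_to_outputs(nodes, graph_output_names):
    """Keep only the nodes needed to compute the requested outputs."""
    needed = set(graph_output_names)
    kept = []
    # Walk backwards so that every kept node can pull in its own producers.
    for node in reversed(nodes):
        if needed.intersection(node["outputs"]):
            kept.append(node)
            needed.update(node["inputs"])
    kept.reverse()
    return kept


# Toy training graph: forward nodes followed by loss and gradient nodes.
training_graph = [
    {"name": "matmul", "inputs": ["x", "w"], "outputs": ["h"]},
    {"name": "relu", "inputs": ["h"], "outputs": ["output"]},
    {"name": "loss", "inputs": ["output", "target"], "outputs": ["loss"]},
    {"name": "loss_grad", "inputs": ["loss"], "outputs": ["d_output"]},
    {"name": "matmul_grad", "inputs": ["d_output", "x"], "outputs": ["w_grad"]},
]

inference_graph = prune_to_outputs(training_graph, ["output"])
kept_names = [node["name"] for node in inference_graph]
print(kept_names)
```

In this toy case the loss and gradient nodes are dropped because nothing that produces "output" depends on them.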
- class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)
Bases: object
Class that provides methods to update the model parameters based on the computed gradients.
- Parameters:
optimizer_uri – The path to the optimizer model.
module – The module to be trained.
- step() -> None
Updates the model parameters based on the computed gradients.
This method updates the model parameters by taking a step in the direction of the computed gradients. The optimizer used depends on the optimizer model provided.
- class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)
Bases: object
Linearly updates the learning rate in the optimizer
The linear learning rate scheduler decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session to 0. The decay is performed after the initial warm-up phase, during which the learning rate is linearly increased from 0 to the initial learning rate provided.
- Parameters:
optimizer – User’s onnxruntime training Optimizer
warmup_step_count – The number of steps in the warm up phase.
total_step_count – The total number of training steps.
initial_lr – The initial learning rate.
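The schedule described above can be written out as a small function. `linear_lr` is a hypothetical helper that follows the usual linear warm-up and linear decay formula. It is a sketch of the shape of the schedule and is not guaranteed to match the scheduler's internal computation step for step.

```python
def linear_lr(step, warmup_step_count, total_step_count, initial_lr):
    """Learning rate at `step` for linear warm-up followed by linear decay to 0."""
    if step < warmup_step_count:
        # Warm-up: ramp linearly from 0 up to initial_lr.
        return initial_lr * step / max(1, warmup_step_count)
    # Decay: ramp linearly from initial_lr down to 0 at total_step_count.
    remaining = max(0, total_step_count - step)
    return initial_lr * remaining / max(1, total_step_count - warmup_step_count)


# Learning rate at a few points of a 110-step run with 10 warm-up steps.
schedule = [linear_lr(s, 10, 110, 0.1) for s in (0, 5, 10, 60, 110)]
print(schedule)
```

The rate peaks at `initial_lr` exactly when the warm-up ends, then falls in equal increments until it reaches 0 at `total_step_count`.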