tvm.relax.training
The Relax training APIs.
- class tvm.relax.training.SetupTrainer(loss: Loss, optimizer: Optimizer, loss_args: list[TensorStructInfo], legalize=True)
Transform a backbone module to a complete, legalized trainer module.
The provided backbone module should contain at least a function named backbone, and should have two int attributes param_num and state_num.
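The accompanying listing appears to have been lost in extraction. Schematically, the backbone module looks like the following (a sketch only; names and the exact TVMScript decorators are illustrative):

```python
# Sketch of the expected backbone module layout (illustrative).
@I.ir_module
class Backbone:
    I.module_attrs({"param_num": param_num, "state_num": state_num})

    @R.function
    def backbone(*input_instances, *parameters, *states):
        # The forward computation, contained in a single DataflowBlock.
        ...
        return *backbone_result, *updated_states
```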
Here each of input_instances, parameters, states, backbone_result and updated_states can denote a number of parameters. The lengths of parameters and states are specified by param_num and state_num, respectively.
states denotes the states that must be maintained as training proceeds, such as the running mean and running variance of the batch norm operator. The updated states are returned in updated_states. states can be empty if there is no state to update.
The transformed module will at least contain the functions and attributes listed below:
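The list itself appears to have been lost in extraction. Based on the surrounding description of the transform, the result roughly comprises the following (a sketch; exact function names may differ across TVM versions):

```python
# Rough shape of the transformed module (illustrative names).
@I.ir_module
class TrainModule:
    # int attributes param_num and state_num are preserved.

    @R.function
    def backbone(...): ...              # the original forward function
    @R.function
    def backbone_loss(...): ...         # backbone with the loss appended
    @R.function
    def backbone_loss_adjoint(...): ... # returns the loss and the gradients
    @R.function
    def optimizer(...): ...             # applies gradients to the parameters
```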
The transformed module contains an attribute optim_states as the initial optimizer states.
Finally, unless legalize is False, the transformed module is legalized by relax.transform.LegalizeOps() to lower relax operators into TIR functions.
- Parameters:
loss (Loss) – The loss function. It will be appended to the backbone function using relax.transform.AppendLoss.
optimizer (Optimizer) – The optimizer. It will be put as the optimizer function of the transformed module.
loss_args (List[TensorStructInfo]) – The arguments to call the loss function.
legalize (bool) – Whether to legalize the module. Default: True.
- class tvm.relax.training.Trainer(train_mod: IRModule, vm: VirtualMachine, device: Device, zero_init_param_state: bool = True)
Unified wrapper for relax training. It accepts the IRModule (the result of SetupTrainer) and the relax VM (which contains the built result of that IRModule), and helps run the VM. It maintains the parameters, the model states and the optimizer states internally.
- Parameters:
train_mod (tvm.IRModule) – The IRModule that will be run. Should be the result of a backbone module being transformed by the SetupTrainer pass.
vm (tvm.relax.VirtualMachine) – The relax virtual machine that contains the built result of train_mod. Given the complexity and flexibility of the build process, the user is required to build train_mod outside of the trainer and pass in the resulting VM.
device (tvm.runtime.Device) – The device to place the parameters and states on.
zero_init_param_state (bool) – If true, all parameters and states will be initialized to zero. This requires all parameters and states to have static shapes.
Examples
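The example appears to have been lost in extraction. A hedged sketch of typical usage follows (the backbone module Backbone, the struct info pred_sinfo of its prediction output, and the batch data x, y are assumed to be defined; API details may vary across TVM versions):

```python
# Hedged sketch: assumes `Backbone`, `pred_sinfo`, `x`, `y` are defined.
import tvm
from tvm import relax
from tvm.relax.training import SetupTrainer, Trainer
from tvm.relax.training.loss import MSELoss
from tvm.relax.training.optimizer import SGD

setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    SGD(0.001),
    [pred_sinfo, pred_sinfo],  # struct info of the loss arguments
)
train_mod = setup_trainer(Backbone)

# The trainer requires the module to be built outside and passed in.
ex = relax.build(train_mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
trainer = Trainer(train_mod, vm, tvm.cpu())

trainer.xaiver_uniform_init_params()
loss = trainer.update(x, y)   # one training step; returns the loss
pred = trainer.predict(x)     # forward pass only
```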
- xaiver_uniform_init_params()
Initialize parameters with Xavier uniform initialization, using the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010).
Requires all parameters to have static shapes.
- zero_init_params()
Zero-initialize all parameters. Requires all parameters to have static shapes.
- zero_init_states()
Zero-initialize all states. Requires all states to have static shapes.
- load_params(params: list[ndarray | Tensor] | dict[str, ndarray | Tensor])
Load parameters from a dict or a list, converting them into tvm.runtime.Tensor on self.device.
- Parameters:
params (List[Union[np.ndarray, Tensor]], Dict[str, Union[np.ndarray, Tensor]]) –
The numerical value of the parameters.
If params is a list, its length should be param_num; the parameter at each corresponding index will be updated.
If params is a dict, it should map parameter names to values. Each name should match the parameter name in the backbone function; the corresponding parameters will be updated.
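For illustration, a hedged sketch of both calling conventions (assumes a trainer whose backbone has a single parameter named w0 of shape (10, 10); the name and shape are hypothetical):

```python
import numpy as np

# By position: the list length must equal param_num.
trainer.load_params([np.random.rand(10, 10).astype("float32")])

# By name: keys must match the parameter names in the backbone function.
trainer.load_params({"w0": np.random.rand(10, 10).astype("float32")})
```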
- load_states(states: list[ndarray | Tensor] | dict[str, ndarray | Tensor])
Load model states from a dict or a list, converting them into tvm.runtime.Tensor on self.device.
- Parameters:
states (List[Union[np.ndarray, Tensor]], Dict[str, Union[np.ndarray, Tensor]]) –
The numerical value of the model states.
If states is a list, its length should be state_num; the state at each corresponding index will be updated.
If states is a dict, it should map state names to values. Each name should match the state name in the backbone function; the corresponding states will be updated.
- predict(*input_instances: ndarray | Tensor) Tensor
Call the backbone function and return its prediction result.
- Parameters:
*input_instances (Union[np.ndarray, Tensor]) – The values corresponding to the input_instances part of the backbone function. Parameters and model states do not need to be provided.
- Returns:
output – The result of the backbone function. If the backbone contains model states, the updated states WILL NOT be returned.
- Return type:
Tensor
- update(input_instances: ndarray | Tensor | list[ndarray | Tensor], targets: ndarray | Tensor | list[ndarray | Tensor]) Tensor
Update parameters and model states. It will calculate the gradients of parameters and update them using the optimizer function.
Parameters, model states and optimizer states are maintained by the trainer, so you do not need to provide them.
- Parameters:
input_instances (Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]]) –
The values corresponding to the input_instances part of the backbone function. Parameters and model states do not need to be provided.
If there is more than one input instance, you can provide a list.
targets (Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]]) –
The values corresponding to the targets part of the backbone function.
If there is more than one target, you can provide a list.
- Returns:
loss – The loss stored in tvm.runtime.Tensor.
- Return type:
Tensor
- profile_adjoint(input_instances: list[ndarray | Tensor], targets: list[ndarray | Tensor]) Report
Profile the adjoint function. It requires the VM to be constructed with profile=True, and runs tvm.relax.VirtualMachine.profile() internally.
- Parameters:
input_instances (Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]]) –
The values corresponding to the input_instances part of the backbone function. Parameters and model states do not need to be provided.
If there is more than one input instance, you can provide a list.
targets (Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]]) –
The values corresponding to the targets part of the backbone function.
If there is more than one target, you can provide a list.
- Returns:
report – The formatted profiling result.
- Return type:
Report
- tvm.relax.training.AppendLoss(func_name: str, loss_function: Function, num_backbone_outputs: int = 1, new_func_name: str | None = None) Pass
Append the loss function to the backbone function specified by func_name. Generally, the loss function is generated by instances of relax.training.Loss.
The backbone function and the loss function should satisfy a few restrictions:
- Both backbone and loss should contain exactly one DataflowBlock.
- Backbone should return either one relax.Var or a tuple of Vars.
- Loss should return a scalar (0-dim Tensor) relax.Var.
They should be like:
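The code listing appears to have been lost in extraction; schematically, following the conventions above (names illustrative):

```python
@R.function
def backbone(*input_instances, *parameters, *states):
    # Exactly one DataflowBlock.
    ...
    return *backbone_result, *updated_states

@R.function
def loss(*backbone_result, *targets):
    # Exactly one DataflowBlock; returns a scalar (0-dim Tensor).
    ...
    return loss
```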
Here each of input_instances, parameters, states, backbone_result and updated_states can denote a number of parameters.
states denotes the states that must be maintained as training proceeds, such as the running mean and running variance of the batch norm operator. The updated states are returned in updated_states. states can be empty if there is no state to update.
The appended result contains a single DataflowBlock holding all bindings from backbone and loss. It will be like:
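Schematically, the appended result looks like the following (a sketch; names illustrative):

```python
@R.function
def backbone_loss(*input_instances, *parameters, *states, *targets):
    # A single DataflowBlock merging the bindings of backbone and loss.
    ...
    return loss, *updated_states
```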
- Parameters:
func_name (str) – The name of the backbone function in the IRModule.
loss_function (Function) – The loss function.
num_backbone_outputs (int) – Specify the number of prediction_outputs of the backbone function. Default: 1.
new_func_name (Optional[str]) – Specify the name of the appended result. If it is not specified, the name will be func_name + “_loss”.
- Returns:
ret – The pass that appends the loss function to the backbone function.
- Return type:
Pass
Examples
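The example listings appear to have been lost in extraction. A hedged sketch of typical usage (assumes a module Backbone containing a function named predict, and a struct info pred_sinfo matching its output; API details may vary across TVM versions):

```python
from tvm import relax
from tvm.relax.training.loss import MSELoss

# Instantiate a loss and materialize it as a relax Function whose inputs
# match the backbone's output struct info (`pred_sinfo` is assumed).
loss_func = MSELoss(reduction="sum")(pred_sinfo, pred_sinfo)

# Append it to the function named "predict" in `Backbone`.
mod = relax.training.AppendLoss("predict", loss_func)(Backbone)

# `mod` now additionally contains "predict_loss", which takes the backbone
# inputs plus the targets and returns the scalar loss.
```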
Notes
This utility could be replaced by an inlining pass if one existed; in some sense it is equivalent to inlining a tail call.