Weight Bindings
Weight bindings refer to the process of assigning trained model weights to the corresponding parameters in the TensorRT-LLM model definition before compiling the TensorRT engine.
This is necessary because TensorRT engines embed the network weights, which must be known at the time of compilation.
The process is as follows:
Model Definition
When defining a model using the TensorRT-LLM Python API, you create instances of various layers and modules, such as the Linear layer. These layers and modules have parameters that represent the learnable weights of the model.
Parameter Definition
In the model definition, you define the parameters for each layer or module.
For example, in the Linear layer, you define the weight and bias parameters using the Parameter class. You specify the shape and data type of these parameters based on the layer's configuration.
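To make the pattern concrete, here is a minimal stand-in for this structure in plain Python. The Parameter and Linear classes below are illustrative mimics, not the actual TensorRT-LLM classes; they only show the idea of declaring shaped, typed parameters whose values are bound later.

```python
import numpy as np

class Parameter:
    """Minimal stand-in for a framework parameter: records shape and
    dtype at definition time; the value is bound later."""
    def __init__(self, shape, dtype=np.float32):
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)
        self.value = None  # filled in during weight binding

class Linear:
    """Linear layer whose learnable weights are declared as Parameters."""
    def __init__(self, in_features, out_features):
        self.weight = Parameter((out_features, in_features))
        self.bias = Parameter((out_features,))

# Defining the layer fixes the parameter shapes, but no weight data exists yet.
fc = Linear(in_features=4, out_features=2)
```

At this point the model definition knows the shape and dtype of every parameter, but each `value` is still empty until the loading and binding steps below.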
Weight Loading
After defining the model architecture, you need to load the trained weights from a checkpoint or a pre-trained model. These weights are typically stored in a file format specific to the training framework, such as PyTorch or TensorFlow.
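As a sketch of this step, the snippet below round-trips an array through a raw binary file with NumPy. The file name is hypothetical; note that a raw binary carries no shape or dtype metadata, so the loader must already know both.

```python
import numpy as np
import os, tempfile

# Pretend these are trained weights exported by the training framework.
trained_w = np.arange(8, dtype=np.float32).reshape(2, 4)

# Write them to a raw binary file, as a checkpoint-conversion script might.
path = os.path.join(tempfile.mkdtemp(), "fc_weight.bin")  # hypothetical name
trained_w.tofile(path)

# Later, load them back. Raw binaries store only the bytes, so the
# shape and dtype must be supplied by the caller.
loaded = np.fromfile(path, dtype=np.float32).reshape(2, 4)
```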
Weight Binding
To bind the loaded weights to the model parameters, you assign the weight values to the corresponding parameter attributes in the model definition.
This is done by accessing the value attribute of each parameter and assigning the loaded weight data to it. For example, the fromfile function can be used to load weight data from a file, and the loaded data is then assigned to the value attribute of the weight and bias parameters of the fully connected (FC) layer in the MLP module of the GPT model.
Engine Compilation
After binding the weights to the model parameters, you can proceed with building the TensorRT engine using the tensorrt_llm.Builder.build_engine function.
During the engine compilation process, TensorRT takes the model definition along with the bound weights and optimises the computation graph for efficient execution on the target GPU.
Weight Refitting (Optional)
TensorRT also supports the ability to refit engines with updated weights after compilation.
This feature is available in TensorRT-LLM through the refit_engine method of the tensorrt_llm.Builder class.
Refitting allows you to update the weights of an existing engine without recompiling it from scratch, which saves time when only the weight values change, for example after further fine-tuning of the same architecture.
By binding the weights to the model parameters before compiling the TensorRT engine, you ensure that the engine has access to the trained weights and can perform inference accurately.
The weight binding process bridges the gap between the model definition and the trained weights, allowing TensorRT to optimise the computation graph and generate an efficient engine for execution.