top_model_mixin.py
The top_model_mixin.py script defines a mixin class called TopModelMixin that provides common functionality and interfaces for top-level model classes in the TensorRT-LLM framework.
Let's break down the script:
Imports
The script imports the modules and classes it needs from the TensorRT-LLM framework, including LoraBuildConfig, Mapping, and PluginConfig.
TopModelMixin Class
The TopModelMixin class is defined as a mixin class that can be inherited by top-level model classes like LLaMAForCausalLM. It provides common functionality and interfaces that are specific to top-level models and not applicable to building blocks like Attention or MLP.
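The mixin relationship described above can be sketched with stand-in classes. This is a minimal illustration, not the library's code: TopModelMixinSketch, AttentionBlock, and ToyForCausalLM are hypothetical names, and raising NotImplementedError in the placeholders is an assumption about how such placeholders are typically written.

```python
class TopModelMixinSketch:
    """Hypothetical stand-in for TopModelMixin; not the library class."""

    @classmethod
    def from_hugging_face(cls, hf_model_dir, dtype="float16", mapping=None, **kwargs):
        # Assumption: a placeholder signals that subclasses must override it.
        raise NotImplementedError("Subclass shall override this")

    def use_lora(self, lora_config):
        raise NotImplementedError("Subclass shall override this")


class AttentionBlock:
    """Hypothetical building block: it does not inherit the mixin."""


class ToyForCausalLM(TopModelMixinSketch):
    """Hypothetical top-level model: it inherits the mixin interface."""

    def __init__(self):
        super().__init__()
        self.attention = AttentionBlock()


model = ToyForCausalLM()
# Only the top-level model exposes the top-level interface.
has_interface = hasattr(model, "from_hugging_face")
block_has_interface = hasattr(model.attention, "from_hugging_face")
```

Keeping these entry points on a mixin, separate from the building blocks, means a block such as an attention layer never advertises a loader it cannot meaningfully implement.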
__init__ Method
The __init__ method is an empty method that can be overridden by subclasses if needed.
from_hugging_face Class Method
The from_hugging_face class method is a placeholder method that subclasses should override. It is intended to create an LLM object and load weights from a Hugging Face model directory.
The method takes parameters such as hf_model_dir, dtype, and mapping to specify the model directory, default weights data type, and multi-GPU parallel strategy.
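A subclass override of from_hugging_face might look like the sketch below. MappingSketch and ToyForCausalLM are hypothetical stand-ins, the "float16" default and the fields on MappingSketch are assumptions, and the actual reading of checkpoint files is deliberately left out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MappingSketch:
    """Hypothetical stand-in for Mapping (the multi-GPU parallel strategy)."""
    world_size: int = 1
    tp_size: int = 1


class ToyForCausalLM:
    """Hypothetical top-level model used only for illustration."""

    def __init__(self, dtype, mapping):
        self.dtype = dtype
        self.mapping = mapping
        self.weights_source = None

    @classmethod
    def from_hugging_face(cls, hf_model_dir: str,
                          dtype: Optional[str] = "float16",
                          mapping: Optional[MappingSketch] = None,
                          **kwargs):
        # Fall back to a single-GPU strategy when none is given (assumption).
        mapping = mapping or MappingSketch()
        model = cls(dtype=dtype, mapping=mapping)
        # A real override would read the checkpoint files here; this sketch
        # only records where the weights would come from.
        model.weights_source = hf_model_dir
        return model


model = ToyForCausalLM.from_hugging_face(
    "/path/to/hf_model", mapping=MappingSketch(world_size=2, tp_size=2))
```

Making this a class method lets callers build a model without first constructing an empty instance.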
convert_hf_checkpoint Class Method
The convert_hf_checkpoint class method is another placeholder method that subclasses should override. It is intended to convert a Hugging Face checkpoint to a TRT-LLM checkpoint.
The method takes parameters such as hf_model_dir, dtype, and output_dir to specify the Hugging Face model directory, default weights data type, and output directory for the converted checkpoint.
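The conversion flow can be illustrated with a toy override that writes a small metadata file into the output directory. The single config.json file and its keys are assumptions made for this sketch and do not describe the real TRT-LLM checkpoint layout. ToyForCausalLM is a hypothetical class.

```python
import json
import tempfile
from pathlib import Path


class ToyForCausalLM:
    """Hypothetical top-level model used only for illustration."""

    @classmethod
    def convert_hf_checkpoint(cls, hf_model_dir, dtype="float16",
                              output_dir=None, **kwargs):
        # Use a fresh temporary directory when no output_dir is given (assumption).
        out = Path(output_dir) if output_dir else Path(tempfile.mkdtemp())
        out.mkdir(parents=True, exist_ok=True)
        # A real converter would also transform and save the weights; this
        # sketch only records the conversion settings.
        config = {"source": str(hf_model_dir), "dtype": dtype}
        (out / "config.json").write_text(json.dumps(config))
        return out


with tempfile.TemporaryDirectory() as tmp:
    out_dir = ToyForCausalLM.convert_hf_checkpoint("/path/to/hf_model",
                                                   output_dir=tmp)
    written = json.loads((out_dir / "config.json").read_text())
```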
use_lora Method
The use_lora method is a placeholder method that subclasses should override. It is intended to load LoRA (Low-Rank Adaptation) weights from a given configuration to the module.
The method takes a lora_config parameter of type LoraBuildConfig.
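A sketch of a use_lora override follows. LoraBuildConfigSketch is a hypothetical stand-in for LoraBuildConfig. Its lora_dir and max_lora_rank fields are assumptions chosen for illustration, as is the validation step.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class LoraBuildConfigSketch:
    """Hypothetical stand-in for LoraBuildConfig; the fields are illustrative."""
    lora_dir: List[str] = field(default_factory=list)
    max_lora_rank: int = 64


class ToyForCausalLM:
    """Hypothetical top-level model used only for illustration."""

    def __init__(self):
        self.lora_config = None

    def use_lora(self, lora_config: LoraBuildConfigSketch):
        # Reject an obviously invalid rank before touching the module (assumption).
        if lora_config.max_lora_rank <= 0:
            raise ValueError("max_lora_rank must be positive")
        # A real override would load the LoRA weights into the module here.
        self.lora_config = lora_config


model = ToyForCausalLM()
model.use_lora(LoraBuildConfigSketch(lora_dir=["/path/to/lora"], max_lora_rank=8))
```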
use_prompt_tuning Method
The use_prompt_tuning method is a placeholder method that subclasses should override. It is intended to enable prompt tuning when building the TRT engine.
The method takes parameters such as max_prompt_embedding_table_size and prompt_table_path to specify the maximum size of the prompt embedding table and the path to the prompt table.
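One way a subclass could record these settings for a later engine build is sketched below. ToyForCausalLM and its prompt_tuning attribute are hypothetical, and the positivity check is an assumption, not documented behavior.

```python
class ToyForCausalLM:
    """Hypothetical top-level model used only for illustration."""

    def __init__(self):
        self.prompt_tuning = None  # None means prompt tuning is disabled

    def use_prompt_tuning(self, max_prompt_embedding_table_size: int,
                          prompt_table_path: str):
        if max_prompt_embedding_table_size <= 0:
            raise ValueError("max_prompt_embedding_table_size must be positive")
        # Store the settings so that a later build step could read them.
        self.prompt_tuning = {
            "max_prompt_embedding_table_size": max_prompt_embedding_table_size,
            "prompt_table_path": prompt_table_path,
        }


model = ToyForCausalLM()
model.use_prompt_tuning(max_prompt_embedding_table_size=128,
                        prompt_table_path="/path/to/prompt_table")
```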
default_plugin_config Method
The default_plugin_config method is a placeholder method that subclasses should override. It is intended to return the default plugin configuration for the model when the plugin_config value is not provided in the to_trt() call. The method takes arbitrary keyword arguments (**kwargs) and returns a PluginConfig object.
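The fallback described above can be sketched as follows. PluginConfigSketch and to_trt_sketch are hypothetical stand-ins for PluginConfig and to_trt(). The gemm_plugin field and the forwarding of keyword arguments into the default are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PluginConfigSketch:
    """Hypothetical stand-in for PluginConfig; the field is illustrative."""
    gemm_plugin: Optional[str] = None


class ToyForCausalLM:
    """Hypothetical top-level model used only for illustration."""

    def default_plugin_config(self, **kwargs) -> PluginConfigSketch:
        # Build the model's default configuration from keyword arguments.
        return PluginConfigSketch(**kwargs)

    def to_trt_sketch(self, plugin_config=None, **kwargs):
        # Use the model's default only when the caller supplied no config.
        if plugin_config is None:
            plugin_config = self.default_plugin_config(**kwargs)
        return plugin_config


model = ToyForCausalLM()
fallback = model.to_trt_sketch(gemm_plugin="float16")
explicit = model.to_trt_sketch(plugin_config=PluginConfigSketch())
```

Checking for None explicitly, not for truthiness, ensures that a configuration object supplied by the caller is always respected.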
Overall, the TopModelMixin class serves as a blueprint for top-level model classes in the TensorRT-LLM framework. It defines the common methods and interfaces that subclasses should implement to support loading weights from Hugging Face, converting checkpoints, using LoRA, enabling prompt tuning, and configuring plugins.
Subclasses can inherit from this mixin and override the placeholder methods with their specific implementations.