trtllm-build configuration file
To create a script that takes the buildconfig.yaml file and translates its settings into a trtllm-build command line, you can use Python's argparse and yaml modules. Here's an example script that accomplishes this:
import argparse
import subprocess

import yaml


def parse_buildconfig(config_file):
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)

    args = []

    # Model Configuration
    if 'model' in config:
        model_config = config['model']
        if 'model_dir' in model_config:
            args.extend(['--checkpoint_dir', model_config['model_dir']])
        if 'output_dir' in model_config:
            args.extend(['--output_dir', model_config['output_dir']])
        if 'dtype' in model_config:
            args.extend(['--logits_dtype', model_config['dtype']])

    # Checkpoint Configuration
    if 'checkpoint' in config:
        checkpoint_config = config['checkpoint']
        if 'checkpoint_dir' in checkpoint_config:
            args.extend(['--checkpoint_dir', checkpoint_config['checkpoint_dir']])
        if 'tp_size' in checkpoint_config:
            args.extend(['--tp_size', str(checkpoint_config['tp_size'])])
        if 'pp_size' in checkpoint_config:
            args.extend(['--pp_size', str(checkpoint_config['pp_size'])])
        if 'vocab_size' in checkpoint_config:
            args.extend(['--vocab_size', str(checkpoint_config['vocab_size'])])
        if 'n_positions' in checkpoint_config:
            args.extend(['--n_positions', str(checkpoint_config['n_positions'])])
        if 'n_layer' in checkpoint_config:
            args.extend(['--n_layer', str(checkpoint_config['n_layer'])])
        if 'n_head' in checkpoint_config:
            args.extend(['--n_head', str(checkpoint_config['n_head'])])
        if 'n_embd' in checkpoint_config:
            args.extend(['--n_embd', str(checkpoint_config['n_embd'])])
        if 'inter_size' in checkpoint_config:
            args.extend(['--inter_size', str(checkpoint_config['inter_size'])])
        if 'n_kv_head' in checkpoint_config:
            args.extend(['--n_kv_head', str(checkpoint_config['n_kv_head'])])
        if 'rms_norm_eps' in checkpoint_config:
            args.extend(['--rms_norm_eps', str(checkpoint_config['rms_norm_eps'])])
        if 'bos_token_id' in checkpoint_config:
            args.extend(['--bos_token_id', str(checkpoint_config['bos_token_id'])])
        if 'eos_token_id' in checkpoint_config:
            args.extend(['--eos_token_id', str(checkpoint_config['eos_token_id'])])
        if 'tie_word_embeddings' in checkpoint_config:
            args.extend(['--tie_word_embeddings', str(checkpoint_config['tie_word_embeddings'])])
        if 'use_cache' in checkpoint_config:
            args.extend(['--use_cache', str(checkpoint_config['use_cache'])])
        if 'torch_dtype' in checkpoint_config:
            args.extend(['--torch_dtype', checkpoint_config['torch_dtype']])
        if 'hidden_act' in checkpoint_config:
            args.extend(['--hidden_act', checkpoint_config['hidden_act']])

    # Build Configuration
    if 'build' in config:
        build_config = config['build']
        if 'max_input_len' in build_config:
            args.extend(['--max_input_len', str(build_config['max_input_len'])])
        if 'max_output_len' in build_config:
            args.extend(['--max_output_len', str(build_config['max_output_len'])])
        if 'max_batch_size' in build_config:
            args.extend(['--max_batch_size', str(build_config['max_batch_size'])])
        if 'max_beam_width' in build_config:
            args.extend(['--max_beam_width', str(build_config['max_beam_width'])])
        if 'max_prompt_embedding_table_size' in build_config:
            args.extend(['--max_prompt_embedding_table_size', str(build_config['max_prompt_embedding_table_size'])])
        if 'gather_context_logits' in build_config:
            args.extend(['--gather_context_logits', str(build_config['gather_context_logits'])])
        if 'gather_generation_logits' in build_config:
            args.extend(['--gather_generation_logits', str(build_config['gather_generation_logits'])])
        if 'strongly_typed' in build_config:
            args.extend(['--strongly_typed', str(build_config['strongly_typed'])])
        if 'profiling_verbosity' in build_config:
            args.extend(['--profiling_verbosity', build_config['profiling_verbosity']])
        if 'enable_debug_output' in build_config:
            args.extend(['--enable_debug_output', str(build_config['enable_debug_output'])])
        if 'max_draft_len' in build_config:
            args.extend(['--max_draft_len', str(build_config['max_draft_len'])])
        if 'use_refit' in build_config:
            args.extend(['--use_refit', str(build_config['use_refit'])])
        if 'weight_sparsity' in build_config:
            args.extend(['--weight_sparsity', str(build_config['weight_sparsity'])])
        if 'max_encoder_input_len' in build_config:
            args.extend(['--max_encoder_input_len', str(build_config['max_encoder_input_len'])])
        if 'use_fused_mlp' in build_config:
            args.extend(['--use_fused_mlp', str(build_config['use_fused_mlp'])])
        if 'dry_run' in build_config:
            args.extend(['--dry_run', str(build_config['dry_run'])])
        if 'visualize_network' in build_config:
            args.extend(['--visualize_network', str(build_config['visualize_network'])])

    return args


def main():
    parser = argparse.ArgumentParser(description='Parse buildconfig.yaml and run trtllm-build')
    parser.add_argument('--config', type=str, required=True, help='Path to the buildconfig.yaml file')
    args = parser.parse_args()

    buildconfig_args = parse_buildconfig(args.config)
    command = ['trtllm-build'] + buildconfig_args
    subprocess.run(command, check=True)


if __name__ == '__main__':
    main()
# TensorRT-LLM Build Configuration File

model:
  model_dir: ./llama-2-7b-chat-hf # Path to the pretrained model directory
  output_dir: ./llama-2-7b-chat-engine # Path to save the built engine
  dtype: float16 # Data type for the model (choices: float32, float16, bfloat16)

checkpoint:
  checkpoint_dir: ../llama-2-7b-chat-hf-output # Path to the TensorRT-LLM checkpoint directory
  tp_size: 1 # Tensor parallelism size, increase for multi-GPU tensor parallelism
  pp_size: 1 # Pipeline parallelism size, increase for multi-GPU pipeline parallelism
  vocab_size: 32000 # Vocabulary size of the model
  n_positions: 2048 # Maximum number of positions (sequence length)
  n_layer: 32 # Number of layers in the model
  n_head: 32 # Number of attention heads
  n_embd: 4096 # Hidden size of the model
  inter_size: 11008 # Intermediate size of the model's feed-forward layers
  #meta_ckpt_dir: # Path to the meta checkpoint directory
  #n_kv_head: # Number of key-value heads (defaults to n_head if not specified)
  #rms_norm_eps: 1e-6 # Epsilon value for RMS normalization
  #use_weight_only: false # Enable weight-only quantization
  #weight_only_precision: int8 # Precision for weight-only quantization (choices: int8, int4)
  #smoothquant: 0.5 # SmoothQuant parameter for quantization
  #per_channel: false # Enable per-channel quantization
  #per_token: false # Enable per-token quantization
  #int8_kv_cache: false # Enable int8 quantization for key-value cache
  #ammo_quant_ckpt_path: # Path to the quantized checkpoint file in .npz format
  #per_group: false # Enable per-group quantization for GPTQ/AWQ quantization
  #load_by_shard: false # Load the pretrained model shard-by-shard
  #hidden_act: silu # Activation function used in the model (default: silu)
  #rotary_base: 10000.0 # Base value for rotary positional embeddings
  #group_size: 128 # Group size used in GPTQ quantization
  #dataset_cache_dir: # Path to the dataset cache directory
  #load_model_on_cpu: false # Load the model on CPU
  #use_parallel_embedding: false # Enable embedding parallelism
  #embedding_sharding_dim: 0 # Dimension for embedding sharding (choices: 0, 1)
  #use_embedding_sharing: false # Enable embedding sharing to reduce engine size
  #workers: 1 # Number of workers for parallel checkpoint conversion
  #moe_num_experts: 0 # Number of experts for Mixture of Experts (MoE) layers
  #moe_top_k: 0 # Top-k value for MoE layers (defaults to 1 if moe_num_experts is set)
  #moe_tp_mode: 0 # Parallelism mode for distributing MoE experts in tensor parallelism
  #moe_renorm_mode: 1 # Renormalization mode for MoE gate logits
  #save_config_only: false # Only save the model configuration without building the engine
  #disable_weight_only_quant_plugin: false # Disable the weight-only quantization plugin

build:
  max_input_len: 256 # Maximum input sequence length
  max_output_len: 256 # Maximum output sequence length
  max_batch_size: 8 # Maximum batch size
  max_beam_width: 1 # Maximum beam width for beam search
  #max_num_tokens: # Maximum number of tokens to generate
  #opt_num_tokens: # Optimal number of tokens to generate
  max_prompt_embedding_table_size: 0 # Maximum size of the prompt embedding table
  gather_context_logits: false # Gather context logits during generation
  gather_generation_logits: false # Gather generation logits during generation
  strongly_typed: false # Enable strongly typed network definition
  #builder_opt: # Builder optimization level
  profiling_verbosity: layer_names_only # Profiling verbosity level (choices: layer_names_only, detailed, none)
  enable_debug_output: false # Enable debug output
  max_draft_len: 0 # Maximum draft length for Medusa-style generation
  use_refit: false # Enable engine refitting
  #input_timing_cache: # Path to the input timing cache file
  #output_timing_cache: # Path to save the output timing cache file
  lora_config: # Configuration for LoRA (Low-Rank Adaptation)
    #lora_dir: # Path to the LoRA checkpoint directory
    #lora_target_modules: # Target modules for LoRA adaptation
    #lora_ckpt_source: hf # Source of LoRA checkpoints (choices: hf, nemo)
    #max_lora_rank: 4 # Maximum rank for LoRA adaptation
  auto_parallel_config: # Configuration for automatic parallelization
    #enabled: false # Enable automatic parallelization
    #tp_size: 1 # Tensor parallelism size for automatic parallelization
    #pp_size: 1 # Pipeline parallelism size for automatic parallelization
    #max_memory_MB: 80000 # Maximum memory in MB for automatic parallelization
    #max_dram_memory_MB: 30000 # Maximum DRAM memory in MB for automatic parallelization
    #compile_max_memory_MB: 17000 # Maximum memory in MB for compilation during automatic parallelization
    #compile_max_dram_memory_MB: 8000 # Maximum DRAM memory in MB for compilation during automatic parallelization
    #debug_mode: false # Enable debug mode for automatic parallelization
  weight_sparsity: false # Enable weight sparsity
  plugin_config: # Configuration for plugins
    #use_custom_all_reduce: false # Use custom all-reduce plugin
    #use_fp8_all_reduce: false # Use FP8 all-reduce plugin
    #use_fp8_cast_plugin: false # Use FP8 cast plugin
    #use_async_malloc: false # Use asynchronous memory allocation plugin
    #use_paged_context_fmha: false # Use paged context fused multi-head attention plugin
    #use_fp8_context_fmha: false # Use FP8 context fused multi-head attention plugin
    #lora_plugin: # Configuration for LoRA plugin
      #type: # Type of LoRA plugin
  max_encoder_input_len: 1024 # Maximum encoder input sequence length for encoder-decoder models
  use_fused_mlp: false # Use fused MLP layers
  dry_run: false # Perform a dry run without building the engine
  visualize_network: false # Visualize the network graph
The script covers all the relevant settings from the buildconfig.yaml file: it reads the model, checkpoint, and build sections and constructs the corresponding command-line arguments for the trtllm-build command.
Save the script (for example as build_trtllm.py) and point it at the configuration file:

python3 build_trtllm.py --config buildconfig.yaml
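If you want to inspect the assembled command before launching a real build, a minimal sketch is to print it instead of executing it. This assumes the script above was saved as build_trtllm.py so that parse_buildconfig can be imported; adjust the module name to whatever you used.

# Print the trtllm-build command that would be executed, without running it.
import shlex

from build_trtllm import parse_buildconfig  # assumed module name for the script above

args = parse_buildconfig('buildconfig.yaml')
print('trtllm-build ' + ' '.join(shlex.quote(str(a)) for a in args))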
This configuration file is divided into three main sections:
Model Configuration
Specifies the paths to the pretrained model directory and the output directory where the built engine will be saved. It also allows you to set the data type for the model (float32, float16, or bfloat16).
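For instance, given the sample model section above, parse_buildconfig contributes the following arguments. Note how the script forwards model_dir under --checkpoint_dir and dtype under --logits_dtype:

# Arguments derived from the model section of the sample buildconfig.yaml
['--checkpoint_dir', './llama-2-7b-chat-hf',
 '--output_dir', './llama-2-7b-chat-engine',
 '--logits_dtype', 'float16']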
Checkpoint Configuration
Defines the settings related to the TensorRT-LLM checkpoint, such as the checkpoint directory, tensor parallelism size, pipeline parallelism size, and various model-specific parameters like vocabulary size, number of layers, attention heads, hidden size, etc.
Many of these settings are optional and can be uncommented and adjusted based on the specific model requirements.
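Keep in mind that the parser above only forwards keys it explicitly checks for, so optional settings you uncomment (rotary_base, group_size, and so on) are silently ignored unless you extend it. One way to handle this generically is a small helper like the sketch below, which assumes each uncommented key corresponds to a same-named command-line option; verify the option names against your TensorRT-LLM version before relying on it.

def forward_remaining(args, section, handled):
    # Forward any scalar keys not already handled explicitly, using the YAML
    # key as the option name (assumption: option names match the YAML keys).
    for key, value in section.items():
        if key in handled or value is None or isinstance(value, (dict, list)):
            continue
        args.extend([f'--{key}', str(value)])
    return args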
Build Configuration
Contains the parameters for building the TensorRT engine, including maximum input and output sequence lengths, maximum batch size, beam width, prompt embedding table size, and various optimization and debugging options.
It also allows you to configure LoRA (Low-Rank Adaptation), automatic parallelization, weight sparsity, and plugin-specific settings.
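Note that the nested lora_config, auto_parallel_config, and plugin_config blocks are not handled by the flat parser above, so enabling them requires mapping their keys explicitly. A minimal sketch for the LoRA block is shown below; the option names mirror the YAML keys and should be checked against your trtllm-build version.

def parse_lora_config(build_config, args):
    # Forward LoRA settings from build.lora_config, if the block is present.
    lora = build_config.get('lora_config') or {}
    if 'lora_dir' in lora:
        args.extend(['--lora_dir', lora['lora_dir']])
    if 'lora_ckpt_source' in lora:
        args.extend(['--lora_ckpt_source', lora['lora_ckpt_source']])
    if 'max_lora_rank' in lora:
        args.extend(['--max_lora_rank', str(lora['max_lora_rank'])])
    if 'lora_target_modules' in lora:
        # Assumes a list of module names in the YAML
        args.extend(['--lora_target_modules', *[str(m) for m in lora['lora_target_modules']]])
    return args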