tensorrt_llm/network.py
The Network Class
The network.py file is a core component of the TensorRT-LLM framework that defines the Network class and its associated utility functions and classes.
It provides an abstraction layer over the TensorRT network definition and manages various aspects of the network, such as layer naming, plugin information, and module call stack.
Let's break down the file and analyse its components in detail:
Imports
The file starts by importing the modules it needs, including collections, contextlib, hashlib, weakref, numpy, tensorrt, and other custom modules from the TensorRT-LLM framework.
The Network class
The Network class is a crucial component of the TensorRT-LLM architecture.
It serves as a high-level wrapper around the TensorRT INetworkDefinition interface, providing a more convenient and intuitive way to define and manipulate the structure of a neural network.
Let's break down the Network class and analyze its key components and functions:
Initialization
The __init__ method is not meant to be called directly by users. Instead, the Builder.create_network() method should be used to create a new Network instance.
The _init method is responsible for initializing the Network object with a TensorRT INetworkDefinition instance.
It sets up various attributes such as inputs, named parameters, plugin configuration, and auto-parallel configuration.
Input and Output Management
The _add_input method adds an input tensor to the network, specifying its name, data type, shape, and optional dimension range.
The _mark_output method marks a tensor as an output of the network, optionally casting it to a specified data type if strongly typed mode is enabled.
The get_inputs and get_outputs methods return the input and output tensors of the network, respectively.
Layer and Tensor Manipulation
The get_layers method returns an iterable of all the layers in the network.
The get_layer_by_name method retrieves a layer by its name.
The get_tensor_users method returns the layers that consume a specific tensor.
The get_tensor_parent method returns the layer that produces a specific tensor.
The mark_removed_layer and is_removed_layer methods mark a layer as removed from the network and check whether a layer has been removed.
Graph Visualization
The to_dot method generates a graphviz representation of the network, allowing visualization of the layers and tensor connections.
It creates nodes for each layer and tensor, and edges representing the producer-consumer relationships.
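The node-and-edge layout described above can be sketched without TensorRT or the graphviz package. The following is a minimal illustration, not the to_dot implementation: the function name to_dot_source, the tuple layout of the toy layers, and the node shapes are all assumptions made for the example.

```python
def to_dot_source(layers):
    """Return DOT text for (layer_name, input_tensors, output_tensors) triples.

    Hypothetical helper: it mirrors the idea of one node per layer and per
    tensor, with edges from producers to consumers.
    """
    lines = ["digraph network {"]
    declared_tensors = set()
    for layer_name, inputs, outputs in layers:
        # One box-shaped node per layer.
        lines.append(f'  "{layer_name}" [shape=box];')
        # One ellipse-shaped node per tensor, declared only once.
        for tensor in list(inputs) + list(outputs):
            if tensor not in declared_tensors:
                declared_tensors.add(tensor)
                lines.append(f'  "{tensor}" [shape=ellipse];')
        # Edges: input tensor -> consuming layer, producing layer -> output tensor.
        for tensor in inputs:
            lines.append(f'  "{tensor}" -> "{layer_name}";')
        for tensor in outputs:
            lines.append(f'  "{layer_name}" -> "{tensor}";')
    lines.append("}")
    return "\n".join(lines)


# Toy two-layer network (names are illustrative only).
toy_layers = [
    ("embed", ["input_ids"], ["hidden"]),
    ("attn", ["hidden"], ["attn_out"]),
]
dot_source = to_dot_source(toy_layers)
```

The resulting string can be written to a .dot file and rendered with any Graphviz tool.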
Internal Graph State
The _get_graph method returns the internal graph state of the network.
The _GraphState class represents the internal graph state, which includes mappings between tensors and their consumer/producer layers, as well as input and output tensors.
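To make the tensor-to-layer mappings concrete, here is a small sketch of how such a graph state could be derived from a list of layers. ToyLayer and build_graph_state are hypothetical names, and inferring the graph inputs and outputs from the mappings is an assumption of this toy; the real class works on TensorRT objects.

```python
from dataclasses import dataclass


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a TensorRT layer."""
    name: str
    inputs: list
    outputs: list


def build_graph_state(layers):
    """Build consumer/producer mappings keyed by tensor name."""
    consumers = {}  # tensor name -> names of layers that read it
    producer = {}   # tensor name -> name of the layer that writes it
    for layer in layers:
        for tensor in layer.inputs:
            consumers.setdefault(tensor, []).append(layer.name)
        for tensor in layer.outputs:
            producer[tensor] = layer.name
    # Assumption for this toy: a tensor with no producer is a graph input,
    # and a tensor with no consumer is a graph output.
    graph_inputs = [t for t in consumers if t not in producer]
    graph_outputs = [t for t in producer if t not in consumers]
    return consumers, producer, graph_inputs, graph_outputs


toy_layers = [
    ToyLayer("embed", ["input_ids"], ["hidden"]),
    ToyLayer("attn", ["hidden"], ["attn_out"]),
    ToyLayer("add", ["hidden", "attn_out"], ["residual"]),
]
consumers, producer, graph_inputs, graph_outputs = build_graph_state(toy_layers)
```

With mappings like these, lookups in the spirit of get_tensor_users and get_tensor_parent reduce to dictionary accesses.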
Network Hashing
The _get_network_hash method generates a hash value for the network based on its structure and layer connections.
This hash identifies the network and can be used for caching or comparison purposes.
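The idea of a structural hash can be illustrated with hashlib alone. This is a sketch under stated assumptions: the function name network_hash, the choice of SHA-256, and the fields fed into the digest are illustrative and may differ from what _get_network_hash actually does.

```python
import hashlib


def network_hash(layers):
    """Hash a list of (name, layer_type, inputs, outputs) tuples.

    Hypothetical helper: layers are visited in order, so the digest depends
    on both the layers themselves and how they are wired together.
    """
    digest = hashlib.sha256()
    for name, layer_type, inputs, outputs in layers:
        record = "|".join([name, layer_type, ",".join(inputs), ",".join(outputs)])
        digest.update(record.encode("utf-8"))
        digest.update(b"\n")  # separator so adjacent records cannot merge
    return digest.hexdigest()


base = [
    ("embed", "GATHER", ["input_ids"], ["hidden"]),
    ("proj", "MATRIX_MULTIPLY", ["hidden"], ["logits"]),
]
# Same layers, but the second layer now reads a different tensor.
rewired = [
    ("embed", "GATHER", ["input_ids"], ["hidden"]),
    ("proj", "MATRIX_MULTIPLY", ["input_ids"], ["logits"]),
]

hash_a = network_hash(base)
hash_b = network_hash(list(base))
hash_c = network_hash(rewired)
```

Two structurally identical networks produce the same digest, while rewiring a single connection changes it, which is the property a cache key needs.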
Context Management
The net_guard function is a context manager that temporarily sets the current network to the provided Network instance.
It scopes network usage to a specific context and restores the previous network when exiting the context.
In the context of TensorRT-LLM, the Network class plays a vital role in defining the structure and connectivity of the neural network.
It provides a high-level interface for creating and manipulating layers, tensors, and their relationships.
The Network class abstracts away the low-level details of the TensorRT API, making it easier to define and optimize the network for LLM tasks.
The Network class is closely integrated with other components of TensorRT-LLM, such as the Builder class, which is responsible for creating the network and building the optimised TensorRT engine.
The Network class also interacts with the graph rewriting and optimisation modules to perform transformations and optimisations on the network structure.
Overall, the Network class serves as a fundamental building block in TensorRT-LLM, providing a high-level representation of the neural network and enabling efficient definition, manipulation, and optimization of the network structure for large language model tasks.
_UniqueNameGenerator class
This class is responsible for generating unique names for layers in the network.
It maintains a dictionary (self.ids) to keep track of the count for each key.
The __call__ method takes a key and an optional module_name and returns a unique name by appending the count to the key.
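A counter-per-key generator of this kind fits in a few lines. The sketch below is a stand-in written for illustration: the class name UniqueNameGenerator, the prefix argument, the "/" and "_" separators, and the way module_name is folded into the key are assumptions, not details confirmed by the text above.

```python
import collections


class UniqueNameGenerator:
    """Hypothetical stand-in for _UniqueNameGenerator."""

    def __init__(self, prefix=""):
        # Per-key counters; unseen keys start at 0.
        self.ids = collections.defaultdict(int)
        self.prefix = prefix

    def __call__(self, key, module_name=""):
        # Assumption: the module name scopes the key, so each module
        # gets its own counter for the same layer kind.
        if module_name:
            key = f"{module_name}/{key}"
        name = f"{self.prefix}{key}_{self.ids[key]}"
        self.ids[key] += 1
        return name


gen = UniqueNameGenerator()
first = gen("matmul")                             # "matmul_0"
second = gen("matmul")                            # "matmul_1"
scoped = gen("matmul", module_name="attention")   # "attention/matmul_0"
```

Keeping one counter per key means names stay short and stable as long as layers are created in the same order.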
PluginInfo class
This class represents the information about a plugin in the network.
It contains the plugin creator (plugin_creator), plugin name (plugin_name), and plugin field collection (pfc).
The __init__ method initialises the plugin information and parses the plugin field collection.
The _parse_pfc method converts the plugin field collection into NumPy arrays and lists for easier access.
Plugin Info Utility Functions
The get_plugin_info, set_plugin_info, and delete_plugin_info functions retrieve, set, and delete plugin information on the TensorRT network.
They use the get_extra_attr and set_extra_attr functions to store and retrieve plugin information as extra attributes of the network.
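The pattern of attaching extra attributes to an object that cannot hold them directly can be sketched with a weak-keyed side table. Everything here is an assumption made for illustration: the WeakKeyDictionary storage, the has_extra_attr helper, the "plugin_infos" attribute name, and the simplified signatures are not taken from the text above, and ToyNetwork stands in for the TensorRT network object.

```python
import weakref

# Hypothetical side table: object -> {attribute name -> value}.
# Weak keys let entries disappear when the network object is collected.
_extra_attrs = weakref.WeakKeyDictionary()


def has_extra_attr(obj, name):
    return name in _extra_attrs.get(obj, {})


def get_extra_attr(obj, name):
    return _extra_attrs[obj][name]


def set_extra_attr(obj, name, value):
    _extra_attrs.setdefault(obj, {})[name] = value


def set_plugin_info(network, layer_name, plugin_info):
    # Plugin information is kept in one dict per network, keyed by layer name.
    if not has_extra_attr(network, "plugin_infos"):
        set_extra_attr(network, "plugin_infos", {})
    get_extra_attr(network, "plugin_infos")[layer_name] = plugin_info


def get_plugin_info(network, layer_name):
    if not has_extra_attr(network, "plugin_infos"):
        return None
    return get_extra_attr(network, "plugin_infos").get(layer_name)


def delete_plugin_info(network, layer_name):
    if has_extra_attr(network, "plugin_infos"):
        get_extra_attr(network, "plugin_infos").pop(layer_name, None)


class ToyNetwork:
    """Stand-in for the TensorRT network object."""


network = ToyNetwork()
set_plugin_info(network, "gemm_0", {"plugin_name": "Gemm"})
```

Storing the data beside the object rather than on it keeps the wrapped object untouched, which matters when that object comes from a binding that does not accept new attributes.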
Weight Management Utility Functions
The get_np_weight and set_np_weight functions retrieve and set NumPy weights for a layer in the network.
They use the get_extra_attr and set_extra_attr functions to store and retrieve weights as extra attributes of the network.
The source marks these functions with a TODO noting that they should be removed after a specific bug is fixed.
Network class
This is the main class that represents the TensorRT-LLM network.
The __init__ method initializes various attributes of the network, such as removed layers, graph alteration flag, and a functional layer information memo.
The _init method is called to initialize the network with a TensorRT network object (trt_network).
It sets up various attributes, including inputs, named parameters, layer precision, name generator, plugin configuration, module call stack, registered NumPy arrays, strongly typed flag, and unfilled weights.
The _get_hash method generates a hash of the network based on its layers, inputs, and outputs.
net_guard context manager
This context manager is used to temporarily set the current network context.
It takes a network object and sets it as the current network using the set_network function from the _common module.
It yields control to the enclosed block of code and restores the previous network context when the block is exited.
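The set-yield-restore behaviour described above maps directly onto contextlib.contextmanager. The sketch below is self-contained and does not import TensorRT-LLM: the module-level _current_network slot and the get_network helper are assumptions, and the try/finally is a design choice of this example so the previous network is restored even if the block raises.

```python
import contextlib

# Hypothetical module-level slot holding the "current" network.
_current_network = None


def set_network(network):
    global _current_network
    _current_network = network


def get_network():
    return _current_network


@contextlib.contextmanager
def net_guard(network):
    """Temporarily make `network` the current network."""
    previous = get_network()
    set_network(network)
    try:
        yield
    finally:
        # Restore the previous network on normal exit and on exceptions.
        set_network(previous)


set_network("outer")
with net_guard("inner"):
    inside = get_network()   # "inner" while the block runs
after = get_network()        # back to "outer" afterwards
```

Because the previous value is captured on entry, guards can be nested and each one restores exactly what it replaced.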
_TrtLlmModuleCallStack class
This class manages the module call stack for the network.
It maintains a list of module names in the call stack and a mapping of module objects to their names.
It provides methods to set module names, get the current module, get the module name for a given module object, set layer range for a module, and manage the call stack.
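A call stack of module names can be sketched as a list plus a name lookup. This is a simplified illustration, not the class itself: ModuleCallStack and its method names are hypothetical, modules are keyed by id() for brevity, and the layer-range bookkeeping mentioned above is left out.

```python
import contextlib


class ModuleCallStack:
    """Hypothetical, simplified stand-in for _TrtLlmModuleCallStack."""

    def __init__(self):
        self.call_stack = []     # names of modules currently being executed
        self.module_names = {}   # id(module) -> registered name

    def set_module_name(self, module, name):
        self.module_names[id(module)] = name

    def get_module_name(self, module):
        return self.module_names.get(id(module))

    def get_current_module(self):
        # The innermost module is the last one pushed; None outside any module.
        return self.call_stack[-1] if self.call_stack else None

    @contextlib.contextmanager
    def entering(self, module):
        # Push on entry and always pop on exit so the stack stays balanced.
        self.call_stack.append(self.get_module_name(module))
        try:
            yield
        finally:
            self.call_stack.pop()


class ToyModule:
    """Stand-in for a TensorRT-LLM module object."""


stack = ModuleCallStack()
decoder, attention = ToyModule(), ToyModule()
stack.set_module_name(decoder, "decoder")
stack.set_module_name(attention, "decoder.attention")

with stack.entering(decoder):
    with stack.entering(attention):
        innermost = stack.get_current_module()
    outer = stack.get_current_module()
```

Knowing the innermost module at the moment a layer is created is what lets a name generator prefix layer names with the module that produced them.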
Overall, the network.py file provides the core functionality for managing the TensorRT-LLM network, including layer naming, plugin information handling, weight management, and module call stack tracking.
It offers a high-level abstraction over the TensorRT network definition and provides utilities for interacting with the network and its components.
In summary, the network.py file is a critical component of the TensorRT-LLM framework that provides the foundation for building and managing the TensorRT network.
It offers a clean and efficient API for interacting with the network, handling plugins, weights, and module call stack, making it easier to work with TensorRT in the context of large language models.