PreTrainedModel class
From the Transformers Library

The `PreTrainedModel` class is a base class provided by the Transformers library that serves as the foundation for all pretrained models. It provides common functionality and methods for loading, saving, and modifying pretrained models.
Configuration

The `PreTrainedModel` class takes a `config` parameter, which is an instance of the `PretrainedConfig` class or one of its subclasses. The `config` object stores the configuration of the model, such as the number of layers, hidden size, and number of attention heads.
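As a minimal sketch, a configuration can be built first and a freshly initialized model constructed from it; `BertConfig`/`BertModel` are used here only as an illustrative pair, and the dimensions are arbitrary:

```python
from transformers import BertConfig, BertModel

# Build a configuration with custom dimensions, then construct a
# randomly initialized model from it (no pretrained weights are loaded).
config = BertConfig(num_hidden_layers=6, hidden_size=384, num_attention_heads=6)
model = BertModel(config)

print(model.config.hidden_size)  # 384
```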
Class Attributes

- `config_class`: A subclass of `PretrainedConfig` used as the configuration class for the specific model architecture.
- `load_tf_weights`: A callable for loading weights from a TensorFlow checkpoint into a PyTorch model.
- `base_model_prefix`: A string indicating the attribute associated with the base model in derived classes that add modules on top of the base model.
- `is_parallelizable`: A boolean flag indicating whether the model supports parallelization.
- `main_input_name`: The name of the main input to the model (e.g., `input_ids` for NLP models); see the subclass sketch after this list.
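As a hedged illustration of how a custom subclass might set these attributes, the `MyConfig`/`MyModel` names and module choices below are hypothetical, not part of the library:

```python
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"  # hypothetical architecture identifier

    def __init__(self, vocab_size=1000, hidden_size=128, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

class MyModel(PreTrainedModel):
    config_class = MyConfig           # configuration class for this architecture
    base_model_prefix = "embeddings"  # attribute holding the base module
    main_input_name = "input_ids"     # primary input to the model

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids):
        return self.embeddings(input_ids)

model = MyModel(MyConfig())
print(model(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 128])
```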
Methods

- `from_pretrained`: A class method that instantiates a pretrained model from a configuration and pretrained weights. It allows loading models from a local directory, a remote repository, or a TensorFlow/Flax checkpoint, and supports options such as specifying the configuration, state dictionary, and cache directory; see the sketch after this list.
- `save_pretrained`: A method that saves the model's configuration and state dictionary to a specified directory.
- `push_to_hub`: A method that uploads the model to the Hugging Face Model Hub.
- `from_tf` / `from_flax`: Boolean arguments to `from_pretrained` (rather than standalone methods) that load the model weights from a TensorFlow or Flax checkpoint, respectively.
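A minimal sketch of the load/save round trip; the checkpoint name is a real Hub identifier, while the directory names are illustrative:

```python
from transformers import AutoModel

# Download pretrained weights from the Hub (or pass a local path),
# caching them in a local directory.
model = AutoModel.from_pretrained("bert-base-uncased", cache_dir="./hf_cache")

# Save the configuration and state dictionary, then reload from disk.
model.save_pretrained("./my_bert")
model = AutoModel.from_pretrained("./my_bert")
```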
Parallelization and Distributed Training

The `is_parallelizable` attribute indicates whether the model can be parallelized across multiple devices. The `from_pretrained` method supports loading models in a distributed manner through the `device_map` argument, which allows specifying the device placement for each submodule of the model.
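For example, passing `device_map="auto"` lets the Accelerate integration decide the placement automatically; this sketch assumes the `accelerate` package is installed:

```python
from transformers import AutoModelForCausalLM

# Shard the model across available GPUs (and CPU memory if needed).
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# Inspect where each submodule was placed.
print(model.hf_device_map)
```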
Model Modifications

The `PreTrainedModel` class provides methods to modify the model's architecture, such as `resize_token_embeddings` to resize the input token embeddings and `prune_heads` to prune attention heads.
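A common pattern is resizing the embeddings after extending the tokenizer; the added token and pruned head indices below are illustrative:

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Grow the embedding matrix to cover newly added tokens.
tokenizer.add_tokens(["<new_token>"])
model.resize_token_embeddings(len(tokenizer))

# Prune attention heads, given {layer index: [head indices to remove]}.
model.prune_heads({0: [0, 2]})
```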
Quantization and Optimization

The `from_pretrained` method supports quantization and optimization configurations through the `quantization_config` argument, allowing quantized model loading using libraries like `bitsandbytes`.
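A sketch of 8-bit loading via `bitsandbytes`, assuming a CUDA GPU and the `bitsandbytes` and `accelerate` packages are installed:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Request 8-bit quantized weights via bitsandbytes.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    quantization_config=bnb_config,
    device_map="auto",  # quantized loading needs a device placement
)
```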
Saving and Loading

The `save_pretrained` method saves the model's configuration and state dictionary to a specified directory. The `from_pretrained` method loads models from a saved directory, a pretrained model configuration, or a TensorFlow/Flax checkpoint.
The `PreTrainedModel` class provides a unified interface for working with pretrained models in the Transformers library. It abstracts away the complexities of loading, saving, and modifying models, making it easier to use and extend pretrained models for various tasks.
Developers can subclass the `PreTrainedModel` class to create their own custom models while leveraging the common functionality provided by the base class. This promotes code reuse, maintainability, and consistency across different model architectures.
Overall, the `PreTrainedModel` class is a fundamental building block in the Transformers library, enabling seamless integration and utilization of pretrained models in a wide range of natural language processing and computer vision tasks.