PreTrainedModel class

From the Transformers Library

The PreTrainedModel class is a base class provided by the Transformers library that serves as the foundation for all of the library's pretrained PyTorch models.

It provides common functionality and methods for loading, saving, and modifying pretrained models.

Configuration

  • The PreTrainedModel class takes a config parameter, which is an instance of the PretrainedConfig class or its subclasses.

  • The config object stores the configuration of the model, such as the number of layers, hidden size, attention heads, etc.
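
As a minimal sketch, a model can be instantiated directly from a config object, which yields randomly initialized weights (the BertConfig values below are arbitrary, illustrative choices):

```python
from transformers import BertConfig, BertModel

# Build a smaller-than-default configuration; the values are illustrative.
config = BertConfig(num_hidden_layers=6, hidden_size=384, num_attention_heads=6)

# Instantiating from a config creates a model with randomly initialized weights.
model = BertModel(config)
print(model.config.num_hidden_layers)  # 6
```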

Class Attributes

  • config_class: A subclass of PretrainedConfig that is used as the configuration class for the specific model architecture.

  • load_tf_weights: A callable used to load weights from a TensorFlow checkpoint into the PyTorch model.

  • base_model_prefix: A string naming the attribute that holds the base model in derived classes that add modules on top of it (for example, bert in BertForSequenceClassification).

  • is_parallelizable: A boolean flag indicating whether the model supports parallelization.

  • main_input_name: The name of the main input to the model (e.g., input_ids for NLP models).
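
These attributes can be inspected on any concrete subclass; a quick sketch using BertModel:

```python
from transformers import BertModel

print(BertModel.config_class)       # the PretrainedConfig subclass, BertConfig
print(BertModel.base_model_prefix)  # "bert"
print(BertModel.main_input_name)    # "input_ids"
```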

Methods

  • from_pretrained: A class method that instantiates a pretrained model from a configuration and pretrained weights.

    • It allows loading models from a local directory, a remote repository, or a TensorFlow/Flax checkpoint.

    • It supports various options such as specifying the configuration, state dictionary, cache directory, etc.

  • save_pretrained: A method to save the model's configuration and state dictionary to a specified directory.

  • push_to_hub: A method to upload the model's weights and configuration to a repository on the Hugging Face Hub.

  • from_tf: An argument to from_pretrained that loads the model weights from a TensorFlow checkpoint instead of a PyTorch one.

  • from_flax: An argument to from_pretrained that loads the model weights from a Flax checkpoint instead of a PyTorch one.
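
The sketch below exercises several of these methods; the checkpoint name, label count, and cache directory are illustrative choices, and push_to_hub assumes you have already authenticated with huggingface-cli login:

```python
from transformers import AutoModelForSequenceClassification

# Download from the Hub (cached after the first call); num_labels is
# forwarded to the configuration, and cache_dir overrides the default cache.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=3,
    cache_dir="./hf-cache",
)

# Upload the model to a Hub repository (requires prior authentication);
# the repository name is hypothetical.
# model.push_to_hub("my-username/my-finetuned-model")
```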

Parallelization and Distributed Training

  • The is_parallelizable attribute indicates whether the model can be parallelized across multiple devices.

  • The from_pretrained method supports loading models in a distributed manner using the device_map argument, which allows specifying the device placement for each submodule of the model.
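
A minimal sketch, assuming the Accelerate library is installed; device_map="auto" asks it to distribute submodules across the available GPUs and CPU:

```python
from transformers import AutoModelForCausalLM

# "auto" asks Accelerate to compute a device placement for each submodule.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# The resulting placement is recorded on the model.
print(model.hf_device_map)
```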

Model Modifications

  • The PreTrainedModel class provides methods to modify the model's architecture, such as resize_token_embeddings to resize the input token embeddings and prune_heads to prune the attention heads.
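
A short sketch of both modifications; the added token and the pruned head indices are arbitrary examples:

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Add a new token ("<NEW_TOKEN>" is hypothetical), then grow the
# embedding matrix to match the enlarged vocabulary.
tokenizer.add_tokens(["<NEW_TOKEN>"])
model.resize_token_embeddings(len(tokenizer))

# Prune attention heads 0 and 2 in layer 0 ({layer_index: [head_indices]}).
model.prune_heads({0: [0, 2]})
```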

Quantization and Optimization

  • The from_pretrained method supports quantization and optimization configurations through the quantization_config argument, allowing for quantized model loading using libraries like bitsandbytes.
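
A minimal sketch, assuming a CUDA GPU and the bitsandbytes and Accelerate packages are available; the checkpoint name is illustrative:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Request 4-bit weight loading via bitsandbytes.
quant_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=quant_config,
    device_map="auto",  # requires the Accelerate library
)
```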

Saving and Loading

  • The save_pretrained method writes the model's configuration (config.json) and state dictionary to a specified directory, producing a checkpoint that can be reloaded later.

  • The from_pretrained method loads models from such a saved directory, from a pretrained checkpoint on the Hub, or from a TensorFlow/Flax checkpoint.
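
The two methods form a simple round trip; the directory name below is arbitrary:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# Write config.json plus the weight files to a local directory...
model.save_pretrained("./my-bert")

# ...and restore an identical model from that directory later.
reloaded = AutoModel.from_pretrained("./my-bert")
```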

The PreTrainedModel class provides a unified interface for working with pretrained models in the Transformers library.

It abstracts away the complexities of loading, saving, and modifying models, making it easier to use and extend pretrained models for various tasks.

Developers can subclass the PreTrainedModel class to create their own custom models while leveraging the common functionalities provided by the base class.
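
A minimal sketch of such a subclass; TinyConfig and TinyModel are hypothetical names, and the custom model inherits save_pretrained and from_pretrained without any extra code:

```python
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class TinyConfig(PretrainedConfig):
    model_type = "tiny"  # hypothetical model type

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

class TinyModel(PreTrainedModel):
    config_class = TinyConfig  # ties the model to its configuration class
    base_model_prefix = "tiny"

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()  # runs weight initialization and final setup

    def forward(self, inputs):
        return self.linear(inputs)

# The subclass inherits saving and loading from PreTrainedModel.
model = TinyModel(TinyConfig())
model.save_pretrained("./tiny-model")
restored = TinyModel.from_pretrained("./tiny-model")
```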

This promotes code reuse, maintainability, and consistency across different model architectures.

Overall, the PreTrainedModel class is a fundamental building block in the Transformers library, enabling seamless integration and utilization of pretrained models in a wide range of natural language processing and computer vision tasks.
