Gemma 2B Fine Tuned Lightweight model

 Kaggle Notebook

Step 1: Configure GPU for Memory Growth

python

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled.")
    except RuntimeError as e:
        print(e)
else:
    print("No GPU found. Using CPU.")
  • Purpose: Ensures the GPU is set to dynamically allocate memory instead of pre-allocating all available GPU memory. This approach prevents memory wastage and allows multiple processes to use the GPU without running into memory allocation errors.
  • Technical Details:
    • tf.config.list_physical_devices('GPU'): Lists available GPUs.
    • tf.config.experimental.set_memory_growth(gpu, True): Allows TensorFlow to allocate GPU memory on demand.
    • Fallback: If no GPU is found, the code defaults to CPU computation.
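
As a quick sanity check (a small optional snippet that reuses the gpus list from the code above), you can confirm that growth was actually enabled:

python
# Optional verification; `gpus` comes from the snippet above
if gpus:
    print(tf.config.experimental.get_memory_growth(gpus[0]))  # expected: True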

Step 2: Enable Mixed Precision for Memory Optimization

python
policy = tf.keras.mixed_precision.Policy("mixed_float16")
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision enabled with policy: {policy}")
  • Purpose: Reduces memory usage and increases computational speed by using lower-precision data types (e.g., float16) where appropriate, while keeping critical calculations in higher precision (e.g., float32).
  • Technical Details:
    • tf.keras.mixed_precision.Policy("mixed_float16"): Specifies float16 for computations while keeping variables (and numerically sensitive accumulations) in float32.
    • set_global_policy(policy): Globally applies the mixed precision policy.
    • Mixed precision is especially effective on GPUs with Tensor Cores (e.g., NVIDIA Volta, Ampere).
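
One way to see the float16/float32 split in practice (an illustrative check, not part of the notebook itself) is to inspect a freshly created layer once the global policy is set:

python
# With mixed_float16 active, new layers compute in float16 but keep float32 variables
dense = tf.keras.layers.Dense(8)
print(dense.compute_dtype)   # float16
print(dense.variable_dtype)  # float32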

Step 3: Load a Smaller Model Variant

python
import keras_nlp

try:
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_1b_en")
    print("Loaded Gemma LM (1B model).")
except Exception:
    # Fall back to the 2B variant if the smaller preset is unavailable
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    print("Loaded Gemma LM (2B model).")
  • Purpose: Dynamically load a smaller model variant if possible, reducing memory and computational requirements. Falls back to a larger model if the smaller one is unavailable.
  • Technical Details:
    • keras_nlp.models.GemmaCausalLM.from_preset: Loads a preconfigured language model (Gemma LM) with pretrained weights.
    • "gemma2_instruct_1b_en": A 1-billion parameter variant.
    • "gemma2_instruct_2b_en": A 2-billion parameter variant used as a fallback.

Step 4: Apply Low-Rank Adaptation (LoRA) for Reduced Parameters

python
gemma_lm.backbone.enable_lora(rank=2)
print("LoRA enabled with rank=2.")
  • Purpose: Reduces the memory footprint of the model by adapting its parameter matrices using a low-rank decomposition.
  • Technical Details:
    • LoRA modifies the transformer layers to optimize memory use while retaining performance.
    • rank=2: Specifies the rank of the adaptation, balancing efficiency and accuracy.
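
To see how much LoRA shrinks the trainable parameter count, you can compare trainable versus total weights after enabling it (a rough check; exact numbers depend on the preset):

python
import numpy as np

trainable = sum(int(np.prod(w.shape)) for w in gemma_lm.backbone.trainable_weights)
total = sum(int(np.prod(w.shape)) for w in gemma_lm.backbone.weights)
print(f"Trainable: {trainable:,} of {total:,} parameters ({100 * trainable / total:.2f}%)")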

Step 5: Reduce Sequence Length for Lower Memory Usage

python
gemma_lm.preprocessor.sequence_length = 128
print("Sequence length set to 128.")
  • Purpose: Lowers the memory usage during training or inference by reducing the number of tokens processed per sequence.
  • Technical Details:
    • Shorter sequences mean fewer computations, leading to faster and more memory-efficient runs.
    • Because self-attention cost grows roughly quadratically with sequence length, reducing a typical length (e.g., 512 or 1024) to 128 yields substantial resource savings.
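
A back-of-the-envelope comparison makes the quadratic effect concrete (illustrative numbers only):

python
# Relative size of the per-head attention-score matrix at different sequence lengths
for seq_len in (1024, 512, 128):
    print(f"{seq_len:>5} tokens -> {seq_len ** 2:>9,} attention scores per head")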

Step 6: Compile the Model with Optimized Settings

python
# Note: this initializer only takes effect for layers built with it;
# compile() itself does not apply it to existing weights.
initializer = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)

gemma_lm.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    weighted_metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
print("Model compiled successfully.")
  • Purpose: Prepares the model for training or inference by configuring loss functions, optimizers, and metrics.
  • Technical Details:
    • Initializer:
      • tf.keras.initializers.TruncatedNormal: Draws initial weights from a truncated normal distribution centered at zero; note it only affects layers that are constructed with it.
    • Loss:
      • SparseCategoricalCrossentropy(from_logits=True): Suitable for multi-class classification tasks where outputs are logits.
    • Optimizer:
      • Adam(learning_rate=3e-5): A widely used optimizer balancing efficiency and convergence.
    • Metrics:
      • SparseCategoricalAccuracy: Tracks classification accuracy for sparse label formats.
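
With the model compiled, fine-tuning itself is a single fit() call. The sketch below uses a hypothetical two-example instruction dataset purely as a placeholder; substitute your real training data and adjust epochs and batch_size to your hardware:

python
# Hypothetical toy dataset; the attached preprocessor tokenizes raw strings for us
train_data = [
    "Instruction: Summarize mixed precision.\nResponse: It mixes float16 and float32 to save memory.",
    "Instruction: What does LoRA do?\nResponse: It adds small low-rank adapters instead of updating all weights.",
]
gemma_lm.fit(train_data, epochs=1, batch_size=1)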

Step 7: Save Optimized Versions of the Model

Saving Model Weights

python
weights_path = "gemma_lm_lightweight.weights.h5"
gemma_lm.backbone.save_weights(weights_path)
print(f"Model weights saved to: {weights_path}")
  • Purpose: Saves only the weights of the backbone model to a lightweight file for reuse or transfer.
  • Technical Details:
    • .h5 format: Common for Keras models and weights storage.

Saving Quantized TensorFlow Lite Model

python
converter = tf.lite.TFLiteConverter.from_keras_model(gemma_lm)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

tflite_path = "gemma_lm_lightweight_v2.tflite"
with open(tflite_path, "wb") as f:
    f.write(quantized_model)
print(f"Quantized model saved to: {tflite_path}")
  • Purpose: Converts the model to TensorFlow Lite format with quantization, making it suitable for deployment on resource-constrained devices.
  • Technical Details:
    • tf.lite.TFLiteConverter.from_keras_model: Converts a Keras model to TensorFlow Lite.
    • converter.optimizations = [tf.lite.Optimize.DEFAULT]: Applies optimizations such as quantization to reduce size and improve performance.
    • .tflite: A lightweight format for deployment.
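
On the target device, the .tflite file is loaded with the TensorFlow Lite Interpreter. The sketch below assumes the conversion above actually succeeded; large causal LMs often need extra steps (select TF ops or signature-based export) before this works:

python
# Minimal on-device loading sketch, assuming the conversion produced a valid file
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
print(interpreter.get_input_details())
print(interpreter.get_output_details())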

Saving Backbone Only

python
backbone_path = "gemma_lm_lightweight_backbone.h5"
gemma_lm.backbone.save(backbone_path)
print(f"Backbone model saved to: {backbone_path}")
  • Purpose: Saves only the backbone of the model, excluding preprocessing or output layers.
  • Technical Details:
    • Backbone-saving allows reuse of core layers for transfer learning or fine-tuning in other tasks.
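
To reuse the saved backbone elsewhere, it can be reloaded with Keras. Because the Gemma layers are custom keras_nlp classes, this sketch assumes keras_nlp is imported first so those classes can be resolved (passing custom_objects may also be needed):

python
import keras_nlp  # registers the custom Gemma layers before loading

restored_backbone = tf.keras.models.load_model(backbone_path)
restored_backbone.summary()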

This structured code is optimized for memory, computational efficiency, and deployment versatility, addressing various stages of model training and optimization.
