Fine-Tune Gemma 2 9b with Keras and LoRA (Part 3)

Continuing our educational series on Arabic language handling with large language models, this part explores how to fine-tune the Gemma2-9b model on an Arabic dataset using Keras, keras_nlp, and the LoRA technique. We cover how to set up the environment, load the model, make the necessary modifications, and train the model with model parallelism, which distributes the model parameters across multiple accelerators.

Overview

Gemma is a family of lightweight open models built from the same research and technology used to create Google's Gemini models. Gemma models can be adapted to specific needs, but the larger variants can be difficult to handle on a single accelerator.

The Gemma2-9b model may be too large to fit on a single accelerator of the kind available on Kaggle, which makes fine-tuning and inference difficult. Two options are quantization, which reduces the memory needed for inference, and model parallelism, which distributes the model parameters across multiple accelerators.
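To make the memory pressure concrete, here is a back-of-the-envelope sketch. The parameter count, the per-accelerator memory, and the device count are all illustrative assumptions rather than measured values, and the estimate covers the weights only (no activations, optimizer state, or replicated variables).

```python
# Rough weight-memory estimate; every figure here is an illustrative assumption.
PARAM_COUNT = 9e9        # assumed: roughly 9 billion parameters
ACCELERATOR_GB = 16      # assumed: memory available on one accelerator
NUM_ACCELERATORS = 8     # assumed: devices available for sharding

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}


def weights_gb(param_count, dtype):
    """Memory needed for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return param_count * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    total = weights_gb(PARAM_COUNT, dtype)
    per_device = total / NUM_ACCELERATORS
    verdict = "fits" if total <= ACCELERATOR_GB else "does not fit"
    print(
        f"{dtype}: {total:.0f} GB total ({verdict} on one device), "
        f"about {per_device:.2f} GB per device if split evenly {NUM_ACCELERATORS} ways"
    )
```

Under these assumptions, lower-precision weights (the quantization route) and splitting the weights across devices (the model-parallel route) both bring the per-device footprint under the limit, which is why the two options are usually discussed together.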

Environment Setup Steps

1. Accept the Terms of Use: accept the Gemma terms on the model's Kaggle page so that the weights can be downloaded.

2. Install Libraries:

   !pip install -q -U keras-nlp tensorflow-text
   !pip install -q -U tensorflow-cpu
   !pip install -q -U huggingface_hub
   !pip install -q -U datasets

3. Set Up Keras with JAX:

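   import os

   # Select the JAX backend before Keras is imported.
   os.environ["KERAS_BACKEND"] = "jax"
   # Set this before JAX initializes its devices so the setting takes effect.
   os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

   import jax

   # List the accelerators visible to JAX.
   jax.devices()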
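Keras reads the KERAS_BACKEND variable when it is first imported, so the order of these statements matters. A small guard like the one below can catch the mistake early in a notebook. The helper name `select_keras_backend` and its `loaded_modules` parameter are hypothetical, introduced only for this sketch.

```python
import os
import sys


def select_keras_backend(backend="jax", loaded_modules=None):
    """Set KERAS_BACKEND, refusing if Keras has already been imported.

    `loaded_modules` defaults to sys.modules; it is a parameter only so the
    check can be exercised without actually importing Keras.
    """
    modules = sys.modules if loaded_modules is None else loaded_modules
    if "keras" in modules:
        raise RuntimeError(
            "Keras is already imported; restart the runtime and set "
            "KERAS_BACKEND before the first `import keras`."
        )
    os.environ["KERAS_BACKEND"] = backend
    return backend
```

Calling `select_keras_backend("jax")` in the first cell, before any Keras import, sets the variable; calling it later raises an error instead of silently keeping whichever backend was already loaded.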

Model Loading and Distribution Configuration

4. Load the Model:

   import keras
   import keras_nlp

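   # Arrange the accelerators as a 1 x 8 mesh: no data parallelism on the
   # "batch" axis, and the weights split 8 ways along the "model" axis.
   # This shape assumes exactly 8 devices are available.
   device_mesh = keras.distribution.DeviceMesh(
       (1, 8),
       ["batch", "model"],
       devices=keras.distribution.list_devices(),
   )
   model_dim = "model"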

   layout_map = keras.distribution.LayoutMap(device_mesh)

   layout_map["token_embedding/embeddings"] = (model_dim, None)
   layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
   layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
   layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
   layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)
   model_parallel = keras.distribution.ModelParallel(
       layout_map=layout_map,
       batch_dim_name="batch",
   )

   keras.distribution.set_distribution(model_parallel)
   gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
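To see what the layout map does, the sketch below mimics it in plain Python. It assumes that each key is treated as a regular expression searched against a variable's path, and that a variable with no matching key is replicated on every device. The variable paths and shapes are hypothetical, chosen only to illustrate the arithmetic; they are not read from the model.

```python
import re

# Same patterns as the layout map above; "model" names the sharded mesh axis.
LAYOUT_RULES = {
    r"token_embedding/embeddings": ("model", None),
    r"decoder_block.*attention.*(query|key|value)/kernel": ("model", None, None),
    r"decoder_block.*attention_output/kernel": ("model", None, None),
    r"decoder_block.*ffw_gating.*/kernel": (None, "model"),
    r"decoder_block.*ffw_linear/kernel": ("model", None),
}
MESH_AXIS_SIZES = {"batch": 1, "model": 8}


def find_layout(variable_path):
    """Return the layout of the first matching rule, or None (replicated)."""
    for pattern, layout in LAYOUT_RULES.items():
        if re.search(pattern, variable_path):
            return layout
    return None


def shard_shape(shape, layout):
    """Per-device shape once each named axis is split across the mesh."""
    if layout is None:
        return tuple(shape)
    return tuple(
        dim // MESH_AXIS_SIZES[axis] if axis else dim
        for dim, axis in zip(shape, layout)
    )


# Hypothetical variable paths and shapes, for illustration only.
examples = {
    "token_embedding/embeddings": (256000, 3584),
    "decoder_block_0/attention/query/kernel": (16, 3584, 256),
    "decoder_block_0/ffw_gating/kernel": (3584, 14336),
    "decoder_block_0/pre_attention_norm/scale": (3584,),
}
for path, shape in examples.items():
    layout = find_layout(path)
    print(f"{path}: layout={layout}, per-device shape={shard_shape(shape, layout)}")
```

Each tuple position corresponds to one dimension of the weight: a mesh-axis name means "split this dimension across that axis", and None means "keep it whole". Small variables such as normalization scales match no rule here and stay replicated.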

5. Run a Prediction Test Before Fine-Tuning:

   print(gemma_lm.generate("Create five realistic usernames in Arabic.", max_length=512))

Dataset Preparation

6. Load the CIDAR Dataset:

   from datasets import load_dataset

   dataset = load_dataset("arbml/CIDAR", split="train")
   df = dataset.to_pandas()
   df.head(10)

7. Format the Data:

   # Prompt template shared by every training example.
   template = "Instruction:\n{instruction}\n\nResponse:\n{output}"

   data = []
   for index, row in df.iterrows():
       data.append(template.format(instruction=row['instruction'], output=row['output']))

   # Keep only the first 2,500 formatted examples.
   data = data[:2500]
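The same formatting step can be written as a small function that also skips incomplete rows. This is a sketch under assumptions: the function name, the `limit` parameter, and the sample rows are hypothetical, and whether the dataset contains empty fields at all is not something the tutorial states.

```python
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{output}"


def format_examples(rows, limit=None):
    """Format rows with the template, skipping rows that have an empty field."""
    formatted = []
    for row in rows:
        instruction = (row.get("instruction") or "").strip()
        output = (row.get("output") or "").strip()
        if not instruction or not output:
            continue  # nothing useful to learn from an incomplete pair
        formatted.append(TEMPLATE.format(instruction=instruction, output=output))
        if limit is not None and len(formatted) >= limit:
            break
    return formatted


# Hypothetical rows standing in for dataset records.
sample_rows = [
    {"instruction": "اذكر عاصمة مصر.", "output": "القاهرة."},
    {"instruction": "", "output": "ignored"},
]
print(format_examples(sample_rows))
```

With a pandas DataFrame, the rows could be supplied as `df.to_dict("records")`, and `limit=2500` would reproduce the subset size used above.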

Activating LoRA and Training the Model

8. Activate LoRA for the Model:

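   # Enable LoRA on the backbone with rank 8.
   gemma_lm.backbone.enable_lora(rank=8)
   # Cap the input sequence length at 512 tokens to limit memory use.
   gemma_lm.preprocessor.sequence_length = 512
   gemma_lm.compile(
       loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       optimizer=keras.optimizers.Adam(learning_rate=5e-5),
       weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
   )
   # Print the model summary, including trainable parameter counts.
   gemma_lm.summary()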
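The reason LoRA keeps training cheap can be shown with simple arithmetic. For a frozen kernel of shape (d_in, d_out), LoRA learns two small matrices of shapes (d_in, rank) and (rank, d_out) whose product is added to the kernel. The sketch below counts parameters for one such kernel; the kernel size is a hypothetical example, not a value read from the model.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (frozen, trainable) parameter counts for one LoRA-adapted kernel."""
    frozen = d_in * d_out               # original weights, left untouched
    trainable = rank * (d_in + d_out)   # the two low-rank factors
    return frozen, trainable


# Hypothetical kernel size; rank=8 matches the value passed to enable_lora.
frozen, trainable = lora_param_counts(d_in=3584, d_out=4096, rank=8)
print(f"frozen: {frozen:,}  trainable: {trainable:,}  ratio: {trainable / frozen:.2%}")
```

Doubling the rank doubles the trainable count, so the rank is the main knob for trading adapter capacity against memory and training time.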

9. Start Training:

   gemma_lm.fit(data, epochs=1, batch_size=4)
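As a quick sanity check on how long one epoch is, the number of optimizer steps follows directly from the dataset size and batch size used above. The helper below is a hypothetical convenience, not part of Keras.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches in one pass over the data (last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


num_examples = 2500   # size of the formatted subset
batch_size = 4        # value passed to fit()
sequence_length = 512 # value set on the preprocessor

print("steps per epoch:", steps_per_epoch(num_examples, batch_size))
print("max tokens per step:", batch_size * sequence_length)
```

With these settings one epoch is 625 steps, each processing at most 2,048 tokens, which is useful when estimating run time on a time-limited notebook session.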

10. Run a Prediction Test After Fine-Tuning:

```python
# Use the same template as the training data, with the response left empty.
print(gemma_lm.generate("Instruction:\nCreate five realistic usernames in Arabic.\n\nResponse:\n", max_length=512))
```
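Because the model was tuned on examples in the Instruction/Response template, prompts at inference time work best when they reuse exactly the same prefix. The helpers below are a hypothetical sketch of that idea; `extract_response` assumes the generated text begins with the prompt it was given.

```python
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{output}"


def build_prompt(instruction):
    """Training template with the response left empty for the model to fill."""
    return TEMPLATE.format(instruction=instruction.strip(), output="")


def extract_response(generated_text, prompt):
    """Drop the echoed prompt, assuming the output starts with it."""
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    return generated_text.strip()


prompt = build_prompt("Create five realistic usernames in Arabic.")

# A training example formatted with the same template starts with the prompt.
training_example = TEMPLATE.format(
    instruction="Create five realistic usernames in Arabic.", output="..."
)
print(training_example.startswith(prompt))
```

The prompt could then be passed as `gemma_lm.generate(prompt, max_length=512)`, and `extract_response` applied to the returned string to keep only the answer.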

Saving the Fine-Tuned Model

11. Set Up Access and Save the Model on Hugging Face:

```python
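from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

# Read the Hugging Face write token stored as a Kaggle secret named "HF_WRITE".
user_secrets = UserSecretsClient()
HF = user_secrets.get_secret("HF_WRITE")
login(token=HF)

# Save the fine-tuned model to the Hugging Face repository.
gemma_lm.save("hf://Ruqiya/gemma2_9b_en_fine_tuned_arabic_dataset")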
```

Conclusion

We have completed Part 3 of our series on fine-tuning the Gemma2-9b model using Keras and LoRA. Through these steps, you have seen how to adapt a large model to Arabic instruction data with LoRA and model parallelism, producing a customized model that can handle a variety of tasks. You also learned how to upload the model to Hugging Face, making it available for experimentation and future projects.

Note: This tutorial aims to provide an understanding of the basic steps, but additional optimization steps may be needed to achieve satisfactory performance.

For further exploration, you can check the links below.