Build a RAG Application Using Gemma 2 (Part 1)

When working on a project that requires handling Arabic, you might wonder whether to use Retrieval-Augmented Generation (RAG) or to fine-tune a model on a new Arabic dataset. In this two-part tutorial series, we will explore both options: building a RAG pipeline and fine-tuning a model on Arabic data, specifically Wikipedia. Throughout the project we will stick to open-source models, using Gemma 2 Instruct and an open-source embeddings model, and we will rely on the LangChain framework to streamline building the RAG pipeline and fine-tuning the model. Let's dive into the practical implementation.

Step 1: Install Necessary Packages

First, install the required packages using the following command:

!pip install langchain langchainhub langchain_community langchain-huggingface faiss-gpu transformers accelerate datasets bitsandbytes langchain-text-splitters sentence-transformers huggingface_hub chromadb gradio > /dev/null 2>&1

Step 2: Import Libraries

Import the necessary libraries for your application:

import os
import torch
from langchain_community.document_loaders import HuggingFaceDatasetLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import hub
from langchain_core.documents import Document
from huggingface_hub import login
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr

Step 3: Log in to the Hugging Face Hub

Log in to the Hugging Face Hub using your access token (stored here as a Colab secret named gemma2):

from google.colab import userdata

hf_token = userdata.get('gemma2')
login(token=hf_token)
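
If you are not running in Colab, a minimal alternative is to read the token from an environment variable instead of Colab secrets (the variable name HF_TOKEN below is just an assumed convention):

# Alternative for non-Colab environments: set the token beforehand,
# e.g. export HF_TOKEN=..., then read it here.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)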

Step 4: Load and Process Dataset

Load a dataset from Hugging Face and process it into smaller chunks:

# Load dataset
dataset_name = "wikimedia/wikipedia"
page_content_column = "text"
name = "20231101.ar"
loader = HuggingFaceDatasetLoader(dataset_name, page_content_column, name)
data = loader.load()

# Select the first 20 entries for demonstration
documents = data[:20]

# Split text into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=0, length_function=len, is_separator_regex=False)

def process_document(document):
    # Split the article into chunks of at most 400 characters.
    chunks = text_splitter.split_text(document.page_content)
    split_docs = []
    for chunk in chunks:
        # The loader serializes page_content with json.dumps, so Arabic text
        # arrives as \uXXXX escape sequences; decode them back into readable
        # characters and fall back to the raw chunk if decoding fails.
        try:
            decoded_content = chunk.encode().decode('unicode_escape')
        except UnicodeDecodeError:
            decoded_content = chunk
        split_docs.append(Document(page_content=decoded_content, metadata=document.metadata))
    return split_docs

split_documents = []

with ThreadPoolExecutor() as executor:
    futures = [executor.submit(process_document, doc) for doc in documents]
    for future in as_completed(futures):
        split_documents.extend(future.result())

split_documents[:2]
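
Loading wikimedia/wikipedia through HuggingFaceDatasetLoader pulls down the full Arabic dump before we slice the first 20 articles. As a lighter-weight sketch (assuming the standard datasets streaming API), you could stream just a handful of articles and wrap them in Document objects yourself; the text then arrives as plain UTF-8, so the unicode_escape decoding above would not be needed:

# Optional: stream only the first 20 Arabic Wikipedia articles instead of
# downloading the whole dump, and build LangChain Documents directly.
from itertools import islice
from datasets import load_dataset

stream = load_dataset("wikimedia/wikipedia", "20231101.ar", split="train", streaming=True)
documents = [
    Document(page_content=row["text"], metadata={"title": row["title"], "id": row["id"]})
    for row in islice(stream, 20)
]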

Step 5: Initialize Embeddings Model

Initialize an embeddings model for processing the text:

model_path = "sentence-transformers/all-MiniLM-L12-v2"
model_kwargs = {'device': 'cuda'}  
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(model_name=model_path, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

text = split_documents[0].page_content
query_result = embeddings.embed_query(text)
query_result[:3]
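
Note that all-MiniLM-L12-v2 was trained mostly on English text, so retrieval quality on Arabic may suffer. One alternative with the same API is to swap in a multilingual sentence-transformers checkpoint, for example:

# Optional: a multilingual embeddings model that also covers Arabic.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    model_kwargs={'device': 'cuda'},
    encode_kwargs={'normalize_embeddings': False},
)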

Step 6: Create and Save Vector Database

Create a vector database using FAISS and save it locally:

vector_db = FAISS.from_documents(split_documents, embeddings)
vector_db.save_local("/faiss_index")
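
To reuse the index later without re-embedding everything, load it back from disk. Recent langchain_community releases require allow_dangerous_deserialization=True because the docstore is pickled, so only load indexes you created yourself:

# Reload the saved FAISS index (only for indexes you trust).
vector_db = FAISS.load_local(
    "/faiss_index",
    embeddings,
    allow_dangerous_deserialization=True,
)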

Step 7: Initialize LLM Model for Text Generation

Initialize a language model for text generation using Gemma 2:

base_model = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model, return_dict=True, low_cpu_mem_usage=True,
                                            torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=20)
llm = HuggingFacePipeline(pipeline=pipe)
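
Before wiring in the retriever, a quick smoke test confirms that the pipeline generates text. Keep in mind that max_new_tokens=20 keeps completions very short; raise it if answers get cut off:

# Quick sanity check of the wrapped pipeline with an arbitrary Arabic prompt.
print(llm.invoke("ما هي عاصمة مصر؟"))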

Step 8: Configure Retrieval

Configure the retriever to use the vector database:

retriever = vector_db.as_retriever()
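
By default the retriever returns the four most similar chunks; you can pass search_kwargs={"k": 3} to as_retriever to change that. Retrievers are Runnables in recent LangChain versions, so you can test retrieval in isolation before building the full chain:

# Inspect what gets retrieved for a sample question.
docs = retriever.invoke("ما هي النماذج اللغوية؟")
print(len(docs), docs[0].page_content[:100])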

Step 9: Set Up QA Chain

Set up a QA chain using the retriever and the language model:

rag_prompt = hub.pull("rlm/rag-prompt")
qa_chain = ({"context": retriever, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser())

def extract_answer(qa_chain_output):
    # The pipeline returns the full prompt plus the generation, and the
    # rag-prompt template ends with "Answer:", so pick out the line that
    # carries the generated answer.
    lines = qa_chain_output.split('\n')
    for line in lines:
        if line.startswith('Answer:'):
            return line.split(':', 1)[1].strip()
    return None

question = "ما هي النماذج اللغوية؟"
result = qa_chain.invoke(question)
extract_answer(result)
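
One detail worth noting: feeding the retriever straight into "context" means the prompt receives the string representation of a list of Document objects. A common refinement (a sketch, not the only way to do it) is to join the page contents into plain text before they reach the prompt:

# Optional: format the retrieved documents as plain text for the prompt.
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

qa_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)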

Step 10: Initialize Gradio Interface

Set up a Gradio interface for user interaction:

def chatbot_response(question):
    result = qa_chain.invoke(question)
    answer = extract_answer(result)  # Reuse the helper from Step 9
    return answer if answer else result  # Fall back to the raw output if no "Answer:" line is found

chatbot = gr.Interface(
    fn=chatbot_response,
    inputs="text",
    outputs="text",
    live=False,
    title="Gemma 2 Chatbot",
    description="Ask me anything"
)

chatbot.launch(debug=True)

You now have a fully functional RAG application built on Gemma 2. This setup loads and processes a dataset, builds a vector database for retrieval, and generates responses with the language model, while the Gradio interface provides an interactive way to ask questions and get answers in real time. In Part 2, we will turn to the second option and fine-tune the model on the Arabic Wikipedia data.