4 Basic Strategies for Optimizing RAG Process

Author: Deephub Imba


This article is about 3000 words long, and it is recommended to read it in 7 minutes.
This article will introduce four strategies for optimizing Retrieval-Augmented Generation (RAG) using private data.


In this article, we introduce four strategies for optimizing Retrieval-Augmented Generation (RAG) over private data. Each strategy targets a different stage of the pipeline, and together they can improve the quality and accuracy of the generated answers so that the system better meets practical application needs.

Brief Review of RAG

RAG consists of two main processes. The first is the “data ingestion process”, which gathers data from different sources, converts it into text, splits it into smaller, coherent, semantically meaningful chunks, and stores them in a vector database. The second is the “inference process”, which starts from a user query, uses the vector database built in the first process to find the relevant chunks, and adds them to the model’s context to produce the output.
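The two processes can be sketched without any library. The following is a minimal illustration under loud assumptions: the bag-of-words `embed` function stands in for a real embedding model, a Python list stands in for the vector database, and the two sample documents (including the names in them) are made up for this example.

```python
import math
import re
from collections import Counter

def embed(text):
    """Toy embedding: a bag-of-words count vector (stand-in for a real model)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def ingest(documents, chunk_size=20):
    """Ingestion: split each document into word chunks and store (vector, chunk) pairs."""
    index = []
    for doc in documents:
        words = doc.split()
        for i in range(0, len(words), chunk_size):
            chunk = " ".join(words[i:i + chunk_size])
            index.append((embed(chunk), chunk))
    return index

def retrieve(index, query, k=1):
    """Inference, step 1: rank the stored chunks by similarity to the query."""
    q = embed(query)
    ranked = sorted(index, key=lambda item: cosine(q, item[0]), reverse=True)
    return [chunk for _, chunk in ranked[:k]]

def build_prompt(index, query, k=1):
    """Inference, step 2: put the retrieved chunks into the model's context."""
    context = "\n\n".join(retrieve(index, query, k))
    return f"### CONTEXT\n{context}\n### QUESTION\n{query}"

# Made-up sample data, not taken from the ColdF file
documents = [
    "Experiment 1 was led by Dr. Alice Example and studied low temperature fusion.",
    "Experiment 2 was led by Dr. Bob Sample and focused on magnetic confinement.",
]
index = ingest(documents)
prompt_text = build_prompt(index, "Who led Experiment 2?")
print(prompt_text)
```

In a real pipeline the prompt would then be sent to the language model; that call is left out here so the sketch stays self-contained.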
Let’s summarize the key points that can be optimized in the RAG process:
1. Chunking Method: Optimize chunk size to ensure meaningful and contextually relevant data segments.
2. Embedding Model: Select and fine-tune models to improve semantic representation.
3. Vector Search Method: Choose effective similarity measures and search parameters.
4. Final Prompts for the Model: Create effective prompts to improve output quality.

A/B Testing in RAG

A/B testing compares two versions of a component that differ only in configuration to determine which one performs better. Each version is run separately and scored against predefined metrics. Which metrics should we use? We follow the paper “RAGAS: Automated Evaluation of Retrieval Augmented Generation”, which proposes three key metrics:
Faithfulness: Checks whether the information in the answer matches the provided context. If everything stated in the answer can be found in or inferred from the context, the answer is faithful.
Answer relevance: Checks whether the generated answer is complete and directly addresses the question asked, regardless of whether the information is correct. For example, if the question is “What is the capital of Portugal?” and the answer is “Lisbon is the capital of Portugal”, the answer is relevant because it directly answers the question. If the answer is “Lisbon is a beautiful city with many attractions”, it is at best partially relevant, because it adds information that is not needed to answer the question. This metric ensures that the answer is focused and to the point.
Context relevance: Checks how much of the information in the retrieved context actually helps answer the question. The context should contain only the necessary details, without extra information that does not help answer the question. In the code, this is measured with the ragas metric context_precision.
In addition, one more metric is used:
Context recall: Measures how well the retrieved context covers the reference (ground-truth) answer. It is similar to context relevance, but it is computed against the reference answer rather than the generated answer, so ground truth is required. To evaluate the effectiveness of these strategies, I prepared a set of 10 questions with reference answers based on the ColdF data.
Faithfulness and answer relevance are generator metrics: they measure hallucination and how directly the answer addresses the question, respectively.
Context relevance and context recall are retrieval metrics: they measure the ability to retrieve the correct chunks from the vector database and to obtain all the necessary information.
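To make two of these definitions concrete, here is a small sketch of the arithmetic behind faithfulness and context recall. It assumes the per-claim verdicts are already available; in RAGAS these verdicts come from an LLM judge, and the lists below are hypothetical values chosen only for illustration.

```python
def faithfulness_score(claim_supported):
    """Share of the answer's claims that the retrieved context supports."""
    if not claim_supported:
        return 0.0
    return sum(claim_supported) / len(claim_supported)

def context_recall_score(truth_sentence_attributed):
    """Share of ground-truth sentences that can be attributed to the retrieved context."""
    if not truth_sentence_attributed:
        return 0.0
    return sum(truth_sentence_attributed) / len(truth_sentence_attributed)

# Hypothetical verdicts: the answer makes three claims, two supported by the context
answer_claims = [True, True, False]
# Hypothetical verdicts: the reference answer has two sentences, one found in the context
truth_sentences = [True, False]

print(round(faithfulness_score(answer_claims), 4))  # 0.6667
print(context_recall_score(truth_sentences))        # 0.5
```

Both scores are simple fractions, which is why a single wrong verdict moves the result noticeably on a small question set.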

Next, we will start using LangChain to implement the RAG process. First, we will install the libraries:

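pip install ollama==0.2.1
pip install chromadb==0.5.0
pip install transformers==4.41.2
pip install torch==2.3.1
pip install langchain==0.2.0
pip install ragas==0.1.9
# Also imported by the code in this article; no version is pinned here
pip install langchain-chroma langchain-community sentence-transformers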

Below is a code snippet using LangChain:


# Import necessary libraries and modules
import logging
from operator import itemgetter

import requests
from langchain_chroma import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_core.embeddings import Embeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import MarkdownHeaderTextSplitter
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Logger used by the embedding classes when tqdm cannot be imported
logger = logging.getLogger(__name__)

# Define a custom embedding class using the DPRQuestionEncoder class
class DPRQuestionEncoderEmbeddings(Embeddings):
    show_progress: bool = False
    """Whether to show a tqdm progress bar. Must have `tqdm` installed."""
    def __init__(self, model_name: str = 'facebook/dpr-question_encoder-single-nq-base'):
        # Initialize the tokenizer and model with the specified model name
        self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_name)
        self.model = DPRQuestionEncoder.from_pretrained(model_name)
    def embed(self, texts):
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        if self.show_progress:
            try:
                from tqdm import tqdm
                iter_ = tqdm(texts, desc="Embeddings")
            except ImportError:
                logger.warning(
                    "Unable to show progress bar because tqdm could not be imported. "
                    "Please install with `pip install tqdm`."
                )
                iter_ = texts
        else:
            iter_ = texts
        for text in iter_:
            # Tokenize the input text, truncating to the encoder's 512-token limit
            inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
            # Generate embeddings using the model
            outputs = self.model(**inputs)
            # Extract the embedding and convert it to a list
            embedding = outputs.pooler_output.detach().numpy()[0]
            embeddings.append(embedding.tolist())
        return embeddings
    def embed_documents(self, documents):
        return self.embed(documents)
    def embed_query(self, query):
        return self.embed([query])[0]

# Define a template for generating prompts
template = """
### CONTEXT {context}
### QUESTION Question: {question}
### INSTRUCTIONS Answer the user's QUESTION using the CONTEXT markdown text above. Provide short and concise answers. Base your answer solely on the facts from the CONTEXT. If the CONTEXT does not contain the necessary facts to answer the QUESTION, return 'NONE'.
"""

# Create a ChatPromptTemplate instance using the template
prompt = ChatPromptTemplate.from_template(template)

# Fetch text data from a URL
url = "https://raw.githubusercontent.com/cgrodrigues/rag-intro/main/coldf_secret_experiments.txt"
response = requests.get(url)
if response.status_code == 200:
    text = response.text
else:
    raise Exception(f"Failed to fetch the file: {response.status_code}")

# Define headers to split the markdown text
headers_to_split_on = [
    ("#", "Header 1")
]

# Create an instance of MarkdownHeaderTextSplitter with the specified headers
markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on, strip_headers=False
)

# Split the text using the markdown splitter
docs_splits = markdown_splitter.split_text(text)

# Initialize a chat model
llm = ChatOllama(model="llama3")

# Create a Chroma vector store from the documents using the custom embeddings
vectorstore = Chroma.from_documents(documents=docs_splits, embedding=DPRQuestionEncoderEmbeddings())

# Create a retriever from the vector store
retriever = vectorstore.as_retriever()

# Define a function to format documents for display
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Create a retrieval-augmented generation (RAG) chain
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"answer": prompt | llm | StrOutputParser(),
        "context": itemgetter("context")}
)

# Invoke the RAG chain with a question
result = rag_chain.invoke("Who led the Experiment 1?")
print(result)

Use the following code to evaluate the metrics:

# Import necessary libraries and modules
import pandas as pd
import requests
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
        context_precision,
        faithfulness,
        answer_relevancy,
        context_recall
)
from langchain_community.chat_models import ChatOllama

def get_questions_answers_contexts(rag_chain):
    """ Read the list of questions and answers and return a
        ragas dataset for evaluation """
    # URL of the file
    url = 'https://raw.githubusercontent.com/cgrodrigues/rag-intro/main/coldf_question_and_answer.psv'

    # Fetch the file from the URL and fail early on an HTTP error
    response = requests.get(url)
    response.raise_for_status()
    data = response.text

    # Split the data into lines
    lines = data.splitlines()
    # Split each line by the pipe symbol and build the evaluation records
    rag_dataset = []
    for line in lines[1:11]:  # Skip the header row; take the first 10 questions
        if line.strip():  # Ensure the line is not empty
            question, reference_answer = line.split('|', 1)
            result = rag_chain.invoke(question)
            generated_answer = result['answer']
            contexts = result['context']
            rag_dataset.append({
                "question": question,
                "answer": generated_answer,
                "contexts": [contexts],
                "ground_truth": reference_answer
            })

    rag_df = pd.DataFrame(rag_dataset)
    rag_eval_dataset = Dataset.from_pandas(rag_df)
    # Return the ragas dataset
    return rag_eval_dataset

def get_metrics(rag_dataset):
    """ For a RAG Dataset calculate the metrics faithfulness,
        answer_relevancy, context_precision and context_recall """
    # The list of metrics that we want to evaluate
    metrics = [
        faithfulness,
        answer_relevancy,
        context_precision,
        context_recall
    ]
    # We will use our local ollama with the LLaMA 3 model
    langchain_llm = ChatOllama(model="llama3")
    langchain_embeddings = DPRQuestionEncoderEmbeddings('facebook/dpr-question_encoder-single-nq-base')
    # Return the metrics
    results = evaluate(rag_dataset, metrics=metrics, llm=langchain_llm, embeddings=langchain_embeddings)
    return results

# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
If your code runs correctly, it should return results like the following:
{
  'faithfulness': 0.8611,
  'answer_relevancy': 0.8653,
  'context_precision': 0.7778,
  'context_recall': 0.8889
}


The first two metrics depend on the generator: to improve them, change the language model or give it better prompts. The last two depend on retrieval: to improve them, look at how documents are stored, indexed, and selected.
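When comparing a variant against this baseline, it helps to look at per-metric differences rather than at two raw result dictionaries. The helper below is a sketch that assumes the results can be read as plain dictionaries; `compare_runs` is a hypothetical name, and the `candidate` scores are invented purely to show the output format.

```python
def compare_runs(baseline, candidate):
    """Return candidate minus baseline for every metric present in both runs."""
    return {
        name: round(candidate[name] - baseline[name], 4)
        for name in baseline
        if name in candidate
    }

# Baseline scores reported above
baseline = {
    "faithfulness": 0.8611,
    "answer_relevancy": 0.8653,
    "context_precision": 0.7778,
    "context_recall": 0.8889,
}
# Hypothetical scores for some variant (not measured, for illustration only)
candidate = {
    "faithfulness": 0.9000,
    "answer_relevancy": 0.8500,
    "context_precision": 0.7778,
    "context_recall": 0.9000,
}

deltas = compare_runs(baseline, candidate)
for name, delta in deltas.items():
    verdict = "better" if delta > 0 else "worse" if delta < 0 else "same"
    print(f"{name}: {delta:+.4f} ({verdict})")
```

Because the scores come from an LLM judge on a small question set, small differences between runs should be read with caution.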

Next, we will begin making improvements.

Chunking

The chunking method determines how the data is segmented for retrieval. Experimenting with different chunk sizes helps find a balance between chunks that are too small (lacking context) and chunks that are too large (carrying redundant content that dilutes the embedding). In the baseline, we split the document by experiment, one chunk per top-level header; this means that some details of an experiment may be diluted and not reflected in the final embedding. One way to address this is to use a parent document retriever. It indexes small child chunks for the similarity search but returns the larger parent chunks they belong to, so the context surrounding the relevant fragment is preserved. The following code is used to test this method:
# Import necessary libraries and modules
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Create the parent document retriever
parent_document_retriever = ParentDocumentRetriever(
    vectorstore = Chroma(collection_name="parents",
                          embedding_function=DPRQuestionEncoderEmbeddings('facebook/dpr-question_encoder-single-nq-base')),
    docstore = InMemoryStore(),
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=200),
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500),
)
parent_document_retriever.add_documents(docs_splits)

# Create a retrieval-augmented generation (RAG) chain
rag_chain_pr = (
    {"context": parent_document_retriever | format_docs, "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"answer": prompt | llm | StrOutputParser(),
        "context": itemgetter("context")}
)
# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain_pr)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
The results are as follows:
[Figure: evaluation metrics for the parent document retriever chain]
This change reduced performance. Context recall dropped, which indicates that retrieval failed to return all the information needed to answer the questions. The changes in faithfulness and answer relevance stem from the more complex contexts passed to the model. We therefore need to try other chunking and retrieval methods.
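The parent document idea itself can be illustrated without LangChain. The sketch below is an assumption-laden toy, not the library's implementation: chunks are split by word count, relevance is a simple word-overlap score instead of an embedding search, and the sample document is made up.

```python
def split_words(text, size):
    """Split text into consecutive chunks of at most `size` words."""
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size)]

def build_parent_child_index(documents, parent_size=40, child_size=8):
    """Split documents into parent chunks, then split each parent into child chunks."""
    parents = []   # position in this list is the parent id
    children = []  # (child text, parent id) pairs; only children are searched
    for doc in documents:
        for parent in split_words(doc, parent_size):
            parent_id = len(parents)
            parents.append(parent)
            for child in split_words(parent, child_size):
                children.append((child, parent_id))
    return parents, children

def overlap(query, text):
    """Toy relevance score: number of shared lowercase words."""
    return len(set(query.lower().split()) & set(text.lower().split()))

def retrieve_parent(parents, children, query):
    """Find the best-matching child, then return it together with its parent."""
    best_child, parent_id = max(children, key=lambda c: overlap(query, c[0]))
    return best_child, parents[parent_id]

# Made-up sample document
documents = [
    "Experiment 1 tested a new alloy. The alloy was cooled slowly over nine hours. "
    "The lead scientist recorded a stable reaction at the end of the run.",
]
parents, children = build_parent_child_index(documents)
child, parent = retrieve_parent(parents, children, "who recorded the stable reaction")
print("matched child:", child)
print("returned parent:", parent)
```

The small child gives a sharp match, while the returned parent carries the surrounding sentences. If the parent is still too small to hold the whole answer, recall can drop, which is consistent with the result observed above.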

Embedding Models

Embedding models convert text chunks into dense vector representations. Models are trained on different corpora and for different tasks, so choosing one that suits the data can improve retrieval. The choice should also balance computational cost against embedding quality.
Here, we compare different embedding models, such as Dense Passage Retrieval (DPR), Sentence-BERT, and Chroma’s default model (“all-MiniLM-L6-v2”). Each model has its strengths, and evaluating them on domain-specific data helps determine which one provides the most accurate semantic representation.
We define a new class, “SentenceBertEncoderEmbeddings”, which wraps a Sentence-BERT model and replaces our previous embedding class, “DPRQuestionEncoderEmbeddings”:
# Import necessary libraries and modules
# (the evaluation imports from the previous section are reused)
import logging
from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer

# Logger used when tqdm cannot be imported
logger = logging.getLogger(__name__)

# Define a custom embedding class using the Sentence-BERT model
class SentenceBertEncoderEmbeddings(Embeddings):
    show_progress: bool = False
    """Whether to show a tqdm progress bar. Must have `tqdm` installed."""
    def __init__(self, model_name: str = 'paraphrase-MiniLM-L6-v2'):
        # Initialize the Sentence-BERT model with the specified model name
        self.model = SentenceTransformer(model_name)
    def embed(self, texts):
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        if self.show_progress:
            try:
                from tqdm import tqdm
                iter_ = tqdm(texts, desc="Embeddings")
            except ImportError:
                logger.warning(
                    "Unable to show progress bar because tqdm could not be imported. "
                    "Please install with `pip install tqdm`.")
                iter_ = texts
        else:
            iter_ = texts
        for text in iter_:
            embeddings.append(self.model.encode(text).tolist())
        return embeddings
    def embed_documents(self, documents):
        return self.embed(documents)
    def embed_query(self, query):
        return self.embed([query])[0]

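# Create a Chroma vector store from the documents using the Sentence-BERT embeddings.
# A separate collection keeps these vectors apart from the DPR ones, whose dimensionality differs.
vectorstore = Chroma.from_documents(collection_name="sbert",
                                    documents=docs_splits,
                                    embedding=SentenceBertEncoderEmbeddings())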
# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
# Create a retrieval-augmented generation (RAG) chain
rag_chain_ce = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"answer": prompt | llm | StrOutputParser(),
        "context": itemgetter("context")}
)
# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain_ce)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
The results are as follows:
[Figure: evaluation metrics for the Sentence-BERT embedding chain]
We can see that performance has also declined. A likely explanation is that DPR, which is trained specifically for passage retrieval, retrieves more accurately than this Sentence-BERT model on our data, making it more suitable for our case, where precise document retrieval is crucial. The significant drop in faithfulness and answer relevance when switching to Sentence-BERT highlights the importance of selecting an appropriate embedding model for tasks that require high retrieval accuracy. It also suggests that different types of RAG tasks may require domain-specific embedding models.
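Before running the full evaluation, a cheaper check is to measure how often each candidate embedding ranks the expected chunk first. The sketch below assumes a small labelled set of (question, expected chunk index) pairs; the two toy "models" (bag of words and character trigrams), the chunks, and the questions are all stand-ins invented for this example.

```python
import math
import re
from collections import Counter

def word_embed(text):
    """Toy model A: bag of whole words."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def trigram_embed(text):
    """Toy model B: bag of character trigrams."""
    s = re.sub(r"[^a-z0-9]+", " ", text.lower()).strip()
    return Counter(s[i:i + 3] for i in range(len(s) - 2))

def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def hit_rate(embed, chunks, labelled_questions):
    """Fraction of questions whose top-ranked chunk is the expected one."""
    vectors = [embed(c) for c in chunks]
    hits = 0
    for question, expected_idx in labelled_questions:
        q = embed(question)
        best = max(range(len(chunks)), key=lambda i: cosine(q, vectors[i]))
        hits += best == expected_idx
    return hits / len(labelled_questions)

# Made-up chunks and labelled questions
chunks = [
    "The cooling system kept the reactor at 4 kelvin.",
    "Funding for the project came from a private grant.",
]
labelled_questions = [
    ("What temperature was the reactor cooled to?", 0),
    ("Who funded the project with a grant?", 1),
]

for name, embed in [("words", word_embed), ("trigrams", trigram_embed)]:
    print(name, hit_rate(embed, chunks, labelled_questions))
```

With real models, `embed` would wrap the `embed_query` and `embed_documents` methods of each embedding class, and the labelled set would come from the question and answer file.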

Vector Search Methods

Vector search retrieves the most relevant chunks according to a similarity measure. Common choices include Euclidean (L2) distance and cosine similarity. Changing the measure can change which chunks are retrieved and therefore the quality of the final output.
Chroma uses L2 distance by default; the following code switches the collection to cosine similarity:
# Import necessary libraries and modules
import pandas as pd
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
        context_precision,
        faithfulness,
        answer_relevancy,
        context_recall
)
from langchain_community.chat_models import ChatOllama

# Create a Chroma vector store from the documents
# using the custom embeddings and also changing to
# cosine similarity search
vectorstore = Chroma.from_documents(collection_name="dist",
                                    documents=docs_splits,
                                    embedding=DPRQuestionEncoderEmbeddings(),
                                    collection_metadata={"hnsw:space": "cosine"})

# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
# Create a retrieval-augmented generation (RAG) chain
rag_chain_dist = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"answer": prompt | llm | StrOutputParser(),
        "context": itemgetter("context")}
)
# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain_dist)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
We can see that faithfulness has improved: using cosine similarity has improved the alignment of the retrieved documents with the query, even though context precision has decreased. Overall, the higher faithfulness and context recall indicate that cosine similarity is the more effective vector search method in this case, which shows how much the choice of search method matters for retrieval performance.
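Why the two measures can rank chunks differently is easiest to see on toy vectors. In this sketch the two-dimensional vectors are made up; L2 distance is sensitive to vector length, while cosine similarity only looks at direction.

```python
import math

def l2_distance(a, b):
    """Euclidean distance: smaller means closer."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_similarity(a, b):
    """Cosine of the angle between the vectors: larger means closer."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

# Made-up vectors for illustration
query = [1.0, 1.0]
chunks = {
    "same direction, large norm": [5.0, 5.0],
    "different direction, small norm": [1.0, 0.0],
}

nearest_l2 = min(chunks, key=lambda name: l2_distance(query, chunks[name]))
nearest_cos = max(chunks, key=lambda name: cosine_similarity(query, chunks[name]))

print("L2 picks:    ", nearest_l2)   # different direction, small norm
print("cosine picks:", nearest_cos)  # same direction, large norm
```

If the embeddings are normalized to unit length, the two measures produce the same ranking, so the choice matters mainly for models whose vectors vary in norm.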

Final Prompts for the Input Model

Constructing the final prompt means integrating the retrieved data into the query sent to the model. Minor changes in the prompt can significantly affect the results, which makes this an iterative process. Providing examples in the prompt can guide the model toward more accurate and relevant outputs. Modifying the prompt only requires editing the template string, not the pipeline code, so we do not run a separate evaluation for it here.
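As an illustration of adding an example to the prompt, the sketch below extends the article's template with one example question and answer. The EXAMPLE section, the `build_prompt` helper, and the sample values are assumptions made for this sketch; the article does not evaluate this variant.

```python
# Variant of the article's template with one worked example added.
# The "### EXAMPLE" section is an assumption of this sketch.
FEW_SHOT_TEMPLATE = """
### CONTEXT
{context}

### EXAMPLE
Question: {example_question}
Answer: {example_answer}

### QUESTION
Question: {question}

### INSTRUCTIONS
Answer the user's QUESTION using the CONTEXT markdown text above, in the same
short style as the EXAMPLE. Base your answer solely on the facts from the
CONTEXT. If the CONTEXT does not contain the necessary facts to answer the
QUESTION, return 'NONE'.
"""

def build_prompt(context, question, example_question, example_answer):
    """Fill the template with the retrieved context, one example, and the question."""
    return FEW_SHOT_TEMPLATE.format(
        context=context,
        question=question,
        example_question=example_question,
        example_answer=example_answer,
    )

# Hypothetical values, for illustration only
prompt_text = build_prompt(
    context="# Experiment 1\nThe experiment ran for nine hours.",
    question="How long did Experiment 1 run?",
    example_question="What was measured in Experiment 2?",
    example_answer="The reaction temperature.",
)
print(prompt_text)
```

In the LangChain chain, the same template string could be passed to `ChatPromptTemplate.from_template`, with the example fields either hard-coded in the template or supplied as extra input variables.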

Conclusion

Optimizing Retrieval-Augmented Generation (RAG) is an iterative process that largely depends on the specific data and context of the application. We explored four key optimization directions: refining chunking methods, selecting and fine-tuning embedding models, choosing effective vector search methods, and creating precise prompts. Each of these components plays a crucial role in improving the performance of RAG systems.
The process of optimizing RAG requires continuous testing, learning from failures, and making informed adjustments. An iterative approach is necessary to customize AI solutions that effectively meet specific needs. Most importantly, the key to success lies in understanding existing data, trying different strategies, and continuously improving the process.
Editor: Wang Jing
