Datawhale Insights
Author: Song Zhixue, Datawhale Member
Hello everyone, I write under the handle "No Scallions, Ginger, or Garlic."
Next, I will guide you step by step through implementing a stripped-down RAG system called Tiny-RAG. Tiny-RAG keeps only the core functions of RAG, namely Retrieval and Generation, and its purpose is to help everyone better understand the principles and implementation of the RAG model.
OK, let’s get started!
1. Introduction to RAG
LLMs can produce misleading "hallucinations," may rely on outdated information, and handle domain-specific knowledge inefficiently; they lack deep insight into specialized fields and also fall short in reasoning ability.
In this context, Retrieval-Augmented Generation (RAG) emerged as a major trend in the AI era.
RAG first retrieves relevant information from a broad document database and only then generates an answer with the language model, which greatly improves the accuracy and relevance of the output. RAG effectively mitigates the hallucination problem, increases the speed of knowledge updates, and makes generated content traceable, making large language models more practical and trustworthy in real-world applications.
What are the basic components of RAG?
- There should be a vectorization module to vectorize document fragments.
- There should be a document loading and segmentation module to load documents and split them into fragments.
- There should be a database to store document fragments and their corresponding vector representations.
- There should be a retrieval module to fetch relevant document fragments based on a query.
- There should be a large model module to answer user questions based on the retrieved documents.
OK, the above are all the modules in the TinyRAG repository.

Next, let’s clarify what the RAG process looks like.
- Indexing: split the document library into shorter chunks and build a vector index with an encoder.
- Retrieval: retrieve the document fragments most similar to the question.
- Generation: generate an answer to the question based on the retrieved context.
This is the process illustrated in the figure below, sourced from the paper Retrieval-Augmented Generation for Large Language Models: A Survey.

2. Vectorization
First, let’s implement a vectorization class, which is the foundation of the RAG architecture. The vectorization class is mainly used to vectorize document fragments, mapping a piece of text to a vector.
First, we need to set up an Embedding base class, so that when we want to use other models we only need to inherit from this base class and modify it as needed, which makes the code easier to extend.
import os
import numpy as np
from typing import List


class BaseEmbeddings:
    """
    Base class for embeddings
    """
    def __init__(self, path: str, is_api: bool) -> None:
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        raise NotImplementedError

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """
        calculate cosine similarity between two vectors
        """
        dot_product = np.dot(vector1, vector2)
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        if not magnitude:
            return 0
        return dot_product / magnitude
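As a quick sanity check (an illustration added here, not code from the repo), the classmethod can be called directly on two toy vectors:

# Parallel vectors score 1.0, orthogonal vectors score 0.0.
print(BaseEmbeddings.cosine_similarity([1.0, 0.0], [2.0, 0.0]))  # 1.0
print(BaseEmbeddings.cosine_similarity([1.0, 0.0], [0.0, 3.0]))  # 0.0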
Let's observe what methods the BaseEmbeddings base class has. First, there is a get_embedding method, which obtains the vector representation of a piece of text; then there is a cosine_similarity method, which calculates the cosine similarity between two vectors. Additionally, when initializing the class, we record the model's path and whether it is an API model. For example, if you use OpenAI's Embedding API, you need to set self.is_api=True.
If you inherit from the BaseEmbeddings class, you only need to implement the get_embedding method; the cosine_similarity method is inherited and can be used directly. This is the benefit of writing a base class.
class OpenAIEmbedding(BaseEmbeddings):
    """
    class for OpenAI embeddings
    """
    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from openai import OpenAI
            self.client = OpenAI()
            self.client.api_key = os.getenv("OPENAI_API_KEY")
            self.client.base_url = os.getenv("OPENAI_BASE_URL")

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        if self.is_api:
            text = text.replace("\n", " ")
            return self.client.embeddings.create(input=[text], model=model).data[0].embedding
        else:
            raise NotImplementedError
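A hedged usage sketch, assuming valid OpenAI credentials (OPENAI_API_KEY and, if needed, OPENAI_BASE_URL) are set in the environment:

embedding_model = OpenAIEmbedding()
vec = embedding_model.get_embedding("Retrieval-Augmented Generation")
print(len(vec))  # embedding dimensionality, e.g. 3072 for text-embedding-3-large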
3. Document Loading and Segmentation
Next, we will implement a document loading and segmentation class, which is mainly used to load documents and split them into fragments.
What types of documents do we need to segment? These documents can be an article, a book, a conversation, a piece of code, etc. The content of these documents can be anything as long as it is text. For example: pdf files, md files, txt files, etc.
Only part of the code is shown here; the complete version can be found in RAG/utils.py. As you can see, the file types that can be loaded include pdf, md, and txt; to support another type, you just write the corresponding reader function.
def read_file_content(cls, file_path: str):
    # A classmethod of ReadFiles: choose the reading method based on the file extension
    if file_path.endswith('.pdf'):
        return cls.read_pdf(file_path)
    elif file_path.endswith('.md'):
        return cls.read_markdown(file_path)
    elif file_path.endswith('.txt'):
        return cls.read_text(file_path)
    else:
        raise ValueError("Unsupported file type")
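To make this concrete, here is a minimal sketch of what two such reader functions might look like. It is illustrative rather than the exact code in RAG/utils.py, and the PDF reader assumes the PyPDF2 package:

import PyPDF2


class ReadFilesSketch:
    """Hypothetical stand-in for the reader methods in RAG/utils.py."""

    @classmethod
    def read_text(cls, file_path: str) -> str:
        # Plain text (and markdown) files can be read directly.
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()

    @classmethod
    def read_pdf(cls, file_path: str) -> str:
        # Concatenate the extracted text of every page.
        with open(file_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            return ''.join(page.extract_text() or '' for page in reader.pages)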
Once we have read the file content, we also need to segment it. How? We segment the document by token length: set a maximum token length and cut the document into fragments of roughly that size.
When segmenting, it is best to keep some overlapping content between adjacent fragments, so that related document fragments can still be retrieved together during search. It is also best to split along sentence boundaries, i.e. roughly on \n, so that each sentence stays complete.
# get_chunk is another classmethod of ReadFiles; `enc` is the tokenizer used to
# count tokens (e.g. tiktoken.get_encoding("cl100k_base")).
def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):
    chunk_text = []

    curr_len = 0
    curr_chunk = ''

    lines = text.split('\n')  # Split the text into lines on newlines

    for line in lines:
        line = line.replace(' ', '')
        line_len = len(enc.encode(line))
        if line_len > max_token_len:
            print('warning line_len = ', line_len)
        if curr_len + line_len <= max_token_len:
            curr_chunk += line
            curr_chunk += '\n'
            curr_len += line_len
            curr_len += 1
        else:
            # Start the next chunk with the tail of the previous one: the overlap
            chunk_text.append(curr_chunk)
            curr_chunk = curr_chunk[-cover_content:] + line
            curr_len = line_len + cover_content

    if curr_chunk:
        chunk_text.append(curr_chunk)

    return chunk_text
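A quick illustration of the overlap behaviour (assuming get_chunk is exposed as a classmethod on ReadFiles, as the signature above suggests; the tiny limits are just for demonstration):

sample = "\n".join(f"Sentence number {i} about RAG." for i in range(1, 8))
chunks = ReadFiles.get_chunk(sample, max_token_len=20, cover_content=5)
for chunk in chunks:
    print(repr(chunk))
# Each chunk after the first starts with the last `cover_content` characters of the previous chunk.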
4. Database and Vector Retrieval
Now that we have completed document segmentation and loading the embedding model, we need to design a vector database to store document fragments and their corresponding vector representations.
We also need to design a retrieval module to fetch relevant document fragments based on a query. OK, let’s go for it!
What functionalities does a database need to implement for the minimal RAG architecture?
- persist: persist the database by saving it locally.
- load_vector: load the database from local storage.
- get_vector: obtain the vector representations of the documents.
- query: retrieve relevant document fragments based on a query.
These four modules are the minimum functionalities that a RAG structure database needs to implement. The specific code can be found in RAG/VectorBase.py.
class VectorStore:
    def __init__(self, document: List[str] = ['']) -> None:
        self.document = document

    def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
        # Obtain the vector representation of every document fragment
        pass

    def persist(self, path: str = 'storage'):
        # Persist the database, saving it locally
        pass

    def load_vector(self, path: str = 'storage'):
        # Load the database from local storage
        pass

    def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
        # Retrieve relevant document fragments based on a query
        pass
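Plain json plus Numpy is enough to fill in the storage-related methods. The following is a minimal sketch of one possible implementation; the actual code in RAG/VectorBase.py may differ in details such as file names:

import json
import os


def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
    # Embed every document fragment and keep the vectors alongside the documents.
    self.vectors = [EmbeddingModel.get_embedding(doc) for doc in self.document]
    return self.vectors

def persist(self, path: str = 'storage'):
    # Write documents and vectors to disk as json so they can be reloaded later.
    os.makedirs(path, exist_ok=True)
    with open(f"{path}/document.json", 'w', encoding='utf-8') as f:
        json.dump(self.document, f, ensure_ascii=False)
    if getattr(self, 'vectors', None):
        with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f:
            json.dump(self.vectors, f)

def load_vector(self, path: str = 'storage'):
    # Read documents and vectors back from disk.
    with open(f"{path}/document.json", 'r', encoding='utf-8') as f:
        self.document = json.load(f)
    with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f:
        self.vectors = json.load(f)

def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
    # Reuse the cosine similarity defined on the embedding base class.
    return BaseEmbeddings.cosine_similarity(vector1, vector2)

Here get_similarity is the helper that the query method below relies on.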
Now let's look at how the query method is implemented: first vectorize the user's question, then compute its similarity against every vector in the database, and finally return the most relevant document fragments. The retrieval step uses nothing but Numpy, so the code is easy to understand and modify. We deliberately avoid a mature vector database here, so that the principles of RAG stay easy to see and the code stays easy to rewrite.
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
    query_vector = EmbeddingModel.get_embedding(query)
    result = np.array([self.get_similarity(query_vector, vector)
                       for vector in self.vectors])
    # argsort() sorts ascending, so take the last k indices and reverse them:
    # the k most similar fragments, most similar first.
    return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()
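The last line packs the top-k selection into a single expression; here is the same indexing trick in isolation:

import numpy as np

scores = np.array([0.2, 0.9, 0.5, 0.7])
docs = np.array(['a', 'b', 'c', 'd'])
print(docs[scores.argsort()[-2:][::-1]].tolist())  # ['b', 'd'] -- the two best matches, best first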
5. Large Model Module
Now we come to the last module, the large model module. This module is mainly used to answer user questions based on the retrieved documents.
Similarly, we will first implement a base class so that when we encounter other models of interest, we can quickly expand.
class BaseModel:
    def __init__(self, path: str = '') -> None:
        self.path = path

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        pass

    def load_model(self):
        pass
The BaseModel includes two methods, chat and load_model. If you use an API model such as OpenAI's, the load_model method is not needed; if you want to run an open-source model locally, load_model becomes necessary.
Here we take the InternLM2-chat-7B model as an example.
class InternLMChat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()

    def chat(self, prompt: str, history: List = [], content: str = '') -> str:
        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
        response, history = self.model.chat(self.tokenizer, prompt, history)
        return response

    def load_model(self):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
We can use a dictionary to store all the prompts, which makes it easier to maintain.
PROMPT_TEMPLATE = dict(
    InternLM_PROMPT_TEMPALTE="""First summarize the context, then use the context to answer the user's question. If you don't know the answer, just say you don't know. Always respond in Chinese.
Question: {question}
Referable context:
···
{context}
···
If the given context cannot help you answer, please respond that there is no such content in the database, and you don't know.
Useful response:"""
)
With this, we can use the InternLM2 model to perform RAG!
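The demo in the next section also imports an OpenAIChat class from RAG.LLM. Here is a minimal sketch of what such an API-backed subclass could look like; it is illustrative rather than the repo's exact implementation, and it simply reuses the prompt template above:

from typing import List

from openai import OpenAI


class OpenAIChat(BaseModel):
    """Hedged sketch of an API-backed chat model, not the exact code in RAG/LLM.py."""

    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo") -> None:
        super().__init__(path)
        self.model = model

    def chat(self, prompt: str, history: List[dict] = [], content: str = '') -> str:
        # Build the RAG prompt from the retrieved context, then call the chat completions API.
        client = OpenAI()  # reads OPENAI_API_KEY (and optionally OPENAI_BASE_URL) from the environment
        full_prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
        history.append({'role': 'user', 'content': full_prompt})
        response = client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=1024,
            temperature=0.1
        )
        return response.choices[0].message.content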
6. LLM Tiny-RAG Demo
Next, let’s take a look at the Tiny-RAG Demo!
from RAG.VectorBase import VectorStore
from RAG.utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding

# Without a previously saved database
docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150)  # Get all file contents in the data directory and segment them
vector = VectorStore(docs)
embedding = ZhipuEmbedding()  # Create the EmbeddingModel
vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage')  # Save the vectors and document content to the storage directory, so the local database can be loaded directly next time

question = 'What is the principle of git?'
content = vector.query(question, EmbeddingModel=embedding, k=1)[0]  # Retrieve the most relevant fragment
chat = InternLMChat(path='model_path')
print(chat.chat(question, [], content))
Of course, we can also load an already processed database from local storage since we have already implemented this functionality in the database section above.
from RAG.VectorBase import VectorStore
from RAG.utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding

# With a previously saved database
vector = VectorStore()
vector.load_vector('./storage')  # Load the local database

question = 'What is the principle of git?'
embedding = ZhipuEmbedding()  # Create the EmbeddingModel
content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
chat = InternLMChat(path='model_path')
print(chat.chat(question, [], content))
7. Summary
After working through the above, have you learned how to build a minimal RAG architecture? I believe you have.
Now let's review: what should a minimal RAG include? (Try writing it down yourself!)
- A vectorization module
- A document loading and segmentation module
- A database
- A vector retrieval module
- A large model module
OK, now that you've learned it, don't forget to give my project a star!
Project address: https://github.com/KMnO4-zx/TinyRAG
If this helped, don't forget to like, share, and follow ↓