My takeaway series follows a Q&A format to explain AI concepts at three levels:
- For anyone with general knowledge who wants to understand the idea.
- For anyone who wants to dive into the code and implementation details of the concept.
- For anyone who wants to understand the mathematics behind the technique.
This article organises the notes I took while taking the RAG course by DeepLearning.AI.
Retrieval Augmented Generation (RAG) is a framework in AI that combines pre-trained language models (PLMs) with external knowledge sources to enhance the generation of text. It leverages the strengths of both retrieval-based methods and generative models to produce more accurate and contextually relevant responses.
The RAG framework is built on top of an LLM with the following three main components:
- Knowledge database: The external source of information. It can be a collection of documents, a database, or even the internet.
- Retriever: This component is responsible for fetching relevant information from the knowledge database.
- Generator: The LLM itself, which produces the final response from the augmented input.
The RAG process typically involves the following steps:
- Retrieval: When a query or prompt is given, the retrieval component searches the knowledge database to find relevant information related to the query.
- Augmentation: The retrieved information is then used to augment the input to the generative language model.
- Generation: The augmented input is fed into the generative language model, which produces a response that incorporates both the original query and the retrieved information.
This is exactly what “R, A, G” stands for: Retrieval, Augmentation, Generation.
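As a rough sketch, the three steps can be wired together like this. Note that the knowledge base, the word-overlap retriever, and the `call_llm` stub are all toy stand-ins for illustration, not a real implementation:

```python
# A minimal, illustrative RAG pipeline: Retrieval -> Augmentation -> Generation.
# Everything here (documents, scoring, LLM call) is a toy stand-in.

KNOWLEDGE_BASE = [
    "Beijing's air quality index improved significantly in 2024.",
    "Rayleigh scattering makes the sky appear blue.",
    "Paris hosts the Eiffel Tower.",
]

def retrieve(query: str, k: int = 2) -> list[str]:
    """Retrieval: rank documents by naive word overlap with the query."""
    q_words = set(query.lower().split())
    scored = sorted(
        KNOWLEDGE_BASE,
        key=lambda doc: len(q_words & set(doc.lower().split())),
        reverse=True,
    )
    return scored[:k]

def augment(query: str, docs: list[str]) -> str:
    """Augmentation: prepend the retrieved context to the original query."""
    context = "\n".join(docs)
    return f"Context:\n{context}\n\nQuestion: {query}"

def call_llm(prompt: str) -> str:
    """Generation: stand-in for a real LLM call (e.g. an API request)."""
    return f"[LLM answer based on prompt of {len(prompt)} chars]"

def rag(query: str) -> str:
    return call_llm(augment(query, retrieve(query)))

print(rag("Why is the sky in Beijing blue?"))
```

In a real system, `retrieve` would query a vector or relational database and `call_llm` would call an actual model, but the data flow stays the same.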
An LLM can only output what it has learned during training, but an LLM with RAG can access external knowledge sources, such as documents, databases, or the internet, to retrieve relevant information that may not be present in its training data.
For example, if you ask ChatGPT “Why is the sky in Beijing blue?”, an LLM without RAG may answer with the physical and geographical principles, while one with RAG can search the web, find the real-time air quality index and recent events in Beijing, and tell you the sky is blue because of good air quality or recent policy changes.
An LLM with RAG can provide more accurate and contextually relevant responses, especially for queries that require up-to-date or domain-specific information. It can cite the source of the information it retrieves, enhancing the credibility and reliability of its responses. It also reduces hallucination.
- Fine-tuning an LLM is computationally expensive and time-consuming. RAG allows new information to be integrated without extensive retraining.
- There may be too much knowledge for an LLM to learn during training, some of which:
  - Can be private or hard to access.
  - Can be real-time data, such as news, for which there is no time to train.
Realistically, directly feeding all the knowledge into the prompt is not feasible because the context window of an LLM is limited. If you have been using ChatGPT or similar models, you may have noticed that there is a maximum token limit for each input: only a certain amount of information can be processed in a single prompt.
Working around this limit requires additional technology. RAG is such a technology: it allows a much larger knowledge base to be integrated without being constrained by the model’s input size.
The knowledge refers to the external source of information that the RAG framework relies on. It can take various forms, including:
- A collection of documents (e.g., PDFs, articles, reports)
- A structured database (e.g., SQL, NoSQL)
- Unstructured data sources (e.g., web pages, social media)
- APIs that provide access to real-time information
The knowledge database serves as the foundation for the retrieval component, enabling it to fetch relevant information to augment the input for the generative language model.
The knowledge database is the storage implementation of the external knowledge; it is the same concept as a database in the traditional database field.
It can be a traditional relational database that stores the knowledge in a structured format.
In recent years, as deep learning techniques have advanced, vector databases have become popular. A vector database stores the knowledge in the form of high-dimensional vectors and is now widely used in RAG systems.
The knowledge can come in various formats, such as text documents, images, or structured data. It is often naturally separated into entries, such as paragraphs, sections, or individual records, but sometimes it needs to be chunked into smaller pieces to fit the database schema or to improve retrieval efficiency.
Chunks that are too large may contain too much irrelevant information and can exceed the context window of the LLM, while chunks that are too small may lose context and coherence. The chunking granularity can vary depending on the application and the nature of the knowledge. Chunks do not even have to be disjoint: overlap is allowed. Granularity is a parameter to tune.
There are several techniques to chunk the knowledge into smaller pieces, including:
- Fixed-size chunking: Dividing the knowledge into chunks of a predetermined size, such as a specific number of words or characters.
- Delimiter-based chunking: Using specific delimiters, such as punctuation marks or line breaks, to identify chunk boundaries.
- Semantic chunking: Using natural language processing techniques to identify meaningful segments of text, such as sentences or paragraphs. This algorithm sweeps through the text and identifies the boundaries based on similarity scores between adjacent segments. This is smarter but computationally more expensive.
- LLM-based chunking: Using a pre-trained language model to identify chunk boundaries based on the context and semantics of the text. The LLM can even produce more information such as summaries or key points for the chunk. This is the most advanced but also the most computationally expensive.
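The simplest of these, fixed-size chunking with overlap, can be sketched as follows. The chunk size and overlap values are arbitrary examples of the tunable parameters mentioned above:

```python
# A sketch of fixed-size chunking with overlap, measured in words.
# chunk_size and overlap are the parameters to tune.

def chunk_text(text: str, chunk_size: int = 50, overlap: int = 10) -> list[str]:
    """Split `text` into word chunks of `chunk_size`, each sharing
    `overlap` words with the previous chunk to preserve context."""
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break
    return chunks

# 120 words with step 40 -> chunks starting at words 0, 40, 80.
chunks = chunk_text("word " * 120, chunk_size=50, overlap=10)
```

Delimiter-based and semantic chunking follow the same interface but compute the boundaries differently.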
There are many relational database management systems (RDBMS) available, such as MySQL, PostgreSQL, SQLite, and Microsoft SQL Server. The choice of RDBMS depends on the specific requirements of the application, such as scalability, performance, and ease of use.
These technologies have been widely used for decades and have become classics in the database field, so I won’t go into the details here.
Vector databases are a more recent development, emerging in the 2020s along with the rise of LLMs and embedding techniques. Several vector database systems are available, such as Pinecone, Weaviate, Milvus, and FAISS (Facebook AI Similarity Search).
Relational databases and vector databases are designed for different purposes: relational databases support traditional key search, while vector databases support semantic search. The pros and cons of the two search methods are discussed below.
In terms of storage efficiency, relational databases are more efficient for structured data, while vector databases are more efficient for high-dimensional vector data. A relational database can grow very large when storing high-dimensional or unstructured data, while a vector database often inflates small pieces of data into high-dimensional vectors.
What is the retriever?
The retriever is the same thing as a retrieval algorithm in the information retrieval (IR) field. Its objective is to find the most relevant documents in the knowledge database for the input query.
There are various information retrieval algorithms, which can be broadly categorized into two types (based on the knowledge database type):
- Key search: Used in relational database, such as SQL queries.
- Semantic search: Used in vector database, such as cosine similarity search.
The retriever can use either one of them, or both of them (hybrid search). Before searching, the retriever can preprocess the prompt to improve the search quality. This is usually called query parsing. After searching, there is usually a further processing step called metadata filtering, which narrows down the retrieved documents based on their metadata (e.g., date, author, source). Please note metadata filtering is not a search technique, but a post-processing step after retrieval.
The retriever typically ranks the documents in the knowledge database by their relevance to the input query and selects the top-k most relevant documents for augmentation (this is essentially k-nearest neighbors, kNN). After retrieval, the documents can be further processed, for example re-ranked or deduplicated, before being passed to the generative language model.
The retriever has its own hyperparameters, such as the number of documents to retrieve (k) and the similarity threshold, which can be tuned to improve the performance of the RAG system.
We can use LLM to rephrase or expand the query, or extract keywords from the query. This can help to improve the search quality by making the query more specific and relevant to the knowledge database.
An NLP task called named entity recognition (NER) can be used to identify and extract named entities (e.g., people, organizations, locations) from the query.
An improvement is ANN (Approximate Nearest Neighbors) search. ANN pre-computes a proximity graph of the vectors in the database before any query arrives. When a query comes in, the search starts from a few random points in the graph and traverses it (DFS or BFS) toward the nearest neighbors based on heuristics (e.g., distance).
This is much faster than exact kNN, especially when the database is large. However, it is approximate and may not find the exact nearest neighbors. In practice, though, ANN often finds candidates that are close enough for RAG systems.
There is a more advanced version of ANN called HNSW (Hierarchical Navigable Small World), which builds a multi-layer graph structure to further improve the search efficiency. HNSW is widely used in modern vector databases. We won’t go into the details here.
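Here is a deliberately tiny sketch of the greedy graph traversal described above, using 1-dimensional "vectors" so the distances are easy to follow. Real ANN implementations such as HNSW are far more sophisticated:

```python
# Toy graph-based ANN: pre-build a proximity graph over 1-d "vectors",
# then greedily walk it from a random start toward the query.
import random

points = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# Pre-computed proximity graph: each point links to its adjacent points.
graph = {i: [j for j in range(len(points)) if 0 < abs(i - j) <= 1]
         for i in range(len(points))}

def ann_search(query: float, start: int) -> int:
    """Greedy traversal: move to a closer neighbor until none is closer."""
    current = start
    while True:
        best = min(graph[current], key=lambda j: abs(points[j] - query))
        if abs(points[best] - query) < abs(points[current] - query):
            current = best
        else:
            return current

random.seed(0)
start = random.randrange(len(points))
print(ann_search(3.2, start))  # -> 3 (index of the nearest point to 3.2)
```

On this simple chain graph the greedy walk always reaches the true nearest neighbor; on realistic high-dimensional graphs it can get stuck in local optima, which is why it is only approximate.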
Key search uses traditional information retrieval techniques, such as TF-IDF (Term Frequency-Inverse Document Frequency) or BM25 (Best Matching 25), to find documents that contain keywords matching the input query.
These techniques analyze the frequency and distribution of words in the documents and the query to determine relevance. The query and documents are represented as bag-of-words vectors \(\mathbf{q}, \mathbf{d}_1, \cdots, \mathbf{d}_N\) (in RAG, the query is the prompt). The vectors are \(T\)-dimensional (\(T\) words in the vocabulary), where each dimension \(t\) corresponds to a term (word). These vectors are usually sparse, meaning most of their elements are zero. Taking TF-IDF as an example, the relevance score between the query and a document can be calculated as follows:
\[\text{TF-IDF}(\mathbf{q}, \mathbf{d}_n) = \sum_{t \in \mathbf{q}} \text{TF}(t, \mathbf{d}_n) \cdot \text{IDF}(t)\]
where \(\text{TF}(t, \mathbf{d})\) is the term frequency of term \(t\) in document \(\mathbf{d}\):
\[\text{TF}(t, \mathbf{d}_n) = \frac{f_{t,\mathbf{d}_n}}{\sum_{t'} f_{t',\mathbf{d}_n}}\]
and \(\text{IDF}(t)\) is the inverse document frequency of term \(t\) in the entire document collection.
\[ \text{IDF}(t) = \log \frac{N}{n_t} \]
where \(N\) is the total number of documents, and \(n_t\) is the number of documents containing term \(t\). The document frequency is inverted because a term that appears in many documents is less informative than a term that appears in few documents. The logarithm is used to dampen the effect of very rare terms.
BM25 is a fancier ranking function that builds upon TF-IDF, and introduces additional parameters. Most key search systems use BM25 nowadays. We won’t go into the details here.
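As a sanity check of the formulas above, here is a toy pure-Python TF-IDF scorer. The three documents and the query are made up for illustration; a real key-search system would use BM25 via a search engine library:

```python
# A small TF-IDF scorer following the TF and IDF formulas above.
import math
from collections import Counter

docs = [
    "the sky is blue".split(),
    "the sun is bright".split(),
    "the sky and the sun".split(),
]
N = len(docs)

def tf(term: str, doc: list[str]) -> float:
    """TF(t, d): count of `term` divided by the total terms in `doc`."""
    return Counter(doc)[term] / len(doc)

def idf(term: str) -> float:
    """IDF(t) = log(N / n_t), with n_t = number of docs containing t."""
    n_t = sum(1 for d in docs if term in d)
    return math.log(N / n_t) if n_t else 0.0

def tfidf_score(query: list[str], doc: list[str]) -> float:
    """Relevance: sum of TF(t, d) * IDF(t) over the query terms."""
    return sum(tf(t, doc) * idf(t) for t in query)

scores = [tfidf_score("blue sky".split(), d) for d in docs]
# doc 0 contains both query terms, so it scores highest.
```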
The vector database stores the embeddings of knowledge contents (words, sentences, documents) generated by the embedding model. In this vector space, similar contents are already close to each other, and dissimilar contents are far apart.
When a query is made, the system converts the query into an embedding using the same model and then performs a similarity search in the vector space to find the most relevant documents. In RAG, this query is the prompt. The similarity can be measured using various metrics for two vectors, such as cosine similarity, Euclidean distance, or dot product.
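This bi-encoder flow can be sketched with hand-made 3-dimensional "embeddings"; a real system would obtain the vectors from an embedding model:

```python
# Semantic search over a toy vector store using cosine similarity.
# The 3-d "embeddings" are hand-made stand-ins for model output.
import math

doc_embeddings = [
    [0.9, 0.1, 0.0],   # doc 0: "weather-like" content
    [0.1, 0.9, 0.1],   # doc 1: unrelated content
    [0.8, 0.2, 0.1],   # doc 2: also "weather-like"
]

def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity: dot product divided by the vector norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query_vec: list[float], db: list[list[float]], k: int = 2) -> list[int]:
    """Exact kNN: score every stored vector, return the k best indices."""
    ranked = sorted(range(len(db)), key=lambda i: cosine(query_vec, db[i]), reverse=True)
    return ranked[:k]

query = [1.0, 0.0, 0.0]  # embedding of the (weather-related) prompt
print(top_k(query, doc_embeddings))  # -> [0, 2]
```

Note this scores every document, which is the exact kNN that ANN indexes are designed to avoid at scale.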
Sometimes the query and each document are passed together to the embedding model, which computes the similarity score directly. This is called a cross-encoder, while the method above is a bi-encoder. It is more accurate but also more computationally expensive, since: 1. it needs a forward pass for every query-document pair (\(N\) per query) instead of \(N + 1\) embeddings in total; 2. the document representations cannot be pre-computed and stored, because they depend on the query. A more advanced technique called ColBERT combines the two approaches. We won’t go into the details here.
Key search:
- Pros:
- Works well for exact keyword matches.
- Simple and efficient for small to medium-sized datasets.
- Easy to implement and understand.
- Cons:
- Limited ability to capture semantic meaning and context.
- May struggle with synonyms or related terms.
Semantic search:
- Pros:
- Flexibility: Can handle synonyms and related terms effectively. More robust to variations in language and phrasing.
- Cons:
- Requires more computational resources for embedding generation and similarity calculations.
- Performance can be sensitive to the quality of the embedding model used.
Both key search and semantic search are applied, producing two ranked lists of documents. A fusion algorithm then combines the two lists into a single ranked list. Fusion can be done in various ways, such as reciprocal rank fusion (RRF), Borda count, or weighted sum. Taking RRF as an example, the final score of a document \(d\) is calculated as follows:
\[\text{RRF}(d) = \sum_{i=1}^{m} \frac{1}{k + \text{rank}_i(d)}\]
where \(m\) is the number of ranking lists (in this case, 2), \(\text{rank}_i(d)\) is the rank of document \(d\) in the \(i\)-th list, and \(k\) is a constant to control the influence of lower-ranked documents.
There are also many parameters here to tune.
Re-ranking is the process of refining the initial list of retrieved documents to improve their relevance to the query. This can be done using the LLM itself, or a separate model specifically trained for re-ranking. Since the retrieved result set is already small, a high-performing but computationally expensive model can be used here, such as a cross-encoder or ColBERT.
All the components and techniques mentioned above can be used in a RAG project. The choice depends on the specific requirements and constraints of the application, such as the size and nature of the knowledge base, the type of queries, the desired response time, and the available computational resources.
There are also a lot of hyperparameters to tune in each component, which can significantly affect the performance of the RAG system. It is important to experiment with different configurations and evaluate their impact on the overall system performance.
We need to evaluate the retriever separately before integrating it into the RAG system. The retriever has a well-defined input and output, and ground-truth relevant documents can be hand-annotated, so it can be evaluated like a supervised learning model by comparing its output with the ground truth and computing metrics such as precision, recall, and F1-score. Common evaluation metrics for retrievers include MAP (Mean Average Precision), MRR (Mean Reciprocal Rank), and Recall@k.
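For instance, Recall@k can be computed as follows; the retrieved list and ground-truth set here are made up for the example:

```python
# Recall@k for retriever evaluation: the fraction of ground-truth
# relevant documents that appear in the top-k retrieved results.

def recall_at_k(retrieved: list[str], relevant: set[str], k: int) -> float:
    hits = sum(1 for doc in retrieved[:k] if doc in relevant)
    return hits / len(relevant)

retrieved = ["d2", "d5", "d1", "d4"]   # retriever output, best first
relevant = {"d1", "d2", "d3"}          # hand-marked ground truth

print(recall_at_k(retrieved, relevant, k=3))  # -> 2/3, since d2 and d1 are hits
```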
The AI assistants you are using, such as ChatGPT, have a web search capability, which is based on the RAG framework.
RAG systems can be built for many application scenarios that have their own knowledge base or need accurate information retrieval, including: domain-specific question answering, enterprise knowledge management, customer support automation, academic research / education assistance, law / medical support, real-time information retrieval, and more.
Other examples: a personal knowledge base (code, notes, …), AI search engine summaries (the knowledge being the entire internet), and a personal assistant (using your own data, such as emails, calendar, and notes).