GPT on an application data source using LangChain and a private LLM.

Sujit Udhane
4 min read · Oct 2, 2023

GenAI high-level solution

[Figure: GenAI high-level diagram, showing a private LLM applied to query a database.]

Tech Stack

Private LLM: Llama-2-7b-chat

Vector database: Chroma [Optional]

LangChain library: SQLDatabaseChain

SQL Toolkit: SQLAlchemy

Webapp: Streamlit

Application database: SQLite [But, can be any RDBMS]
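
These components map to a handful of Python packages. A suggested environment, as a sketch (package names only; pin versions that match your setup):

# Suggested dependencies (names only; exact versions depend on your setup):
#   pip install langchain langchain-experimental llama-cpp-python chromadb
#   pip install sentence-transformers streamlit sqlalchemy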

Potential Use Cases

1. Business analytics and insights: Product or business owners can extract knowledge directly from the data; today such queries are usually written for them by the engineering team.
2. IT support: For a system support engineer (L2/L3 support), the intermediate steps returned with the GPT answer show how the SQL query was constructed, giving direct access to the actual database query (accurate in the majority of cases). This speeds up root-cause identification and issue resolution.
3. Observability: A support engineer (L2/L3 support) can query application logs, which are typically stored in data sources such as Elasticsearch.
4. Q&A: An application user can naturally ask questions to acquire further insights, and GPT against a database can help extract the most knowledge from the system.
5. A new method of system communication: As a product manager, you can develop new channels of communication with the system to facilitate user tasks.

Reference code

from langchain.llms import LlamaCpp
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.callbacks.manager import CallbackManager
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.schema import Document
from langchain.memory.vectorstore import VectorStoreRetrieverMemory

QUERY_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {input}"""
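
To see exactly what the model receives, you can fill the template by hand. A small sketch, using one of the few-shot pairs defined below (the schema string is abbreviated here):

# Illustrative only: render the prompt with sample values to inspect it.
sample_prompt = QUERY_TEMPLATE.format(
    dialect="sqlite",
    table_info='CREATE TABLE "employee" (...)',  # abbreviated schema dump
    few_shot_examples='Question: How many employees are there\n'
                      'SQLQuery: SELECT COUNT(*) FROM "employee"',
    input="How many employees are there",
)
print(sample_prompt)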

prompt_to_query_dict = {
    'List all artists.': 'SELECT * FROM artists;',
    "Find all albums for the artist 'AC/DC'.": "SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');",
    "List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
    'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',
    'List all customers from Canada.': "SELECT * FROM customers WHERE Country = 'Canada';",
    'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',
    'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',
    'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',
    'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',
    'Which albums are from the year 2000?': "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
    'How many employees are there': 'SELECT COUNT(*) FROM "employee"'
}

def settings():

    # Turn each example question into a Document, keeping its SQL in metadata
    DOMAIN_DOCS = [
        Document(page_content=question, metadata={'sql_query': sql_query})
        for question, sql_query in prompt_to_query_dict.items()
    ]

    # Embeddings used to index the example questions
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )

    # Index the domain documents in a persistent Chroma collection
    chromaVectorStore = Chroma.from_documents(
        DOMAIN_DOCS,
        embedding=embeddings,
        collection_name="app_datasource_store",
        persist_directory="./chroma_db_app_datasource"
    )
    datasource_retriever = chromaVectorStore.as_retriever(search_type="mmr")
    applicationDataSourceDocsMemory = VectorStoreRetrieverMemory(retriever=datasource_retriever)

    # Callbacks support token-wise streaming
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

    n_gpu_layers = 1
    n_batch = 512

    # LLM (place your model in the models directory)
    llm = LlamaCpp(
        model_path="./models/llama-2-7b-chat.Q8_0.gguf",
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        n_ctx=4000,
        f16_kv=True,
        callback_manager=callback_manager,
        verbose=True,  # Verbose is required to pass to the callback manager
    )

    db = SQLDatabase.from_uri("sqlite:///./Chinook.db")

    # Create an SQLDatabaseChain that interprets the input prompt using the
    # domain docs (via memory) and the LLM
    chain = SQLDatabaseChain.from_llm(llm, db, verbose=False, use_query_checker=True, top_k=8,
                                      return_intermediate_steps=True, memory=applicationDataSourceDocsMemory)

    return chain
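
To sanity-check the chain without the Streamlit UI, a minimal sketch you could run in a separate script (assuming Chinook.db and the GGUF model file are in place):

# Hypothetical standalone smoke test: build the chain and ask one question.
chain = settings()
result = chain("How many employees are there")
print(result["intermediate_steps"])  # the generated SQL and its raw result
print(result["result"])              # the final natural-language answer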

class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.info(self.text)

class PrintRetrievalHandler(BaseCallbackHandler):
    def __init__(self, container):
        self.container = container.expander("Context Retrieval")

    def on_retriever_start(self, query: str, **kwargs):
        self.container.write(f"**Question:** {query}")

    def on_retriever_end(self, documents, **kwargs):
        for idx, doc in enumerate(documents):
            # The domain documents store their example SQL under 'sql_query'
            sql_query = doc.metadata.get("sql_query", "")
            self.container.write(f"**Example {idx + 1}:** {sql_query}")
            self.container.text(doc.page_content)
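
Note that StreamHandler is defined but never attached in the listing. A minimal sketch of wiring it up, inside the `if question:` block shown below, so tokens stream into a Streamlit box (the variable names match that block):

# Hypothetical wiring (not in the original app): stream tokens into the UI.
# stream_box is a Streamlit placeholder; the handler appends each new token.
stream_box = st.empty()
stream_handler = StreamHandler(stream_box)
result = chain(updated_question, callbacks=[stream_handler])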

st.header("`Interacting with database`")

# Build the SQLDatabaseChain once per session
if 'chain' not in st.session_state:
    st.session_state['chain'] = settings()
chain = st.session_state.chain

# User input
question = st.text_input("`Ask a question:`")

if question:

    # Show the retrieved context in an expander
    retrieval_streamer_cb = PrintRetrievalHandler(st.container())

    # Fill the prompt template with the dialect, schema, and few-shot examples
    few_shot_examples = "\n".join(
        f"Question: {q}\nSQLQuery: {sql}" for q, sql in prompt_to_query_dict.items()
    )
    updated_question = QUERY_TEMPLATE.format(
        dialect=chain.database.dialect,
        table_info=chain.database.table_info,
        few_shot_examples=few_shot_examples,
        input=question,
    )

    # Generate answer (with intermediate steps)
    qa_chain = chain(updated_question, callbacks=[retrieval_streamer_cb])

    # Show the intermediate steps (including the generated SQL)
    st.write(qa_chain["intermediate_steps"])

    # Write answer
    answer = st.empty()
    answer.info('`Answer:`\n\n' + qa_chain["result"])
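
Assuming the listing above is saved as a single file, you would launch it with Streamlit:

# Launch the app (assuming the file is saved as app.py):
#   streamlit run app.py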

Possible Improvements

1. You can investigate other RDBMS backends, such as PostgreSQL.
2. Memory can be made optional when constructing the SQLDatabaseChain that uses a vector database to convey domain-specific knowledge; in my observations, the memory occasionally leads to erroneous query construction and SQL syntax errors (see the sketch after this list).
3. Other open-source, commercially friendly LLMs, such as StarCoder, may produce better results than a generic LLM like Llama-2.
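
For point 2, dropping the memory is a one-line change to the chain construction in settings():

# Variant without vector-store memory; the few-shot examples can instead be
# passed through the prompt text itself (as QUERY_TEMPLATE already does).
chain = SQLDatabaseChain.from_llm(llm, db, verbose=False, use_query_checker=True,
                                  top_k=8, return_intermediate_steps=True)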

Challenges and possible remedies

1. Database access: An application’s data source is a crucial component of the overall system and must be carefully supervised and controlled. For a solution like this, create a restricted database user with read-only access that is limited to an approved set of tables.
2. Sensitive information leakage: Response data should be sanitised with a technique like masking before producing the final response.
3. User management: The user should be able to get information specific to them based on their user persona. You could implement a role-based access control system.
4. Customer-sensitive information: The solution must safeguard sensitive customer data. This can be accomplished by adding a default filter condition to the generated SQL; LangChain’s PromptTemplate can help here (see the sketch below).
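
A hedged sketch of point 4, assuming a hypothetical customer_id column; the idea is to append a mandatory scoping instruction to the prompt:

from langchain.prompts import PromptTemplate

# Hypothetical guardrail: instruct the model to scope every query to the
# current customer. The customer_id column and placeholder are illustrative.
SCOPED_TEMPLATE = QUERY_TEMPLATE + """

Always add the condition customer_id = {customer_id} to the WHERE clause of
the SQLQuery, so that a user can only ever see their own rows."""

scoped_prompt = PromptTemplate(
    input_variables=["dialect", "table_info", "few_shot_examples",
                     "input", "customer_id"],
    template=SCOPED_TEMPLATE,
)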

Please remember to clap. Please spread the word if you enjoyed the post.

Additionally, you can connect with me on LinkedIn at https://www.linkedin.com/in/sujit-udhane/


Sujit Udhane

I am a Lead Platform Architect based in Pune, India. I have 20+ years of experience in technology, with the last 10+ years working as an Architect.