
Text-to-SQL: Converting Natural Language into Database Query Statements


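For background on Text-to-SQL methods, see my other article, "A Study of Text-to-SQL Methods".

Talking directly to the database: text2sql

Text2sql means converting text into SQL. My company recently had a need for this, so I surveyed the text2sql tools available, such as Alibaba's Chat2DB and the open-source, MIT-licensed Vanna. After trying them out, we decided to build our own, based on Vanna's idea: RAG plus a large language model.

Implementing text2sql with the open-source Vanna is fairly convenient, since Vanna can connect to a database directly. It becomes awkward, however, when a user has permission to access several databases: once Vanna has stored everything as vectors, a new question is compared against all entries without distinguishing which database they belong to. So I implemented text2sql myself, still following Vanna's idea of "training" in advance on the DDL, the SQL-question pairs, and the database documents.

These are brief notes for later study.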

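Basic approach

1. Collect each database's DDL statements, SQL-question pairs, and Document information.

2. Use the user's question and the database Documents to decide which database to analyze.

3. Model training: use the database's DDL statements, metadata (information describing the database's own data), related documentation, and reference SQL examples to train a RAG "model".

This "model" combines embeddings with a vector database, so that the structure and content of the database can be indexed and retrieved efficiently.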

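4. Semantic retrieval: when the user enters a question in natural language, (1) the vector store is searched for content related to the question; (2) BM25 text recall finds the content that best matches the question's terms; (3) the RRF algorithm and a re-ranking algorithm are each applied to narrow the results down to the most relevant content.

Semantic matching: use a model (such as BERT) to measure the semantic similarity between the query and the documents.

Text recall matching: BM25 text recall finds the content most relevant to the question.

Rerank: reorder the search results.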

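5. Prompt construction: the retrieved information is assembled into the Prompt to form a structured description of the query. This Prompt is then passed to the LLM (large language model) to generate an accurate SQL query.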
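Step 4 mentions RRF (Reciprocal Rank Fusion) for combining the vector-search ranking with the BM25 ranking. As a rough illustration, here is a minimal sketch of RRF in plain Python. The function name `reciprocal_rank_fusion`, the constant `k=60`, and the table ids are assumptions made for this example, not taken from the author's code.

```python
from collections import defaultdict


def reciprocal_rank_fusion(rankings, k=60):
    """Fuse several ranked id lists into one list, best first.

    Each id receives 1 / (k + rank) from every list it appears in
    (rank starts at 1); ids are then sorted by their summed score.
    k=60 is a commonly used default, assumed here.
    """
    scores = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Sort by descending score; break ties by id so the order is deterministic.
    return sorted(scores, key=lambda doc_id: (-scores[doc_id], doc_id))


# Hypothetical ids, as if returned by the vector index and by BM25.
vector_hits = ["ddl_orders", "ddl_users", "ddl_logs"]
bm25_hits = ["ddl_users", "ddl_tasks", "ddl_orders"]

fused = reciprocal_rank_fusion([vector_hits, bm25_hits])
print(fused)
```

Because RRF uses only rank positions, it avoids having to put L2 distances and BM25 scores on a common scale. A re-ranking model could then be applied to the top few fused results.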

Implementation logic diagram

Implementation architecture diagram:

The concrete implementation is as follows:

1. Database selection

    import json

    import faiss
    import requests
    from sentence_transformers import SentenceTransformer


    class DataBaseSearch(object):
        def __init__(self, _model):
            self.name = 'DataBaseSearch'
            self.model = _model
            # Query instruction prepended to every question before encoding.
            self.instruction = "为这段内容生成表示以用于匹配文本描述:"
            self.SIZE = 1024  # embedding dimension of bge-large-zh-v1.5
            self.index = faiss.IndexFlatL2(self.SIZE)
            self.textdata = []
            self.subdata = {}
            self.i2key = {}       # faiss row number -> database id
            self.id2ddls = {}
            self.id2sqlques = {}
            self.id2docs = {}
            self.strtexts = {}
            self.load_textdata()      # load the text data
            self.load_textdata_vec()  # vectorize the text data

        def load_textdata(self):
            try:
                response = requests.post(url="xxx", verify=False)
                print(response.text)
                jsonobj = json.loads(response.text)
                for textdata in jsonobj["data"]:  # one entry per database
                    cid = textdata["dataSetID"]
                    cddls = textdata["ddl"]
                    csql_ques = textdata["exp"]
                    cdocuments = textdata["Intro"]
                    self.textdata.append((cid, cddls, csql_ques, cdocuments))
            except Exception as e:
                print(e)

        def load_textdata_vec(self):
            for num0, (_id, _ddls, _sql_ques, _documents) in enumerate(self.textdata):
                # Only the SQL-question pairs and the documents describe a database here.
                _strtexts = str(_sql_ques) + str(_documents)
                text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)
                self.index.add(text_embeddings)
                self.i2key[num0] = _id
                self.strtexts[_id] = _strtexts
                self.id2ddls[_id] = _ddls
                self.id2sqlques[_id] = _sql_ques
                self.id2docs[_id] = _documents

        def calculate_score(self, score, question, kws):
            pass  # not implemented yet

        def find_vec_database(self, question, k, theta):
            q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
            D, I = self.index.search(q_embeddings, k)
            result = []
            for i in range(k):
                sim_i = I[0][i]
                if sim_i < 0:  # faiss pads with -1 when fewer than k vectors exist
                    continue
                uuid = self.i2key.get(sim_i, "none")
                # Squared L2 distance scaled by 1000; smaller means more similar.
                score = int(D[0][i] * 1000)
                if score < theta:
                    result.append({"score": score, "dataSetID": uuid})
            return result


    if __name__ == '__main__':
        modelpath = r"E:\module\bge-large-zh-v1.5"  # raw string, otherwise "\b" is an escape
        model = SentenceTransformer(modelpath)
        vs = DataBaseSearch(model)
        result = vs.find_vec_database("查询济南市第三幼儿园所有小班班级?", 1, 2000)
        print(result)

2. sql_ques: training on SQL-question pairs

    # sqlques_engine.py (module name as imported in step 5)
    import json

    import faiss
    import requests
    from sentence_transformers import SentenceTransformer


    class SqlQuesSearch(object):
        def __init__(self, _model):
            self.name = "SqlQuesSearch"
            self.model = _model
            self.instruction = "为这段内容生成表示以用于匹配文本描述:"
            self.SIZE = 1024
            self.index = faiss.IndexFlatL2(self.SIZE)
            self.sqlquedata = []
            self.i2dbid = {}   # faiss row number -> database id
            self.i2sqlid = {}  # faiss row number -> sql id
            self.id2sqlque = {}
            self.id2que = {}
            self.id2sql = {}
            self.dbid2sqlques = {}
            self.load_textdata()      # load the text data
            self.load_textdata_vec()  # vectorize the text data

        def load_textdata(self):
            try:
                response = requests.post(url="xxx", verify=False)
                print(response.text)
                jsonobj = json.loads(response.text)
                for datadata in jsonobj["data"]:  # SQL-question pairs of each database
                    dbid = datadata["dataSetID"]
                    sql_ques = datadata["exp"]
                    self.sqlquedata.append((dbid, sql_ques))
            except Exception as e:
                print(e)

        def load_textdata_vec(self):
            num0 = 0
            for db_id, sql_ques in self.sqlquedata:
                for sql_que in sql_ques:
                    sql_id = sql_que["sql_id"]
                    question = sql_que["question"]
                    sql = sql_que["sql"]
                    # Only the question text is embedded; the SQL is looked up by id.
                    que_embeddings = self.model.encode([question], normalize_embeddings=True)
                    self.index.add(que_embeddings)
                    self.i2dbid[num0] = db_id
                    self.i2sqlid[num0] = sql_id
                    self.id2que[sql_id] = question
                    self.id2sql[sql_id] = sql
                    num0 += 1
            print("init sql-que vec", num0)

        def calculate_score(self, sim_v, question, sql_ques):
            pass  # not implemented yet

        def find_vec_sqlque(self, question, k, theta, dataSetID, number=None):
            # number: maximum number of results to return (None means no cap)
            q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
            D, I = self.index.search(q_embeddings, k)
            result = []
            for i in range(k):
                sim_i = I[0][i]
                if sim_i < 0:  # faiss pads with -1 when fewer than k vectors exist
                    continue
                dbid = self.i2dbid.get(sim_i, "none")  # database id
                sqlid = self.i2sqlid.get(sim_i, "none")
                if dbid == dataSetID:  # keep only hits from the selected database
                    score = int(D[0][i] * 1000)
                    if score < theta:
                        result.append({
                            "score": score,
                            "question": self.id2que.get(sqlid, "none"),
                            "sql": self.id2sql.get(sqlid, "none"),
                        })
                if len(result) == number:
                    break
            return result


    if __name__ == '__main__':
        modelpath = r"E:\module\bge-large-zh-v1.5"
        model = SentenceTransformer(modelpath)
        vs = SqlQuesSearch(model)
        result = vs.find_vec_sqlque("查询7月18日所有的儿童观察记录?", 3, 2000, dataSetID=111)
        print(result)

3. Database DDL training

    # ddl_engine.py (module name as imported in step 5)
    import json

    import faiss
    import requests
    from sentence_transformers import SentenceTransformer


    class DdlQuesSearch(object):
        def __init__(self, _model):
            self.name = "DdlQuesSearch"
            self.model = _model
            self.instruction = "为这段内容生成表示以用于匹配文本描述:"
            self.SIZE = 1024
            self.index = faiss.IndexFlatL2(self.SIZE)
            self.ddldata = []
            self.sqlques = {}
            self.i2dbid = {}   # faiss row number -> database id
            self.i2ddlid = {}  # faiss row number -> ddl id
            self.dbid2ddls = {}
            self.id2ddl = {}
            self.ddlid2dbid = {}
            self.load_ddldata()  # load the DDL data
            self.load_ddl_vec()  # vectorize the DDL data

        def load_ddldata(self):
            try:
                response = requests.post(url="xxx", verify=False)
                print(response.text)
                jsonobj = json.loads(response.text)
                databases = jsonobj["data"]  # was undefined in the original listing
                for database in databases:
                    db_id = database["dataSetID"]
                    ddls = database["ddl"]
                    self.ddldata.append((db_id, ddls))
            except Exception as e:
                print(e)

        def load_ddl_vec(self):
            num0 = 0
            for db_id, ddls in self.ddldata:
                for ddl in ddls:
                    ddl_id = ddl["ddl_id"]
                    ddl_text = ddl["ddl"]
                    ddl_embeddings = self.model.encode([ddl_text], normalize_embeddings=True)
                    self.index.add(ddl_embeddings)
                    self.i2dbid[num0] = db_id
                    self.i2ddlid[num0] = ddl_id
                    self.id2ddl[ddl_id] = ddl_text
                    self.ddlid2dbid[ddl_id] = db_id
                    # Group DDLs per database (the original stored the global dict here).
                    self.dbid2ddls.setdefault(db_id, {})[ddl_id] = ddl_text
                    num0 += 1
            print("init ddl vec", num0)

        def find_vec_ddl(self, question, k, theta, dataSetID, number=None):
            # dataSetID: database id; number: maximum number of results (None means no cap)
            q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
            D, I = self.index.search(q_embeddings, k)
            result = []
            for i in range(k):
                sim_i = I[0][i]
                if sim_i < 0:  # faiss pads with -1 when fewer than k vectors exist
                    continue
                dbid = self.i2dbid.get(sim_i, "none")  # database id
                ddlid = self.i2ddlid.get(sim_i, "none")
                if dbid == dataSetID:
                    score = int(D[0][i] * 1000)
                    if score < theta:
                        result.append({"score": score, "ddl": self.id2ddl.get(ddlid, "none")})
                if len(result) == number:
                    break
            return result


    if __name__ == '__main__':
        modelpath = r"E:\module\bge-large-zh-v1.5"
        model = SentenceTransformer(modelpath)
        vs = DdlQuesSearch(model)
        ss = vs.find_vec_ddl("定时任务执行记录表", 2, 2000, 111)
        print(ss)

4. Database document training

    import json

    import requests


    class DocQuesSearch(object):
        def __init__(self):
            self.name = "TestDataSearch"
            self.docdata = []
            self.load_doc_data()

        def load_doc_data(self):
            try:
                response = requests.post(url="xxx", verify=False)
                print(response.text)
                jsonobj = json.loads(response.text)
                for database in jsonobj["data"]:
                    db_id = database["dataSetID"]
                    doc = database["Intro"]
                    self.docdata.append((db_id, doc))
            except Exception as e:
                print(e)

        def find_similar_doc(self, dataSetID):
            # No vector search here: return every document of the selected database.
            return [doc for dbid, doc in self.docdata if dbid == dataSetID]


    if __name__ == '__main__':
        docques_search = DocQuesSearch()
        result = docques_search.find_similar_doc(222)
        print(result)

5. Generating the SQL statement (the qwen-max model is used here)

    import os
    import random
    import re

    from dashscope import Generation
    from sentence_transformers import SentenceTransformer

    from ddl_engine import DdlQuesSearch
    from sqlques_engine import SqlQuesSearch

    DEBUG_INFO = None  # last (prompt, answer) pair, kept for debugging


    class Genarate(object):
        def __init__(self):
            self.api_key = os.environ.get('api_key')
            self.model_name = os.environ.get('model')

        def system_message(self, message):
            return {'role': 'system', 'content': message}

        def user_message(self, message):
            return {'role': 'user', 'content': message}

        def assistant_message(self, message):
            return {'role': 'assistant', 'content': message}

        def submit_prompt(self, prompt):
            resp = Generation.call(
                model=self.model_name,
                messages=prompt,
                seed=random.randint(1, 10000),
                result_format='message',
                api_key=self.api_key)
            if resp.status_code == 200:
                answer = resp.output.choices[0].message.content
                global DEBUG_INFO
                DEBUG_INFO = (prompt, answer)
                return answer
            return None  # the request failed

        def generate_sql(self, question, sql_result, ddl_result, doc_result):
            prompt = self.get_sql_prompt(
                question=question,
                sql_result=sql_result,
                ddl_result=ddl_result,
                doc_result=doc_result)
            print("SQL Prompt:", prompt)
            llm_response = self.submit_prompt(prompt)
            if llm_response is None:  # avoid running the regexes on None
                return None
            return self.extrat_sql(llm_response)

        def extrat_sql(self, llm_response):
            # Try progressively looser patterns and return the last match of the first that hits.
            patterns = [
                r"WITH.*?;",          # CTE ending with a semicolon
                r"SELECT.*?;",        # plain SELECT ending with a semicolon
                r"```sql\s*(.*)```",  # fenced block tagged as sql
                r"```(.*)```",        # any fenced block
            ]
            for pattern in patterns:
                sqls = re.findall(pattern, llm_response, re.DOTALL)
                if sqls:
                    return sqls[-1]
            return llm_response

        def get_sql_prompt(self, question, sql_result, ddl_result, doc_result):
            initial_prompt = (
                "You are a SQL expert. "
                "Please help to generate a SQL query to answer the question. "
                "Your response should ONLY be based on the given context and follow "
                "the response guidelines and format instructions.\n")
            initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_result)
            initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)
            initial_prompt += (
                "===Response Guidelines \n"
                "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
                "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
                "3. If the provided context is insufficient, please explain why it can't be generated. \n"
                "4. Please use the most relevant table(s). \n"
                "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
            )
            message_log = [self.system_message(initial_prompt)]
            message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)
            return message_log

        def add_ddl_to_prompt(self, initial_prompt, ddl_result):
            """Append the retrieved DDL statements to the system prompt."""
            ddl_list = [ddl_['ddl'] for ddl_ in ddl_result]
            if len(ddl_list) > 0:
                initial_prompt += "\n===Tables \n"
                for ddl in ddl_list:
                    initial_prompt += f"{ddl}\n\n"
            return initial_prompt

        def add_sqlques_to_prompt(self, question, sql_result, message_log):
            """Add retrieved question/SQL pairs as few-shot turns, then the user's question."""
            for example in sql_result:
                if example is not None and "question" in example and "sql" in example:
                    message_log.append(self.user_message(example["question"]))
                    message_log.append(self.assistant_message(example["sql"]))
            message_log.append(self.user_message(question))
            return message_log

        def add_documentation_to_prompt(self, initial_prompt, doc_result):
            if len(doc_result) > 0:
                initial_prompt += "\n===Additional Context \n\n"
                for doc in doc_result:
                    initial_prompt += f"{doc}\n\n"
            return initial_prompt


    if __name__ == '__main__':
        modelpath = r"E:\module\bge-large-zh-v1.5"
        model = SentenceTransformer(modelpath)
        vs = DdlQuesSearch(model)
        ss = vs.find_vec_ddl("定时任务执行记录表", 1, 2000, 111)
        print(ss)

6. Execution results

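As the figure shows, the SQL was generated correctly and executes normally. Because the table was only pulled over and holds no data, the query returns an empty result.

If you need the source code, leave a comment.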
