barathm111 commited on
Commit
9dfd473
1 Parent(s): 130beb7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -17
app.py CHANGED
@@ -13,6 +13,8 @@ class QueryRequest(BaseModel):
13
  @app.get("/")
14
  def home():
15
  return {"message": "SQL Generation Server is running"}
 
 
16
  @app.post("/generate")
17
  def generate(request: QueryRequest):
18
  try:
@@ -22,29 +24,15 @@ def generate(request: QueryRequest):
22
 
23
  generated_text = output[0]['generated_text']
24
  sql_query = generated_text.split("SQL query:")[-1].strip()
25
- # Basic validation to ensure it's a valid SQL query
26
- if not sql_query.lower().startswith(('select', 'show', 'describe', 'insert', 'update', 'delete')):
 
27
  raise ValueError("Generated text is not a valid SQL query")
28
 
29
- # Further validation to ensure no additional text
30
- sql_query = sql_query.split(';')[0].strip()
31
-
32
- # Comprehensive list of SQL keywords
33
- allowed_keywords = {
34
- 'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
35
- }
36
- # Ensure the query only contains allowed keywords
37
- tokens = sql_query.lower().split()
38
- for token in tokens:
39
- if not any(token.startswith(keyword) for keyword in allowed_keywords):
40
- raise ValueError("Generated text contains invalid SQL syntax")
41
-
42
  return {"output": sql_query}
43
  except Exception as e:
44
  raise HTTPException(status_code=500, detail=str(e))
45
 
46
-
47
-
48
  if __name__ == "__main__":
49
  import uvicorn
50
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
13
  @app.get("/")
14
  def home():
15
  return {"message": "SQL Generation Server is running"}
16
+
17
+
18
  @app.post("/generate")
19
  def generate(request: QueryRequest):
20
  try:
 
24
 
25
  generated_text = output[0]['generated_text']
26
  sql_query = generated_text.split("SQL query:")[-1].strip()
27
+
28
+ # Basic validation
29
+ if not sql_query.lower().startswith(('select', 'show', 'describe')):
30
  raise ValueError("Generated text is not a valid SQL query")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  return {"output": sql_query}
33
  except Exception as e:
34
  raise HTTPException(status_code=500, detail=str(e))
35
 
 
 
36
  if __name__ == "__main__":
37
  import uvicorn
38
  uvicorn.run(app, host="0.0.0.0", port=7860)