barathm111 commited on
Commit
bc41330
1 Parent(s): 6bfb9ac

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -13,8 +13,6 @@ class QueryRequest(BaseModel):
13
  @app.get("/")
14
  def home():
15
  return {"message": "SQL Generation Server is running"}
16
-
17
- @app.post("/generate")
18
  def generate(request: QueryRequest):
19
  try:
20
  text = request.text
@@ -24,14 +22,30 @@ def generate(request: QueryRequest):
24
  generated_text = output[0]['generated_text']
25
  sql_query = generated_text.split("SQL query:")[-1].strip()
26
 
27
- # Basic validation
28
- if not sql_query.lower().startswith(('select', 'show', 'describe')):
 
29
  raise ValueError("Generated text is not a valid SQL query")
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return {"output": sql_query}
32
  except Exception as e:
33
  raise HTTPException(status_code=500, detail=str(e))
34
 
 
 
35
  if __name__ == "__main__":
36
  import uvicorn
37
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
13
  @app.get("/")
14
  def home():
15
  return {"message": "SQL Generation Server is running"}
 
 
16
  def generate(request: QueryRequest):
17
  try:
18
  text = request.text
 
22
  generated_text = output[0]['generated_text']
23
  sql_query = generated_text.split("SQL query:")[-1].strip()
24
 
25
+
26
+ # Basic validation to ensure it's a valid SQL query
27
+ if not sql_query.lower().startswith(('select', 'show', 'describe', 'insert', 'update', 'delete')):
28
  raise ValueError("Generated text is not a valid SQL query")
29
 
30
+ # Further validation to ensure no additional text
31
+ sql_query = sql_query.split(';')[0].strip()
32
+
33
+ # Comprehensive list of SQL keywords
34
+ allowed_keywords = {
35
+ 'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
36
+ }
37
+ # Ensure the query only contains allowed keywords
38
+ tokens = sql_query.lower().split()
39
+ for token in tokens:
40
+ if not any(token.startswith(keyword) for keyword in allowed_keywords):
41
+ raise ValueError("Generated text contains invalid SQL syntax")
42
+
43
  return {"output": sql_query}
44
  except Exception as e:
45
  raise HTTPException(status_code=500, detail=str(e))
46
 
47
+
48
+
49
  if __name__ == "__main__":
50
  import uvicorn
51
+ uvicorn.run(app, host="0.0.0.0", port=7860)@app.post("/generate")