barathm111 commited on
Commit
0d9dd80
1 Parent(s): 3b0a41a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -6
app.py CHANGED
@@ -1,11 +1,22 @@
1
- from fastapi import FastAPI, HTTPException
 
 
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
 
 
 
 
 
5
  app = FastAPI()
6
 
7
  # Initialize the text generation pipeline
8
- pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
 
 
 
 
 
9
 
10
  class QueryRequest(BaseModel):
11
  text: str
@@ -15,9 +26,11 @@ def home():
15
  return {"message": "SQL Generation Server is running"}
16
 
17
  @app.post("/generate")
18
- def generate(request: QueryRequest):
19
  try:
20
  text = request.text
 
 
21
  prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"
22
  output = pipe(prompt, max_new_tokens=100)
23
 
@@ -35,16 +48,29 @@ def generate(request: QueryRequest):
35
  allowed_keywords = {
36
  'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
37
  }
 
38
  # Ensure the query only contains allowed keywords
39
  tokens = sql_query.lower().split()
40
  for token in tokens:
41
  if not any(token.startswith(keyword) for keyword in allowed_keywords):
42
- raise ValueError("Generated text contains invalid SQL syntax")
43
 
 
44
  return {"output": sql_query}
 
 
 
45
  except Exception as e:
46
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
  import uvicorn
50
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import logging
2
+ from fastapi import FastAPI, HTTPException, Request
3
+ from fastapi.responses import JSONResponse
4
  from pydantic import BaseModel
5
  from transformers import pipeline
6
 
7
+ # Configure logging
8
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
9
+ logger = logging.getLogger(__name__)
10
+
11
  app = FastAPI()
12
 
13
  # Initialize the text generation pipeline
14
+ try:
15
+ pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
16
+ logger.info("Model loaded successfully")
17
+ except Exception as e:
18
+ logger.error(f"Failed to load the model: {str(e)}")
19
+ raise
20
 
21
  class QueryRequest(BaseModel):
22
  text: str
 
26
  return {"message": "SQL Generation Server is running"}
27
 
28
  @app.post("/generate")
29
+ async def generate(request: QueryRequest):
30
  try:
31
  text = request.text
32
+ logger.info(f"Received request: {text}")
33
+
34
  prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"
35
  output = pipe(prompt, max_new_tokens=100)
36
 
 
48
  allowed_keywords = {
49
  'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
50
  }
51
+
52
  # Ensure the query only contains allowed keywords
53
  tokens = sql_query.lower().split()
54
  for token in tokens:
55
  if not any(token.startswith(keyword) for keyword in allowed_keywords):
56
+ raise ValueError(f"Generated text contains invalid SQL syntax: {token}")
57
 
58
+ logger.info(f"Generated SQL query: {sql_query}")
59
  return {"output": sql_query}
60
+ except ValueError as ve:
61
+ logger.warning(f"Validation error: {str(ve)}")
62
+ raise HTTPException(status_code=400, detail=str(ve))
63
  except Exception as e:
64
+ logger.error(f"Error in generate endpoint: {str(e)}", exc_info=True)
65
+ raise HTTPException(status_code=500, detail="An error occurred while generating the SQL query")
66
+
67
+ @app.exception_handler(HTTPException)
68
+ async def http_exception_handler(request: Request, exc: HTTPException):
69
+ return JSONResponse(
70
+ status_code=exc.status_code,
71
+ content={"message": exc.detail},
72
+ )
73
 
74
  if __name__ == "__main__":
75
  import uvicorn
76
+ uvicorn.run(app, host="0.0.0.0", port=7860)