|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer |
|
import gradio as gr |
|
from gradio.themes.base import Base |
|
from gradio.themes.utils import colors, fonts, sizes |
|
from typing import Iterable |
|
|
|
class SQLGEN(Base):
    """Custom Gradio theme used by the SQLGEN demo.

    Defaults: stone primary hue, green secondary hue, gray neutrals, medium
    spacing and corner radius, large text, and IBM Plex Mono (loaded via
    ``fonts.GoogleFont``) as both the UI font and the monospace font, each
    followed by generic fallbacks. Every option is keyword-only and is handed
    unchanged to ``gradio.themes.base.Base``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Gather every theme option so the hand-off to Base is a single call.
        theme_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**theme_options)
|
|
|
|
|
|
|
# Hugging Face Hub repository holding the fine-tuned SQL-generation model;
# both the tokenizer and the model weights are loaded from it.
model_id = "alibidaran/Gemma2_SQLGEN"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map='auto' lets transformers/accelerate decide where the weights live,
# so callers should not assume a particular device.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
)

# Pad on the right. generate_sql tokenizes a single prompt, so no padding
# actually occurs there.
tokenizer.padding_side = 'right'
|
|
|
|
|
def generate_sql(query, context):
    """Generate a SQL completion for a natural-language question.

    Builds a ``##Question / ##Context / ##Answer`` prompt from the inputs and
    samples up to 100 new tokens from the model. Only the newly generated
    text is returned; the prompt tokens are sliced off before decoding.

    Args:
        query: The natural-language question, inserted after ``##Question:``.
        context: Free-form text inserted after ``##Context:`` (presumably the
            table schema the question refers to -- confirm against the
            fine-tuning data format).

    Returns:
        The decoded completion as a string, with special tokens removed.
    """
    # Keep this template byte-identical: the "<s>" marker and spacing are part
    # of the prompt format the model was presumably fine-tuned on (TODO confirm).
    text = f"<s>##Question: {query} \n ##Context: {context} \n ##Answer:"

    # The model is loaded with device_map='auto', so send the inputs to
    # wherever it was actually placed rather than hard-coding 'cuda'.
    # Hard-coding 'cuda' crashes on CPU-only hosts.
    inputs = tokenizer(text, return_tensors='pt').to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )

    # generate() returns prompt + completion for each sequence; keep only the
    # tokens produced after the prompt for the single (first) sequence.
    prompt_len = inputs["input_ids"].shape[1]
    completion_ids = outputs[0, prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
|
|
|
|
|
# Two free-text inputs (question and context) feed generate_sql; the generated
# SQL is shown in a code component, styled with the custom SQLGEN theme.
interface = gr.Interface(
    fn=generate_sql,
    inputs=['text', 'text'],
    outputs=gr.Code(),
    title='SQLGEN',
    theme=SQLGEN(),
)


if __name__ == '__main__':
    interface.launch()