king007 commited on
Commit
3423dc9
1 Parent(s): 0dcb146

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -3,6 +3,9 @@ import gradio as gr
3
 
4
  tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompt-generator-v12")
5
  model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompt-generator-v12", from_tf=True)
 
 
 
6
 
7
  def generate(prompt):
8
 
@@ -10,7 +13,18 @@ def generate(prompt):
10
  generated_ids = model.generate(batch["input_ids"], max_new_tokens=150)
11
  output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
12
  return output[0]
13
-
 
 
 
 
 
 
 
 
 
 
 
14
  input_component = gr.Textbox(label = "Input a persona, e.g. photographer", value = "photographer")
15
  output_component = gr.Textbox(label = "Prompt")
16
  examples = [["photographer"], ["developer"]]
 
3
 
4
  tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompt-generator-v12")
5
  model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompt-generator-v12", from_tf=True)
6
+ #
7
+ tokenizer2 = AutoTokenizer.from_pretrained("Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum")
8
+ model2 = AutoModelForSeq2SeqLM.from_pretrained("Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum", from_tf=True)
9
 
10
  def generate(prompt):
11
 
 
13
  generated_ids = model.generate(batch["input_ids"], max_new_tokens=150)
14
  output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
15
  return output[0]
16
+
17
+ def generate2(prompt, max_new_tokens):
18
+ batch = tokenizer2(prompt, return_tensors="pt")
19
+ generated_ids = model2.generate(batch["input_ids"], max_new_tokens=150)
20
+ output = tokenizer2.batch_decode(generated_ids, skip_special_tokens=True)
21
+ return output[0]
22
+ def generate_prompt(type, prompt, max_new_tokens):
23
+ if type==1:
24
+ return generate(prompt)
25
+ elif type==2:
26
+ return generate2(prompt, max_new_tokens)
27
+ #
28
  input_component = gr.Textbox(label = "Input a persona, e.g. photographer", value = "photographer")
29
  output_component = gr.Textbox(label = "Prompt")
30
  examples = [["photographer"], ["developer"]]