#!/usr/bin/env python
# coding: utf-8
"""Gradio demo: ask questions about a chart image using the google/deplot model."""

# import required libraries
import functools

from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
import gradio as gr

# Hugging Face model id used for both the processor and the model.
MODEL_NAME = "google/deplot"
# Upper bound on the number of tokens generated per answer.
MAX_NEW_TOKENS = 512
# Row separator in the decoded output: the literal text "<0x0A>" (an escaped newline).
ROW_SEPARATOR = "<0x0A>"
# Cell separator within one row of the decoded output.
CELL_SEPARATOR = "|"


@functools.lru_cache(maxsize=None)
def _load_model(model_name):
    """Load and return ``(processor, model)`` for *model_name*, once per name.

    Cached so the pretrained weights are loaded only on the first request
    instead of on every call to ``query``. Only MODEL_NAME is used here, so the
    cache holds a single entry.
    """
    processor = Pix2StructProcessor.from_pretrained(model_name)
    model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
    return processor, model


def _split_table(text):
    """Split the decoded model output into rows.

    The text is split on ROW_SEPARATOR and each row is stripped. A row that
    contains CELL_SEPARATOR becomes a list of stripped cell strings; a row
    without it is kept as a single (stripped) string.
    """
    rows = [line.strip() for line in text.split(ROW_SEPARATOR)]
    return [
        [cell.strip() for cell in row.split(CELL_SEPARATOR)]
        if CELL_SEPARATOR in row
        else row
        for row in rows
    ]


def query(image, user_question):
    """Answer *user_question* about the chart in *image*.

    image: a single chart image (the UI below passes a PIL image). Only the
        first generated sequence is decoded, so batches are not supported.
    user_question: the text prompt sent to the model.

    Returns a list of rows: each row is a list of cell strings, or a plain
    string for output lines that contain no "|" separator.
    """
    # Loaded once and reused across requests (see _load_model).
    processor, model = _load_model(MODEL_NAME)
    # process the inputs for prediction
    inputs = processor(images=image, text=user_question, return_tensors="pt")
    predictions = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Only the first sequence is decoded, matching a single input image.
    result = processor.decode(predictions[0], skip_special_tokens=True)
    # process the results for output table
    return _split_table(result)


# Interface framework to customize the io page
ui = gr.Interface(
    title="Chart Q/A",
    fn=query,
    inputs=[gr.Image(label="Upload Here", type="pil"), gr.Textbox(label="Question?")],
    outputs="list",
    examples=[
        ["./samples/sample1.png", "Generate underlying data table of the figure"],
        ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"],
        # ["./samples/sample3.webp", "What are the 2020 net sales?"],
    ],
    cache_examples=True,
    allow_flagging='never',
)
ui.queue(api_open=True)
ui.launch(inline=False, share=False, debug=True)