# demo/examples/example_imageQA_scripts.py
# huangzhii
# update
# 32486dc
import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
import os
from PIL import Image
from textgrad.autograd import MultimodalLLMCall
from textgrad.loss import ImageQALoss
from io import BytesIO
class ImageQA:
    """Streamlit demo: improve an LLM answer to a question about an image with TextGrad.

    Intended call order: construct once, call ``load_layout()`` to render the
    inputs and build the TextGrad variables and loss, then call
    ``show_results()`` to run one optimization step and render every step
    recorded so far in ``st.session_state.results``.
    """

    # Image modes Pillow's PNG encoder can write without conversion.
    _PNG_WRITABLE_MODES = frozenset(
        {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}
    )

    def __init__(self, data, engine_name: str = "gpt-4o") -> None:
        """Create the engine, set it as TextGrad's backward engine, and seed session state.

        Args:
            data: mapping providing the defaults read by ``load_layout``:
                ``"image_URL"``, ``"question_text"`` and ``"evaluation_instruction"``.
            engine_name: TextGrad engine name used for the forward call, the
                loss, and the backward pass. Defaults to ``"gpt-4o"``.
        """
        self.data = data
        self.engine_name = engine_name
        self.llm_engine = tg.get_engine(engine_name)
        print("="*50, "init", "="*50)
        self.loss_value = ""
        self.gradients = ""
        # Seed each key on its own, so a session that already has one of them
        # still gets the other (show_results needs both).
        if "iteration" not in st.session_state:
            st.session_state.iteration = 0
        if "results" not in st.session_state:
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    @classmethod
    def _png_compatible_mode(cls, mode: str) -> str:
        """Return *mode* if Pillow can write it to PNG unchanged, otherwise ``"RGB"``."""
        return mode if mode in cls._PNG_WRITABLE_MODES else "RGB"

    def load_layout(self):
        """Render the image/question/instruction inputs and build the TextGrad inputs and loss.

        Sets ``self.image_variable``, ``self.question_variable``,
        ``self.evaluation_instruction_text`` and ``self.loss_fn``, all of which
        ``show_results`` relies on, so this must be called first.
        """
        st.markdown("**This is a solution optimization for image QA.**")
        col1, col2 = st.columns([1, 1])
        with col1:
            uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
            if uploaded_file is not None:
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded Image")
            else:
                # Passed to PIL.Image.open, so despite the key name this must be
                # a local path or file object, not an HTTP URL.
                image_url = self.data["image_URL"]
                image = Image.open(image_url)
                st.image(image_url, caption="Default: MathVista image")

            # PNG cannot store every mode (e.g. CMYK JPEG uploads); convert
            # first instead of letting image.save() raise.
            target_mode = self._png_compatible_mode(image.mode)
            if image.mode != target_mode:
                image = image.convert(target_mode)
            png_buffer = BytesIO()
            image.save(png_buffer, format="PNG")
            self.image_variable = tg.Variable(
                png_buffer.getvalue(),
                role_description="image to answer a question about",
                requires_grad=False,
            )
        with col2:
            question_text = st.text_area("Question:", self.data["question_text"], height=150)
            self.question_variable = tg.Variable(
                question_text, role_description="question", requires_grad=False
            )
            self.evaluation_instruction_text = st.text_area(
                "Evaluation instruction:", self.data["evaluation_instruction"], height=100
            )
        self.loss_fn = ImageQALoss(
            evaluation_instruction=self.evaluation_instruction_text,
            engine=self.engine_name,
        )
        if "current_response" not in st.session_state:
            st.session_state.current_response = ""

    def _run(self):
        """Generate a response, score it, backpropagate, and apply one TGD step.

        Stores the loss text in ``self.loss_value``, the response's gradients in
        ``self.gradients``, and the post-step response text in
        ``st.session_state.current_response``.
        """
        # NOTE(review): each call starts from a brand-new LLM answer; the previous
        # st.session_state.current_response is not read anywhere in this class,
        # so iterations do not build on one another. Confirm that is intended.
        self.response = MultimodalLLMCall(self.engine_name)([
            self.image_variable,
            self.question_variable,
        ])
        optimizer = tg.TGD(parameters=[self.response])
        loss = self.loss_fn(
            question=self.question_variable,
            image=self.image_variable,
            response=self.response,
        )
        # Loss and gradients describe the response *before* optimizer.step().
        self.loss_value = loss.value
        loss.backward()
        self.gradients = self.response.gradients
        optimizer.step()  # Let's update the response
        st.session_state.current_response = self.response.value

    def show_results(self):
        """Run one optimization step, record it, and render one tab per recorded step.

        Each tab shows the post-step response alongside the loss and gradients
        computed for the pre-step response.
        """
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            "iteration": st.session_state.iteration,
            "loss_value": self.loss_value,
            "response": self.response.value,
            "gradients": self.gradients,
        })
        results = st.session_state.results
        # Derive the tabs from the stored results so tab count and contents
        # can never disagree.
        tabs = st.tabs([f"Iteration {i + 1}" for i in range(len(results))])
        for tab, result in zip(tabs, results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("## Current solution:")
                st.markdown(result["response"])
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("## Loss value")
                    st.markdown(result["loss_value"])
                with col2:
                    st.markdown("## Gradients")
                    for gradient in result["gradients"]:
                        st.markdown("### Gradient")
                        st.markdown(gradient.value)