import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
import os


class CodeEditor:
    def __init__(self, data) -> None:
        self.data = data
        # Initialize only if not already set so the original content is retained across reruns
        if 'original_code_content' not in st.session_state:
            st.session_state.original_code_content = self.data["default_initial_solution"]
        self.llm_engine = tg.get_engine("gpt-4o")
        print("=" * 50, "init", "=" * 50)
        self.loss_value = ""
        self.gradients = []
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        if 'results' not in st.session_state:
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        # Initialize session state for the problem description and other fields if not already set
        if 'problem' not in st.session_state:
            st.session_state.problem = self.data["default_problem_description"]
        if 'loss_system_prompt' not in st.session_state:
            st.session_state.loss_system_prompt = self.data["default_loss_system_prompt"]
        if 'instruction' not in st.session_state:
            st.session_state.instruction = self.data["instruction"]

        col1, col2 = st.columns([1, 1])
        with col1:
            st.session_state.problem = st.text_area(
                "Problem description:", st.session_state.problem, height=300)
        with col2:
            st.session_state.loss_system_prompt = st.text_area(
                "Loss system prompt:", st.session_state.loss_system_prompt, height=150)
            st.session_state.instruction = st.text_area(
                "Instruction for formatted LLM call:", st.session_state.instruction, height=100)

        # The code content also needs to persist across reruns
        if 'code_content' not in st.session_state:
            st.session_state.code_content = self.data["default_initial_solution"]

        def update_code_content(value):
            # Only accept manual edits before the first optimization step
            if st.session_state.iteration == 0:
                st.session_state.code_content = value
                # print(f"Code updated: {st.session_state.code_content}")

        col1, col2 = st.columns(2)
        with col1:
            with elements("monaco_editors_widget_original"):
                st.markdown("**Initial solution:**")
                # code = editor.Monaco(
                #     height=300,
                #     defaultLanguage="python",
                #     defaultValue=st.session_state.original_code_content,
                #     onChange=update_code_content,
                #     label="Initial Solution Viewer",
                # )
                code = st.text_area("Edit your code here:",
                                    st.session_state.original_code_content, height=300)
                # Update session state when the text changes
                if code is not None and st.session_state.original_code_content != code:
                    update_code_content(code)

    def _run(self):
        # Code is the variable of interest we want to optimize -- so requires_grad=True
        solution = st.session_state.code_content
        code = tg.Variable(value=solution,
                           requires_grad=True,
                           role_description="code instance to optimize")
        # We are not interested in optimizing the problem -- so requires_grad=False
        problem = tg.Variable(st.session_state.problem,
                              requires_grad=False,
                              role_description="the coding problem")
        # Let TGD know to update code!
        optimizer = tg.TGD(parameters=[code])

        loss_system_prompt = tg.Variable(st.session_state.loss_system_prompt,
                                         requires_grad=False,
                                         role_description="system prompt to the loss function")

        # Interpolate the instruction now; {{problem}} and {{code}} stay as placeholders
        # for FormattedLLMCall to fill at call time.
        format_string = "{instruction}\nProblem: {{problem}}\nCurrent Code: {{code}}"
        format_string = format_string.format(instruction=st.session_state.instruction)
        fields = {"problem": None, "code": None}
        formatted_llm_call = tg.autograd.FormattedLLMCall(engine=self.llm_engine,
                                                          format_string=format_string,
                                                          fields=fields,
                                                          system_prompt=loss_system_prompt)

        # Finally, the loss function
        def loss_fn(problem: tg.Variable, code: tg.Variable) -> tg.Variable:
            inputs = {"problem": problem, "code": code}
            return formatted_llm_call(inputs=inputs,
                                      response_role_description=f"evaluation of the {code.get_role_description()}")

        loss = loss_fn(problem, code)
        self.loss_value = loss.value
        self.graph = loss.generate_graph()
        loss.backward()
        self.gradients = code.gradients
        optimizer.step()  # Update the code
        st.session_state.code_content = code.value

    def show_results(self):
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'gradients': self.gradients,
            'code_content': st.session_state.code_content,
        })

        tabs = st.tabs([f"Iteration {i + 1}" for i in range(st.session_state.iteration)])

        # Include the Highlight.js library and a theme CSS
        st.markdown("""
        """, unsafe_allow_html=True)

        for i, tab in enumerate(tabs):
            with tab:
                result = st.session_state.results[i]
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("### Current solution")
                st.markdown(f"""
{result["code_content"]}
""", unsafe_allow_html=True) col1, col2 = st.columns([1, 1]) with col1: st.markdown("### Loss value") st.markdown("**Loss value is based on previous code.**") st.markdown(result['loss_value']) with col2: st.markdown("### Code gradients") for j, g in enumerate(result['gradients']): # st.markdown(f"### Gradient {j}") st.markdown(g.value)