# chatbot.py
import openai
import streamlit as st
import wandb
from set_env import set_env
import weave
# Make the API keys available before any client below is created.
# NOTE(review): set_env is a project helper; presumably it loads each key into
# the environment -- confirm against set_env.
_ = set_env("OPENAI_API_KEY")
_ = set_env("WANDB_API_KEY")
# Log in to Weights & Biases.
# highlight-next-line
wandb.login()
# Initialize Weave for the "feedback-example" project and keep the returned client.
# highlight-next-line
weave_client = weave.init("feedback-example")
# Module-level OpenAI client used by chat_response.
oai_client = openai.OpenAI()
def init_states():
    """Seed ``st.session_state`` with the keys the app relies on.

    Keys that are already present are left untouched:

    * ``messages`` -- the chat history, starts as an empty list.
    * ``calls`` -- the stored Weave call objects, starts as an empty list.
    * ``session_id`` -- conversation identifier; a fixed placeholder string.
    """
    # Built on every invocation so the list defaults are fresh objects.
    initial_values = {
        "messages": [],
        "calls": [],
        "session_id": "123abc",
    }
    for state_key, initial in initial_values.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial
# highlight-next-line
@weave.op
def chat_response(full_history):
    """
    Stream a chat completion for the conversation and render it as it arrives.

    full_history is the whole conversation so far, as a list of dicts:
    [{"role":"user"|"assistant","content":...}, ...]

    Returns a dict with a single "response" key holding whatever
    st.write_stream returned for the streamed completion.
    """
    completion_stream = oai_client.chat.completions.create(
        messages=full_history,
        model="gpt-4",
        stream=True,
    )
    # st.write_stream draws the chunks into the current Streamlit container
    # and hands back the accumulated output once the stream is exhausted.
    return {"response": st.write_stream(completion_stream)}
def render_feedback_buttons(call_idx):
    """Draw the feedback widgets for one stored call.

    Lays out a thumbs-up button, a thumbs-down button, and a free-text field
    with its submit button in three columns. Feedback is attached to
    ``st.session_state.calls[call_idx]``; ``call_idx`` is also part of every
    widget key so each call gets its own set of widgets.
    """
    up_col, down_col, note_col = st.columns([1, 1, 4])

    # The two reaction buttons differ only in emoji and widget key.
    reaction_buttons = (
        (up_col, "👍", f"thumbs_up_{call_idx}"),
        (down_col, "👎", f"thumbs_down_{call_idx}"),
    )
    for column, emoji, widget_key in reaction_buttons:
        with column:
            if st.button(emoji, key=widget_key):
                st.session_state.calls[call_idx].feedback.add_reaction(emoji)
                st.success("Thanks for the feedback!")

    with note_col:
        note = st.text_input("Feedback", key=f"feedback_input_{call_idx}")
        # Create the button unconditionally so it is shown even while the
        # text field is empty; only a click with non-empty text is recorded.
        submitted = st.button(
            "Submit Feedback", key=f"submit_feedback_{call_idx}"
        )
        if submitted and note:
            st.session_state.calls[call_idx].feedback.add_note(note)
            st.success("Feedback submitted!")
def display_old_messages():
    """Replay the stored conversation, with feedback widgets on assistant replies.

    Every entry of ``st.session_state.messages`` is rendered in a chat bubble
    for its role. The n-th assistant message (zero-based) is paired with
    ``st.session_state.calls[n]``; feedback widgets are drawn only when a call
    is stored at that position.
    """
    # Running zero-based position of the current assistant message among all
    # assistant messages seen so far. Keeping a counter avoids re-scanning the
    # message prefix for every assistant message.
    assistant_idx = -1
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message["role"] != "assistant":
                continue
            assistant_idx += 1
            # Render thumbs up/down & text feedback when a call is recorded.
            if assistant_idx < len(st.session_state.calls):
                render_feedback_buttons(assistant_idx)
def display_chat_prompt():
    """Show the chat input box and, when a prompt is submitted, run one turn.

    A submitted prompt is rendered and appended to
    ``st.session_state.messages``, the full history is sent through
    ``chat_response``, and the assistant reply plus its Weave call object are
    stored in ``st.session_state.messages`` and ``st.session_state.calls``.
    Feedback widgets for the new reply are rendered right away.
    """
    prompt = st.chat_input("Ask me anything!")
    if not prompt:
        return

    # Echo the user's message immediately, then record it in the session.
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Only the role and content of each stored message go to the API.
    full_history = [
        {"role": past["role"], "content": past["content"]}
        for past in st.session_state.messages
    ]

    # The Weave attributes tag the call made inside this block with the
    # session id and environment.
    with st.chat_message("assistant"), weave.attributes(
        {"session": st.session_state["session_id"], "env": "prod"}
    ):
        # .call() yields both the op's result and the Weave call object.
        result, call = chat_response.call(full_history)

        assistant_message = {"role": "assistant", "content": result["response"]}
        st.session_state.messages.append(assistant_message)
        # Keep the call object so feedback can be linked to this response.
        st.session_state.calls.append(call)

        # Zero-based position of the new reply among all assistant messages.
        assistant_total = sum(
            1 for m in st.session_state.messages if m["role"] == "assistant"
        )
        new_assistant_idx = assistant_total - 1
        if new_assistant_idx < len(st.session_state.calls):
            render_feedback_buttons(new_assistant_idx)
def main():
    """Render the page: title, session state setup, past messages, chat input."""
    st.title("Chatbot with immediate feedback forms")
    # Order matters: session state must exist before the history is drawn,
    # and the prompt handler runs after the existing messages are on screen.
    page_steps = (init_states, display_old_messages, display_chat_prompt)
    for step in page_steps:
        step()
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()