import os
import uuid

import streamlit as st
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_google_genai import ChatGoogleGenerativeAI
# Module-level setup — Streamlit re-executes this file on every interaction.
load_dotenv(".env", override=True)

# Database URL for the persistent chat history store; None when the
# environment variable is missing (SQLChatMessageHistory will then fail
# at first use — set CONNECTION_STRING in .env).
CONNECTION_STRING = os.getenv("CONNECTION_STRING")

# Inject the custom stylesheet. Read with an explicit encoding so the
# bytes decode identically on every platform (default encoding varies).
with open("design.css", encoding="utf-8") as source:
    st.markdown(f"<style>{source.read()}</style>", unsafe_allow_html=True)
def get_session_history(session_id):
    """Return the SQL-backed message history for *session_id*.

    All sessions share the module-level CONNECTION_STRING database;
    each conversation is keyed by its session id.
    """
    return SQLChatMessageHistory(
        session_id=session_id,
        connection=CONNECTION_STRING,
    )
def get_response(question, chat_session_id, *,
                 model_name="gemini-1.5-flash",
                 temperature=1.0,
                 max_tokens=1024,
                 system_prompt="Answer the user question"):
    """Stream the model's reply to *question*, persisting the exchange.

    BUG FIX: the original body read ``question``, ``chat_session_id``,
    ``model_name``, ``temperature``, ``max_tokens`` and ``system_prompt``
    as free variables, but they were locals of ``main`` — every call
    raised NameError. They are now explicit parameters.

    Args:
        question: The user's message text.
        chat_session_id: Key identifying this conversation's stored history.
        model_name: Gemini model identifier.
        temperature: Sampling temperature in [0.0, 1.0].
        max_tokens: Cap on the generated response length.
        system_prompt: System instruction prepended to the conversation.

    Returns:
        An iterator of plain-text response chunks (for st.write_stream).
    """
    llm = ChatGoogleGenerativeAI(
        model=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    # Chain: template -> model -> plain string chunks.
    runnable = prompt | llm | StrOutputParser()

    # Inject prior turns under "history" and append new turns to the same
    # SQL-backed store, keyed by session id.
    with_message_history = RunnableWithMessageHistory(
        runnable,
        get_session_history,
        input_messages_key="question",
        history_messages_key="history",
    )
    return with_message_history.stream(
        {"question": question},
        config={"configurable": {"session_id": chat_session_id}},
    )


def main():
    """Render the chat UI: sidebar settings, replayed history, input box."""
    with st.sidebar:
        model_name = st.selectbox(
            "Select AI Model",
            ("gemini-1.5-flash", "gemini-1.5-pro"),
            index=0,
        )
        temperature = st.slider('Temperature', min_value=0.0, max_value=1.0, value=1.0, step=0.01)
        max_tokens = st.slider("Max Tokens", min_value=128, max_value=4096, value=1024, step=128)
        instruction_prompt = st.text_area(
            "Instructions",
            "Answer the user question",
            height=120
        )

    # BUG FIX: use the sidebar instructions; the original discarded them
    # and hard-coded system_prompt = "...".
    system_prompt = instruction_prompt

    # BUG FIX: the original referenced an undefined chat_session_id.
    # Mint one per browser session and keep it stable across reruns.
    if "chat_session_id" not in st.session_state:
        st.session_state.chat_session_id = str(uuid.uuid4())
    chat_session_id = st.session_state.chat_session_id

    # BUG FIX: the original iterated an undefined chat_history.
    # Replay stored messages so the conversation survives Streamlit reruns.
    chat_history = get_session_history(chat_session_id)
    for msg in chat_history.messages:
        st.chat_message(msg.type).write(msg.content)

    if prompt := st.chat_input("Ask Ellon Chat AI"):
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("ai"):
            # BUG FIX: the original called get_response(...) with a literal
            # Ellipsis; pass the real question, session id and settings.
            message = get_response(
                prompt,
                chat_session_id,
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                system_prompt=system_prompt,
            )
            st.write_stream(message)


main()