| |
| import streamlit as st |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import sqlparse |
|
|
| |
# Page-level Streamlit settings: title, icon, and a centered layout.
st.set_page_config(page_title="AI SQL Query Generator", page_icon="🤖", layout="centered")
|
|
| |
@st.cache_resource
def load_model(model_name="tscholak/cxmefzzi"):
    """Load the tokenizer and seq2seq model used for SQL generation.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded objects instead of reloading them on every script rerun.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.
            Defaults to the checkpoint this app was written against, so
            existing ``load_model()`` callers are unaffected.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
|
|
| |
def format_sql(sql):
    """Return *sql* pretty-printed by ``sqlparse``.

    The statement is re-indented and its keywords are upper-cased.
    """
    formatting_options = {"reindent": True, "keyword_case": 'upper'}
    return sqlparse.format(sql, **formatting_options)
|
|
| |
def generate_sql(input_text, tokenizer, model):
    """Ask the model to translate a natural-language question into SQL.

    The question is prepended with a fixed instruction prefix, tokenized
    (truncated to at most 512 tokens, tensors requested in "pt" format),
    and handed to ``model.generate`` with ``max_length=256``. The first
    generated sequence is decoded with special tokens skipped.

    Args:
        input_text: The user's question.
        tokenizer: Callable tokenizer that also provides ``decode``.
        model: Object providing ``generate``.

    Returns:
        The decoded text of the first generated sequence.
    """
    prompt = "Translate English to SQL: " + input_text
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = model.generate(**encoded, max_length=256)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
|
|
| |
def _run_generation(question, tokenizer, model):
    """Generate SQL for *question* and render the result, a warning, or an error."""
    if not question.strip():
        st.warning("Please enter a question")
        return

    with st.spinner("Generating SQL query..."):
        # Only inference and formatting are guarded, so a failure in the
        # rendering calls below is not misreported as a generation error.
        try:
            raw_sql = generate_sql(question, tokenizer, model)
            formatted_sql = format_sql(raw_sql)
        except Exception as e:
            st.error(f"Error generating SQL: {e}")
            return

    # An empty decode is not a usable query; don't report it as a success.
    if not raw_sql.strip():
        st.warning("The model returned an empty query. Try rephrasing your question.")
        return

    st.subheader("Generated SQL Query:")
    st.code(formatted_sql, language="sql")

    st.success("Query generated successfully!")

    with st.expander("Debug Info"):
        st.write("Model: tscholak/cxmefzzi")
        st.write(f"Raw Output: `{raw_sql}`")


def main():
    """Render the app page: header, question input, generate button, and help.

    Loads the tokenizer and model via ``load_model()``, collects a question
    from a text area, and, when the "Generate SQL" button is pressed, hands
    the question to ``_run_generation``. The usage instructions and example
    questions are always rendered below the interactive section.
    """
    st.title("🤖 AI-Powered SQL Query Generator")
    st.markdown("Convert natural language questions to SQL queries")

    tokenizer, model = load_model()

    user_input = st.text_area(
        "Enter your question in natural language:",
        placeholder="e.g., Show all customers from California who made purchases after January 2023",
        height=150,
    )

    if st.button("Generate SQL"):
        _run_generation(user_input, tokenizer, model)

    st.markdown("---")
    st.markdown("### How to use:")
    st.markdown("1. Enter a question about data you want to query")
    st.markdown("2. Click 'Generate SQL'")
    st.markdown("3. Copy the generated SQL and use it in your database")

    st.markdown("### Example queries:")
    st.code("Show the total sales per product category in 2022", language="text")
    st.code("List employees hired before 2020 with salary above $50,000", language="text")
    st.code("Count orders by customer country and sort descending", language="text")
|
|
# Build the page only when this file is the entry script, not when imported.
if __name__ == "__main__":
    main()