Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import streamlit as st | |
| from diffusers import StableDiffusionPipeline | |
| from transformers import MBart50TokenizerFast, MBartForConditionalGeneration | |
# Hugging Face Hub repositories used by the app. The former
# ``runwayml/stable-diffusion-v1-5`` repository was removed from the Hub; the
# same v1.5 weights are mirrored at ``stable-diffusion-v1-5/stable-diffusion-v1-5``.
# Both ids can be overridden through environment variables without code edits.
DIFFUSION_MODEL_ID = os.getenv(
    "DIFFUSION_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5"
)
TRANSLATION_MODEL_ID = os.getenv(
    "TRANSLATION_MODEL_ID",
    "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation",
)
# Device string handed to ``pipe.to(...)`` for the diffusion pipeline.
DEVICE_NAME = os.getenv("DEVICE_NAME", "cpu")
# Optional Hub access token; ``None`` when the variable is not set.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
def load_translation_models(translation_model_id):
    """Load the mBART-50 tokenizer and seq2seq model for prompt translation.

    Both objects are fetched from ``translation_model_id`` using the
    (possibly ``None``) ``HUGGING_FACE_TOKEN`` credential. The tokenizer's
    source language is set to ``pt_XX`` (Portuguese) before returning.

    Args:
        translation_model_id: Hub repository id of the translation model.

    Returns:
        A ``(tokenizer, text_model)`` tuple.
    """
    hub_kwargs = {"use_auth_token": HUGGING_FACE_TOKEN}

    tokenizer = MBart50TokenizerFast.from_pretrained(
        translation_model_id, **hub_kwargs
    )
    # Input prompts come from a Portuguese UI.
    tokenizer.src_lang = 'pt_XX'

    text_model = MBartForConditionalGeneration.from_pretrained(
        translation_model_id, **hub_kwargs
    )
    return tokenizer, text_model
def pipeline_generate(diffusion_model_id):
    """Build a Stable Diffusion pipeline and move it to ``DEVICE_NAME``.

    Args:
        diffusion_model_id: Hub repository id of the diffusion weights.

    Returns:
        The ``StableDiffusionPipeline`` with attention slicing enabled.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        diffusion_model_id, use_auth_token=HUGGING_FACE_TOKEN
    ).to(DEVICE_NAME)

    # Attention slicing computes attention in chunks to reduce peak memory;
    # recommended if your computer has < 64 GB of RAM.
    pipeline.enable_attention_slicing()
    return pipeline
def translate(prompt, tokenizer, text_model):
    """Translate a single prompt with the given tokenizer/model pair.

    The tokenizer returned by ``load_translation_models`` has its source
    language set to ``pt_XX``, so *prompt* is expected to be Portuguese.

    Args:
        prompt: text to translate.
        tokenizer: tokenizer that encodes *prompt* and decodes the output.
        text_model: seq2seq model whose ``generate`` produces the output ids.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    encoded = tokenizer([prompt], return_tensors="pt")
    # Beam search (8 beams), stopping once all beams finish, capped at 100 tokens.
    generation_options = {
        "max_new_tokens": 100,
        "num_beams": 8,
        "early_stopping": True,
    }
    generated_ids = text_model.generate(**encoded, **generation_options)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
def generate_image(pipe, prompt, warmup=None):
    """Run the diffusion pipeline on *prompt* and return the first image.

    Args:
        pipe: a loaded Stable Diffusion pipeline (e.g. from
            ``pipeline_generate``); it is called as ``pipe(prompt, ...)`` and
            its result must expose an ``images`` sequence.
        prompt: text prompt for the pipeline.
        warmup: whether to run a throw-away 1-step pass first. ``None``
            (default) enables it only when the pipeline lives on an ``mps``
            device. The diffusers Apple-Silicon guide recommends a one-time
            warmup pass there because the first inference can differ.
            ``True``/``False`` force it on/off.

    Returns:
        ``pipe(prompt).images[0]``.
    """
    if warmup is None:
        # ``pipe.device`` is normally a ``torch.device`` (with ``.type``);
        # also tolerate a plain string or a missing attribute.
        device = getattr(pipe, "device", None)
        warmup = getattr(device, "type", device) == "mps"
    if warmup:
        pipe(prompt, num_inference_steps=1)
    return pipe(prompt).images[0]
def _load_shared_resources():
    """Load both models once and pair them with a lock that guards their use.

    Intended to be invoked through Streamlit's resource cache (see
    ``process_prompt``), so the downloads/loads happen once per server
    process instead of on every button click.

    Returns:
        ``(tokenizer, text_model, pipe, lock)``.
    """
    import threading

    tokenizer, text_model = load_translation_models(TRANSLATION_MODEL_ID)
    pipe = pipeline_generate(DIFFUSION_MODEL_ID)
    return tokenizer, text_model, pipe, threading.Lock()


def process_prompt(prompt):
    """Translate *prompt* to English and render it with Stable Diffusion.

    Args:
        prompt: user-entered image description (the UI asks for Portuguese).

    Returns:
        The image returned by ``generate_image`` for the translated prompt.
    """
    # Streamlit re-executes this script on every interaction, so plain module
    # globals do not persist between runs; its resource cache does, and it is
    # shared by all sessions. ``st.cache_resource`` superseded
    # ``st.experimental_singleton``, so use whichever this release provides.
    cache_resource = (
        getattr(st, "cache_resource", None) or st.experimental_singleton
    )
    tokenizer, text_model, pipe, lock = cache_resource(_load_shared_resources)()

    # The cached models are shared across sessions. Serialize access so two
    # requests never drive the same pipeline (whose scheduler keeps per-run
    # state) concurrently.
    with lock:
        en_prompt = translate(prompt, tokenizer, text_model)
        return generate_image(pipe, en_prompt)
# --- Streamlit page -----------------------------------------------------
# Streamlit runs this section top to bottom on every user interaction.
st.write("# Crie imagens com Stable Diffusion")
prompt_input = st.text_input("Escreva uma descrição da imagem")

# The "process" button lives in a placeholder so it can be swapped for a
# disabled copy while the image is being generated.
placeholder = st.empty()
btn = placeholder.button('Processar imagem', disabled=False, key=1)
reload = st.button('Reiniciar', disabled=False)

if btn:
    if not prompt_input.strip():
        # Nothing to translate or draw; keep the enabled button so the
        # user can fill in a description and try again.
        st.warning("Escreva uma descrição antes de processar a imagem.")
    else:
        placeholder.button('Processar imagem', disabled=True, key=2)
        image = process_prompt(prompt_input)
        st.image(image)
        # Hide the process button once the image is shown; "Reiniciar"
        # brings it back.
        placeholder.empty()

if reload:
    # The click itself already triggered a rerun; this forces one more so the
    # page starts from a clean state. ``st.rerun`` superseded
    # ``st.experimental_rerun``, which newer Streamlit releases no longer ship.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()