| import os |
| import logging |
| from datetime import datetime |
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig, ChatTemplateConfig |
| from lmdeploy.vl import load_image |
|
|
class ConversationalAgent:
    """Chat backend that connects an lmdeploy pipeline to Gradio callbacks.

    The instance keeps a single conversation session (``self.sess``) plus a
    few per-conversation bookkeeping attributes, and stores a copy of every
    image it receives under ``<outputs_dir>/uploaded``.
    """

    # Image modes written to JPEG unchanged; any other mode (e.g. RGBA or P
    # from PNG uploads) is converted to RGB first because the JPEG encoder
    # rejects it.
    _JPEG_MODES = ("L", "RGB", "CMYK")

    def __init__(self,
                 model_path,
                 outputs_dir) -> None:
        """Build the pipeline and prepare the image storage directory.

        Args:
            model_path: Model location handed to ``lmdeploy.pipeline``.
            outputs_dir: Directory under which the ``uploaded`` folder is
                created (if missing) to hold received images.
        """
        self.pipe = pipeline(model_path,
                             chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2'),
                             backend_config=TurbomindEngineConfig(session_len=8192))
        self.uploaded_images_storage = os.path.abspath(
            os.path.join(outputs_dir, "uploaded"))
        os.makedirs(self.uploaded_images_storage, exist_ok=True)
        # Initialise all conversation state here so that callbacks fired
        # before start_chat() do not fail on missing attributes.
        self._reset_state()

    def _reset_state(self):
        """Drop the current session and clear per-conversation bookkeeping."""
        self.sess = None
        self.context = ""
        self.current_image_id = -1
        self.image_list = []
        self.pixel_values_list = []
        self.seen_image_idx = []

    def _save_image(self, image):
        """Write ``image`` as a JPEG into the storage directory.

        The file name starts from the number of entries currently in the
        directory and is advanced until it does not clash with an existing
        file, so earlier images are never overwritten.

        Args:
            image: PIL image to store. The object itself is not modified; a
                converted copy is saved when its mode is not JPEG-compatible.

        Returns:
            Absolute path of the written file.
        """
        index = len(os.listdir(self.uploaded_images_storage))
        save_image_path = os.path.join(self.uploaded_images_storage, f"{index}.jpg")
        while os.path.exists(save_image_path):
            index += 1
            save_image_path = os.path.join(self.uploaded_images_storage, f"{index}.jpg")

        to_save = image if image.mode in self._JPEG_MODES else image.convert("RGB")
        to_save.save(save_image_path)
        logging.info(f"image save path: {save_image_path}")
        return save_image_path

    @staticmethod
    def _image_tag(save_image_path):
        """Return the ``<img>`` markup used to show a stored image in the chat."""
        return (f'<img src="./file={save_image_path}" '
                f'style="width: 200px; height: auto; display: inline-block;">')

    def start_chat(self, chat_state):
        """Reset the conversation and return UI updates for an active chat.

        Args:
            chat_state: Opaque state value; returned unchanged.

        Returns:
            A 6-tuple: five ``gr.update`` objects (the first enables the text
            input and sets its placeholder, the second disables a component,
            the remaining three enable components) followed by ``chat_state``.
            Which components they target is decided by the caller's wiring.
        """
        self._reset_state()
        logging.info("=" * 30 + "Start Chat" + "=" * 30)

        return (
            gr.update(interactive=True, placeholder='input the text.'),
            gr.update(interactive=False),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            chat_state
        )

    def restart_chat(self, chat_state):
        """End the conversation and return UI updates for an idle chat.

        Args:
            chat_state: Opaque state value; returned unchanged.

        Returns:
            A 7-tuple: ``None`` (clears the first output), five ``gr.update``
            objects that disable the chat inputs and re-enable one component,
            and ``chat_state``. Which components they target is decided by
            the caller's wiring.
        """
        self._reset_state()

        logging.info("=" * 30 + "End Chat" + "=" * 30)

        return (
            None,
            gr.update(interactive=False, placeholder="Please click the <Start Chat> button to start chat!"),
            gr.update(interactive=True),
            gr.update(interactive=False),
            gr.update(value=None, interactive=False),
            gr.update(interactive=False),
            chat_state
        )

    def upload_image(self, image: Image.Image, chat_history: gr.Chatbot, chat_state: gr.State):
        """Store an uploaded image and acknowledge it in the chat history.

        Args:
            image: Uploaded PIL image, or ``None`` when nothing was provided.
            chat_history: Chat history list; a ``(image markup, "Received.")``
                entry is appended in place.
            chat_state: Opaque state value; returned unchanged.

        Returns:
            ``(None, chat_history, chat_state)``.
        """
        logging.info(f"type(image): {type(image)}")

        # Nothing to store or acknowledge when no image was supplied.
        if image is None:
            return None, chat_history, chat_state

        self.image_list.append(image)
        save_image_path = self._save_image(image)
        chat_history.append((gr.HTML(self._image_tag(save_image_path)), "Received."))

        return None, chat_history, chat_state

    def respond(
        self,
        message,
        image,
        chat_history: gr.Chatbot,
        top_p,
        temperature,
        chat_state,
    ):
        """Send the user's turn to the model and append the reply to the chat.

        Args:
            message: User text for this turn.
            image: Optional PIL image sent together with the text; when given
                it is also saved to the storage directory.
            chat_history: Chat history list; the new turn is appended in place.
            top_p: Sampling parameter passed to ``GenerationConfig``.
            temperature: Sampling parameter passed to ``GenerationConfig``.
            chat_state: Opaque state value; returned unchanged.

        Returns:
            ``("", None, chat_history, chat_state)``.
        """
        import html  # local import: only needed to escape user text for gr.HTML

        current_time = datetime.now().strftime("%b%d-%H:%M:%S")
        logging.info(f"Time: {current_time}")
        logging.info(f"User: {message}")
        gen_config = GenerationConfig(top_p=top_p, temperature=temperature)

        chat_input = message
        save_image_path = None
        if image is not None:
            save_image_path = self._save_image(image)
            # The model receives the original image, not the archived JPEG.
            chat_input = (message, image)

        # Continue the running session when there is one, otherwise open a new one.
        if self.sess is None:
            self.sess = self.pipe.chat(chat_input, gen_config=gen_config)
        else:
            self.sess = self.pipe.chat(chat_input, session=self.sess, gen_config=gen_config)
        response = self.sess.response.text

        if save_image_path is not None:
            # The user text lands inside raw HTML, so escape it to keep
            # characters such as "<" and "&" from being parsed as markup.
            user_html = f'{html.escape(str(message))}\n\n{self._image_tag(save_image_path)}'
            chat_history.append((gr.HTML(user_html), response))
        else:
            chat_history.append((message, response))

        logging.info(f"generated text = \n{response}")

        return "", None, chat_history, chat_state
|
|