John6666 committed
Commit adefd5c · verified · 1 Parent(s): aefd7f3

Upload 16 files

Files changed (5):
  1. app.py +11 -0
  2. t2i/infer.py +6 -3
  3. t2i/pipe.py +2 -6
  4. t2i/utils.py +598 -497
  5. t2i_config.py +47 -30
app.py CHANGED
@@ -1,6 +1,17 @@
 import spaces
 import gradio as gr
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
+from t2i_config import KERNELS_PREFETCH_ON_STARTUP, KERNELS_PREFETCH_REPOS
+
+if KERNELS_PREFETCH_ON_STARTUP:
+    try:
+        from kernels import has_kernel, get_kernel
+        for _repo_id in KERNELS_PREFETCH_REPOS:
+            if has_kernel(_repo_id):
+                get_kernel(_repo_id)
+    except Exception as _e:
+        print(f"INFO : Kernels prefetch skipped: {_e}")
+
 from t2i.infer import (infer, infer_multi, infer_simple, save_image_history, save_gallery_history,
                        update_param_mode_gr, update_ar_gr,
                        MAX_SEED, MAX_IMAGE_SIZE, ASPECT_RATIOS, FILE_FORMATS, DEFAULT_TASKS, DEFAULT_DURATION,
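
Note on the new block: has_kernel and get_kernel come from Hugging Face's kernels package, and get_kernel downloads the compiled kernel from the Hub (or reuses the local cache). Running it at import time warms the cache before the first request instead of during it. A minimal standalone sketch of the same calls, assuming the kernels package is installed and the repo publishes a build for the running platform:

    from kernels import has_kernel, get_kernel

    repo = "kernels-community/flash-attn3"  # the repo this Space prefetches by default
    if has_kernel(repo):                    # a compatible build exists for this environment
        kernel = get_kernel(repo)           # downloads once, then serves from the local cache
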
t2i/infer.py CHANGED
@@ -74,10 +74,13 @@ def infer_body(prompt: str, negative_prompt: str, seed: int, randomize_seed: boo
     kwargs, ikwargs = {"generator": generator}, {}
     metadata = {"prompt": prompt, "negative_prompt": negative_prompt, "Model": Path(model.split("/")[-1]).stem, "seed": seed}
     if negative_prompt: kwargs["negative_prompt"] = negative_prompt
+    if param_mode == "Auto":
+        params = get_auto_param(model_type)
+        kwargs |= params
+        metadata |= params
     elif param_mode != "Default":
-        params = get_auto_param(model_type) if param_mode == "Auto" else {}
-        kwargs |= {"guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, "width": width, "height": height} | params
-        metadata |= {"num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, "resolution": f"{width} x {height}"} | params
+        kwargs |= {"guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, "width": width, "height": height}
+        metadata |= {"num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, "resolution": f"{width} x {height}"}

     if task == TASK_T2I:
         if pipe_type == "Long Prompt Weighting" and model_type == "SDXL": kwargs["clip_skip"], metadata["clip_skip"] = clip_skip, clip_skip
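
The infer.py hunk fixes two things at once. First, the old param handling was an elif chained to the negative-prompt check, so whenever a negative prompt was set, neither "Auto" nor "Custom" parameters were applied at all; the new code makes them independent statements. Second, "Auto" no longer merges the UI's guidance/steps/width/height together with the model-type defaults (the old "| params" merge); the branches are now disjoint. A worked example of the "Auto" branch for SDXL, using AUTO_PARAM_DICT from t2i/utils.py:

    kwargs = {"generator": generator}
    params = {"guidance_scale": 7.0, "num_inference_steps": 28}  # get_auto_param("SDXL")
    kwargs |= params  # Auto: model-type defaults only, no width/height forced
    # Custom would instead merge only the UI values:
    # kwargs |= {"guidance_scale": gs, "num_inference_steps": steps, "width": w, "height": h}
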
t2i/pipe.py CHANGED
@@ -5,7 +5,7 @@ import torch
 from diffusers import DiffusionPipeline, AutoencoderKL
 from diffusers.models.attention_processor import AttnProcessor2_0
 from t2i_config import models, sdxl_vaes, sd15_vaes, PIPELINE_MAX_GIB
-from t2i.utils import (logger, get_token, free_memory, calc_pipe_size, is_weight_url, get_file,
+from t2i.utils import (logger, get_token, free_memory, calc_pipe_size, is_weight_url, get_file, apply_attention_backend,
                        get_model_type, get_model_type_from_pipe, get_task_class, DEFAULT_TASKS, IS_ZEROGPU, DEVICE, DTYPE, IS_QUANT,
                        MAX_SEED, MAX_IMAGE_SIZE, DEFAULT_MODEL_TYPE, DEFAULT_STR, ASPECT_RATIOS, PIPELINE_TYPES, DEFAULT_VAE, PARAM_MODES)

@@ -43,11 +43,7 @@ class Pipeline:
         self.lastmod = time.time()
         if device != "cpu" and not IS_QUANT:
             if self.pipe.device != device: self.pipe.to(device)
-            # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
-            #if model_type in ["SD 1.5", "SDXL"]: self.pipe.unet.set_attn_processor(AttnProcessor2_0())
-            #elif model_type in ["FLUX"]: self.pipe.transformer.set_attn_processor(AttnProcessor2_0())
-            #self.pipe.vae.set_attn_processor(AttnProcessor2_0())
-            #logger.debug(f"SDPA enabled {type(self.pipe).__name__} ({model_type}) on {device}.") # by default in PyTorch 2.x
+            apply_attention_backend(self.pipe)
         return self.pipe

     def quantize(self):
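
apply_attention_backend (added in t2i/utils.py below) supersedes the commented-out AttnProcessor2_0 lines: instead of swapping attention processors per model class, it routes through the Diffusers attention dispatcher. For reference, a minimal sketch of the underlying call on a single model, assuming a diffusers release recent enough to ship set_attention_backend:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
    pipe.transformer.set_attention_backend("_flash_3_hub")  # FlashAttention-3 fetched via the kernels hub

The helper's value over this direct call is that it probes every pipeline component, skips those that lack the method, and logs failed calls at debug level instead of raising.
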
t2i/utils.py CHANGED
@@ -1,497 +1,598 @@
 import spaces
 import os, gc, json, uuid, time, datetime, re, urllib, tempfile, math, inspect
 from typing import Any, Tuple, Dict, List, Optional, Iterator
 from dataclasses import dataclass, field
 from pathlib import Path
 from PIL import Image, PngImagePlugin
 import torch
 import numpy as np
 import gradio as gr
 from huggingface_hub import HfApi, hf_hub_download
 from safetensors.torch import load_file
 from diffusers import (AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting, DiffusionPipeline,
                        StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline,
                        StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, AutoencoderKL)
 from t2i.controlnet_union.pipeline.pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
-from t2i_config import STORAGE_MAX_GIB, IS_DEBUG
+from t2i_config import STORAGE_MAX_GIB, IS_DEBUG, ATTENTION_BACKEND, ATTENTION_BACKEND_NON_HOPPER


 DEFAULT_STR = "Default"
 IS_ZEROGPU = True if os.getenv("SPACES_ZERO_GPU", None) else False
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 IS_QUANT = False if IS_ZEROGPU else False # https://huggingface.co/posts/cbensimon/565026286160860#684a4147f1e1efa28f85ba5c
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048 #1216
 PIPELINE_TYPES = ["Default", "Long Prompt Weighting"]
 DEFAULT_VAE = DEFAULT_STR
 PARAM_MODES = ["Auto", "Default", "Custom"]
 DEFAULT_I2I_STRENGTH = 0.8
 DEFAULT_UPSCALE_STRENGTH = 0.55
 DEFAULT_UPSCALE_BY = 1.5
 DEFAULT_CLIP_SKIP = 2

+# Attention backend switching (Diffusers attention dispatcher)
+# Works across SD1.5/SDXL/FLUX by applying to any component that supports set_attention_backend().
+# Config lives in t2i_config.py (ATTENTION_BACKEND, ATTENTION_BACKEND_NON_HOPPER).
+def _is_hopper_gpu() -> bool:
+    if not torch.cuda.is_available():
+        return False
+    try:
+        major, minor = torch.cuda.get_device_capability()
+        return major >= 9 # SM90+ (Hopper)
+    except Exception:
+        return False
+
+
+def _resolve_attention_backend() -> Optional[str]:
+    backend = ATTENTION_BACKEND
+    if backend is None:
+        return None
+    backend = str(backend).strip()
+    if backend == "":
+        return None
+    if backend.lower() == "auto":
+        return "_flash_3_hub" if _is_hopper_gpu() else (ATTENTION_BACKEND_NON_HOPPER or "flash_hub")
+    return backend
+
+
+def _iter_attention_targets(pipe: Any) -> Iterator[Any]:
+    # common attributes
+    for name in ["unet", "transformer", "controlnet"]:
+        if hasattr(pipe, name):
+            obj = getattr(pipe, name)
+            if obj is None:
+                continue
+            if isinstance(obj, (list, tuple, set)):
+                for o in obj:
+                    if o is not None:
+                        yield o
+            elif isinstance(obj, dict):
+                for o in obj.values():
+                    if o is not None:
+                        yield o
+            else:
+                yield obj
+
+    # pipeline.components (dict)
+    if hasattr(pipe, "components"):
+        try:
+            comps = getattr(pipe, "components")
+            if isinstance(comps, dict):
+                for o in comps.values():
+                    if o is None:
+                        continue
+                    if isinstance(o, (list, tuple, set)):
+                        for x in o:
+                            if x is not None:
+                                yield x
+                    elif isinstance(o, dict):
+                        for x in o.values():
+                            if x is not None:
+                                yield x
+                    else:
+                        yield o
+        except Exception:
+            pass
+
+
+def apply_attention_backend(pipe: Any) -> bool:
+    backend = _resolve_attention_backend()
+    if not backend:
+        return False
+
+    prev = getattr(pipe, "_t2i_attention_backend", None)
+    if prev == backend:
+        return False
+
+    applied: List[str] = []
+    seen = set()
+
+    for obj in _iter_attention_targets(pipe):
+        oid = id(obj)
+        if oid in seen:
+            continue
+        seen.add(oid)
+
+        if not hasattr(obj, "set_attention_backend"):
+            continue
+
+        try:
+            obj.set_attention_backend(backend)
+            applied.append(type(obj).__name__)
+        except Exception as e:
+            logger.debug(f"set_attention_backend({backend}) failed on {type(obj).__name__}: {e}")
+
+    if applied:
+        pipe._t2i_attention_backend = backend
+        logger.debug(f"Attention backend set to {backend} on {list_uniq_order(applied)}.")
+        return True
+
+    logger.debug(f"Attention backend {backend} was not applied (no compatible components).")
+    pipe._t2i_attention_backend = None
+    return False
+
+
 def get_logger():
     import logging
     from pytz import timezone
     from datetime import datetime
     logger = logging.getLogger(__name__)
     if IS_DEBUG: logger.setLevel(logging.DEBUG)
     else: logger.setLevel(logging.INFO)
     sh = logging.StreamHandler()
     sh.setLevel(logging.DEBUG if IS_DEBUG else logging.INFO)
     def customTime(*args):
         return datetime.now(timezone('Asia/Tokyo')).timetuple()
     formatter = logging.Formatter(
         fmt='%(levelname)s : %(asctime)s : %(message)s',
         datefmt="%Y-%m-%d %H:%M:%S %z"
     )
     formatter.converter = customTime
     sh.setFormatter(formatter)
     logger.addHandler(sh)
     return logger


 logger = get_logger()


 def get_token() -> Any:
     return os.getenv("HF_TOKEN", None)


 def list_uniq_order(l: list) -> List:
     return list(dict.fromkeys(l))


 def free_memory():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         #torch.cuda.ipc_collect()
     gc.collect()


 def calc_module_size(model: torch.nn.Module) -> int:
     param_size = 0
     for param in model.parameters():
         param_size += param.nelement() * param.element_size()
     buffer_size = 0
     for buffer in model.buffers():
         buffer_size += buffer.nelement() * buffer.element_size()
     return int(buffer_size + param_size)


 def calc_pipe_size(pipe: Any) -> int:
     return sum([calc_module_size(m) for m in pipe.components.values() if isinstance(m, torch.nn.Module)])


 def calc_pix_8(x: float) -> int:
     y = math.ceil(x)
     return y - (y % 8)


 def calc_pix_64(x: float) -> int:
     y = math.ceil(x)
     return y - (y % 64)


 WEIGHT_EXTS = [".safetensors", ".sft", ".bin", ".pth"]


 def is_weight_url(url: str) -> bool:
     if "http" not in url: return False
     for ext in WEIGHT_EXTS:
         if ext in url: return True
     return False


 def read_safetensors_key(path: str) -> List[str]:
     try:
         keys = []
         state_dict = load_file(str(Path(path)))
         for k in list(state_dict.keys()):
             keys.append(k)
             state_dict.pop(k)
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
     finally:
         del state_dict
         free_memory()
     return keys


 def split_hf_url(url: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
     try:
         s = list(re.findall(r'^(?:(?:https?://huggingface.co/)|(?:https?://hf.co/))(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
         if len(s) < 4: return "", "", "", ""
         repo_id = s[1]
         if s[0] == "datasets": repo_type = "dataset"
         elif s[0] == "spaces": repo_type = "space"
         else: repo_type = "model"
         subfolder = urllib.parse.unquote(s[2]) if s[2] else None
         filename = urllib.parse.unquote(s[3])
         return repo_id, filename, subfolder, repo_type
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
         return "", "", None, ""


 def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)) -> Optional[str]:
     hf_token = get_token()
     repo_id, filename, subfolder, repo_type = split_hf_url(url)
     if not repo_id:
         logger.info(f"Failed to download {url}")
         return None
     try:
         logger.debug(f"Downloading {url} to {directory}")
         if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
         else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
         return path
     except Exception as e:
         logger.info(f"Failed to download {e}")
         return None


 @dataclass(order=True)
 class LocalFile:
     path: str = ""
     url: str = ""
     lastmod: float = 0.
     size: int = 0
     keys: list = field(default_factory=list)

     def __str__(self):
         return f"{self.path} ({self.url}) Size:{float(self.size) / (1024.**3):.2f}GiB LastMod.:{datetime.datetime.fromtimestamp(self.lastmod).strftime('%Y/%m/%d %H:%M:%S')}"

     def __del__(self):
         delpath = Path(self.path)
         if delpath.exists() and delpath.is_file(): delpath.unlink()
         logger.debug(f"Deleted {self.path}.")


 class LocalFiles:
     def __init__(self):
         self.files: Dict[str, LocalFile] = {}
         self.temp_dir = tempfile.mkdtemp()
         self.max_gib = STORAGE_MAX_GIB

     def __call__(self, url: str) -> Optional[str]:
         try:
             if url in self.files.keys():
                 self.files[url].lastmod = time.time()
                 return self.files[url].path
             path = download_hf_file(self.temp_dir, url)
             if not path: return None
             self.files[url] = LocalFile(path=path, url=url, lastmod=time.time(), size=os.path.getsize(Path(path)), keys=read_safetensors_key(path))
             logger.info(f"Downloaded {self.files[url]}.")
             self.clean()
             return path
         except Exception as e:
             logger.debug(f"{inspect.currentframe().f_code.co_name}: {e}")
             return None

     def __str__(self):
         return "\n".join([str(x) for x in self.files.values()])

     def clean(self):
         items = sorted(list(self.files.values()), key=lambda x:x.lastmod, reverse=True)
         sum_bytes = 0
         max_bytes = self.max_gib * (1024 ** 3)
         del_items = []
         for item in items:
             sum_bytes += item.size
             if sum_bytes > max_bytes: del_items.append(item.name)
         for item in del_items:
             self.files.pop(item)

     def get_keys(self, url: str) -> Optional[list[str]]:
         if url not in self.files.keys(): self.__call__(url)
         return self.files[url].keys if url in self.files.keys() else None


 local_files = LocalFiles()


 def get_file(url: str) -> Optional[str]:
     path = local_files(url)
     return path


 def get_file_keys(url: str) -> Optional[List[str]]:
     return local_files.get_keys(url)


 MODEL_TYPE_CLASS = {
     "diffusers:StableDiffusionPipeline": "SD 1.5",
     "diffusers:StableDiffusionXLPipeline": "SDXL",
     "diffusers:FluxPipeline": "FLUX",
 }


 PIPELINE_TO_TYPE = {k.replace("diffusers:", ""): v for k, v in MODEL_TYPE_CLASS.items()}
 MODEL_TYPE_VALUES = list(MODEL_TYPE_CLASS.values())
 DEFAULT_MODEL_TYPE = "Auto"
 MODEL_TYPES = [DEFAULT_MODEL_TYPE] + MODEL_TYPE_VALUES


 def get_model_type_from_repo_id(repo_id: str) -> str:
     api = HfApi(token=get_token())
     default = "SDXL"
     try:
         model = api.model_info(repo_id=repo_id, timeout=5.0)
         tags = model.tags
         for tag in tags:
             if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
     except Exception:
         return default
     return default


 MODEL_TYPE_KEY = {
     "model.diffusion_model.output_blocks.1.1.norm.bias": "SDXL",
     "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "SD 1.5",
     "double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
     "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
     "model.diffusion_model.joint_blocks.9.x_block.attn.ln_k.weight": "SD 3.5",
 }


 def get_model_type_from_key(url: str) -> str:
     default = "SDXL"
     try:
         keys = get_file_keys(url)
         for k, v in MODEL_TYPE_KEY.items():
             if k in set(keys): return v
     except Exception:
         return default
     return default


 def get_model_type_from_url(url: str) -> str:
     default = "SDXL"
     try:
         return get_model_type_from_key(url)
     except Exception:
         return default


 def get_model_type(name: str) -> str:
     model_type = DEFAULT_MODEL_TYPE
     try:
         if is_weight_url(name): model_type = get_model_type_from_url(name)
         else: model_type = get_model_type_from_repo_id(name)
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
     finally:
         logger.debug(f"{name} is determined as {model_type}.")
     return model_type


 def get_model_type_from_pipe(pipe: Any) -> str:
     model_type = PIPELINE_TO_TYPE.get(type(pipe).__name__, DEFAULT_MODEL_TYPE)
     logger.debug(f"{type(pipe).__name__} is determined as {model_type}.")
     return model_type


 AR_TO_REZ = {
     "1:1 (Square)": "1024x1024",
     "3:2 (Landscape)": "1216x832",
     "2:3 (Portrait)": "832x1216",
     "16:9 (HD TV)": "1344x768",
     "9:16 (Selfie)": "768x1344",
     "4:3 (SD TV)": "1152x896",
     "3:4 (Standard)": "896x1152",
     "21:9 (Cinema)": "1536x640",
     "9:21": "640x1536",
     "3:1": "1728x576",
     "1:3": "576x1728",
     "4:1": "2048x512",
     "1:4": "512x2048"
 }
 SDXL_REZ = [s for s in AR_TO_REZ.values()]
 AR_CUSTOM = "Custom"
 ASPECT_RATIOS = [AR_CUSTOM] + [s for s in AR_TO_REZ.keys()]
 MODEL_TYPE_REZ = {"SDXL": 1024, "SD 1.5": 512, "FLUX": 1024}


 def get_rez_from_ar(ar: str, type: str="SDXL", sw: int=1024, sh: int=1024) -> Tuple[int, int]:
     if ar == AR_CUSTOM: return sw, sh
     br = AR_TO_REZ.get(ar, SDXL_REZ[0])
     bw, bh = int(br.split("x")[0]), int(br.split("x")[1])
     sr = 1024 # SDXL
     tr = MODEL_TYPE_REZ.get(type, 1024)
     return calc_pix_64(bw * tr / sr), calc_pix_64(bh * tr / sr)


 def update_ar_gr(ar: str):
     if ar == AR_CUSTOM: return gr.update(visible=True), gr.update(visible=True)
     else: return gr.update(visible=False), gr.update(visible=False)


 AUTO_PARAM_DICT = {
     "SD 1.5": {"guidance_scale": 7., "num_inference_steps": 50},
     "SDXL": {"guidance_scale": 7., "num_inference_steps": 28},
     "FLUX": {"guidance_scale": 3.5, "num_inference_steps": 28},
 }


 def get_auto_param(type: str) -> Dict:
     param = AUTO_PARAM_DICT.get(type, {})
     return param


 def update_param_mode_gr(mode: str):
     if mode in ["Auto", "Default"]: return gr.update(visible=False), gr.update(visible=False)
     else: return gr.update(visible=True), gr.update(visible=True)


 TASK_T2I = "Text-to-Image"
 TASK_I2I = "Image-to-Image"
 TASK_INPAINT = "Inpaint"
 DEFAULT_PIPE_CLASS = "Auto"


 DIFFUSERS_TASK = {
     DEFAULT_PIPE_CLASS: {
         TASK_T2I: AutoPipelineForText2Image,
         TASK_I2I: AutoPipelineForImage2Image,
         TASK_INPAINT: AutoPipelineForInpainting,
     },
     "SD 1.5": {
         TASK_T2I: StableDiffusionPipeline,
         TASK_I2I: StableDiffusionImg2ImgPipeline,
         TASK_INPAINT: StableDiffusionControlNetInpaintPipeline,
     },
     "SDXL": {
         TASK_T2I: StableDiffusionXLPipeline,
         TASK_I2I: StableDiffusionXLImg2ImgPipeline,
         TASK_INPAINT: StableDiffusionXLControlNetUnionInpaintPipeline,
     },
     "FLUX": {
         TASK_T2I: FluxPipeline,
         TASK_I2I: FluxImg2ImgPipeline,
         TASK_INPAINT: FluxInpaintPipeline,
     },
 }


 def get_tasks(model_type: str=DEFAULT_PIPE_CLASS) -> List[str]:
     if model_type not in DIFFUSERS_TASK.keys(): model_type = DEFAULT_PIPE_CLASS
     return [x for x in DIFFUSERS_TASK.get(model_type, DEFAULT_PIPE_CLASS).keys()]


 def get_task_class(model_type: str, task: str) -> Any:
     if model_type not in DIFFUSERS_TASK.keys(): model_type = DEFAULT_PIPE_CLASS
     try:
         return DIFFUSERS_TASK[model_type][task]
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
         return DIFFUSERS_TASK[DEFAULT_PIPE_CLASS][task]


 DEFAULT_TASKS = get_tasks()
 KNOWN_PIPE_CLASS = [x for x in DIFFUSERS_TASK.keys() if x != DEFAULT_PIPE_CLASS]


 HF_DEFAULT_STEPS = 50
 DEFAULT_INFER_TIME = 10.
 MODEL_INFER_TIME = {"SD 1.5": 5.0, "SDXL": 8.5}


 def get_final_steps(type: str, mode: str, steps: int) -> int:
     if mode not in ["Default", "Auto"]: return steps
     if mode == "Auto":
         param = get_auto_param(type)
         s = param.get("num_inference_steps", None)
         if s is None: return HF_DEFAULT_STEPS
         else: return s
     elif mode == "Default": return HF_DEFAULT_STEPS
     return steps


 def estimate_model_infer_time(type: str=DEFAULT_MODEL_TYPE, task: str=DEFAULT_TASKS[0], mode: str=PARAM_MODES[0], steps: int=HF_DEFAULT_STEPS) -> float:
     steps = get_final_steps(type, mode, steps)
     base_time = MODEL_INFER_TIME.get(type, DEFAULT_INFER_TIME)
     time = (base_time * 0.25) + (base_time * 0.75 * float(steps) / float(HF_DEFAULT_STEPS))
     if task == TASK_INPAINT: time *= 3. if type == "SD 1.5" else 1.5
     elif task == TASK_I2I: time *= 1.5
     return time


 def resize_ref_image(image: Image.Image) -> Image.Image:
     MIN_SIZE = 256
     try:
         ow, oh = image.size
         if ow > oh:
             tw = max(calc_pix_8(ow), MIN_SIZE)
             th = calc_pix_8(tw * oh / ow)
         else:
             th = max(calc_pix_8(oh), MIN_SIZE)
             tw = calc_pix_8(th * ow / oh)
         return image.resize((tw, th), Image.LANCZOS)
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
         return image


 def get_image_mask(image_dict: Optional[Dict]) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
     image, mask = None, None
     try:
         if isinstance(image_dict, dict):
             image = image_dict.get("background", None)
             layers = image_dict.get("layers", None)
             mask = layers[0] if layers is not None and len(layers) > 0 else None
         if isinstance(image, str): image = Image.open(image)
         if isinstance(image, Image.Image): image = resize_ref_image(image).convert("RGB")
         if isinstance(mask, str): mask = Image.open(mask)
         if isinstance(mask, Image.Image): mask = resize_ref_image(mask).convert("L")
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
     finally:
         logger.debug(f"Image:{image}, Mask:{mask}")
         return image, mask


 FILE_FORMAT_MAP = {"PNG": "png", "WebP": "webp", "JPEG": "jpg"}
 FILE_FORMATS = [x for x in FILE_FORMAT_MAP.keys()]


 def save_image(image: Image.Image, metadata: dict, format: str=FILE_FORMATS[0]) -> Optional[str]:
     try:
         ext = FILE_FORMAT_MAP.get(format, "png")
         savefile = f'{metadata["Model"]}_{str(uuid.uuid4())}.{ext}'
         if ext in ["png"]:
             metadata_str = json.dumps(metadata)
             info = PngImagePlugin.PngInfo()
             info.add_text("metadata", metadata_str)
             image.save(savefile, "PNG", pnginfo=info)
         else: image.save(savefile)
         return str(Path(savefile).resolve())
     except Exception as e:
         logger.info(f"Failed to save image file: {e}")
         raise Exception(f"Failed to save image file: {e}") from e


 def save_image_history(image: str, gallery: Optional[List], files: Optional[List], progress=gr.Progress(track_tqdm=True)):
     if not gallery: gallery = []
     if not files: files = []
     try:
         if isinstance(image, str):
             files.insert(0, str(Path(image).resolve()))
             gallery.insert(0, (str(Path(image).resolve()), str(Path(image).name)))
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
     finally:
         return gr.update(value=gallery), gr.update(value=files, visible=True)


 def save_gallery_history(images: Optional[List], gallery: Optional[List], files: Optional[List], progress=gr.Progress(track_tqdm=True)):
     if not gallery: gallery = []
     if not files: files = []
     try:
         gallery = list_uniq_order(images.copy() + gallery)
         files = [x[0] for x in gallery]
     except Exception as e:
         logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
     finally:
         return gr.update(value=gallery), gr.update(value=files, visible=True)
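
On the 'auto' path, _is_hopper_gpu keys off the CUDA compute capability: 'auto' resolves to "_flash_3_hub" (FlashAttention-3 from the kernels hub) on SM90+ parts such as H100/H200, and to ATTENTION_BACKEND_NON_HOPPER ("flash_hub" by default) everywhere else. The probe it relies on:

    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on an H100
        print("Hopper-class" if major >= 9 else "pre-Hopper")
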
t2i_config.py CHANGED
@@ -1,30 +1,47 @@

 models = [
     'Yntec/YiffyMix',
     'Raelina/Rae-Diffusion-XL-V2',
     'Raelina/Raemu-XL-V4',
     'Raelina/Raemu-XL-V5',
     'Raelina/Raena-XL-V2',
     'Raelina/Raehoshi-illust-XL',
     'Raelina/Raehoshi-illust-xl-2',
     'Raelina/Raehoshi-Illust-XL-2.1',
     'Raelina/Raehoshi-illust-XL-3',
     'Raelina/Raehoshi-illust-XL-4',
     'Raelina/Raehoshi-illust-XL-8',
     "https://huggingface.co/Yntec/epiCPhotoGasm/blob/main/epiCPhotoGasmVAE.safetensors",
 ]


 sdxl_vaes = [
     "madebyollin/sdxl-vae-fp16-fix",
     "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
 ]


 sd15_vaes = []


 STORAGE_MAX_GIB = 40
 PIPELINE_MAX_GIB = 30
 DEFAULT_DURATION = 0 # if 0, auto
 IS_DEBUG = True
+
+# kernels attention backend (Diffusers attention dispatcher)
+# '' or None: disabled. 'auto': Hopper->'_flash_3_hub' else ATTENTION_BACKEND_NON_HOPPER.
+ATTENTION_BACKEND = 'auto'
+ATTENTION_BACKEND_NON_HOPPER = 'flash_hub'
+
+
+# kernels hub prefetch (to avoid first-inference heavy download)
+# Notes:
+# - This does not remove the download requirement; it moves it to app startup.
+# - Add more repos if you also use 'flash_hub' (FlashAttention2) or 'sage_hub'.
+KERNELS_PREFETCH_ON_STARTUP = True
+KERNELS_PREFETCH_REPOS = [
+    "kernels-community/flash-attn3",
+    # "kernels-community/flash-attn2",
+    # "kernels-community/sage_attention",
+]
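
To pin one backend instead of auto-detecting, a sketch of a local override (both names below already appear in the comments above): set ATTENTION_BACKEND to the backend string and prefetch the matching kernel repo.

    ATTENTION_BACKEND = 'sage_hub'  # bypass the Hopper check entirely
    KERNELS_PREFETCH_REPOS = [
        "kernels-community/sage_attention",  # kernel repo matching 'sage_hub'
    ]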