| | import os |
| | from typing import Optional, List, Type |
| | import torch |
| | from library import sdxl_original_unet |
| | from library.utils import setup_logging |
| | setup_logging() |
| | import logging |
| | logger = logging.getLogger(__name__) |
| |
|
| | |
# Module-level switches that decide which U-Net layers receive an LLLite adapter.
# They are read when ControlNetLLLite scans the U-Net in its constructor.

# If True, no adapters are created for layers under "input_blocks".
SKIP_INPUT_BLOCKS: bool = False

# If True, no adapters are created for layers under "output_blocks".
SKIP_OUTPUT_BLOCKS: bool = True

# If True, Conv2d layers are never wrapped (Linear layers still are).
SKIP_CONV2D: bool = False

# If True, only Transformer2DModel blocks are scanned; if False, ResnetBlock2D,
# Downsample2D and Upsample2D blocks are scanned as well.
TRANSFORMER_ONLY: bool = True

# If True, only layers whose name contains "attn1" or "attn2" are wrapped.
ATTN1_2_ONLY: bool = True

# If True, attention output projections ("to_out") are skipped; used alongside
# the ATTN1_2_ONLY filter.
ATTN_QKV_ONLY: bool = True

# If True, only "proj_out", the attn1 to_k / to_v / to_out projections and
# "ff_net_2" layers are wrapped.
ATTN1_ETC_ONLY: bool = False

# Highest "transformer_blocks" index that still receives adapters; None = no limit.
TRANSFORMER_MAX_BLOCK_INDEX: Optional[int] = None
| |
|
| |
|
class LLLiteModule(torch.nn.Module):
    """ControlNet-LLLite adapter for a single ``Linear`` or ``Conv2d`` layer.

    The adapter does the following, in order:

    1. Encodes a conditioning image with ``conditioning1``.
    2. Concatenates that embedding with a down-projection of the wrapped layer's
       input (``down``).
    3. Mixes the result (``mid``).
    4. Projects it back to the input width (``up``).
    5. Adds it, scaled by ``multiplier``, to the input before calling the
       wrapped layer's original ``forward``.

    Args:
        depth: Depth of the conditioning encoder; must be 1, 2 or 3. The encoder
            downsamples the conditioning image by 8x, 16x or 32x respectively.
            Its spatial output must match the wrapped layer's input (H x W for
            ``Conv2d``, token count for ``Linear``) or the concatenation in
            ``forward`` fails.
        cond_emb_dim: Channel count of the conditioning embedding.
        name: Identifier stored as ``lllite_name``.
        org_module: The layer to wrap. A layer whose class is named ``Conv2d``
            uses ``in_channels``; any other layer must expose ``in_features``.
        mlp_dim: Hidden width of the down/mid/up projections.
        dropout: Optional dropout probability applied to the mid features in
            training mode.
        multiplier: Scale of the adapter output; 0.0 bypasses the adapter.

    Raises:
        ValueError: If ``depth`` is not 1, 2 or 3.
    """

    def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
        super().__init__()
        if depth not in (1, 2, 3):
            # Any other depth would leave conditioning1 emitting cond_emb_dim // 2
            # channels, which does not match the input width expected by ``mid``.
            raise ValueError(f"LLLiteModule depth must be 1, 2 or 3, got {depth!r}")

        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
        self.lllite_name = name
        self.cond_emb_dim = cond_emb_dim
        # Kept in a list so the wrapped layer is not registered as a sub-module
        # (its parameters stay out of this adapter's parameters()/state_dict()).
        self.org_module = [org_module]
        self.dropout = dropout
        self.multiplier = multiplier

        in_dim = org_module.in_channels if self.is_conv2d else org_module.in_features

        # Conditioning image (3 channels) -> embedding with cond_emb_dim channels.
        self.conditioning1 = self._build_conditioning(depth, cond_emb_dim)

        # down: input width -> mlp_dim
        # mid:  [cond_emb, down(x)] -> mlp_dim
        # up:   mlp_dim -> input width
        # Conv2d targets use 1x1 convolutions; other targets use Linear layers.
        self.down = torch.nn.Sequential(
            self._projection(in_dim, mlp_dim),
            torch.nn.ReLU(inplace=True),
        )
        self.mid = torch.nn.Sequential(
            self._projection(mlp_dim + cond_emb_dim, mlp_dim),
            torch.nn.ReLU(inplace=True),
        )
        self.up = torch.nn.Sequential(
            self._projection(mlp_dim, in_dim),
        )

        # Only the weight is zeroed. The bias keeps PyTorch's default
        # initialisation, so an untrained adapter adds up[0].bias * multiplier,
        # not exactly zero.
        torch.nn.init.zeros_(self.up[0].weight)

        self.depth = depth
        self.cond_emb = None  # set by set_cond_image()
        self.batch_cond_only = False
        self.use_zeros_for_batch_uncond = False

    @staticmethod
    def _build_conditioning(depth, cond_emb_dim):
        """Build the conditioning-image encoder for ``depth`` (1, 2 or 3).

        Layer order is fixed so state_dict keys stay "conditioning1.<i>.*".
        """
        half = cond_emb_dim // 2
        # Shared first stage: 4x4 conv, stride 4 (4x downsampling).
        layers = [torch.nn.Conv2d(3, half, kernel_size=4, stride=4, padding=0)]
        if depth == 1:
            # 4x * 2x = 8x total downsampling.
            layers += [
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(half, cond_emb_dim, kernel_size=2, stride=2, padding=0),
            ]
        elif depth == 2:
            # 4x * 4x = 16x total downsampling.
            layers += [
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(half, cond_emb_dim, kernel_size=4, stride=4, padding=0),
            ]
        else:
            # depth == 3: 4x * 4x * 2x = 32x total downsampling.
            layers += [
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(half, half, kernel_size=4, stride=4, padding=0),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(half, cond_emb_dim, kernel_size=2, stride=2, padding=0),
            ]
        return torch.nn.Sequential(*layers)

    def _projection(self, in_features, out_features):
        """Return a 1x1 Conv2d for Conv2d targets, otherwise a Linear layer."""
        if self.is_conv2d:
            return torch.nn.Conv2d(in_features, out_features, kernel_size=1, stride=1, padding=0)
        return torch.nn.Linear(in_features, out_features)

    def set_cond_image(self, cond_image):
        """Encode ``cond_image`` and cache the result as ``cond_emb``.

        Passing None clears the cached embedding. This runs ``conditioning1``,
        so wrap the call in ``torch.no_grad()`` when gradients are not needed.
        For non-Conv2d targets the (N, C, H, W) encoder output is flattened into
        (N, H * W, C) tokens.
        """
        if cond_image is None:
            self.cond_emb = None
            return

        cx = self.conditioning1(cond_image)
        if not self.is_conv2d:
            n, c, h, w = cx.shape
            cx = cx.view(n, c, h * w).permute(0, 2, 1)
        self.cond_emb = cx

    def set_batch_cond_only(self, cond_only, zeros):
        """Configure how the batch is split between conditioned and other rows.

        Args:
            cond_only: If True, the adapter is applied only to the odd rows of
                the input batch.
            zeros: If True and the embedding is duplicated in ``forward``, the
                even duplicated rows are zeroed.
        """
        self.batch_cond_only = cond_only
        self.use_zeros_for_batch_uncond = zeros

    def apply_to(self):
        """Patch the wrapped layer so its ``forward`` routes through this adapter.

        The layer's original ``forward`` is kept as ``org_forward``. ``forward``
        below relies on it, so this must be called before the adapter runs.
        """
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def forward(self, x):
        """Add the adapter residual to ``x`` and call the wrapped layer's original forward.

        When ``multiplier`` is 0.0 or no conditioning image is set, ``x`` is
        passed through to the original forward unchanged.
        """
        if self.multiplier == 0.0 or self.cond_emb is None:
            return self.org_forward(x)

        cx = self.cond_emb

        if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]:
            # x holds twice the conditioning batch (presumably the unconditional
            # and conditional halves of classifier-free guidance), so the
            # embedding is duplicated to match.
            # NOTE(review): repeat() tiles as [c0..cN-1, c0..cN-1], while the
            # zeroing below and the x[1::2] slicing in batch_cond_only mode index
            # interleaved rows. The layouts agree only for a conditioning batch
            # of 1 - confirm the intended batch layout.
            cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
            if self.use_zeros_for_batch_uncond:
                cx[0::2] = 0.0

        # In batch_cond_only mode only the odd rows of x receive the adapter.
        x_for_down = x[1::2] if self.batch_cond_only else x
        # Concatenate along channels: dim 1 for (N, C, H, W), dim 2 for (N, L, C) tokens.
        cx = torch.cat([cx, self.down(x_for_down)], dim=1 if self.is_conv2d else 2)
        cx = self.mid(cx)

        if self.dropout is not None and self.training:
            cx = torch.nn.functional.dropout(cx, p=self.dropout)

        cx = self.up(cx) * self.multiplier

        if self.batch_cond_only:
            # Scatter the residual back into a full-size tensor; even rows get zero.
            zx = torch.zeros_like(x)
            zx[1::2] += cx
            cx = zx

        return self.org_forward(x + cx)
| |
|
| |
|
class ControlNetLLLite(torch.nn.Module):
    """Builds and manages LLLiteModule adapters for an SDXL U-Net.

    On construction, the U-Net is scanned for sub-modules whose class name is
    in ``UNET_TARGET_REPLACE_MODULE``. When ``TRANSFORMER_ONLY`` is False, the
    classes in ``UNET_TARGET_REPLACE_MODULE_CONV2D_3X3`` are scanned as well.

    Each ``Linear`` layer inside them that passes the module-level name filters
    gets its own ``LLLiteModule``. ``Conv2d`` layers are included too, unless
    ``SKIP_CONV2D`` is set.

    The adapters are kept in the plain list ``unet_modules``. They become child
    modules of this network, and get hooked into the U-Net, only in
    ``apply_to()``. Until then ``parameters()``, ``to()``, ``state_dict()`` and
    ``load_state_dict()`` do not see them.
    """

    # Class names of U-Net blocks scanned for Linear/Conv2d layers to wrap.
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    # Additional block class names scanned when TRANSFORMER_ONLY is False.
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]

    def __init__(
        self,
        unet: sdxl_original_unet.SdxlUNet2DConditionModel,
        cond_emb_dim: int = 16,
        mlp_dim: int = 16,
        dropout: Optional[float] = None,
        varbose: Optional[bool] = False,
        multiplier: Optional[float] = 1.0,
    ) -> None:
        """Create one adapter per eligible layer of ``unet``.

        Args:
            unet: U-Net to scan for layers to wrap.
            cond_emb_dim: Channel count of each adapter's conditioning embedding.
            mlp_dim: Hidden width of each adapter's projections.
            dropout: Optional dropout probability used by the adapters in training.
            varbose: Unused; kept with its original spelling for backward compatibility.
            multiplier: Initial output scale of every adapter.

        Raises:
            NotImplementedError: If a wrapped layer lives outside
                input_blocks / middle_block / output_blocks.
        """
        super().__init__()

        def create_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            module_class: Type[object],
        ) -> List[torch.nn.Module]:
            """Create a ``module_class`` adapter for every eligible layer of ``root_module``.

            ``target_replace_modules`` holds class *names* of the blocks to scan.
            """
            prefix = "lllite_unet"

            modules = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue

                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if not (is_linear or (is_conv2d and not SKIP_CONV2D)):
                        continue

                    # The dotted path is expected to start with
                    # "<block_name>.<index1>.<index2>". The U-Net level, and
                    # hence the adapter's conditioning depth, is derived from it.
                    block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
                    level = int(index1)
                    if block_name == "input_blocks":
                        if SKIP_INPUT_BLOCKS:
                            continue
                        depth = 1 if level <= 2 else (2 if level <= 5 else 3)
                    elif block_name == "middle_block":
                        depth = 3
                    elif block_name == "output_blocks":
                        if SKIP_OUTPUT_BLOCKS:
                            continue
                        depth = 3 if level <= 2 else (2 if level <= 5 else 1)
                        if int(index2) >= 2:
                            depth -= 1
                    else:
                        raise NotImplementedError(f"unsupported U-Net block: {block_name}")

                    lllite_name = (prefix + "." + name + "." + child_name).replace(".", "_")

                    if TRANSFORMER_MAX_BLOCK_INDEX is not None:
                        p = lllite_name.find("transformer_blocks")
                        if p >= 0:
                            # "transformer_blocks_<k>_..." -> k
                            tf_index = int(lllite_name[p:].split("_")[2])
                            if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
                                continue

                    # Never wrap "emb_layers" or the to_k / to_v projections of
                    # attn2. Presumably this is because attn2's k/v inputs are the
                    # text context, whose token count would not match the
                    # conditioning embedding - confirm.
                    if "emb_layers" in lllite_name or (
                        "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
                    ):
                        continue

                    if ATTN1_2_ONLY:
                        if not ("attn1" in lllite_name or "attn2" in lllite_name):
                            continue
                        if ATTN_QKV_ONLY and "to_out" in lllite_name:
                            continue

                    if ATTN1_ETC_ONLY:
                        keep = (
                            "proj_out" in lllite_name
                            or (
                                "attn1" in lllite_name
                                and ("to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name)
                            )
                            or "ff_net_2" in lllite_name
                        )
                        if not keep:
                            continue

                    # Distinct name: do not rebind the outer loop's ``module``.
                    lllite_module = module_class(
                        depth,
                        cond_emb_dim,
                        lllite_name,
                        child_module,
                        mlp_dim,
                        dropout=dropout,
                        multiplier=multiplier,
                    )
                    modules.append(lllite_module)
            return modules

        target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
        if not TRANSFORMER_ONLY:
            target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        # A plain list does not register the adapters as sub-modules; apply_to() does that.
        self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
        logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")

    def forward(self, x):
        """Return ``x`` unchanged.

        The adapters run through the patched forwards of the wrapped U-Net layers.
        """
        return x

    def set_cond_image(self, cond_image):
        """Encode ``cond_image`` in every adapter (None clears the embeddings).

        This runs the adapters' encoders, so wrap the call in ``torch.no_grad()``
        when gradients are not needed.
        """
        for module in self.unet_modules:
            module.set_cond_image(cond_image)

    def set_batch_cond_only(self, cond_only, zeros):
        """Forward ``cond_only`` / ``zeros`` to every adapter's set_batch_cond_only()."""
        for module in self.unet_modules:
            module.set_batch_cond_only(cond_only, zeros)

    def set_multiplier(self, multiplier):
        """Set the output scale of every adapter."""
        for module in self.unet_modules:
            module.multiplier = multiplier

    def load_weights(self, file):
        """Load weights from ``file`` non-strictly.

        Returns:
            The result of ``load_state_dict``, which reports missing and
            unexpected keys.

        Adapter keys only match after apply_to() has registered the adapters.
        """
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE: torch.load unpickles the file, which can execute arbitrary
            # code. Only load checkpoints from trusted sources, or prefer
            # .safetensors files.
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self):
        """Hook every adapter into its U-Net layer and register it as a child module."""
        logger.info("applying LLLite for U-Net...")
        for module in self.unet_modules:
            module.apply_to()
            self.add_module(module.lllite_name, module)

    def is_mergeable(self):
        """Adapters cannot be merged into the base weights."""
        return False

    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()

    def enable_gradient_checkpointing(self):
        """No-op: gradient checkpointing is not implemented for the adapters."""
        pass

    def prepare_optimizer_params(self):
        """Enable gradients on all registered parameters and return them."""
        self.requires_grad_(True)
        return self.parameters()

    def prepare_grad_etc(self):
        """Enable gradients on all registered parameters."""
        self.requires_grad_(True)

    def on_epoch_start(self):
        """Switch to training mode at the start of each epoch."""
        self.train()

    def get_trainable_params(self):
        """Return all registered parameters."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the state dict to ``file`` (.safetensors or torch.save format).

        Args:
            file: Destination path; the ".safetensors" extension selects safetensors.
            dtype: If not None, every tensor is copied to CPU and cast to this dtype.
            metadata: Optional safetensors metadata; an empty mapping is treated as None.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke test: wrap a freshly initialised SDXL U-Net with LLLite
    # adapters and run a few mixed-precision optimisation steps on random
    # tensors. Needs a CUDA device and the bitsandbytes package.

    logger.info("create unet")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    unet.to("cuda").to(torch.float16)

    logger.info("create ControlNet-LLLite")
    control_net = ControlNetLLLite(unet, 32, 64)
    # apply_to() is what registers the adapters as children, so it runs before
    # .to("cuda") in order for the move to include them.
    control_net.apply_to()
    control_net.to("cuda")

    logger.info(control_net)

    trainable_count = sum(param.numel() for param in control_net.parameters() if param.requires_grad)
    logger.info(f"number of parameters {trainable_count}")

    # Pause so the model summary above can be inspected; press Enter to continue.
    input()

    unet.set_use_memory_efficient_attention(True, False)
    unet.set_gradient_checkpointing(True)
    unet.train()

    control_net.train()

    import bitsandbytes

    optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    logger.info("start training")
    steps = 10
    batch_size = 1

    # First named parameter whose name contains "up"; logged after every step
    # to watch it change.
    sample_param = [named for named in control_net.named_parameters() if "up" in named[0]][0]

    for step in range(steps):
        logger.info(f"step {step}")

        # Random stand-ins for the conditioning image (scaled to [-1, 1)) and
        # the four U-Net inputs.
        conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
        latents = torch.randn(batch_size, 4, 128, 128).cuda()
        timesteps = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
        context = torch.randn(batch_size, 77, 2048).cuda()
        adm_cond = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()

        with torch.cuda.amp.autocast(enabled=True):
            control_net.set_cond_image(conditioning_image)

            output = unet(latents, timesteps, context, adm_cond)
            target = torch.randn_like(output)
            loss = torch.nn.functional.mse_loss(output, target)

        # Scaled backward / step / update, outside autocast.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        logger.info(f"{sample_param}")
| |
|
| | |
| |
|
| | |
| |
|