import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm


# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """Reads tensors from a .safetensors file one at a time instead of
    materializing the whole state dict in RAM."""

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data section starts after the 8-byte length field and the JSON header
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # safetensors layout: little-endian u64 header length, then a JSON header
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    # NOTE: the body of _deserialize_tensor was lost in the source; the version
    # below is a reconstruction following the safetensors spec, covering the
    # dtypes this app writes plus common integer buffers.
    _DTYPES = {
        "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
        "I64": torch.int64, "I32": torch.int32, "I16": torch.int16,
        "U8": torch.uint8, "BOOL": torch.bool,
    }

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._DTYPES[metadata["dtype"]]
        shape = metadata["shape"]
        if dtype == torch.bfloat16:
            # numpy has no bfloat16: read the raw buffer as int16, then reinterpret
            tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch.int16).view(torch.bfloat16)
        else:
            tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=dtype)
        return tensor.reshape(shape)
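
# Hedged usage sketch (not part of the app flow; "model.safetensors" is a
# placeholder path): stream one tensor without loading the full file.
#
#   with MemoryEfficientSafeOpen("model.safetensors") as f:
#       first_key = f.keys()[0]
#       t = f.get_tensor(first_key)  # only this tensor's bytes are read
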
# --- Globals & Helpers ---
# NOTE: this block was lost in the source. The names (api, TempDir,
# cleanup_temp, get_key_stem, parse_hf_url, index_sv_ratio) are reconstructed
# from how they are used elsewhere in this file; treat the bodies as
# best-effort sketches, not the original implementations.
api = HfApi()
TempDir = Path("./temp_work")  # assumption: local scratch directory


def cleanup_temp():
    """Wipes and recreates the scratch directory between tasks."""
    if TempDir.exists():
        shutil.rmtree(TempDir, ignore_errors=True)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()


def get_key_stem(key: str) -> str:
    """Strips LoRA suffixes so the up/down/alpha keys of one block share a stem."""
    return re.sub(r"\.(lora_(?:down|up|A|B)\.weight|alpha)$", "", key)


def parse_hf_url(url: str):
    """Extracts (repo_id, filename) from a huggingface.co resolve/blob URL,
    or returns (None, None) if the input is not such a URL."""
    m = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)/(?:resolve|blob)/[^/]+/(.+?)(?:\?.*)?$", url)
    return (m.group(1), m.group(2)) if m else (None, None)


def index_sv_ratio(S: torch.Tensor, target_ratio: float) -> int:
    """Counts singular values larger than max(S) / target_ratio
    (kohya-style 'sv_ratio' dynamic rank selection)."""
    min_sv = S[0] / target_ratio
    return max(int(torch.sum(S > min_sv).item()), 1)


def download_lora_smart(input_str, token=None):
    """Resolves a LoRA input that may be a direct URL, a repo id, or a
    'user/repo/path/file' string, and returns a local file path."""
    input_str = input_str.strip()
    # Unique local name so multiple LoRAs in one task don't collide (assumption)
    local_path = TempDir / f"lora_{abs(hash(input_str))}.safetensors"
    try:
        # 1. Direct huggingface.co link
        repo_id, filename = parse_hf_url(input_str)
        if repo_id and filename:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path:
                shutil.move(found, local_path)
            return local_path

        # 2. Repo ID with an embedded file path ("user/repo/file.safetensors")
        if input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path:
                shutil.move(found, local_path)
            return local_path

        # Standard Auto-Discovery
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes:
                target = safes[0]
        if not target:
            raise ValueError("No safetensors found")
        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path:
            shutil.move(found, local_path)
        return local_path
    except Exception as e:
        # 3. Last Resort: Raw Requests (for non-HF links)
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                return local_path
            except Exception as req_e:
                raise ValueError(f"All download methods failed.\nRepo Logic Error: {e}\nURL Logic Error: {req_e}")
        raise e


def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Groups a flat LoRA state dict into {stem: {down, up, rank, alpha}}."""
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs:
                pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    # A missing alpha defaults to the rank, i.e. a scale factor of 1.0
    for stem in pairs:
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs
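
# Hedged reference for the key layouts load_lora_to_memory understands.
# A rank-r block for a Linear(in_f -> out_f) appears in one of two namings:
#
#   "<stem>.lora_down.weight" / "<stem>.lora_A.weight"   shape (r, in_f)
#   "<stem>.lora_up.weight"   / "<stem>.lora_B.weight"   shape (out_f, r)
#   "<stem>.alpha"                                       scalar, defaults to r
#
# The delta applied when merging into a base weight is then:
#   delta_W = scale * (alpha / r) * (up @ down)          shape (out_f, in_f)
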
class ShardBuffer:
    """Accumulates tensors and writes/uploads fixed-size .safetensors shards,
    tracking a weight_map for the final index.json."""

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_size = 0

    def add_tensor(self, key, tensor):
        if tensor.dtype == torch.bfloat16:
            # numpy cannot serialize bf16 directly: reinterpret the bits as int16
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"
        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)  # JSON-friendly
        })
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer:
            return
        self.shard_count += 1
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename
        header_json = json.dumps(header).encode('utf-8')
        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            # NOTE: the tail of this method was lost; reconstructed per the
            # safetensors spec: u64 header length, JSON header, then raw data.
            f.write(struct.pack("<Q", len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])
        # Upload each shard immediately so local disk stays small (assumption)
        api.upload_file(path_or_fileobj=str(out_path), path_in_repo=path_in_repo,
                        repo_id=self.output_repo, token=self.hf_token)
        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()

    def finalize(self):
        # NOTE: reconstructed. Writes the index.json mapping every weight to its
        # shard so from_pretrained-style loaders can find them.
        self.flush()
        index_name = f"{self.filename_prefix}.safetensors.index.json"
        index = {"metadata": {"total_size": self.total_size}, "weight_map": self.index_map}
        index_path = self.output_dir / index_name
        with open(index_path, "w") as f:
            json.dump(index, f, indent=2)
        path_in_repo = f"{self.subfolder}/{index_name}" if self.subfolder else index_name
        api.upload_file(path_or_fileobj=str(index_path), path_in_repo=path_in_repo,
                        repo_id=self.output_repo, token=self.hf_token)


# =================================================================================
# TAB 1: MERGE TO BASE + RESHARD
# =================================================================================
# NOTE: task_merge was lost in the source. This body is a reconstruction from
# the UI wiring (token, base repo, subfolder, lora, scale, precision, shard
# size, output repo, extras repo, private flag) and the helpers above. It also
# assumes LoRA stems follow this app's extract naming (base key minus ".weight").
def task_merge(hf_token, base_repo, subfolder, lora_input, scale, precision, shard_gb, out_repo, extras_repo, private):
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(precision, torch.float32)
    try:
        lora = load_lora_to_memory(download_lora_smart(lora_input, hf_token), torch.float32)
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)

        files = list_repo_files(repo_id=base_repo, token=hf_token)
        prefix = f"{subfolder}/" if subfolder else ""
        shards = sorted(f for f in files if f.startswith(prefix) and f.endswith(".safetensors"))
        if not shards:
            return f"Error: no .safetensors found under '{prefix}' in {base_repo}"

        buf = ShardBuffer(shard_gb, TempDir, out_repo, subfolder, hf_token,
                          filename_prefix="diffusion_pytorch_model")
        for shard in shards:
            local = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
            with MemoryEfficientSafeOpen(local) as f:
                for key in f.keys():
                    w = f.get_tensor(key).float()
                    pair = lora.get(key.replace(".weight", ""))
                    if pair and "up" in pair and "down" in pair:
                        up, down = pair["up"].float(), pair["down"].float()
                        # Collapse conv LoRA factors to 2D before the matmul
                        delta = (up.flatten(1) @ down.flatten(1)).reshape(w.shape)
                        w += scale * (pair["alpha"] / pair["rank"]) * delta
                    buf.add_tensor(key, w.to(dtype))
            os.remove(local)
        buf.finalize()

        # Copy the rest of the diffusers layout (VAE, text encoder, configs)
        if extras_repo and extras_repo.strip():
            for fname in list_repo_files(repo_id=extras_repo, token=hf_token):
                if prefix and fname.startswith(prefix) and ".safetensors" in fname:
                    continue  # we just produced these weights ourselves
                p = hf_hub_download(repo_id=extras_repo, filename=fname, token=hf_token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=p, path_in_repo=fname, repo_id=out_repo, token=hf_token)
        return f"Done! Merged to {out_repo}"
    except Exception as e:
        return f"Error: {e}"
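
# Hedged sketch of what ShardBuffer.finalize produces for a two-shard run with
# filename_prefix="diffusion_pytorch_model" (key names illustrative):
#
#   diffusion_pytorch_model.safetensors.index.json
#   {
#     "metadata": {"total_size": <sum of tensor bytes>},
#     "weight_map": {
#       "blocks.0.attn.qkv.weight": "diffusion_pytorch_model-00001.safetensors",
#       "blocks.9.attn.qkv.weight": "diffusion_pytorch_model-00002.safetensors"
#     }
#   }
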
# =================================================================================
# TAB 2: EXTRACT ADAPTER
# =================================================================================
def identify_and_download_model(input_str, token=None):
    """
    Smart resolver for full model weights.
    1. If input is a URL -> downloads specific file.
    2. If input is a Repo ID -> scans for diffusers format (unet/transformer)
       or standard safetensors.
    """
    print(f"Resolving model input: {input_str}")

    # --- STRATEGY A: Direct URL ---
    repo_id_from_url, filename_from_url = parse_hf_url(input_str)
    if repo_id_from_url and filename_from_url:
        print(f"Detected Direct Link. Repo: {repo_id_from_url}, File: {filename_from_url}")
        local_path = TempDir / os.path.basename(filename_from_url)
        # Clean up previous download if the name conflicts
        if local_path.exists():
            os.remove(local_path)
        try:
            hf_hub_download(repo_id=repo_id_from_url, filename=filename_from_url, token=token, local_dir=TempDir)
            # Find where it landed (handling subfolders in local_dir)
            found = list(TempDir.rglob(os.path.basename(filename_from_url)))[0]
            return found
        except Exception as e:
            print(f"URL Download failed: {e}. Trying fallback...")

    # --- STRATEGY B: Repo Discovery (Auto-Detect) ---
    # From here on, input_str is treated as a Repo ID (e.g. "ostris/Z-Image-De-Turbo")
    print(f"Scanning Repo {input_str} for model weights...")
    try:
        files = list_repo_files(repo_id=input_str, token=token)
    except Exception as e:
        raise ValueError(f"Failed to list repo '{input_str}'. If this is a URL, ensure it is formatted correctly. Error: {e}")

    # Priority list: diffusers layouts first, then single-file checkpoints
    priorities = [
        "transformer/diffusion_pytorch_model.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "model.safetensors",
        # Fallback to any safetensors that isn't an adapter or LoRA
        lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f
    ]
    target_file = None
    for p in priorities:
        if callable(p):
            candidates = [f for f in files if p(f)]
            if candidates:
                # Pick the first candidate (list_repo_files gives no sizes to rank by)
                target_file = candidates[0]
                break
        elif p in files:
            target_file = p
            break
    if not target_file:
        raise ValueError(f"Could not find a valid model weight file in {input_str}. Ensure it contains .safetensors weights.")

    print(f"Downloading auto-detected weight file: {target_file}")
    hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir)
    # Locate the actual path
    found = list(TempDir.rglob(os.path.basename(target_file)))[0]
    return found


def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    lora_sd = {}
    print("Calculating diffs & extracting LoRA...")

    # Only keys present in both checkpoints can be diffed
    keys = set(org.keys()).intersection(set(tuned.keys()))
    for key in tqdm(keys, desc="Extracting"):
        # Skip integer buffers/metadata
        if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key:
            continue
        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()
        # Skip if shapes mismatch (shouldn't happen if the models match)
        if mat_org.shape != mat_tuned.shape:
            continue
        diff = mat_tuned - mat_org
        # Skip if there is no meaningful difference
        if torch.max(torch.abs(diff)) < 1e-4:
            continue

        out_dim = diff.shape[0]
        in_dim = diff.shape[1] if len(diff.shape) > 1 else 1
        r = min(rank, in_dim, out_dim)
        is_conv = len(diff.shape) == 4
        if is_conv:
            diff = diff.flatten(start_dim=1)
        elif len(diff.shape) == 1:
            diff = diff.unsqueeze(1)  # Handle biases if needed

        try:
            # Use svd_lowrank for a massive speedup on CPU vs linalg.svd;
            # oversample q slightly, but never beyond the matrix dimensions
            q = min(r + 4, *diff.shape)
            U, S, V = torch.svd_lowrank(diff, q=q, niter=4)
            Vh = V.t()
            U = U[:, :r]
            S = S[:r]
            Vh = Vh[:r, :]
            # Merge S into U for the standard LoRA format
            U = U @ torch.diag(S)

            # Clamp outliers
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(torch.abs(dist), clamp)
            if hi_val > 0:
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)

            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)

            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception as e:
            print(f"Skipping {key} due to error: {e}")

    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)


def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    try:
        print("Downloading Original Model...")
        p1 = identify_and_download_model(org, hf_token)
        print("Downloading Tuned Model...")
        p2 = identify_and_download_model(tun, hf_token)
        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done! Extracted to " + out
    except Exception as e:
        return f"Error: {e}"
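
# Hedged sanity check (illustrative, not executed by the app): an extracted
# pair should reproduce the weight diff up to the rank-r truncation error.
#
#   diff = mat_tuned - mat_org                    # (out, in)
#   up   = lora_sd[f"{stem}.lora_up.weight"]      # (out, r) = U @ diag(S)
#   down = lora_sd[f"{stem}.lora_down.weight"]    # (r, in)  = Vh
#   err  = torch.norm(diff - up @ down) / torch.norm(diff)
#   # err shrinks as r approaches the rank of diff (clamping adds a small bias)
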
# =================================================================================
# TAB 3: MERGE ADAPTERS (EMA)
# =================================================================================
def sigma_rel_to_gamma(sigma_rel):
    """Converts a relative EMA width sigma_rel into the exponent gamma of a
    power-function EMA profile (post-hoc EMA, Karras et al. 2024), solving
    sigma_rel^2 = (gamma + 1) / ((gamma + 2)^2 (gamma + 3))."""
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)
    gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
    return gamma


def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for url in urls:
            paths.append(download_lora_smart(url, hf_token))
    except Exception as e:
        return f"Download Error: {e}"
    if not paths:
        return "No models found"

    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()

    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # Power-function EMA: beta_t = (1 - 1/t)^(gamma + 1). The base
            # adapter is sample t=1, so the first merge happens at t=2
            # (the original used t = i + 1, which discarded the base entirely).
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta
        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"


# =================================================================================
# TAB 4: RESIZE (CPU Optimized)
# =================================================================================
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    login(hf_token.strip())
    try:
        path = download_lora_smart(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"

    state = load_file(path, device="cpu")
    new_state = {}
    groups = {}
    for k in state:
        # Group the up/down/alpha of one block under a shared stem (the original
        # split on ".lora_", which left alpha keys stranded in their own groups)
        stem = get_key_stem(k)
        if stem not in groups:
            groups[stem] = {}
        if "lora_down" in k or "lora_A" in k:
            groups[stem]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k:
            groups[stem]["up"] = state[k]
        if "alpha" in k:
            groups[stem]["alpha"] = state[k]

    print(f"Resizing {len(groups)} blocks...")
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()

            # 1. Merge Up/Down
            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged

            # 2. FAST SVD (svd_lowrank)
            target_rank = int(new_rank)
            # Add a buffer to q to ensure convergence
            q = min(target_rank + 10, min(flat.shape))
            U, S, V = torch.svd_lowrank(flat, q=q)
            Vh = V.t()

            # 3. Dynamic Rank Selection
            if dynamic_method == "sv_ratio":
                target_rank = index_sv_ratio(S, dynamic_param)
            # Hard limit by the user's max rank
            target_rank = min(target_rank, int(new_rank), S.shape[0])

            # 4. Truncate
            U = U[:, :target_rank]
            S = S[:target_rank]
            Vh = Vh[:target_rank, :]

            # 5. Reconstruct the Up matrix
            U = U @ torch.diag(S)
            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])

            # 6. Save (safetensors requires contiguous tensors)
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
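
# Hedged worked example for the "sv_ratio" dynamic method: with singular values
# S = [4.0, 2.0, 0.9, 0.1] and dynamic_param = 4.0, the cutoff is
# max(S) / 4.0 = 1.0, so the block keeps the 2 values above 1.0 and is resized
# to rank 2 (still capped by the user's "To Rank" setting).
#
#   S = torch.tensor([4.0, 2.0, 0.9, 0.1])
#   index_sv_ratio(S, 4.0)  # -> 2
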
# =================================================================================
# UI
# =================================================================================
css = ".container { max-width: 900px; margin: auto; }"

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")
    with gr.Tabs():
        with gr.Tab("Merge to Base + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        with gr.Tab("Merge Multiple Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs (comma-separated)")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                # minimum is 0 so the EMA override can actually be disabled
                # (task_merge_adapters only uses gamma when sigma_rel > 0)
                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta, 0 = use Beta)", value=0.21, minimum=0.0, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)

        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    # css belongs on gr.Blocks (set above); launch() does not accept a css kwarg
    demo.queue().launch(ssr_mode=False)
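
# Hedged headless sketch (assumption, not part of the original app): the task
# functions can also be called directly, mirroring the button wiring above.
# Tokens and repo ids are placeholders.
#
#   print(task_resize("hf_xxx", "user/some-lora", 8, "sv_ratio", 4.0, "user/some-lora-r8"))
#   print(task_merge_adapters("hf_xxx", "user/lora-a, user/lora-b", 0.95, 0.0, "user/merged-lora"))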