import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
    """
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data offsets are relative to the end of the header:
        # 8 bytes (header length prefix) + the header itself.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # Copy into a writable bytearray so torch.frombuffer does not warn
        # about (or alias) the read-only bytes object.
        return torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8).view(dtype).reshape(shape)

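# Example usage (hypothetical filename): inspect a shard and pull a single
# tensor without loading the whole file into RAM:
#
#   with MemoryEfficientSafeOpen("shard.safetensors") as f:
#       print(f.keys())                  # tensor names
#       w = f.get_tensor(f.keys()[0])    # reads only this tensor's bytes
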
# Prefer /tmp for scratch space; fall back to a local directory if /tmp is
# not writable.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

def download_file(input_path, token, filename=None):
    local_path = TempDir / (filename if filename else "model.safetensors")
    if input_path.startswith("http"):
        print(f"Downloading {local_path.name} from URL...")
        try:
            response = requests.get(input_path, stream=True, timeout=30)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except Exception as e:
            raise ValueError(f"Download failed: {e}")
    else:
        print(f"Downloading from Hub: {input_path}")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.safetensors"
            except Exception:
                filename = "adapter_model.safetensors"

        try:
            hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
            downloaded = TempDir / filename
            if not downloaded.exists():
                found = list(TempDir.rglob(filename))
                if found:
                    downloaded = found[0]
            # The repo filename can differ from the requested local name (e.g.
            # when the filename was auto-detected above), so normalize it here
            # to guarantee the returned path actually exists.
            if downloaded.exists() and downloaded != local_path:
                shutil.move(str(downloaded), str(local_path))
        except Exception as e:
            raise ValueError(f"Hub download failed: {e}")

    return local_path

def get_key_stem(key):
    # Normalize a checkpoint key down to its bare module path so LoRA keys
    # and base-model keys can be matched against each other.
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    # Strip prefixes repeatedly, since they can be nested
    # (e.g. "base_model.model.transformer.").
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key

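# For example (illustrative keys):
#   get_key_stem("diffusion_model.blocks.0.attn.to_q.weight")  -> "blocks.0.attn.to_q"
#   get_key_stem("blocks.0.attn.to_q.lora_down.weight")        -> "blocks.0.attn.to_q"
# so a LoRA pair and its base weight reduce to the same stem.
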
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs:
                pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                # lora_down is [rank, in_features]; its first dim is the rank.
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    # A missing alpha defaults to the rank, which makes alpha/rank scaling 1.0.
    for stem in pairs:
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs

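# The returned mapping groups each LoRA pair under its stem, e.g. (illustrative):
#   pairs["blocks.0.attn.to_q"] == {"down": Tensor[r, in], "up": Tensor[out, r],
#                                   "rank": r, "alpha": float}
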
class ShardBuffer:
    """Accumulates serialized tensors in RAM and writes/uploads a shard
    whenever the buffer exceeds max_size_gb."""
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_size = 0

    def add_tensor(self, key, tensor):
        tensor = tensor.contiguous()
        if tensor.dtype == torch.bfloat16:
            # numpy has no bfloat16; reinterpret as int16 to get the raw bytes.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"

        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)
        })
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer:
            return
        self.shard_count += 1
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        # Build the safetensors header by hand: dtype/shape/offsets per tensor.
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename

        header_json = json.dumps(header).encode('utf-8')
        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)

        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()

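# download_lora_smart() below calls verify_safetensors(), which is not defined
# in this section. A minimal sketch of the assumed check: the file is accepted
# if its 8-byte length prefix and JSON header parse cleanly.
def verify_safetensors(path):
    try:
        with open(path, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            json.loads(f.read(header_size).decode("utf-8"))
        return True
    except Exception:
        return False
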
def download_lora_smart(input_str, token):
    """Robust LoRA downloader that handles direct URLs and Hub repo IDs."""
    local_path = TempDir / "adapter.safetensors"
    if local_path.exists():
        os.remove(local_path)

    if input_str.startswith("http"):
        print(f"Downloading LoRA from URL: {input_str}")
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        try:
            response = requests.get(input_str, stream=True, headers=headers, timeout=60)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            if verify_safetensors(local_path):
                return local_path
        except Exception as e:
            print(f"URL download failed: {e}. Trying as Repo ID...")

    print(f"Attempting download from Hub Repo: {input_str}")
    try:
        # "user/repo/path/file.safetensors" form: split the repo id from the
        # file path inside the repo.
        if input_str.endswith(".safetensors") and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = "/".join(parts[:2])
            target = "/".join(parts[2:])
        else:
            repo_id = input_str
            files = list_repo_files(repo_id=repo_id, token=token)
            # Prefer conventional adapter filenames before falling back to the
            # first .safetensors file in the repo.
            candidates = ["adapter_model.safetensors", "model.safetensors"]
            target = next((f for f in files if f in candidates), None)
            if not target:
                safes = [f for f in files if f.endswith(".safetensors")]
                if safes:
                    target = safes[0]
            if not target:
                raise ValueError("No .safetensors found")

        hf_hub_download(repo_id=repo_id, filename=target, token=token, local_dir=TempDir)

        downloaded = TempDir / target
        if downloaded != local_path:
            shutil.move(str(downloaded), str(local_path))

        return local_path
    except Exception as e:
        raise ValueError(f"Could not download LoRA. Checked URL and Repo. Error: {e}")

def get_tensor_byte_size(shape, dtype_str):
    """Calculates byte size of a tensor based on shape and dtype."""
    bytes_per_map = {
        "F64": 8, "I64": 8, "F32": 4, "I32": 4,
        "F16": 2, "BF16": 2, "I16": 2, "I8": 1, "U8": 1, "BOOL": 1
    }
    bytes_per = bytes_per_map.get(dtype_str, 4)
    numel = 1
    for d in shape:
        numel *= d
    return numel * bytes_per

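# Sanity check: a 1024x1024 BF16 tensor is 1024 * 1024 * 2 = 2,097,152 bytes
# (2 MiB), i.e. get_tensor_byte_size([1024, 1024], "BF16") == 2097152.
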
def plan_resharding(input_shards, max_shard_size_gb, filename_prefix):
    """
    Pass 1: reads headers ONLY. Groups tensors into virtual shards of at most
    max_shard_size_gb each. Returns the plan (a list of shard definitions)
    and the total model size in bytes.
    """
    print(f"Planning resharding (Max {max_shard_size_gb} GB)...")
    max_bytes = int(max_shard_size_gb * 1024**3)

    all_tensors = []
    for p in input_shards:
        with MemoryEfficientSafeOpen(p) as f:
            for k in f.keys():
                shape = f.header[k]['shape']
                dtype = f.header[k]['dtype']
                size = get_tensor_byte_size(shape, dtype)
                all_tensors.append({
                    "key": k,
                    "shape": shape,
                    "dtype": dtype,
                    "size": size,
                    "source": p
                })

    all_tensors.sort(key=lambda x: x["key"])

    plan = []
    current_shard = []
    current_size = 0
    for t in all_tensors:
        if current_size + t['size'] > max_bytes and current_shard:
            plan.append(current_shard)
            current_shard = []
            current_size = 0
        current_shard.append(t)
        current_size += t['size']
    if current_shard:
        plan.append(current_shard)

    total_shards = len(plan)
    total_model_size = sum(t['size'] for shard in plan for t in shard)
    print(f"Plan created: {total_shards} shards. Total size: {total_model_size / 1024**3:.2f} GB")

    final_plan = []
    for i, shard_tensors in enumerate(plan):
        name = f"{filename_prefix}-{i+1:05d}-of-{total_shards:05d}.safetensors"
        final_plan.append({
            "filename": name,
            "tensors": shard_tensors
        })

    return final_plan, total_model_size

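# The packing is greedy in key order: with tensor sizes [2, 2, 2] GB and a
# 5 GB limit, the plan comes out as two shards, [2, 2] and [2].
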
def copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder):
    """
    Downloads NON-WEIGHT files (json, txt, model) from the base repo and
    uploads them to the output repo.
    """
    print(f"Copying config files from {base_repo}...")
    try:
        files = list_repo_files(repo_id=base_repo, token=hf_token)

        allowed_ext = ['.json', '.txt', '.model', '.py', '.yml', '.yaml']
        blocked_ext = ['.safetensors', '.bin', '.pt', '.pth', '.msgpack', '.h5']

        for f in files:
            if base_subfolder and not f.startswith(base_subfolder):
                continue

            ext = os.path.splitext(f)[1]
            if ext in blocked_ext: continue
            if ext not in allowed_ext: continue

            print(f"Transferring {f}...")
            local = hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=TempDir)

            if base_subfolder:
                rel_name = f[len(base_subfolder):].lstrip('/')
            else:
                rel_name = f

            target_path = f"{output_subfolder}/{rel_name}" if output_subfolder else rel_name
            api.upload_file(path_or_fileobj=local, path_in_repo=target_path, repo_id=output_repo, token=hf_token)
            os.remove(local)

    except Exception as e:
        print(f"Config copy warning: {e}")

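# task_merge() below calls streaming_copy_structure(), which is not defined in
# this section. A minimal sketch of the assumed behavior: stream every file
# from the structure repo into the output repo one at a time (so peak disk
# usage stays at one file), skipping anything under ignore_prefix.
def streaming_copy_structure(hf_token, src_repo, dst_repo, ignore_prefix=""):
    for f in list_repo_files(repo_id=src_repo, token=hf_token):
        if ignore_prefix and f.startswith(ignore_prefix):
            continue
        print(f"Copying {f}...")
        local = hf_hub_download(repo_id=src_repo, filename=f, token=hf_token, local_dir=TempDir)
        api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=hf_token)
        os.remove(local)
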
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    if not hf_token:
        return "Error: Token missing."
    login(hf_token)

    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"

    # Quantized ("qint") or transformer subfolders are written back out as a
    # plain "transformer" folder so the output stays diffusers-loadable.
    if base_subfolder:
        output_subfolder = "transformer" if "qint" in base_subfolder or "transformer" in base_subfolder else base_subfolder
    else:
        output_subfolder = ""

    copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder)

    if structure_repo:
        print(f"Copying extras from {structure_repo}...")
        # Copy VAE / text encoder / scheduler etc., but skip the transformer
        # weights since we are about to regenerate them.
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix="transformer")

    progress(0.1, desc="Downloading Input Model...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = []
    for f in files:
        if f.endswith(".safetensors"):
            if base_subfolder and not f.startswith(base_subfolder):
                continue
            local = TempDir / "inputs" / os.path.basename(f)
            os.makedirs(local.parent, exist_ok=True)
            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local.parent, local_dir_use_symlinks=False)
            # hf_hub_download may nest the file under the repo's folder
            # structure, so locate it before recording the path.
            found = list(local.parent.rglob(os.path.basename(f)))
            if found:
                input_shards.append(found[0])

    if not input_shards:
        return "No safetensors found."
    input_shards.sort()

    # Diffusers expects transformer weights to be named
    # "diffusion_pytorch_model-*"; everything else uses the generic prefix.
    sample_name = os.path.basename(input_shards[0])
    if "diffusion_pytorch_model" in sample_name or output_subfolder == "transformer":
        prefix = "diffusion_pytorch_model"
        index_file = "diffusion_pytorch_model.safetensors.index.json"
    else:
        prefix = "model"
        index_file = "model.safetensors.index.json"

    progress(0.2, desc="Planning Shards...")
    plan, total_model_size = plan_resharding(input_shards, shard_size, prefix)

    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.25, desc="Loading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e:
        return f"LoRA Error: {e}"

    index_map = {}
    for i, shard_plan in enumerate(plan):
        filename = shard_plan['filename']
        tensors_to_write = shard_plan['tensors']

        progress(0.3 + (0.7 * i / len(plan)), desc=f"Merging {filename}")
        print(f"Generating {filename} ({len(tensors_to_write)} tensors)...")

        # Pre-compute the header: every tensor is written in the output dtype,
        # so offsets can be derived from shapes alone.
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        tgt_dtype_str = "BF16" if dtype == torch.bfloat16 else "F16" if dtype == torch.float16 else "F32"
        for t in tensors_to_write:
            out_size = get_tensor_byte_size(t['shape'], tgt_dtype_str)
            header[t['key']] = {
                "dtype": tgt_dtype_str,
                "shape": t['shape'],
                "data_offsets": [current_offset, current_offset + out_size]
            }
            current_offset += out_size
            index_map[t['key']] = filename

        header_json = json.dumps(header).encode('utf-8')

        out_path = TempDir / filename
        with open(out_path, 'wb') as f_out:
            f_out.write(struct.pack('<Q', len(header_json)))
            f_out.write(header_json)

            # Keep source shards open across tensors to avoid re-reading headers.
            open_files = {}
            for t_plan in tqdm(tensors_to_write, leave=False):
                src = t_plan['source']
                if src not in open_files:
                    open_files[src] = MemoryEfficientSafeOpen(src)

                v = open_files[src].get_tensor(t_plan['key'])
                k = t_plan['key']

                base_stem = get_key_stem(k)
                match = None
                if base_stem in lora_pairs:
                    match = lora_pairs[base_stem]

                # Some trainers store a fused qkv LoRA; match the base model's
                # separate q/k/v projections against it.
                if not match:
                    if "to_q" in base_stem:
                        qkv = base_stem.replace("to_q", "qkv")
                        if qkv in lora_pairs: match = lora_pairs[qkv]
                    elif "to_k" in base_stem:
                        qkv = base_stem.replace("to_k", "qkv")
                        if qkv in lora_pairs: match = lora_pairs[qkv]
                    elif "to_v" in base_stem:
                        qkv = base_stem.replace("to_v", "qkv")
                        if qkv in lora_pairs: match = lora_pairs[qkv]

                if match:
                    down = match["down"]
                    up = match["up"]
                    # Standard LoRA scaling: scale * alpha / rank.
                    scaling = scale * (match["alpha"] / match["rank"])
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)
                    try:
                        if len(up.shape) == 4:
                            # 1x1 conv LoRA: collapse to a matmul, restore dims after.
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except RuntimeError:
                        delta = up.T @ down

                    delta = delta * scaling

                    # Shape reconciliation: the delta may target a fused qkv
                    # weight (3x the base rows); take the slice matching this
                    # projection.
                    valid = True
                    if delta.shape == v.shape:
                        pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else:
                        valid = False

                    if valid:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                    del delta

                if v.dtype != dtype:
                    v = v.to(dtype)
                if dtype == torch.bfloat16:
                    # numpy has no bfloat16; reinterpret as int16 for raw bytes.
                    raw = v.view(torch.int16).numpy().tobytes()
                else:
                    raw = v.numpy().tobytes()
                f_out.write(raw)
                del v

            for fh in open_files.values():
                fh.file.close()

        path_in_repo = f"{output_subfolder}/{filename}" if output_subfolder else filename
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
        os.remove(out_path)
        gc.collect()

    # Write the index that maps every tensor to the shard holding it.
    # total_size reflects the OUTPUT dtype, not the input files.
    final_total_size = 0
    tgt_dtype_str = "BF16" if dtype == torch.bfloat16 else "F16" if dtype == torch.float16 else "F32"
    for shard_plan in plan:
        for t in shard_plan['tensors']:
            final_total_size += get_tensor_byte_size(t['shape'], tgt_dtype_str)

    index_data = {"metadata": {"total_size": final_total_size}, "weight_map": index_map}
    with open(TempDir / index_file, "w") as f:
        json.dump(index_data, f, indent=4)

    path_in_repo = f"{output_subfolder}/{index_file}" if output_subfolder else index_file
    api.upload_file(path_or_fileobj=TempDir / index_file, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

    cleanup_temp()
    return f"Success! {len(plan)} shards created at {output_repo}"

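# The resulting index file follows the standard Hugging Face sharded-weights
# format, e.g. (illustrative):
# {
#   "metadata": {"total_size": 12345678},
#   "weight_map": {"blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors", ...}
# }
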
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    tuned_keys = set(tuned.keys())
    lora_sd = {}
    print("Calculating diffs...")
    for key in tqdm(org.keys()):
        if key not in tuned_keys:
            continue
        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()
        diff = mat_tuned - mat_org
        # Skip 1-D tensors (biases/norms) that SVD cannot factor, and layers
        # that did not change.
        if diff.dim() < 2:
            continue
        if torch.max(torch.abs(diff)) < 1e-4:
            continue

        out_dim, in_dim = diff.shape[:2]
        r = min(rank, in_dim, out_dim)
        is_conv = len(diff.shape) == 4
        if is_conv:
            diff = diff.flatten(start_dim=1)

        try:
            # Low-rank approximation of the weight delta: keep the top-r
            # singular vectors and fold the singular values into U.
            U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            U = U @ torch.diag(S)
            # Clamp outliers at the requested quantile for numerical safety.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp)
            U = U.clamp(-hi_val, hi_val)
            Vh = Vh.clamp(-hi_val, hi_val)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U
            lora_sd[f"{stem}.lora_down.weight"] = Vh
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception:
            pass
    org.file.close()
    tuned.file.close()
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)

def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    login(hf_token)
    try:
        p1 = download_file(org, hf_token, filename="org.safetensors")
        p2 = download_file(tun, hf_token, filename="tun.safetensors")
        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e:
        return f"Error: {e}"

def sigma_rel_to_gamma(sigma_rel):
    # Convert the relative EMA width sigma_rel to the power-function exponent
    # gamma by solving gamma^3 + 7*gamma^2 + (16 - t)*gamma + (12 - t) = 0
    # with t = sigma_rel^-2 (post-hoc EMA, Karras et al. 2024), taking the
    # largest non-negative real root.
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)
    gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
    return gamma

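# Round-trip check: gamma satisfies
#   sigma_rel^2 = (gamma + 1) / ((gamma + 2)^2 * (gamma + 3)),
# which expands to exactly the cubic solved above with t = sigma_rel^-2.
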
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    cleanup_temp()
    login(hf_token)

    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
    except Exception as e:
        return f"Download Error: {e}"

    if not paths:
        return "No models found"

    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()

    # sigma_rel > 0 switches from a fixed-beta EMA to the power-function EMA
    # profile, where beta grows with the step count.
    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # The accumulator was seeded with the first adapter (step 1), so
            # the i-th merge corresponds to EMA step i + 2; at step 1 the beta
            # would be 0 and the seed adapter would be discarded.
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta

        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k in curr and "alpha" not in k:
                # EMA update: keep beta of the accumulator, blend in the rest.
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"

def index_sv_ratio(S, target):
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    return max(1, min(index, len(S) - 1))

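# index_sv_ratio keeps every singular value within `target` ratio of the
# largest. For example, with S = [10, 4, 2, 0.5] and target = 4 the cutoff is
# 10 / 4 = 2.5, so rank 2 is kept ([10, 4]).
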
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    login(hf_token)
    try:
        path = download_file(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"

    state = load_file(path, device="cpu")
    new_state = {}

    # Group down/up/alpha tensors by module key. Alpha keys do not contain
    # ".lora_", so strip ".alpha" explicitly to land in the same group.
    groups = {}
    for k in state:
        simple = k.split(".lora_")[0].replace(".alpha", "")
        if simple not in groups:
            groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            # Fold the original alpha/rank scaling into the merged delta so the
            # resized adapter (saved with alpha == rank) keeps its strength.
            old_rank = down.shape[0]
            old_alpha = float(g["alpha"]) if "alpha" in g else float(old_rank)

            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged
            flat = flat * (old_alpha / old_rank)

            U, S, Vh = torch.linalg.svd(flat, full_matrices=False)

            target_rank = int(new_rank)
            if dynamic_method == "sv_ratio":
                target_rank = index_sv_ratio(S, dynamic_param)
            target_rank = min(target_rank, S.shape[0])

            U = U[:, :target_rank]
            S = S[:target_rank]
            U = U @ torch.diag(S)
            Vh = Vh[:target_rank, :]

            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])

            new_state[f"{stem}.lora_down.weight"] = Vh
            new_state[f"{stem}.lora_up.weight"] = U
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"

css = ".container { max-width: 900px; margin: auto; }"

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")

    with gr.Tabs():
        with gr.Tab("Merge to Base + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA Repo (user/repo) or direct URL", value="GuangyuanSD/Z-Image-Re-Turbo-LoRA")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        with gr.Tab("Merge Multiple Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs (comma-separated)")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel (0 = use fixed Beta)", value=0.21, minimum=0.0, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)

        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)