| """Console logger utilities. |
| |
| Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py |
| Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging |
| """ |
|
|
| import math |
| import os |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import fsspec |
| import torch |
| from timm.scheduler import CosineLRScheduler |
| import omegaconf |
| import rich |
| import rich.syntax |
| import rich.tree |
|
|
| from decoupled_utils import rank_zero_fn, rprint |
| from decoupled_utils import (get_hostname, get_num_devices, get_tpu_devices, gprint, |
| is_torch_cuda_available, is_torch_xla_available, rprint) |
|
|
|
|
| def print_trainable_parameters(model): |
| trainable_params = 0 |
| all_param = 0 |
| for _, param in model.named_parameters(): |
| all_param += param.numel() |
| if param.requires_grad: |
| trainable_params += param.numel() |
| print( |
| f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" |
| ) |
|
|
| def fsspec_exists(filename): |
| """Check if a file exists using fsspec.""" |
| fs, _ = fsspec.core.url_to_fs(filename) |
| return fs.exists(filename) |
|
|
|
|
| def fsspec_listdir(dirname): |
| """Listdir in manner compatible with fsspec.""" |
| fs, _ = fsspec.core.url_to_fs(dirname) |
| return fs.ls(dirname) |
|
|
|
|
| def fsspec_mkdirs(dirname, exist_ok=True): |
| """Mkdirs in manner compatible with fsspec.""" |
| fs, _ = fsspec.core.url_to_fs(dirname) |
| fs.makedirs(dirname, exist_ok=exist_ok) |
|
|
|
|
| def print_nans(tensor, name): |
| if torch.isnan(tensor).any(): |
| gprint(f"{name} has nans: {tensor}") |
|
|
|
|
| class CosineDecayWarmupLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): |
| """Wrap timm.scheduler.CosineLRScheduler |
| Enables calling scheduler.step() without passing in epoch. |
| Supports resuming as well. |
| Adapted from: |
| https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py |
| """ |
|
|
| def __init__(self, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
| self._last_epoch = -1 |
| self.step(epoch=0) |
|
|
| def step(self, epoch=None): |
| if epoch is None: |
| self._last_epoch += 1 |
| else: |
| self._last_epoch = epoch |
| |
| |
| |
| |
| |
| |
| |
| if self.t_in_epochs: |
| super().step(epoch=self._last_epoch) |
| else: |
| super().step_update(num_updates=self._last_epoch) |
|
|
|
|
| class Sampler: |
| def __init__(self, shape): |
| self.shape = shape |
|
|
| def _sampling_noise(self): |
| pass |
|
|
| def _hard_sample(self, logits): |
| pass |
|
|
| def _soft_sample(self, logits): |
| return 0 |
|
|
| def sample(self, logits): |
| noise = self._sampling_noise() |
| noise = noise[: logits.shape[0], :] |
| logits = logits + noise.to(dtype=logits.dtype, device=logits.device) |
| hard_sample = self._hard_sample(logits) |
| soft_sample = self._soft_sample(logits) |
| return soft_sample + (hard_sample - soft_sample).detach() |
|
|
|
|
| class TopKSampler(Sampler): |
| def __init__(self, k, shape, gamma_tau=1.0): |
| super().__init__(shape) |
| self.k = k |
| self.gamma_tau = gamma_tau |
| self.num_betas = 10 |
| self.sampler = torch.distributions.gamma.Gamma(1 / k * torch.ones(self.num_betas, *self.shape), 1.0) |
|
|
| def _sampling_noise(self): |
| noise = self.sampler.sample() |
| beta = self.k / torch.arange(1, self.num_betas + 1, 1, dtype=torch.float32) |
| beta = beta[:, None, None] |
| assert beta.ndim == noise.ndim |
| s = noise / beta |
| s = torch.sum(s, axis=0) |
| s = s - math.log(10.0) |
| s = self.gamma_tau * (s / self.k) |
| return s |
|
|
| def _hard_sample(self, logits): |
| assert logits.ndim == 2 |
| thresholds, _ = torch.sort(logits, dim=-1) |
| thresholds = thresholds[:, -self.k][:, None] |
| return (logits >= thresholds).type(logits.dtype) |
|
|
| def _soft_sample(self, logits): |
| soft_top_k = logits - torch.mean(logits, dim=-1, keepdim=True) |
| return soft_top_k / torch.norm(soft_top_k, dim=-1, keepdim=True) |
|
|
|
|
| class DeterministicTopK(TopKSampler): |
| def __init__(self, k): |
| super().__init__(k, shape=(1, 1)) |
|
|
| def _sampling_noise(self): |
| return 0 |
|
|
| def discreize(self, x): |
| hard_sample = self._hard_sample(x) |
| soft_sample = self._soft_sample(x) |
| return soft_sample + (hard_sample - soft_sample).detach() |
|
|
|
|
| class GumbelSampler(Sampler): |
|
|
| def __init__(self, shape, temperature=1.0): |
| super().__init__(shape) |
| self.temperature = temperature |
|
|
| def _sampling_noise(self): |
| return -(1e-10 - (torch.rand(*self.shape) + 1e-10).log()).log() |
|
|
| def _hard_sample(self, logits): |
| assert logits.ndim == 2 |
| indices = torch.argmax(logits, dim=-1) |
| zeros = logits * 0 |
| ones = torch.ones_like(logits[:, :, :1]) |
| return torch.scatter(zeros, -1, indices[:, :, None], ones) |
|
|
| def _soft_sample(self, logits): |
| return torch.nn.functional.softmax(logits / self.temperature, dim=-1) |
|
|
|
|
| class BinarySampler(GumbelSampler): |
|
|
| def sample(self, probs): |
| |
| pos_noise = self._sampling_noise().to(dtype=probs.dtype, device=probs.device) |
| neg_noise = self._sampling_noise().to(dtype=probs.dtype, device=probs.device) |
| del_noise_exp = (neg_noise - pos_noise).exp() |
| hard_sample = (probs * (1 + del_noise_exp) > 1).to(probs.dtype) |
| soft_sample = probs / (probs + (1 - probs) * del_noise_exp) |
| return soft_sample + (hard_sample - soft_sample).detach() |
|
|
|
|
| class GaussianSampler: |
| def __init__(self): |
| self.softplus = torch.nn.Softplus() |
|
|
| def sample(self, x): |
| assert x.ndim == 2 |
| n = x.shape[-1] // 2 |
| mu = x[:, :n] |
| sigma = self.softplus(x[:, n:]).sqrt() |
| return mu + sigma * torch.randn_like(mu) |
|
|
|
|
| def is_global_rank_zero(): |
| """Helper function to determine if the current process is global_rank 0 (the main process)""" |
| |
| |
| rank = os.environ.get("RANK", None) |
| if rank is not None: |
| return rank == 0 |
|
|
| |
| |
| slurm_rank = os.environ.get("SLURM_PROCID", None) |
| if slurm_rank is not None: |
| return slurm_rank == 0 |
|
|
| |
| mpi_rank = os.environ.get("OMPI_COMM_WORLD_RANK", None) |
| if mpi_rank is not None: |
| return mpi_rank == 0 |
|
|
| |
| |
| |
| node_rank = os.environ.get("NODE_RANK", os.environ.get("GROUP_RANK", 0)) |
| local_rank = os.environ.get("LOCAL_RANK", 0) |
| return node_rank == 0 and local_rank == 0 |
|
|
|
|
| def get_rank(): |
| """Helper function that returns torch.distributed.get_rank() if DDP has been initialized otherwise it returns 0.""" |
|
|
| if is_global_rank_zero(): |
| return 0 |
| else: |
| return torch.distributed.get_rank() |
|
|
|
|
| def set_numa_affinity(gpu_index, verbose=False): |
| import pynvml as nvml |
|
|
| nvml.nvmlInit() |
| """This util will assign to the current process the cpu cores set that resides on the same NUMA |
| node as the GPU. Typically if you have 8 GPUs, then the first 4 are on the first NUMA node and |
| the remaining 4 are on the second. |
| |
| `gpu_index` is typically the same as `LOCAL_RANK` in the distributed training, but beware that |
| `CUDA_VISIBLE_DEVICES` could impact that. e.g. `CUDA_VISIBLE_DEVICES=0,7` won't do the right |
| thing - then you will probably want to remap the ids with something like: |
| |
| ``` |
| if "CUDA_VISIBLE_DEVICES" in os.environ: |
| ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","))) |
| gpu_index = ids[gpu_index] # remap |
| ``` |
| |
| """ |
|
|
| num_elements = math.ceil(os.cpu_count() / 64) |
| handle = nvml.nvmlDeviceGetHandleByIndex(gpu_index) |
| affinity_string = "" |
| for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements): |
| |
| affinity_string = f"{j:064b}{affinity_string}" |
| affinity_list = [int(x) for x in affinity_string] |
| affinity_list.reverse() |
| affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0] |
|
|
| if verbose: |
| cores = os.sched_getaffinity(0) |
| gprint(f"before: {len(cores)} visible cpu cores: {cores}") |
|
|
| try: |
| os.sched_setaffinity(0, affinity_to_set) |
| except Exception as e: |
| gprint(f"Failed to set affinity: {e}") |
|
|
| if verbose: |
| cores = os.sched_getaffinity(0) |
| gprint(f"after: {len(cores)} visible cpu cores: {cores}") |
|
|
|
|
| from typing import Dict, Union |
|
|
| import torch |
| from torch.nn import Module |
|
|
|
|
| def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> Dict[str, float]: |
| """Compute each parameter's gradient's norm and their overall norm. |
| |
| The overall norm is computed over all gradients together, as if they |
| were concatenated into a single vector. |
| |
| Args: |
| module: :class:`torch.nn.Module` to inspect. |
| norm_type: The type of the used p-norm, cast to float if necessary. |
| Can be ``'inf'`` for infinity norm. |
| group_separator: The separator string used by the logger to group |
| the gradients norms in their own subfolder instead of the logs one. |
| |
| Return: |
| norms: The dictionary of p-norms of each parameter's gradient and |
| a special entry for the total p-norm of the gradients viewed |
| as a single vector. |
| |
| """ |
| norm_type = float(norm_type) |
| if norm_type <= 0: |
| raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}") |
|
|
| norms = {f"{group_separator}{name}": p.grad.data.norm(norm_type) for name, p in module.named_parameters() if p.grad is not None} |
| total_norm = torch.tensor(list(norms.values())).norm(norm_type) |
| return norms, total_norm |
|
|
| has_set_omega_conf_resolvers = False |
|
|
| def set_omega_conf_resolvers(): |
| global has_set_omega_conf_resolvers |
| if has_set_omega_conf_resolvers: |
| return |
| has_set_omega_conf_resolvers = True |
| import omegaconf |
| from omegaconf import OmegaConf |
|
|
| def get_dir_name(_root_): |
| if str(_root_.mode) == "eval": |
| return "eval" |
| elif _root_.debug: |
| return "debug" |
| else: |
| return _root_.data.train |
|
|
| def getpythoncmd(_root_): |
| return _root_.python_orig + "--multi_gpu \\\n" if (_root_.trainer.devices * _root_.trainer.num_nodes > 1) else _root_.python_orig |
|
|
| def custom_batch_size(): |
| if is_torch_cuda_available() and torch.cuda.get_device_properties(0).total_memory >= 23 * 1024 * 1024 * 1024: |
| return 64 |
| elif is_torch_cuda_available() and torch.cuda.get_device_properties(0).total_memory >= 10 * 1024 * 1024 * 1024: |
| return 32 |
| else: |
| return 28 |
|
|
| def get_slurm_name(_root_): |
| return _root_.slurm_name if hasattr(_root_, "slurm_name") and _root_.slurm_name is not None else _root_.wandb.project |
| |
| partition_time_limit_min = { |
| "partition_name": 60 * 6, |
| } |
| |
| gpu_constraints = { |
| "cluster_name": "gpu_constraints", |
| } |
|
|
| partitions = { |
| "cluster_name": "partition_name", |
| } |
|
|
|
|
| babel_exclude_nodes = set() |
| if os.environ.get("BAD_NODES", None) is not None: |
| babel_exclude_nodes.update(os.environ.get("BAD_NODES").split(",")) |
|
|
| exclude_nodes = { |
| "cluster_name": "nodes_to_exclude", |
| } |
|
|
| def get_hostname_split(): |
| return get_hostname().split("-")[0].split(".")[0] |
|
|
| omegaconf.OmegaConf.register_new_resolver("getpythoncmd", getpythoncmd) |
| omegaconf.OmegaConf.register_new_resolver("get_dir_name", get_dir_name) |
| omegaconf.OmegaConf.register_new_resolver("cwd", os.getcwd) |
| omegaconf.OmegaConf.register_new_resolver("device_count", get_num_devices) |
| omegaconf.OmegaConf.register_new_resolver("eval", eval) |
| omegaconf.OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y) |
| omegaconf.OmegaConf.register_new_resolver("find_grad_accum", lambda x, y: round(x / y)) |
| omegaconf.OmegaConf.register_new_resolver("find_partition", lambda: partitions[get_hostname_split()] if get_hostname_split() in partitions else "all") |
| omegaconf.OmegaConf.register_new_resolver("find_constraint", lambda: gpu_constraints[get_hostname_split()] if get_hostname_split() in gpu_constraints else "") |
| omegaconf.OmegaConf.register_new_resolver("is_ar", lambda parameterization: parameterization == "ar") |
| omegaconf.OmegaConf.register_new_resolver("kv_cache_batch_size", lambda eval_batch_size, cfg: eval_batch_size * 2 if cfg is not None else eval_batch_size) |
| omegaconf.OmegaConf.register_new_resolver("exclude_nodes", lambda: exclude_nodes[get_hostname_split()] if get_hostname_split() in exclude_nodes else "") |
| omegaconf.OmegaConf.register_new_resolver("get_slurm_name", get_slurm_name) |
| |
|
|
| def adjust_n_blocks(_root_): |
| return ( |
| (_root_.model.base_n_blocks - 1 if _root_.model.base_n_blocks < 24 else _root_.model.base_n_blocks - 2) |
| if str(_root_.backbone) == "maskdit" |
| else _root_.model.base_n_blocks |
| ) |
|
|
| omegaconf.OmegaConf.register_new_resolver("adjust_n_blocks", adjust_n_blocks) |
| omegaconf.OmegaConf.register_new_resolver("partition_limit", lambda x: partition_time_limit_min[x] if x in partition_time_limit_min else 60 * 6) |
| omegaconf.OmegaConf.register_new_resolver("custom_batch_size", custom_batch_size) |
| omegaconf.OmegaConf.register_new_resolver("get_repo_dir", lambda: os.getenv("UNIDISC_DIR", str(Path(__file__).parent))) |
|
|
|
|
| @rank_zero_fn |
| def _print_config(config, resolve: bool = True, save_cfg: bool = True) -> None: |
| """Prints content of DictConfig using Rich library and its tree structure. |
| |
| Args: |
| config (DictConfig): Configuration composed by Hydra. |
| resolve (bool): Whether to resolve reference fields of DictConfig. |
| save_cfg (bool): Whether to save the configuration tree to a file. |
| """ |
|
|
| style = "dim" |
| tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) |
|
|
| fields = config.keys() |
| for field in fields: |
| branch = tree.add(field, style=style, guide_style=style) |
|
|
| config_section = config.get(field) |
| branch_content = str(config_section) |
| if isinstance(config_section, omegaconf.DictConfig): |
| branch_content = omegaconf.OmegaConf.to_yaml(config_section, resolve=resolve) |
|
|
| branch.add(rich.syntax.Syntax(branch_content, "yaml")) |
|
|
| rich.print(tree) |
| if save_cfg: |
| with fsspec.open("config_tree.txt", "w") as fp: |
| rich.print(tree, file=fp) |
|
|
| def set_torch_defaults(benchmark=True): |
| torch.set_float32_matmul_precision("medium") |
| if is_torch_cuda_available(): |
| rprint(f"Setting torch defaults") |
| exec("import torch.backends.cuda as cuda") |
| exec("import torch.backends.cudnn as cudnn") |
| exec("cudnn.enabled = True") |
| if benchmark: |
| exec("cudnn.benchmark = True") |
| else: |
| rprint(f"Warning: Not benchmarking") |
| exec("cudnn.allow_tf32 = True") |
| exec("cuda.matmul.allow_tf32 = True") |
| exec("cudnn.deterministic = False") |
| else: |
| rprint(f"Warning: CUDA not available. Not setting defaults.") |
|
|
| from torch.distributed.elastic.multiprocessing.errors import (ChildFailedError, |
| record) |
| from torch.distributed.elastic.multiprocessing.errors.handlers import \ |
| get_error_handler |
|
|
| _NOT_AVAILABLE = "<N/A>" |
| class ErrorHandler: |
| def __init__(self, error_handler=None): |
| self.error_handler = error_handler or get_error_handler() |
|
|
| def __enter__(self): |
| assert self.error_handler is not None |
| self.error_handler.initialize() |
| return self |
|
|
| def __exit__(self, exc_type, exc_value, traceback): |
| if exc_type is not None: |
| if issubclass(exc_type, SystemExit) and exc_value.code == 0: |
| return True |
| elif issubclass(exc_type, ChildFailedError): |
| rank, failure = exc_value.get_first_failure() |
| if failure.error_file != _NOT_AVAILABLE: |
| self.error_handler.dump_error_file(failure.error_file, failure.exitcode) |
| else: |
| rprint( |
| "local_rank %s FAILED with no error file. " |
| "Decorate your entrypoint fn with @record for traceback info. " |
| "See: https://pytorch.org/docs/stable/elastic/errors.html", |
| rank |
| ) |
| return False |
| self.error_handler.record_exception(exc_value) |
| return False |
|
|
|
|
| def convert_state_dict_keys(state_dict): |
| new_state_dict = {} |
| for k, v in state_dict.items(): |
| if "attn_out" in k: |
| new_key = k.replace("attn_out", "attention.attn_out") |
| elif "attn_qkv" in k: |
| new_key = k.replace("attn_qkv", "attention.attn_qkv") |
| else: |
| new_key = k |
| new_state_dict[new_key] = v |
| return new_state_dict |
|
|
| from accelerate.utils import extract_model_from_parallel |
| def apply_compile(model, **compile_kwargs): |
| """ |
| Apply torch.compile to each TransformerBlock, which makes compilation efficient due to |
| repeated structure. Alternatively one can compile the whole model (after applying DP). |
| """ |
| for layer_id, transformer_block in extract_model_from_parallel(model).blocks.named_children(): |
| transformer_block = torch.compile(transformer_block, **compile_kwargs) |
| extract_model_from_parallel(model).blocks.register_module(layer_id, transformer_block) |
|
|
| output_layer = torch.compile(extract_model_from_parallel(model).output_layer, **compile_kwargs) |
| extract_model_from_parallel(model).register_module("output_layer", output_layer) |
|
|
| def compile_model(config, model): |
| compile_kwargs = dict() |
|
|
| if config.backbone == "maskdit": |
| compile_kwargs["dynamic"] = True |
|
|
| compile_kwargs["mode"] = config.trainer.compile_mode |
| rprint(f"Using compile mode: {config.trainer.compile_mode}") |
|
|
| if getattr(config.trainer, "sd3_compile_config", True): |
| torch._inductor.config.conv_1x1_as_mm = True |
| torch._inductor.config.coordinate_descent_tuning = True |
| torch._inductor.config.epilogue_fusion = False |
| torch._inductor.config.coordinate_descent_check_all_directions = True |
| rprint(f"Using SD3 compile config") |
|
|
| if config.trainer.compile_fullgraph: |
| compile_kwargs["fullgraph"] = True |
| rprint(f"Using fullgraph compile") |
|
|
| if getattr(config.trainer, "compile_per_layer", False): |
| apply_compile(model, **compile_kwargs) |
| else: |
| model = torch.compile(model, **compile_kwargs) |
|
|
| return model |
|
|