# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from warnings import warn

import torch
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
    get_adapter_params,
    get_merged_lora_ckpt,
    set_trainable_params,
    validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

from tqdm import tqdm

log = utils.get_logger("DEBUG")

class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
    """
    Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
    distributed training and can be run on a single node (1 to 8 GPUs).

    Features:
        - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU is not
            supported.

        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
            activations in memory and instead recompute them during the backward pass. This is especially
            helpful for larger batch sizes when you're memory constrained. These memory savings come at
            the cost of training performance: in most cases training slows down quite a bit as a result
            of the activation recomputation.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
            most cases this should halve the memory footprint of full precision (fp32) training, without
            loss in model quality (this will depend on the model, training data and other settings). For
            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
            precision are currently not supported.

        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
            controlled using the ``gradient_accumulation_steps`` flag.

                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.

            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
            total batch size of 64.

            Gradient accumulation is especially useful when you are memory constrained. In this case,
            accumulating gradients might give you better training speed than enabling activation
            checkpointing.

        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
            training. Currently we checkpoint both the adapter weights (trainable params only) and the
            complete merged weights (adapter weights added back to the base model). For more details
            please take a look at our LoRA tutorial
            (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).

            Optimizer state and recipe state (seed, total_epochs, number of epochs run, etc.) are
            only saved at the end of a given epoch and used in case of resuming training. Resuming
            training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
            currently not supported.

            For more details on the checkpointer, please take a look at
            our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).

        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
    has example commands for how to kick-off training.
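
    Example launch (the config name here is illustrative; use ``tune ls`` to see
    the configs that actually ship with torchtune and the exact command for each):

        tune run --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora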

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        ValueError: If ``world_size`` is 1.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)

        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        _, rank = utils.get_world_size_and_rank()

        # _is_rank_zero is used primarily for logging. In the future, the logger
        # should directly take care of this.
        self._is_rank_zero = rank == 0

        # Logging attributes.
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        # Training attributes.
        self._enable_activation_checkpointing = cfg.enable_activation_checkpointing

        # These attributes constitute the recipe state and are updated by
        # ``load_checkpoint`` when ``resume_from_checkpoint`` is True.
        self.seed = utils.set_seed(seed=cfg.seed)
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
        """
        Extract the checkpoint state from file and validate. This includes the
        base model weights. If ``resume_from_checkpoint`` is True, this also
        includes the adapter weights and recipe state.
        """
        self._checkpointer = config.instantiate(
            cfg_checkpointer,
            resume_from_checkpoint=self._resume_from_checkpoint,
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()

        # When resuming from checkpoint for LoRA, the recipe expects the adapter
        # weights and recipe state to be present. The keys should match up with
        # what ``save_checkpoint`` used to create these intermediate checkpoints.
        if self._resume_from_checkpoint:
            if utils.ADAPTER_KEY not in checkpoint_dict:
                raise ValueError(
                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
                )
            # ``_update_recipe_state`` raises if the recipe state cannot be
            # correctly loaded, so no additional checks are needed here.
            self._update_recipe_state(checkpoint_dict)
        return checkpoint_dict

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Updates the recipe state from checkpoint.
        """
        if not (
            utils.SEED_KEY in ckpt_dict
            and utils.TOTAL_EPOCHS_KEY in ckpt_dict
            and utils.MAX_STEPS_KEY in ckpt_dict
        ):
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            )

        if (
            self.seed != ckpt_dict[utils.SEED_KEY]
            or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
            or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
        ):
            warn(
                message=(
                    "Configured value for seed, epochs or max_steps_per_epoch "
                    "does not match the value stored in checkpoint."
                )
            )
        self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
        self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
        self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
        self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]

    def setup(self, cfg: DictConfig) -> None:
        """
        Setup the recipe state. This includes recipe state (if ``resume_from_checkpoint`` is True),
        model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
        """
        if self._is_rank_zero:
            self._metric_logger = config.instantiate(cfg.metric_logger)

            # Log the config with any parameter overrides applied.
            self._metric_logger.log_config(cfg)

        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            base_model_state_dict=checkpoint_dict[utils.MODEL_KEY],
            lora_weights_state_dict=(
                checkpoint_dict[utils.ADAPTER_KEY]
                if self._resume_from_checkpoint
                else None
            ),
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            opt_state_dict=(
                checkpoint_dict[utils.OPT_KEY]
                if self._resume_from_checkpoint
                else None
            ),
        )

        self._loss_fn = config.instantiate(cfg.loss)

        # The sampler and dataloader depend on the tokenizer and loss_fn, so they
        # are set up after both of those have been initialized.
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
        )

        # The number of training steps in each epoch depends on the number of
        # batches produced by the dataloader, the gradient accumulation steps, and
        # the max_steps_per_epoch param set by the user. It is used for logging
        # and tracking training state, so it must be computed after the dataloader
        # has been set up.
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # The learning rate scheduler can only be set up after the number of
        # training steps has been computed.
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.lr_scheduler,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
    ) -> nn.Module:
        """
        Model initialization has some important considerations:
            a. To minimize GPU peak memory, we load the model on CPU with the right
               dtype. To ensure that we don't instantiate ``world_size`` number of models,
               we initialize on the meta device for all ranks other than rank 0.
            b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
               model weights from checkpoint.
            c. While wrapping the model with FSDP, we set ``sync_module_states``
               to True and broadcast module params and buffers from rank 0.
            d. The ``device_id`` param ensures that the FSDP initialization happens on
               the correct device.
        """
        if self._is_rank_zero:
            log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
            init_start = time.perf_counter()

            with utils.set_default_dtype(self._dtype):
                model = config.instantiate(cfg_model)

            log.info(
                f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
            )

            # The model contains LoRA params which won't have any matching keys in
            # the base model state dict, so we load with strict=False. Before
            # loading, ensure the state dict keys for the base model and the
            # adapters (if available) match the keys in the full LoRA module.
            # This is a good sanity check to prevent silent errors.
            validate_state_dict_for_lora(
                lora_attn_modules=cfg_model.lora_attn_modules,
                apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
                apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
                full_model_state_dict_keys=model.state_dict().keys(),
                lora_state_dict_keys=(
                    lora_weights_state_dict.keys()
                    if lora_weights_state_dict is not None
                    else None
                ),
                base_model_state_dict_keys=base_model_state_dict.keys(),
            )

            # Load both the base model weights and (if available) the adapter
            # weights. Both loads should happen only on rank 0.
            model.load_state_dict(base_model_state_dict, strict=False)
            if lora_weights_state_dict:
                model.load_state_dict(lora_weights_state_dict, strict=False)

        else:
            # For non-zero ranks, initialize the model on the meta device to avoid
            # materializing ``world_size`` copies of the weights.
            with utils.set_default_dtype(self._dtype), torch.device("meta"):
                model = config.instantiate(cfg_model)

        if self._dtype == torch.bfloat16:
            model = model.to(torch.bfloat16)

        # LoRA hyper-params needed for merging weights while saving checkpoints.
        self._lora_rank = cfg_model.lora_rank
        self._lora_alpha = cfg_model.lora_alpha

        # Note: this needs to be set before wrapping with FSDP.
        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

        model = FSDP(
            module=model,
            auto_wrap_policy=utils.lora_fsdp_wrap_policy(
                modules_to_wrap={modules.TransformerDecoderLayer}
            ),
            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
            device_id=self._device,
            # This recipe does not currently support mixed precision training.
            mixed_precision=None,
            # Ensure we broadcast module params and buffers from rank 0.
            sync_module_states=True,
            # Initialize empty modules on all non-zero ranks.
            param_init_fn=(
                lambda module: module.to_empty(
                    device=torch.device("cuda"), recurse=False
                )
                if not self._is_rank_zero
                else None
            ),
        )

        # Ensure no params and buffers are left on the meta device.
        utils.validate_no_params_on_meta_device(model)

        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )
        if self._is_rank_zero:
            memory_stats = utils.get_memory_stats(device=self._device)
            utils.log_memory_stats(memory_stats)

        # Synchronize before training begins.
        torch.distributed.barrier()

        return model

    def _setup_optimizer(
        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
    ) -> Optimizer:
        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
        if opt_state_dict:
            # The optimizer state dict from the checkpoint must be transformed to
            # match the FSDP-sharded parameters before it can be loaded.
            opt_state_dict = utils.transform_opt_state_dict(
                opt_state_dict, self._model, optimizer
            )
            optimizer.load_state_dict(opt_state_dict)

        if self._is_rank_zero:
            log.info("Optimizer and loss are initialized.")
        return optimizer

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: DictConfig,
        num_training_steps: int,
        last_epoch: int,
    ) -> LRScheduler:
        lr_scheduler = config.instantiate(
            cfg_lr_scheduler,
            self._optimizer,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )
        if self._is_rank_zero:
            log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports
        ``DistributedSampler`` with map-style datasets which fit into memory. Other
        samplers, iterable datasets and streaming datasets are not supported.
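
        An illustrative dataset entry in the YAML config (the component shown is
        just one example; any torchtune dataset builder can be used):

            dataset:
              _component_: torchtune.datasets.alpaca_dataset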
| """ |
| world_size, rank = utils.get_world_size_and_rank() |
|
|
| if isinstance(cfg_dataset, ListConfig): |
| datasets = [ |
| config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) |
| for single_cfg_dataset in cfg_dataset |
| ] |
| ds = ConcatDataset(datasets=datasets) |
| packed = False |
| else: |
| ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) |
| packed = cfg_dataset.get("packed", False) |
|
|
| sampler = DistributedSampler( |
| ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 |
| ) |
|
|
| dataloader = DataLoader( |
| dataset=ds, |
| batch_size=batch_size, |
| sampler=sampler, |
| collate_fn=partial( |
| utils.padded_collate, |
| padding_idx=self._tokenizer.pad_id, |
| ignore_idx=self._loss_fn.ignore_index, |
| ) |
| if not packed |
| else None, |
| ) |
|
|
| if self._is_rank_zero: |
| log.info("Dataset and Sampler are initialized.") |
|
|
| return sampler, dataloader |

    def save_checkpoint(
        self,
        epoch: int,
    ) -> None:
        """
        Checkpoint the state of the recipe. The constructed checkpoint state dict
        contains the following information:
            - Merged weights with key MODEL_KEY
            - Adapter weights with key ADAPTER_KEY
            - Relevant recipe state if training is not complete

        The checkpointer saves the merged weights, adapter weights and recipe state
        in different checkpoint files. To correctly resume training, the adapter
        weights and recipe state must be provided along with the base model weights.
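
        A sketch of the state dict handed to the checkpointer on rank 0 (the keys
        are the torchtune constants used below; the on-disk layout depends on the
        configured checkpointer):

            {
                utils.MODEL_KEY: ...,    # merged base + adapter weights
                utils.ADAPTER_KEY: ...,  # adapter weights only
                utils.OPT_KEY: ...,      # optimizer state, intermediate checkpoints only
                utils.SEED_KEY: ...,     # recipe state, intermediate checkpoints only
            }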
| """ |
| |
| checkpoint_dict = {} |
|
|
| intermediate_checkpoint = epoch + 1 < self.total_epochs |
| |
| |
| with FSDP.state_dict_type( |
| self._model, |
| StateDictType.FULL_STATE_DICT, |
| FullStateDictConfig(offload_to_cpu=True, rank0_only=True), |
| FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), |
| ): |
| cpu_state_dict = self._model.state_dict() |
| if intermediate_checkpoint: |
| opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) |
| else: |
| opt_state_dict = None |
|
|
| |
| |
| if self._is_rank_zero: |
|
|
| |
| |
| adapter_key_filter = lambda x: x in self.adapter_params |
| adapter_state_dict = { |
| k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) |
| } |
| checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) |
|
|
| |
| merged_state_dict = get_merged_lora_ckpt( |
| cpu_state_dict, |
| rank=self._lora_rank, |
| alpha=self._lora_alpha, |
| ) |
| checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict}) |
|
|
| |
| |
| if intermediate_checkpoint: |
| checkpoint_dict.update( |
| { |
| utils.OPT_KEY: opt_state_dict, |
| utils.SEED_KEY: self.seed, |
| utils.EPOCHS_KEY: self.epochs_run, |
| utils.TOTAL_EPOCHS_KEY: self.total_epochs, |
| utils.MAX_STEPS_KEY: self.max_steps_per_epoch, |
| } |
| ) |
|
|
| self._checkpointer.save_checkpoint( |
| checkpoint_dict, |
| epoch=epoch, |
| intermediate_checkpoint=intermediate_checkpoint, |
| ) |

    def train(self) -> None:
        """
        The core training loop.
        """
        # Clean up before training begins.
        utils.cleanup_before_training()

        _, rank = utils.get_world_size_and_rank()

        # Zero out the gradients before starting training.
        self._optimizer.zero_grad()

        # Initialize the token count and running loss (used for gradient accumulation).
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # self.epochs_run should be non-zero when we're resuming from a checkpoint.
        for curr_epoch in range(self.epochs_run, self.total_epochs):

            # Update the sampler to ensure the data is correctly shuffled across
            # epochs in case shuffle is True.
            self._sampler.set_epoch(curr_epoch)

            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Both are shape [b, s].
                tokens, labels = batch["tokens"], batch["labels"]
                # Get the attention mask and position ids from the dataset if they
                # exist. Currently, only sample packing in PackedDataset returns these.
                mask = batch.get("mask", None)  # shape [b, s, s]
                input_pos = batch.get("input_pos", None)  # shape [b, s]

                tokens = tokens.to(self._device)
                num_tokens += tokens.numel()
                labels = labels.to(self._device)
                mask = mask.to(self._device) if mask is not None else None
                input_pos = (
                    input_pos.to(self._device) if input_pos is not None else None
                )

                logits = self._model(tokens, mask=mask, input_pos=input_pos)
                # Shift the logits and labels so that tokens < n predict n.
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
                # Move the class dimension to position 1, as expected by the loss.
                logits = logits.transpose(1, 2)
                # Compute the loss.
                loss = self._loss_fn(logits, labels)

                # Normalize by the number of accumulation steps so that gradients
                # average (rather than sum) across micro-batches.
                loss = loss / self._gradient_accumulation_steps
                running_loss += loss
                loss.backward()
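
                # For illustration: with gradient_accumulation_steps=4, the block
                # below only runs when idx is 3, 7, 11, ..., so each optimizer step
                # reflects gradients accumulated over 4 micro-batches.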
                # Step with the optimizer once the accumulation window is complete.
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()

                    # Update the number of steps when the weights are updated.
                    self.global_step += 1

                    loss_to_log = running_loss.item()
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    # Log per-step metrics.
                    if (
                        self.global_step % self._log_every_n_steps == 0
                        and self._is_rank_zero
                    ):
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": self._optimizer.param_groups[0]["lr"],
                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(utils.get_memory_stats(device=self._device))
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

                    # Reset running stats for the next step.
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

    def cleanup(self) -> None:
        if self._is_rank_zero:
            self._metric_logger.close()
        destroy_process_group()


@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command-line
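
    For example, any config field can be overridden from the command line with
    ``key=value`` syntax (the config name and value here are illustrative):

        tune run --nproc_per_node 2 lora_finetune_distributed \
            --config llama2/7B_lora batch_size=8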
| """ |
| if not utils.is_distributed(): |
| raise RuntimeError( |
| "Distributed finetune recipe should be run via a distributed launcher." |
| "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" |
| ) |
| os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" |
| init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") |
|
|
| config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) |
|
|
| recipe = LoRAFinetuneRecipeDistributed(cfg=cfg) |
| recipe.setup(cfg=cfg) |
| recipe.train() |
| recipe.cleanup() |


if __name__ == "__main__":
    sys.exit(recipe_main())