import logging
import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoModelForCausalLM
# Module-level logger for fixture progress messages.
logger = logging.getLogger(__name__)
# Configure root logging once at import so fixture INFO messages are visible
# when the test session runs.
logging.basicConfig(level=logging.INFO)
# Fixed RNG seed so every process generates identical random weights/gradients.
SEED = 0xdeadbeef
def pytest_addoption(parser):
    """Register the custom boolean command-line flags for this test suite.

    All three flags are plain store-true switches defaulting to False.
    """
    flag_specs = (
        ("--measure-perf",
         "Measure execution time and peak memory usage during optimizer step."),
        ("--do-profile", "Enable profiling during tests."),
        ("--skip-verify",
         "Skip verification of optimizer step correctness with sequential implementation.\n"
         "This can be useful when GPU memory is limited."),
    )
    for flag, help_text in flag_specs:
        parser.addoption(
            flag,
            action="store_true",
            default=False,
            help=help_text,
        )
def pytest_configure(config):
    """Validate flag combinations once the command line has been parsed.

    Raises:
        pytest.UsageError: if --do-profile is set without --measure-perf.
    """
    profiling_requested = config.getoption("--do-profile")
    # Note: --measure-perf is only queried when profiling was requested.
    if profiling_requested and not config.getoption("--measure-perf"):
        raise pytest.UsageError(
            "--do-profile requires --measure-perf. Please enable both flags.")
@pytest.fixture(scope="session")
def measure_perf(request):
    """True when the session was started with --measure-perf."""
    flag_value = request.config.getoption("--measure-perf")
    return flag_value
@pytest.fixture(scope="session")
def do_profile(request):
    """True when the session was started with --do-profile."""
    flag_value = request.config.getoption("--do-profile")
    return flag_value
@pytest.fixture(scope="session")
def skip_verify(request):
    """True when the session was started with --skip-verify."""
    flag_value = request.config.getoption("--skip-verify")
    return flag_value
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Session-wide, autouse fixture that initializes torch.distributed (NCCL).

    Skips the whole session when torch is too old, initialization fails, or
    the world size is not exactly 8. On normal teardown the process group is
    destroyed.
    """
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for parallel muon")
        # Fix: removed unreachable `return` — pytest.skip() raises.
    try:
        dist.init_process_group(backend="nccl")
        # Pin each rank to a distinct local GPU.
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        # Fix: use the module logger instead of bare print() for consistency
        # with the rest of the file.
        logger.error("Failed to initialize torch.distributed: %s", e)
        pytest.skip("Failed to initialize torch.distributed")
    if dist.get_world_size() != 8:
        # NOTE(review): skipping here leaves the just-created process group
        # alive; acceptable since the whole session is being skipped.
        pytest.skip("Need 8 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=8 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`."
                    "To run with less than 8 gpus, modify "
                    "the test cases accordingly.")
    yield
    dist.destroy_process_group()
@pytest.fixture(scope="session")
def inputs():
    """Load Motif-2.6B model and generate random gradients for testing.

    Returns:
        tuple[torch.nn.Module, list[torch.Tensor], dict[int, torch.Tensor]]:
            - torch.nn.Module: The Motif-2.6B model.
            - list[torch.Tensor]: A list of random gradients for each model parameter.
            - dict[int, torch.Tensor]: A dictionary mapping layer indices to random QK logits.
    """
    model_name = "Motif-Technologies/Motif-2.6B-4layer-random"
    # Seed CPU and CUDA RNGs so every rank generates identical tensors.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    logger.info(
        f"Loaded model {model_name}. ({len(list(model.parameters()))} parameters)"
    )
    # One random gradient tensor per parameter, matching device and dtype.
    grads: list[torch.Tensor] = [
        torch.randn_like(param, device=param.device, dtype=param.dtype)
        for param in model.parameters()
    ]
    # One random bf16 vector (one value per attention head) per layer.
    qk_logits: dict[int, torch.Tensor] = {
        i: torch.randn(model.config.num_attention_heads,
                       device=model.device,
                       dtype=torch.bfloat16)
        for i in range(model.config.num_hidden_layers)
    }
    # Fix: return a tuple as documented (previously returned a list).
    return model, grads, qk_logits
def _create_moe_model(num_experts=8, top_k=2, n_layers=4):
    """Create a torchtitan Llama4 MoE model with random gradients."""
    # Imported lazily so importing this file does not require torchtitan.
    from torchtitan.models.llama4.model.args import TransformerModelArgs
    from torchtitan.models.llama4.model.model import Transformer
    from torchtitan.models.moe import MoEArgs

    # Seed CPU and CUDA RNGs for reproducible weights and gradients.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    expert_cfg = MoEArgs(
        num_experts=num_experts,
        num_shared_experts=1,
        top_k=top_k,
        score_func="sigmoid",
    )
    arch_cfg = TransformerModelArgs(
        dim=2048,
        n_layers=n_layers,
        n_heads=16,
        n_kv_heads=8,
        vocab_size=32000,
        norm_eps=1e-5,
        rope_theta=10000,
        max_seq_len=4096,
        moe_args=expert_cfg,
        interleave_moe_layer_step=1,
    )
    model = Transformer(arch_cfg)
    model.init_weights()
    logger.info(f"Created torchtitan Llama4 MoE model "
                f"(num_experts={num_experts}, n_layers={n_layers}, "
                f"{len(list(model.parameters()))} parameters)")
    grads = []
    for weight in model.parameters():
        grads.append(
            torch.randn_like(weight, device=weight.device, dtype=weight.dtype))
    return [model, grads]
@pytest.fixture(scope="session")
def moe_inputs():
    """Session-cached MoE model with the standard 8-expert configuration."""
    return _create_moe_model(num_experts=8, top_k=2)
@pytest.fixture(scope="session")
def moe_inputs_few_experts():
    """Session-cached MoE model with only 2 experts (triggers EFSDP Shard(1) mode)."""
    return _create_moe_model(num_experts=2, top_k=1)