# PepFlow / train.py
import wandb
import os
import shutil
import argparse
import torch
import torch.cuda.amp as amp
import torch.distributed as distrib
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, random_split
from torch.nn.parallel import DistributedDataParallel as DDP
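# Note: amp, distrib, random_split, and DDP are imported but unused below; they
# are only relevant to mixed-precision or distributed variants of this script.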
from tqdm.auto import tqdm
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
from pepflow.utils.vc import get_version, has_changes
from pepflow.utils.misc import BlackHole, inf_iterator, load_config, seed_all, get_logger, get_new_log_dir, current_milli_time
from pepflow.utils.data import PaddingCollate
from pepflow.utils.train import ScalarMetricAccumulator, count_parameters, get_optimizer, get_scheduler, log_losses, recursive_to, sum_weighted_losses
from models_con.pep_dataloader import PepDataset
from models_con.flow_model import FlowModel
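
# Typical launch, using the argparse defaults below (paths are illustrative):
#   python train.py --config ./configs/angle/learn_angle.yaml --logdir ./logs --device cuda:0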
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./configs/angle/learn_angle.yaml')
    parser.add_argument('--logdir', type=str, default='./logs')
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--name', type=str, default='pepflow')
    args = parser.parse_args()

    # Version control
    branch, version = get_version()
    version_short = '%s-%s' % (branch, version[:7])
    if has_changes() and not args.debug:
        c = input('Uncommitted changes detected. Start training anyway? (y/n) ')
        if c != 'y':
            exit()
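    # The gate above keeps commit.txt (written below) consistent with the code
    # that actually ran; --debug runs skip both the gate and the log directory.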

    # Load configs
    config, config_name = load_config(args.config)
    seed_all(config.train.seed)
    config['device'] = args.device

    # Logging
    if args.debug:
        logger = get_logger('train', None)
        writer = BlackHole()
    else:
        run = wandb.init(project=args.name, config=config, name='%s[%s]' % (config_name, args.tag))
        if args.resume:
            log_dir = os.path.dirname(os.path.dirname(args.resume))
        else:
            log_dir = get_new_log_dir(args.logdir, prefix='%s[%s]' % (config_name, version_short), tag=args.tag)
            with open(os.path.join(log_dir, 'commit.txt'), 'w') as f:
                f.write(branch + '\n')
                f.write(version + '\n')
        ckpt_dir = os.path.join(log_dir, 'checkpoints')
        os.makedirs(ckpt_dir, exist_ok=True)
        logger = get_logger('train', log_dir)
        # TensorBoard is disabled in favor of wandb; keep `writer` defined in both
        # branches so validate() (which logs through it) cannot hit a NameError.
        writer = BlackHole()
        # writer = torch.utils.tensorboard.SummaryWriter(log_dir)
        # tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler(log_dir)
        if not os.path.exists(os.path.join(log_dir, os.path.basename(args.config))):
            shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
    logger.info(args)
    logger.info(config)

    # Data
    logger.info('Loading datasets...')
    # train_dataset = get_dataset(config.dataset.train)
    # val_dataset = get_dataset(config.dataset.val)
    train_dataset = PepDataset(structure_dir=config.dataset.train.structure_dir, dataset_dir=config.dataset.train.dataset_dir,
                               name=config.dataset.train.name, transform=None, reset=config.dataset.train.reset)
    # val_dataset = PepDataset(structure_dir=config.dataset.val.structure_dir, dataset_dir=config.dataset.val.dataset_dir,
    #                          name=config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
    train_loader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True, collate_fn=PaddingCollate(), num_workers=args.num_workers, pin_memory=True)
    train_iterator = inf_iterator(train_loader)
    # val_loader = DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, collate_fn=PaddingCollate(), num_workers=args.num_workers)
    logger.info('Train %d' % len(train_dataset))  # validation set is currently disabled
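    # inf_iterator cycles train_loader indefinitely, so training below is driven
    # by an iteration budget (config.train.max_iters) rather than by epochs.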

    # Model
    logger.info('Building model...')
    # model = get_model(config.model).to(args.device)
    model = FlowModel(config.model).to(args.device)
    # wandb.watch(model, log='all', log_freq=1)
    logger.info('Number of parameters: %d' % count_parameters(model))

    # Optimizer & Scheduler
    optimizer = get_optimizer(config.train.optimizer, model)
    scheduler = get_scheduler(config.train.scheduler, optimizer)
    optimizer.zero_grad()
    it_first = 1

    # Resume
    if args.resume is not None:
        logger.info('Resuming from checkpoint: %s' % args.resume)
        ckpt = torch.load(args.resume, map_location=args.device)
        it_first = ckpt['iteration']  # + 1
        model.load_state_dict(ckpt['model'])
        logger.info('Resuming optimizer states...')
        optimizer.load_state_dict(ckpt['optimizer'])
        logger.info('Resuming scheduler states...')
        scheduler.load_state_dict(ckpt['scheduler'])
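    # With the `# + 1` left commented out, a resumed run repeats the saved
    # iteration once before continuing.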

    def train(it):
        time_start = current_milli_time()
        model.train()

        # Prepare data
        batch = recursive_to(next(train_iterator), args.device)

        # Forward pass
        # loss_dict, metric_dict = model.get_loss(batch)  # get loss and metrics
        loss_dict = model(batch)  # dict of per-term losses
        loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
        # loss = loss / config.train.accum_grad
        time_forward_end = current_milli_time()

        # Rescue for NaN loss: dump the offending batch and states for debugging,
        # then substitute a zero-valued tensor so backward() runs without NaNs.
        if torch.isnan(loss):
            print('NAN Loss!')
            torch.save({'batch': batch, 'loss': loss, 'loss_dict': loss_dict,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iteration': it}, os.path.join(log_dir, 'nan.pt'))
            loss = torch.tensor(0., requires_grad=True).to(loss.device)

        # Backward pass
        loss.backward()

        # Rescue for NaN gradients: zero out NaN entries before clipping
        for param in model.parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                param.grad[torch.isnan(param.grad)] = 0
        orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)

        # Optimizer step
        # if it % config.train.accum_grad == 0:
        optimizer.step()
        optimizer.zero_grad()
        time_backward_end = current_milli_time()

        # Logging
        scalar_dict = {}
        # scalar_dict.update(metric_dict['scalar'])
        scalar_dict.update({
            'grad': orig_grad_norm,
            'lr': optimizer.param_groups[0]['lr'],
            'time_forward': (time_forward_end - time_start) / 1000,
            'time_backward': (time_backward_end - time_forward_end) / 1000,
        })
        log_losses(loss, loss_dict, scalar_dict, it=it, tag='train', logger=logger)
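
    # Illustrative shape of the weighting consumed by sum_weighted_losses in
    # train() above (key names are hypothetical placeholders; they must match
    # the keys of the loss_dict returned by FlowModel):
    #
    #   train:
    #     loss_weights:
    #       trans: 1.0
    #       rot: 1.0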

    def validate(it):
        # Note: requires the val_dataset / val_loader definitions above to be
        # uncommented; validation is currently disabled in the main loop.
        scalar_accum = ScalarMetricAccumulator()
        with torch.no_grad():
            model.eval()
            for i, batch in enumerate(tqdm(val_loader, desc='Validate', dynamic_ncols=True)):
                # Prepare data
                batch = recursive_to(batch, args.device)
                # Forward pass
                # loss_dict, metric_dict = model.get_loss(batch)
                loss_dict = model(batch)
                loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
                scalar_accum.add(name='loss', value=loss, batchsize=len(batch['aa']), mode='mean')
                # Assumes the model's loss_dict carries a 'scalar' sub-dict of
                # metrics (cf. the commented get_loss API above).
                for k, v in loss_dict['scalar'].items():
                    scalar_accum.add(name=k, value=v, batchsize=len(batch['aa']), mode='mean')

        avg_loss = scalar_accum.get_average('loss')
        summary = scalar_accum.log(it, 'val', logger=logger, writer=writer)
        for k, v in summary.items():
            wandb.log({f'val/{k}': v}, step=it)

        # Trigger scheduler
        if config.train.scheduler.type == 'plateau':
            scheduler.step(avg_loss)
        else:
            scheduler.step()
        return avg_loss
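    # Caution: scheduler.step() is only ever called inside validate(), and the
    # main loop below has validation commented out, so the learning rate stays
    # at its initial value unless validation is re-enabled.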

    try:
        for it in range(it_first, config.train.max_iters + 1):
            train(it)
            # if it % config.train.val_freq == 0:
            #     avg_val_loss = validate(it)
            # Periodic checkpointing (restores the commented-out debug guard:
            # log_dir / ckpt_dir are undefined when running with --debug)
            if it % config.train.val_freq == 0 and not args.debug:
                ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
                torch.save({
                    'config': config,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'iteration': it,
                    # 'avg_val_loss': avg_val_loss,
                }, ckpt_path)
    except KeyboardInterrupt:
        logger.info('Terminating...')
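
# Resuming from a saved checkpoint (paths are illustrative):
#   python train.py --config ./configs/angle/learn_angle.yaml --resume ./logs/<run>/checkpoints/5000.pt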