| | import wandb |
| |
|
| | import os |
| | import shutil |
| | import argparse |
| | import torch |
| | import torch.cuda.amp as amp |
| | import torch.distributed as distrib |
| | from torch.nn.utils import clip_grad_norm_ |
| | from torch.utils.data import DataLoader, random_split |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| | from tqdm.auto import tqdm |
| | |
| | |
| |
|
| | from pepflow.utils.vc import get_version, has_changes |
| | from pepflow.utils.misc import BlackHole, inf_iterator, load_config, seed_all, get_logger, get_new_log_dir, current_milli_time |
| | from pepflow.utils.data import PaddingCollate |
| | from pepflow.utils.train import ScalarMetricAccumulator, count_parameters, get_optimizer, get_scheduler, log_losses, recursive_to, sum_weighted_losses |
| |
|
| | from models_con.pep_dataloader import PepDataset |
| | |
| |
|
| | from models_con.flow_model import FlowModel |
| |
|
if __name__ == '__main__':
    # Command-line interface for the pepflow training script.
    parser = argparse.ArgumentParser()
    # Experiment configuration (YAML) and where to write run artifacts.
    parser.add_argument('--config', type=str, default='./configs/angle/learn_angle.yaml')
    parser.add_argument('--logdir', type=str, default="./logs")
    # Run identity / bookkeeping.
    parser.add_argument('--name', type=str, default='pepflow')
    parser.add_argument('--tag', type=str, default='')
    # Execution environment.
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_workers', type=int, default=4)
    # Debug mode disables persistent logging; resume points at a checkpoint file.
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--resume', type=str, default=None)
    args = parser.parse_args()
| |
|
| | |
| | branch, version = get_version() |
| | version_short = '%s-%s' % (branch, version[:7]) |
| | if has_changes() and not args.debug: |
| | c = input('Start training anyway? (y/n) ') |
| | if c != 'y': |
| | exit() |
| |
|
| | |
| | config, config_name = load_config(args.config) |
| | seed_all(config.train.seed) |
| | config['device'] = args.device |
| |
|
| | |
| | if args.debug: |
| | logger = get_logger('train', None) |
| | writer = BlackHole() |
| | else: |
| | run = wandb.init(project=args.name, config=config, name='%s[%s]' % (config_name, args.tag)) |
| | if args.resume: |
| | log_dir = os.path.dirname(os.path.dirname(args.resume)) |
| | else: |
| | log_dir = get_new_log_dir(args.logdir, prefix='%s[%s]' % (config_name, version_short), tag=args.tag) |
| | with open(os.path.join(log_dir, 'commit.txt'), 'w') as f: |
| | f.write(branch + '\n') |
| | f.write(version + '\n') |
| | ckpt_dir = os.path.join(log_dir, 'checkpoints') |
| | if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) |
| | logger = get_logger('train', log_dir) |
| | |
| | |
| | if not os.path.exists(os.path.join(log_dir, os.path.basename(args.config))): |
| | shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config))) |
| | logger.info(args) |
| | logger.info(config) |
| |
|
| | |
| | logger.info('Loading datasets...') |
| | |
| | |
| | train_dataset = PepDataset(structure_dir = config.dataset.train.structure_dir, dataset_dir = config.dataset.train.dataset_dir, |
| | name = config.dataset.train.name, transform=None, reset=config.dataset.train.reset) |
| | |
| | |
| | train_loader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True, collate_fn=PaddingCollate(), num_workers=args.num_workers, pin_memory=True) |
| | train_iterator = inf_iterator(train_loader) |
| | |
| | logger.info('Train %d | Val %d' % (len(train_dataset), len(train_dataset))) |
| |
|
| | |
| | logger.info('Building model...') |
| | |
| | model = FlowModel(config.model).to(args.device) |
| | |
| | logger.info('Number of parameters: %d' % count_parameters(model)) |
| |
|
| | |
| | optimizer = get_optimizer(config.train.optimizer, model) |
| | scheduler = get_scheduler(config.train.scheduler, optimizer) |
| | optimizer.zero_grad() |
| | it_first = 1 |
| |
|
| | |
| | if args.resume is not None: |
| | logger.info('Resuming from checkpoint: %s' % args.resume) |
| | ckpt = torch.load(args.resume, map_location=args.device) |
| | it_first = ckpt['iteration'] |
| | model.load_state_dict(ckpt['model']) |
| | logger.info('Resuming optimizer states...') |
| | optimizer.load_state_dict(ckpt['optimizer']) |
| | logger.info('Resuming scheduler states...') |
| | scheduler.load_state_dict(ckpt['scheduler']) |
| |
|
| | def train(it): |
| | time_start = current_milli_time() |
| | model.train() |
| |
|
| | |
| | batch = recursive_to(next(train_iterator), args.device) |
| |
|
| | |
| | |
| | loss_dict = model(batch) |
| | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) |
| | |
| | time_forward_end = current_milli_time() |
| |
|
| | if torch.isnan(loss): |
| | print('NAN Loss!') |
| | torch.save({'batch':batch,'loss':loss,'loss_dict':loss_dict,'model': model.state_dict(), |
| | 'optimizer': optimizer.state_dict(), |
| | 'scheduler': scheduler.state_dict(), |
| | 'iteration': it,},os.path.join(log_dir,'nan.pt')) |
| | loss = torch.tensor(0.,requires_grad=True).to(loss.device) |
| |
|
| | loss.backward() |
| |
|
| | |
| | for param in model.parameters(): |
| | if param.grad is not None: |
| | if torch.isnan(param.grad).any(): |
| | param.grad[torch.isnan(param.grad)] = 0 |
| |
|
| | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm) |
| |
|
| | |
| | |
| | optimizer.step() |
| | optimizer.zero_grad() |
| | time_backward_end = current_milli_time() |
| |
|
| | |
| | scalar_dict = {} |
| | |
| | scalar_dict.update({ |
| | 'grad': orig_grad_norm, |
| | 'lr': optimizer.param_groups[0]['lr'], |
| | 'time_forward': (time_forward_end - time_start) / 1000, |
| | 'time_backward': (time_backward_end - time_forward_end) / 1000, |
| | }) |
| | log_losses(loss, loss_dict, scalar_dict, it=it, tag='train', logger=logger) |
| |
|
| | def validate(it): |
| | scalar_accum = ScalarMetricAccumulator() |
| | with torch.no_grad(): |
| | model.eval() |
| |
|
| | for i, batch in enumerate(tqdm(val_loader, desc='Validate', dynamic_ncols=True)): |
| | |
| | batch = recursive_to(batch, args.device) |
| |
|
| | |
| | |
| | loss_dict = model(batch) |
| | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) |
| | scalar_accum.add(name='loss', value=loss, batchsize=len(batch['aa']), mode='mean') |
| | for k, v in loss_dict['scalar'].items(): |
| | scalar_accum.add(name=k, value=v, batchsize=len(batch['aa']), mode='mean') |
| | |
| | avg_loss = scalar_accum.get_average('loss') |
| | summary = scalar_accum.log(it, 'val', logger=logger, writer=writer) |
| | for k,v in summary.items(): |
| | wandb.log({f'val/{k}': v}, step=it) |
| | |
| | if config.train.scheduler.type == 'plateau': |
| | scheduler.step(avg_loss) |
| | else: |
| | scheduler.step() |
| | return avg_loss |
| |
|
| | try: |
| | for it in range(it_first, config.train.max_iters + 1): |
| | train(it) |
| | |
| | |
| | |
| | if it % config.train.val_freq == 0: |
| | ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it) |
| | torch.save({ |
| | 'config': config, |
| | 'model': model.state_dict(), |
| | 'optimizer': optimizer.state_dict(), |
| | 'scheduler': scheduler.state_dict(), |
| | 'iteration': it, |
| | |
| | }, ckpt_path) |
| | except KeyboardInterrupt: |
| | logger.info('Terminating...') |