| | import os |
| | import sys |
| | import pathlib |
| | CURRENT_DIR = pathlib.Path(__file__).parent |
| | sys.path.append(str(CURRENT_DIR)) |
| |
|
| | import numpy as np |
| | from tqdm import tqdm |
| | import torch |
| | import torch.nn as nn |
| | from torch.utils import data |
| | import torchvision.transforms as transform |
| | import torch.nn.functional as F |
| | import onnxruntime |
| | from PIL import Image |
| | import argparse |
| | import datasets.utils as utils |
| |
|
class Configs:
    """Command-line options for evaluating an ONNX SemanticFPN segmentation model."""

    def __init__(self):
        parser = argparse.ArgumentParser(description='PyTorch SemanticFPN model')

        # Dataset and model selection.
        parser.add_argument('--dataset', type=str, default='citys', help='dataset name (default: citys)')
        parser.add_argument('--onnx_path', type=str, default='FPN_int_NHWC.onnx', help='onnx path')
        parser.add_argument('--num-classes', type=int, default=19,
                            help='the classes numbers (default: 19 for cityscapes)')
        parser.add_argument('--test-folder', type=str, default='./data/cityscapes',
                            help='test dataset folder (default: ./data/cityscapes)')

        # Pre-processing sizes and batching.
        parser.add_argument('--base-size', type=int, default=1024, help='the shortest image size')
        parser.add_argument('--crop-size', type=int, default=256, help='input size for inference')
        # The help text previously claimed a default of 10 while the real default is 1.
        parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                            help='input batch size for testing (default: 1)')

        # ONNX Runtime execution-provider selection.
        parser.add_argument('--ipu', action='store_true', help='use ipu')
        parser.add_argument('--provider_config', type=str, default=None, help='provider config path')

        self.parser = parser

    def parse(self, argv=None):
        """Parse command-line options, print them, and return the namespace.

        Args:
            argv: list of argument strings to parse; ``None`` (the default)
                makes argparse read ``sys.argv[1:]``, as before.

        Returns:
            The ``argparse.Namespace`` produced by the parser.
        """
        args = self.parser.parse_args(argv)
        print(args)
        return args
| |
|
| |
|
def build_data(args, subset_len=None, sample_method='random'):
    """Build a DataLoader over the 'val' split of ``args.dataset``.

    Images go through ``ToTensor`` followed by ``Normalize`` with the
    mean/std constants below before being batched.

    Args:
        args: parsed options; reads ``dataset``, ``test_folder``,
            ``base_size``, ``crop_size`` and ``batch_size``.
        subset_len: if truthy, restrict the dataset to this many samples.
        sample_method: ``'random'`` draws ``subset_len`` distinct indices at
            random; any other value keeps the first ``subset_len`` samples.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=False`` and
        ``drop_last=False``.

    Raises:
        ValueError: if ``subset_len`` is larger than the dataset.
    """
    import random

    from datasets import get_segmentation_dataset

    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size,
    }
    testset = get_segmentation_dataset(args.dataset, split='val', mode='testval',
                                       root=args.test_folder, **data_kwargs)

    if subset_len:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if subset_len > len(testset):
            raise ValueError(
                f"subset_len ({subset_len}) exceeds dataset size ({len(testset)})")
        if sample_method == 'random':
            # Sample indices of the dataset itself (previously referenced the
            # not-yet-assigned loader variable and an unimported `random`).
            indices = random.sample(range(len(testset)), subset_len)
        else:
            indices = list(range(subset_len))
        testset = torch.utils.data.Subset(testset, indices)

    return data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False)
| |
|
| |
|
def eval_miou(data, path="FPN_int.onnx", device='cpu', cfg=None):
    """Run the ONNX model at ``path`` over ``data`` and print a confusion-matrix report.

    Each image batch is permuted with axes (0, 2, 3, 1) before being fed to
    ONNX Runtime, and the first model output is permuted back with axes
    (0, 3, 1, 2). This presumably converts NCHW <-> NHWC for an NHWC-exported
    model — confirm against the exported graph. If the output's spatial size
    differs from the target's, it is bilinearly resized. The per-pixel argmax
    is then accumulated into ``utils.ConfusionMatrix``.

    Args:
        data: iterable of ``(image, target)`` torch tensor batches, such as the
            loader returned by ``build_data``.
        path: path of the ONNX model file.
        device: torch device used for the target/prediction comparison.
        cfg: options namespace providing ``num_classes``, ``ipu`` and
            ``provider_config``. When ``None``, the module-level ``args``
            (set by the ``__main__`` block) is used, as before.
    """
    if cfg is None:
        # Backward-compatible fallback to the script-level global.
        cfg = args

    confmat = utils.ConfusionMatrix(cfg.num_classes)

    if cfg.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": cfg.provider_config}]
    else:
        providers = ['CPUExecutionProvider']
        provider_options = None
    session = onnxruntime.InferenceSession(path, providers=providers,
                                           provider_options=provider_options)
    # The input name is fixed for the session; look it up once, not per batch.
    input_name = session.get_inputs()[0].name

    for image, target in tqdm(data, desc='\r'):
        target = target.to(device)
        # ONNX Runtime consumes host numpy arrays, so feed the CPU copy directly.
        ort_input = {input_name: image.cpu().numpy().transpose(0, 2, 3, 1)}
        # session.run returns a list of outputs; only the first is used.
        ort_output = session.run(None, ort_input)[0].transpose(0, 3, 1, 2)
        ort_output = torch.from_numpy(ort_output).to(device)
        if ort_output.size()[2:] != target.size()[1:]:
            ort_output = F.interpolate(ort_output, size=target.size()[1:],
                                       mode='bilinear', align_corners=True)

        confmat.update(target.flatten(), ort_output.argmax(1).flatten())

    confmat.reduce_from_all_processes()
    print('Evaluation Metric: ')
    print(confmat)
| |
|
| |
|
def main(args):
    """Evaluate the ONNX model named by ``args.onnx_path`` on the validation split.

    Prints a banner, builds the loader with ``build_data`` and passes it,
    together with the model path, to ``eval_miou``.
    """
    print('===> Evaluation mIoU: ')
    eval_miou(build_data(args), path=args.onnx_path)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point. The parsed options are bound to the module-level
    # name `args` on purpose: eval_miou reads args.num_classes, args.ipu and
    # args.provider_config from this global rather than from a parameter.
    args = Configs().parse()
    main(args)
| |
|
| |
|