| import torch |
| import data_utils as du |
|
|
def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
    """Round-trip a complex coil image through an autoencoder via a 2-channel encoding.

    Pipeline:
    1. Normalize the image with ``du.normalize_complex_coil_image``.
    2. Convert it to a two-channel array with ``du.complex_to_two_channel_image``.
    3. Add a leading batch axis and encode it.
    4. Decode from the *mean* of the latent distribution (no sampling).
    5. Map the decoded output back with ``du.two_channel_to_complex_image``.

    Args:
        coil_complex_image: Complex coil image accepted by
            ``du.normalize_complex_coil_image``.
        autoencoder: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``. This is presumably a diffusers-style
            ``AutoencoderKL``; confirm. It is moved to ``device`` with ``.to``.
        device: Torch device used for inference. Defaults to ``"cuda"``.

    Returns:
        ``(normalized_input, recon)``:
        - ``normalized_input``: the normalized complex image that was encoded.
        - ``recon``: the result of ``du.two_channel_to_complex_image`` on the
          decoded array. That array still carries the leading batch axis of
          size 1 when it is passed to the helper.
    """
    normalized_image = du.normalize_complex_coil_image(coil_complex_image)
    two_channel_image = du.complex_to_two_channel_image(normalized_image)
    # Add a batch axis, then cast to float32 and move to the target device in
    # one call. This replaces the legacy ``.type(torch.FloatTensor)`` CPU cast.
    two_channel_tensor = torch.from_numpy(two_channel_image)[None, ...].to(
        device=device, dtype=torch.float32
    )
    # NOTE(review): if ``autoencoder`` is an nn.Module, ``.to`` moves the
    # caller's model in place. The model is also not switched to eval() here;
    # confirm that is intended.
    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        # Use the posterior mean rather than a random sample from latent_dist.
        latents = autoencoder.encode(two_channel_tensor).latent_dist.mean
        decoded_image = autoencoder.decode(latents).sample
    # Produced under no_grad, so no autograd graph is attached and .detach()
    # is unnecessary before converting to NumPy.
    recon = du.two_channel_to_complex_image(decoded_image.cpu().numpy())
    return normalized_image, recon
|
|
def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
    """Round-trip a complex coil image through an autoencoder via a 3-channel encoding.

    Pipeline:
    1. Normalize the image with ``du.normalize_complex_coil_image``.
    2. Convert it with ``du.create_three_channel_image``.
    3. Add a leading batch axis and encode it.
    4. Decode from the *mean* of the latent distribution (no sampling).

    Args:
        coil_complex_image: Complex coil image accepted by
            ``du.normalize_complex_coil_image``.
        autoencoder: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``. This is presumably a diffusers-style
            ``AutoencoderKL``; confirm. It is moved to ``device`` with ``.to``.
        device: Torch device used for inference. Defaults to ``"cuda"``.

    Returns:
        ``(model_input, recon)``:
        - ``model_input``: the three-channel NumPy array that was encoded.
          Unlike ``run_inference_two_channels``, this is not the complex image.
        - ``recon``: the first (and only) item of the decoded batch as a NumPy
          array. It is *not* converted back to a complex image.
    """
    normalized_image = du.normalize_complex_coil_image(coil_complex_image)
    three_channel_image = du.create_three_channel_image(normalized_image)
    # Add a batch axis, then cast to float32 and move to the target device in
    # one call. This replaces the legacy ``.type(torch.FloatTensor)`` CPU cast.
    three_channel_tensor = torch.from_numpy(three_channel_image)[None, ...].to(
        device=device, dtype=torch.float32
    )
    # NOTE(review): if ``autoencoder`` is an nn.Module, ``.to`` moves the
    # caller's model in place. The model is also not switched to eval() here;
    # confirm that is intended.
    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        # Use the posterior mean rather than a random sample from latent_dist.
        latents = autoencoder.encode(three_channel_tensor).latent_dist.mean
        decoded_image = autoencoder.decode(latents).sample
    # Drop the batch axis added above. Tensors produced under no_grad carry no
    # autograd graph, so .detach() is unnecessary before .numpy().
    recon = decoded_image[0].cpu().numpy()
    return three_channel_image, recon
|
|