JTriggerFish committed on
Commit ·
3196863
1
Parent(s): f5cfba7
Fix posterior VP interpolation to use float32 precision
Browse files
- README.md +1 -0
- fcdm_diffae/encoder.py +18 -13
README.md
CHANGED
|
@@ -17,6 +17,7 @@ library_name: fcdm_diffae
|
|
| 17 |
|
| 18 |
| Date | Change |
|
| 19 |
|------|--------|
|
|
|
|
| 20 |
| 2026-04-07 | Rename package `capacitor_diffae` → `fcdm_diffae`, class `FCDMDiffAE`; encode() now returns whitened latents, decode() dewhitens internally |
|
| 21 |
| 2026-04-06 | Initial release |
|
| 22 |
|
|
|
|
| 17 |
|
| 18 |
| Date | Change |
|
| 19 |
|------|--------|
|
| 20 |
+
| 2026-04-08 | Fix posterior VP interpolation to use float32 precision (was using model dtype) |
|
| 21 |
| 2026-04-07 | Rename package `capacitor_diffae` → `fcdm_diffae`, class `FCDMDiffAE`; encode() now returns whitened latents, decode() dewhitens internally |
|
| 22 |
| 2026-04-06 | Initial release |
|
| 23 |
|
fcdm_diffae/encoder.py
CHANGED
|
@@ -30,24 +30,28 @@ class EncoderPosterior:
|
|
| 30 |
|
| 31 |
@property
|
| 32 |
def alpha(self) -> Tensor:
|
| 33 |
-
"""VP signal coefficient: sqrt(sigmoid(logsnr))."""
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
@property
|
| 37 |
def sigma(self) -> Tensor:
|
| 38 |
-
"""VP noise coefficient: sqrt(sigmoid(-logsnr))."""
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
def mode(self) -> Tensor:
|
| 42 |
-
"""Posterior mode in token space: alpha * mean."""
|
| 43 |
-
return self.alpha.to(dtype=self.mean.dtype)
|
| 44 |
|
| 45 |
def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
|
| 46 |
-
"""Sample from posterior: alpha * mean + sigma * eps."""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
class Encoder(nn.Module):
|
|
@@ -123,7 +127,8 @@ class Encoder(nn.Module):
|
|
| 123 |
if self.bottleneck_posterior_kind == "diagonal_gaussian":
|
| 124 |
mean, logsnr = projection.chunk(2, dim=1)
|
| 125 |
mean = self.norm_out(mean)
|
| 126 |
-
|
| 127 |
-
|
|
|
|
| 128 |
z = self.norm_out(projection)
|
| 129 |
return z
|
|
|
|
| 30 |
|
| 31 |
@property
|
| 32 |
def alpha(self) -> Tensor:
|
| 33 |
+
"""VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
|
| 34 |
+
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 35 |
+
return torch.sigmoid(logsnr_fp32).sqrt()
|
| 36 |
|
| 37 |
@property
|
| 38 |
def sigma(self) -> Tensor:
|
| 39 |
+
"""VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
|
| 40 |
+
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 41 |
+
return torch.sigmoid(-logsnr_fp32).sqrt()
|
| 42 |
|
| 43 |
def mode(self) -> Tensor:
|
| 44 |
+
"""Posterior mode in token space: alpha * mean, computed in float32."""
|
| 45 |
+
return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype)
|
| 46 |
|
| 47 |
def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
|
| 48 |
+
"""Sample from posterior: alpha * mean + sigma * eps, computed in float32."""
|
| 49 |
+
mean_fp32 = self.mean.to(torch.float32)
|
| 50 |
+
eps = torch.randn(
|
| 51 |
+
mean_fp32.shape, device=mean_fp32.device, dtype=torch.float32,
|
| 52 |
+
generator=generator,
|
| 53 |
+
)
|
| 54 |
+
return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype)
|
| 55 |
|
| 56 |
|
| 57 |
class Encoder(nn.Module):
|
|
|
|
| 127 |
if self.bottleneck_posterior_kind == "diagonal_gaussian":
|
| 128 |
mean, logsnr = projection.chunk(2, dim=1)
|
| 129 |
mean = self.norm_out(mean)
|
| 130 |
+
logsnr_fp32 = logsnr.to(torch.float32)
|
| 131 |
+
alpha = torch.sigmoid(logsnr_fp32).sqrt()
|
| 132 |
+
return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
|
| 133 |
z = self.norm_out(projection)
|
| 134 |
return z
|