data-archetype commited on
Commit
588c6b6
·
verified ·
1 Parent(s): c70ce92

Patch VP logSNR stability in exported full_capacitor code

Browse files
Files changed (1) hide show
  1. full_capacitor/encoder.py +6 -5
full_capacitor/encoder.py CHANGED
@@ -10,6 +10,7 @@ from __future__ import annotations
10
  from dataclasses import dataclass
11
 
12
  import torch
 
13
  from torch import Tensor, nn
14
 
15
  from .fcdm_block import FCDMBlock
@@ -30,15 +31,15 @@ class EncoderPosterior:
30
 
31
  @property
32
  def alpha(self) -> Tensor:
33
- """VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
34
  logsnr_fp32 = self.logsnr.to(torch.float32)
35
- return torch.sigmoid(logsnr_fp32).sqrt()
36
 
37
  @property
38
  def sigma(self) -> Tensor:
39
- """VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
40
  logsnr_fp32 = self.logsnr.to(torch.float32)
41
- return torch.sigmoid(-logsnr_fp32).sqrt()
42
 
43
  def mode(self) -> Tensor:
44
  """Posterior mode in token space: alpha * mean, computed in float32."""
@@ -130,7 +131,7 @@ class Encoder(nn.Module):
130
  mean, logsnr = projection.chunk(2, dim=1)
131
  mean = self.norm_out(mean)
132
  logsnr_fp32 = logsnr.to(torch.float32)
133
- alpha = torch.sigmoid(logsnr_fp32).sqrt()
134
  return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
135
  z = self.norm_out(projection)
136
  return z
 
10
  from dataclasses import dataclass
11
 
12
  import torch
13
+ import torch.nn.functional as F
14
  from torch import Tensor, nn
15
 
16
  from .fcdm_block import FCDMBlock
 
31
 
32
  @property
33
  def alpha(self) -> Tensor:
34
+ """VP signal coefficient computed stably in float32."""
35
  logsnr_fp32 = self.logsnr.to(torch.float32)
36
+ return torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
37
 
38
  @property
39
  def sigma(self) -> Tensor:
40
+ """VP noise coefficient computed stably in float32."""
41
  logsnr_fp32 = self.logsnr.to(torch.float32)
42
+ return torch.exp(0.5 * F.logsigmoid(-logsnr_fp32))
43
 
44
  def mode(self) -> Tensor:
45
  """Posterior mode in token space: alpha * mean, computed in float32."""
 
131
  mean, logsnr = projection.chunk(2, dim=1)
132
  mean = self.norm_out(mean)
133
  logsnr_fp32 = logsnr.to(torch.float32)
134
+ alpha = torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
135
  return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
136
  z = self.norm_out(projection)
137
  return z