|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
class ProbabilityPathTracer:
    """Record an oracle model's average log-likelihood of partially masked token sequences.

    Each call to ``log_step`` scores the tokens of ``xt`` that are not equal to
    ``tokenizer.mask_token_id`` and stores the score in ``self.history`` under
    the key ``"trace_step_<step_idx>"``.
    """

    def __init__(self, oracle_model, tokenizer, device):
        """
        Args:
            oracle_model: Callable invoked as
                ``oracle_model(input_ids=..., attention_mask=...)`` whose result
                exposes a ``.logits`` tensor.
            tokenizer: Object providing ``mask_token_id``.
            device: Stored as ``self.device``.
        """
        self.oracle = oracle_model
        self.tokenizer = tokenizer
        # NOTE(review): not read by the methods shown here; xt is scored on
        # whatever device it already lives on — confirm callers place it correctly.
        self.device = device
        self.mask_id = tokenizer.mask_token_id
        # "trace_step_<step_idx>" -> score (float) returned by compute_loglikeli.
        self.history = {}

    @torch.inference_mode()
    def compute_loglikeli(self, xt: torch.Tensor) -> float:
        """Return the oracle's average log-likelihood of the revealed tokens of ``xt``.

        Args:
            xt: Integer token-id tensor of shape ``(batch, seq_len)``; a 1-D
                ``(seq_len,)`` tensor is treated as a batch of one. Positions
                equal to ``self.mask_id`` are excluded from the score.

        Returns:
            For each sequence, the mean log-probability the oracle assigns to
            its revealed tokens, then averaged over the sequences that have at
            least one revealed token. ``0.0`` if no token in ``xt`` is revealed.
        """
        if xt.dim() == 1:
            xt = xt.unsqueeze(0)

        is_revealed = xt != self.mask_id
        if not is_revealed.any():
            return 0.0

        logits = self.oracle(
            input_ids=xt,
            attention_mask=torch.ones_like(xt, device=xt.device),
        ).logits

        # Masked positions are never scored, so drop them from the loss rather
        # than multiplying by 0 afterwards: an inf NLL (e.g. a -inf logit for the
        # mask token) times 0 is NaN, and a mask id outside the oracle's
        # vocabulary would make cross_entropy fail outright.
        ignore_index = -100
        targets = xt.masked_fill(~is_revealed, ignore_index)

        # NOTE(review): logits[:, i] is scored against xt[:, i] with no shift.
        # That matches a model whose output at position i predicts token i; for a
        # causal next-token LM it would be off by one — confirm the oracle type.
        # reshape (not view) so non-contiguous xt / logits do not raise.
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=ignore_index,
            reduction="none",
        ).view(xt.shape)

        # Ignored positions already contribute 0 to the per-row sum.
        counts = is_revealed.float().sum(dim=1)
        per_seq_ll = -nll.float().sum(dim=1) / counts.clamp(min=1)

        # Average only over rows with something to score, so fully masked
        # rows do not drag the batch score toward 0.
        return per_seq_ll[counts > 0].mean().item()

    def log_step(self, xt, step_idx):
        """Score ``xt`` with ``compute_loglikeli`` and store it as ``"trace_step_<step_idx>"``."""
        score = self.compute_loglikeli(xt)
        self.history[f"trace_step_{step_idx}"] = score

    def get_trace(self):
        """Return the history dict itself (not a copy) of scores recorded by ``log_step``."""
        return self.history