import os
import gc
import torch

import torch.nn.functional as F
import lightning as pl

from typing import Optional
from transformers import AutoModelForMaskedLM, AutoTokenizer

from src.utils.model_utils import _print
from src.utils.optimizer_utils import get_optimizer, get_scheduler


class MembraneDiffusion(pl.LightningModule):
    def __init__(self, config):
        """
        Args:
            config (OmegaConf): config.yaml file with all training parameters
        """
        super().__init__()
        self.config = config
        self.save_hyperparameters(logger=True)

        self.model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_evoflow, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

    def forward(self, input_ids, attention_mask, guidance: Optional[bool] = False):
        """
        Forward pass through language model.

        Args:
            - input_ids (torch.Tensor): [B, L], token ids
            - attention_mask (torch.Tensor): [B, L], pad/non-pad binary mask
        Returns:
            - logits (torch.Tensor): [B, L, V], unnormalized model outputs
        """
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    # -------# Diffusion #-------- #
    def step(self, batch):
        labels = batch['input_ids'] 

        # Forward diffusion
        t1 = self.sample_t(labels) # Sample timestep
        xt, _ = self.noise_x0(labels, t1, maskable_mask=self.is_maskable(labels)) # Noise sequence
        logits = self.forward(input_ids=xt, attention_mask=batch['attention_mask']) # Model logits

        # Loss computation
        weight = self.get_weight(t1, weight_type=self.config.lm.weight_type)  # RDM uses a weighted cross entropy loss
        loss_out = self.compute_loss(logits, labels, weight) # Compute loss and ppl

        self.cleanup()
        return loss_out['loss'], loss_out['ppl']
    
    def sample_t(self, labels, rdm_coupling=False):
        """
        Sample diffusion timesteps. Non-coupling RDM only uses one timestep (t1).
        """
        batch_size = labels.size(0)
        timesteps = torch.randint(
            1,
            self.config.lm.num_diffusion_timesteps + 1,
            (2 * batch_size,) if rdm_coupling else (batch_size,),
            device=labels.device,
        )

        if rdm_coupling:
            return timesteps.chunk(2)
        return timesteps

    def noise_x0(self, x0, t1, maskable_mask):
        """
        Apply noise to the initial sequence x0.
        """
        u = torch.rand_like(x0, dtype=torch.float) 
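        # Each maskable token is masked independently with probability t1 / num_diffusion_timesteps,
        # so larger timesteps corrupt a larger fraction of the sequence.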
        t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x0.masked_fill(t1_mask, self.mask_id)
        return x_t1, t1_mask

    def get_weight(self, t, weight_type):
        """
        Compute the weighting factor for the RDM-derived loss (weighted cross-entropy).
        """
        num_timesteps = self.config.lm.num_diffusion_timesteps
        weight = {
            "linear": (num_timesteps - (t - 1)),  # num_timesteps * (1 - (t-1)/num_timesteps)
            "constant": num_timesteps * torch.ones_like(t),
        }[weight_type][:, None].float() / num_timesteps
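        # "linear": w(t) = (T - t + 1) / T, decaying from 1 at t = 1 to 1/T at t = T;
        # "constant": w(t) = 1 for every timestep.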
        return weight.squeeze()

    def compute_loss(self, logits, labels, weight):
        """
        Compute the cross entropy loss per sample.
        First, compute the per-token loss (with no reduction), then average it over the
        valid (non-pad) positions of each sequence, apply the RDM timestep weight, and
        finally average over the batch.

        Args:
            logits (torch.Tensor): [B, L, vocab_size], unnormalized model outputs
            labels (torch.Tensor): [B, L], target token ids (pad positions are ignored via ignore_index)
            weight (torch.Tensor): [B], per-sample weight for the loss calculation
        Returns:
            dict: 'loss' is the weighted loss averaged over the batch;
                  'ppl' is the perplexity (exp of the unweighted per-sample loss) averaged over the batch
        """

        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction='none',
            ignore_index=self.pad_id,
        )

        loss_token = loss_token.view(labels.size(0), labels.size(1)) # Reshape to [B, L]
        valid_mask = (labels != self.pad_id)

        sample_loss = (loss_token * valid_mask.float()).sum(dim=1) / valid_mask.float().sum(dim=1).clamp(min=1)
        ppl = torch.exp(sample_loss) # Perplexity from the unweighted cross entropy
        sample_loss = sample_loss * weight # RDM weighting

        return {'ppl': ppl.mean(), 'loss': sample_loss.mean()}
    

    # -------# Training / Evaluation #-------- #
    def training_step(self, batch):
        loss, ppl = self.step(batch)
        self.log("train/loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl.item(), on_step=True, on_epoch=False, prog_bar=False)
        return loss
    
    def validation_step(self, batch):
        loss, ppl = self.step(batch)
        self.cleanup()
        self.log("val/loss", loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        loss, ppl = self.step(batch)
        self.cleanup()
        self.log('test/loss', loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss


    # -------# Helper methods #-------- #
    def is_maskable(self, input_ids: torch.Tensor):
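        """Boolean mask of positions that may be noised (excludes pad, CLS, and EOS tokens)."""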
        return (
            (input_ids != self.tokenizer.pad_token_id) 
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def configure_optimizers(self):
        """
        Choosing which optimizer and lr scheduler to use.
        """
        optimizer = get_optimizer(self.config, self.model.parameters())
        lr_scheduler, extra_kwargs = get_scheduler(self.config, optimizer) # Polynomial scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, **extra_kwargs},
        }

    def validate_config(self):
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.training.mode in ["train", "test", "resume_from_checkpoint"], "invalid mode"

    def get_state_dict(self, ckpt_path):
        def remove_model_prefix(state_dict):
            # Strip the leading "model." prefix that Lightning prepends to submodule keys
            return {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }

        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)

        return state_dict

    def cleanup(self):
        torch.cuda.empty_cache()
        gc.collect()
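
# Illustrative usage (a minimal sketch, not part of the module). It assumes an OmegaConf
# config exposing the fields referenced above (config.lm.pretrained_evoflow,
# config.lm.num_diffusion_timesteps, config.lm.weight_type, plus the optimizer, scheduler,
# and checkpointing settings consumed by get_optimizer / get_scheduler). The config path
# and dataloaders are hypothetical placeholders; batches are expected to be dicts with
# 'input_ids' and 'attention_mask'.
#
#   from omegaconf import OmegaConf
#   import lightning as pl
#
#   config = OmegaConf.load("configs/train.yaml")       # hypothetical path
#   module = MembraneDiffusion(config)
#   trainer = pl.Trainer(accelerator="auto", max_epochs=1)
#   trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)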