import torch
from torch import nn

from model.cache_utils import FLACache
from model.config import PlayHeadConfig
from model.prediction_head import CircularHead, LogitsHead
from model.simple_gla import SimpleGLABlock


class PlayHead(nn.Module):
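    """Predicts a position within a fixed-length cycle from cross-attention weights.

    Attention over sink tokens and cyclically indexed positions is turned into
    per-step positional features, refined by a stack of gated linear attention
    (GLA) blocks, and decoded by optional logits / circular prediction heads.
    """
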
    def __init__(self, cfg: PlayHeadConfig):
        super().__init__()
        self.cycle_len = cfg.cycle_len
        self.num_sink_tokens = cfg.num_sink_tokens
        self.pos_embedding = nn.Embedding(cfg.num_sink_tokens + cfg.cycle_len, cfg.dim)
        self.avg_pool_stride = cfg.avg_pool_stride
        self.net = nn.ModuleList(
            [
                SimpleGLABlock(
                    dim=cfg.dim,
                    num_heads=cfg.dim // 128,
                    layer_idx=i,
                    expand_k=0.5,
                    expand_v=1.0,
                    use_short_conv=True,
                    ffn_expansion_factor=4,
                )
                for i in range(cfg.num_layers)
            ]
        )

        self.logits_head = (
            LogitsHead(cfg.dim, cfg.cycle_len) if cfg.logits_head else None
        )
        self.circular_head = CircularHead(cfg.dim) if cfg.circular_head else None

    def forward(
        self,
        cross_attention_weights: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
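        """Compute training losses.

        cross_attention_weights: (B, T, A) attention over T positions
        (num_sink_tokens sink tokens followed by cyclically indexed tokens),
        one distribution per output step along A. target gives the
        ground-truth cyclic position in [0, cycle_len) per output step.
        """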
        B, T, A = cross_attention_weights.shape

        # Cyclic position ids for the regular tokens; sink tokens get their own
        # embeddings, stored after the cycle_len cyclic positions.
        device = cross_attention_weights.device
        pos = torch.arange(T - self.num_sink_tokens, device=device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens, device=device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        # Each of the A output steps becomes an attention-weighted average of
        # these positional embeddings: (B, A, T) @ (1, T, dim) -> (B, A, dim).
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x)

        losses = dict()
        if self.logits_head is not None:
            losses |= self.logits_head.compute_loss(x, target.long(), mask=mask)
        if self.circular_head is not None:
            losses |= self.circular_head.compute_loss(x, target, mask=mask)

        return losses

    def init_cache(self):
        return FLACache(num_states=len(self.net))

    def predict(
        self,
        cross_attention_weights: torch.Tensor,
        previous_position: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
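        """Predict the cyclic position for each output step at inference time.

        Returns pred_position; when previous_position is given, also returns
        the signed number of steps moved since it (wrapped to the shortest
        path around the cycle). Returns None if no logits head is configured.
        """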
        # Downsample the non-sink attention along the position axis; sink
        # columns stay at full resolution and are re-attached in front.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            cross_attention_weights[:, self.num_sink_tokens :].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)

        sink_ca = cross_attention_weights[:, : self.num_sink_tokens]
        cross_attention_weights = torch.cat((sink_ca, avg_pool_ca), dim=1)

        B, T, A = cross_attention_weights.shape
        device = cross_attention_weights.device
        pos = torch.arange(T - self.num_sink_tokens, device=device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens, device=device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x, cache=cache)
        if self.logits_head is None:
            return None

        logits = self.logits_head(x)
        pred_position = torch.argmax(logits, -1)

        if previous_position is not None:
            # Compare positions on the unit circle so the step count is the
            # wrapped (shortest) signed difference, robust to cycle wrap-around.
            current_angle, previous_angle = map(
                lambda p: torch.exp(1j * 2 * torch.pi * p / self.cycle_len),
                (pred_position, previous_position),
            )
            diff = current_angle / previous_angle
            step = (diff.angle() / (2 * torch.pi / self.cycle_len)).round().long()
            return pred_position, step

        return pred_position
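

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model code. The PlayHeadConfig
    # field names below are inferred from the attribute accesses in __init__
    # and may not match the real constructor signature; shapes are arbitrary.
    cfg = PlayHeadConfig(
        dim=256,
        cycle_len=100,
        num_sink_tokens=4,
        avg_pool_stride=2,
        num_layers=2,
        logits_head=True,
        circular_head=False,
    )
    head = PlayHead(cfg)

    B, A = 2, 50  # batch size, output steps
    T = cfg.num_sink_tokens + 3 * cfg.cycle_len  # sink tokens + three full cycles
    ca = torch.softmax(torch.randn(B, T, A), dim=1)  # attention sums to 1 over T
    target = torch.randint(0, cfg.cycle_len, (B, A))

    losses = head(ca, target)  # training losses from the enabled heads
    pred = head.predict(ca, cache=head.init_cache())  # (B, A) cyclic positions
    print(sorted(losses), tuple(pred.shape))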