File size: 8,700 Bytes
901446a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcbc792
901446a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcbc792
901446a
fcbc792
 
 
 
901446a
 
 
 
fcbc792
 
901446a
 
 
fcbc792
901446a
fcbc792
901446a
 
fcbc792
 
 
 
 
 
 
 
 
 
 
 
901446a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum

from nltk.tokenize import TreebankWordDetokenizer

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    DebertaV2Model,
    PreTrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

class ModelURI(StrEnum):
    BASE = "microsoft/deberta-v3-base"
    LARGE = "microsoft/deberta-v3-large"

class ConSec(PreTrainedModel):
    """ConSec-style word-sense disambiguation head over a DeBERTa-v3 encoder.

    The encoder yields token vectors; ``forward`` pools the tokens of the
    marked ``[START]``..``[END]`` entity span into one vector and scores it
    against the vector of each candidate gloss (the hidden states of the
    ``[GLOSS]`` marker tokens whose position id coincides with ``[START]``).
    """

    def __init__(self, config: PreTrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction from a hub checkpoint: load the pretrained
            # encoder and widen the embeddings by the two marker tokens the
            # tokenizer adds (ids recorded via add_special_tokens).
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path,
                                                       device_map="auto",
                                                       dtype=torch.bfloat16)
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            # Re-instantiation from a saved ConSec checkpoint: build the bare
            # architecture; from_pretrained fills in the weights afterwards.
            self.BaseModel = DebertaV2Model(config)
        config.init_basemodel = False

        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: initialise from a pretrained hub checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int, gloss: int):
        """Record the token ids of the [START], [END] and [GLOSS] markers."""
        self.config.start_token = start
        self.config.end_token = end
        self.config.gloss_token = gloss

    def forward(self,
                input_ids: torch.Tensor | None = None,
                attention_mask: torch.Tensor | None = None,
                token_type_ids: torch.Tensor | None = None,
                position_ids: torch.Tensor | None = None,
                inputs_embeds: torch.Tensor | None = None,
                labels: torch.Tensor | None = None,
                output_attentions: bool | None = None,
                output_hidden_states: bool | None = None,
                return_dict: bool | None = None,
                **kwargs) -> TokenClassifierOutput:
        """Score every candidate gloss against the marked entity span.

        Returns a TokenClassifierOutput whose ``logits`` hold one score per
        gloss; ``loss`` is cross-entropy when ``labels`` is supplied.
        """
        base_model_output = self.BaseModel(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           token_type_ids=token_type_ids,
                                           position_ids=position_ids,
                                           inputs_embeds=inputs_embeds,
                                           output_attentions=output_attentions,
                                           output_hidden_states=output_hidden_states,
                                           **kwargs)
        token_vectors = base_model_output.last_hidden_state
        # 0/1 mask marking the inclusive [START]..[END] span of each row.
        # strict=True guards against unpaired markers.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Entity vector = sum of the token vectors inside the marked span.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            input_ids, starts, position_ids, token_vectors
        )
        # Dot product of the entity vector with each gloss vector.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(self, input_ids: torch.Tensor,
                      starts: torch.Tensor,
                      position_ids: torch.Tensor,
                      token_vectors: torch.Tensor) -> torch.Tensor:
        """Collect per-row gloss vectors, zero-padded to a common count.

        For each row, selects the hidden states of the [GLOSS] tokens whose
        position id equals the position id of that row's [START] token (the
        tagger overlays candidate glosses onto the [START] position).
        """
        with self.device:
            vectors = [token_vectors[i, ((position_ids[i] == position_ids[i, j])
                                         & (input_ids[i] == self.config.gloss_token))]
                       for (i, j) in starts]
            maxlen = max(vector.shape[0] for vector in vectors)
            # Pad in the hidden-state dtype. This was hard-coded to bfloat16,
            # which made torch.cat fail (or silently upcast) whenever the
            # encoder runs in a different precision, e.g. a checkpoint
            # reloaded through the DebertaV2Model(config) path.
            return torch.stack([torch.cat([vector,
                                           torch.zeros((maxlen - vector.shape[0],
                                                        vector.shape[1]),
                                                       dtype=vector.dtype)])
                                for vector in vectors])
    
def json_sequencer(sentence: list[dict]) -> Generator[tuple[list[str], list[str], int]]:
    """Yield one disambiguation task per ambiguous chunk, fewest candidates first.

    Each yield is ``(words, candidates, span)`` where ``words`` is the full
    sentence with the target chunk bracketed by ``[START]``/``[END]`` markers,
    ``candidates`` is that chunk's candidate list, and ``span`` is its index.
    """
    # Sorting (count, index) tuples reproduces the stable fewest-first order.
    sites = sorted((len(chunk["candidates"]), position)
                   for position, chunk in enumerate(sentence)
                   if "candidates" in chunk)
    for _, span in sites:
        before = [word for chunk in sentence[:span] for word in chunk["words"]]
        after = [word for chunk in sentence[span + 1:] for word in chunk["words"]]
        bracketed = before + ["[START]"] + sentence[span]["words"] + ["[END]"] + after
        yield bracketed, sentence[span]["candidates"], span
        
def json_labeller(sentence, tags):
    """Write each tag's label onto the chunk at its index; return the (mutated) sentence."""
    for position, label in ((tag["index"], tag["label"]) for tag in tags):
        sentence[position]["label"] = label
    return sentence
    
class ConSecTagger:
    """Greedy sentence tagger driving a ConSec-style model.

    Disambiguates one marked span at a time (in the order produced by the
    sequencer, fewest candidates first) and feeds the glosses of already
    decided spans back into later model inputs as extra context.
    """

    def __init__(self,model,
                 tokenizer,
                 ontology,
                 sequencer=json_sequencer,
                 labeller=json_labeller):
        # model: exposes .device and an HF-style forward (input_ids, ...).
        # tokenizer: must already contain the added [START] and [GLOSS] tokens.
        # ontology: iterable of objects with .concept and .definition attributes.
        self.model = model
        self.tokenizer = tokenizer
        special_tokens = self.tokenizer.get_added_vocab()
        self.start_token = special_tokens["[START]"]
        self.gloss_token = special_tokens["[GLOSS]"]
        self.sequencer = sequencer
        self.detokenizer = TreebankWordDetokenizer()
        # Concept id -> gloss text lookup, built once from the ontology.
        self.glosses = {synset.concept:synset.definition
                        for synset in ontology}
        self.label=labeller


    def __call__(self,sentence):
        """Tag every ambiguous span of `sentence` and return the labelled sentence."""
        already_tagged = []
        for (words,candidates,index) in self.sequencer(sentence):
            text = self.detokenizer.detokenize(words)
            # The leading '' makes "[GLOSS] ".join() prefix every gloss —
            # including the first — with a [GLOSS] marker token.
            glosses = ['']
            glosses.extend([self.glosses[candidate] for candidate in candidates])
            # Glosses of spans decided in earlier iterations, as context.
            glosses.extend([self.glosses[previous["label"]] for previous in already_tagged])
            with self.model.device:
                tokens = self.tokenizer(text,"[GLOSS] ".join(glosses),
                                        return_tensors="pt")
                length = tokens.input_ids.shape[1]
                positions = torch.arange(length)
                # Token index of the [START] marker and its word index; the
                # word index is stored so later iterations can anchor this
                # span's context gloss back onto it.
                place = (tokens.input_ids==self.start_token).nonzero(as_tuple=True)[1].item()
                wordpos = tokens.token_to_word(place)
                gloss_positions = [index.item() 
                                   for index in (tokens.input_ids==self.gloss_token).nonzero(as_tuple=True)[1]]
                gloss_positions.append(length)
                n_candidates = len(candidates)
                # Remap position ids so every gloss segment "overlays" the
                # span it describes: candidate glosses start at the [START]
                # token's position; context glosses start at the token
                # position of their already-tagged word.
                for (i,position) in enumerate(gloss_positions[:-1]):
                    if i<n_candidates:
                        end = (place + gloss_positions[i+1]-position)
                        positions[position:gloss_positions[i+1]] = torch.arange(place,end)
                    else:
                        known = already_tagged[i-n_candidates]
                        start = tokens.word_to_tokens(known["place"]).start
                        end = (start + gloss_positions[i+1] - position)
                        positions[position:gloss_positions[i+1]] = torch.arange(start,end)
                prediction = self.model(input_ids=tokens.input_ids,
                                    attention_mask=tokens.attention_mask,
                                    token_type_ids=tokens.token_type_ids,
                                    position_ids=positions.reshape((1,length)))
                try:
                    # argmax is expected to index into the candidate list.
                    label = candidates[prediction.logits.argmax()]
                except IndexError:
                    # Unexpected argmax beyond the candidates: dump the raw
                    # inputs/outputs for debugging, then re-raise.
                    print(text)
                    print(gloss_positions)
                    print([positions[pos].item() for pos in gloss_positions[:-1]])
                    print(already_tagged)
                    print(candidates)
                    print(prediction.logits)
                    print(prediction.logits.argmax())
                    raise
                already_tagged.append({"label":label,
                                       "place":wordpos,
                                       "index":index})
        return(self.label(sentence,already_tagged))