File size: 3,033 Bytes
89874f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

class ShieldFilter:
    """Redact PII in free text using a token-classification model.

    Each token is labelled by the model; labels are folded into coarser
    groups via ``group_map`` and every run of consecutive tokens sharing a
    group is replaced in the text by a single ``[GROUP]`` placeholder.

    Attributes:
        tokenizer: tokenizer loaded from ``model_path``; must be able to
            return offset mappings.
        model: token-classification model, in eval mode on ``device``.
        device: CUDA when available, otherwise CPU.
        group_map: fine-grained model label -> placeholder tag. Labels that
            are not listed are used verbatim as the tag.
    """

    # Tokens (special tokens included) per forward pass. Longer inputs are
    # split into consecutive windows of this size instead of being cut off.
    MAX_LENGTH = 512

    # Tokenizer outputs that describe the encoding and are not model inputs.
    _NON_MODEL_KEYS = frozenset({"offset_mapping", "overflow_to_sample_mapping"})

    def __init__(self, model_path="LH-Tech-AI/Shield-82M"):
        """Load tokenizer and model and put the model in inference mode.

        Args:
            model_path: identifier or path passed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        print(f"Loading Shield-82M from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        self.group_map = {
            "FIRSTNAME": "PERSON", "MIDDLENAME": "PERSON", "LASTNAME": "PERSON",
            "BUILDINGNUMBER": "ADDRESS", "STREET": "ADDRESS", "CITY": "ADDRESS",
            "STATE": "ADDRESS", "ZIPCODE": "ADDRESS", "SECONDARYADDRESS": "ADDRESS",
            "EMAIL": "EMAIL", "PHONENUMBER": "PHONE", "PHONEIMEI": "PHONE",
            "DATE": "DOB", "TIME": "DOB",
        }

    def protect(self, text):
        """Return ``text`` with every detected PII span replaced by ``[TAG]``.

        The whole input is processed: text longer than ``MAX_LENGTH`` tokens
        is split into consecutive windows that are each run through the
        model, so nothing past the first window is returned unredacted.

        Args:
            text: the string to filter.

        Returns:
            The filtered string; an empty (falsy) input is returned as is.
        """
        if not text:
            return text

        # No return_tensors here: windows may differ in length, and keeping
        # plain lists avoids needing a pad token to stack them.
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
        )
        model_keys = [key for key in encoded.keys() if key not in self._NON_MODEL_KEYS]
        id2label = self.model.config.id2label

        labels = []
        offsets = []
        with torch.no_grad():
            for row, window_offsets in enumerate(encoded["offset_mapping"]):
                # One window per forward pass (batch of 1), exactly what the
                # model sees for an input that fits in a single window.
                model_inputs = {
                    key: torch.tensor([encoded[key][row]], device=self.device)
                    for key in model_keys
                }
                logits = self.model(**model_inputs).logits
                predicted_ids = logits.argmax(dim=-1)[0].tolist()
                labels.extend(id2label[label_id] for label_id in predicted_ids)
                offsets.extend(window_offsets)

        # Windows do not overlap, so concatenating them yields every token
        # once, in text order; an entity that straddles a window boundary is
        # still merged because the span state carries over.
        spans = self._collect_spans(labels, offsets, self.group_map)
        return self._apply_spans(text, spans)

    @staticmethod
    def _collect_spans(labels, offsets, group_map):
        """Merge consecutive same-group tokens into (start, end, tag) spans."""
        spans = []
        current_group = None
        span_start = -1
        span_end = -1

        for label, (start, end) in zip(labels, offsets):
            # Special tokens carry the empty offset (0, 0) and map to no text.
            if start == 0 and end == 0:
                continue

            group = None if label == "O" else group_map.get(label, label)
            if group != current_group:
                if current_group is not None:
                    spans.append((span_start, span_end, current_group))
                current_group = group
                span_start = int(start)
            span_end = int(end)

        if current_group is not None:
            spans.append((span_start, span_end, current_group))
        return spans

    @staticmethod
    def _apply_spans(text, spans):
        """Rebuild ``text`` with each (start, end, tag) span replaced by ``[tag]``."""
        pieces = []
        cursor = 0
        for start, end, tag in sorted(spans, key=lambda span: span[0]):
            # An empty slice results if a span starts before the cursor, so
            # already-consumed characters are never emitted twice.
            pieces.append(text[cursor:start])
            pieces.append(f"[{tag}]")
            cursor = max(cursor, end)
        pieces.append(text[cursor:])
        return "".join(pieces)

if __name__ == "__main__":
    # Demo run: load the default model, then show one sentence before and
    # after filtering.
    pii_filter = ShieldFilter()
    demo_text = "My name is John Doe. Email: john@example.com. Phone: +49 123 45678."
    print(f"Original: {demo_text}")
    redacted = pii_filter.protect(demo_text)
    print(f"Protected: {redacted}")