import io
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mini-batch sizes for the training and validation loaders.
BATCH_SIZE_TRAIN, BATCH_SIZE_VAL = 96, 32

# Number of full passes over the training split.
EPOCHS = 20

# AdamW hyper-parameters.
LR, WEIGHT_DECAY = 1e-3, 1e-4

# Where the final state_dict is written.
MODEL_SAVE_PATH = "LookThem_V76_Full_LiteResidual.pth"


# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------

# Per-channel mean / std commonly used for ImageNet normalisation.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_rgb(img):
    """Return ``img`` converted to 3-channel RGB.

    Defined at module level instead of as a ``lambda`` so the transform
    pipeline (and therefore the Dataset holding it) can be pickled when
    DataLoader worker processes are started with the "spawn" or "forkserver"
    start methods; lambdas cannot be pickled.
    """
    return img.convert("RGB")


# Training: fixed 256x256 resize (aspect ratio is not preserved) + random flip.
transform_train = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation: same geometry and normalisation, no augmentation.
transform_val = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


# ---------------------------------------------------------------------------
# Dataset wrapper
# ---------------------------------------------------------------------------

class ImageNet100ParquetDataset(Dataset):
    """Map-style Dataset over an indexable split whose rows have "image" and "label".

    The "image" cell may be a decoded PIL image, an encoded dict
    ``{"bytes": ..., "path": ...}``, raw encoded bytes, or a filesystem path;
    every form is turned into a PIL image before ``transform`` is applied.

    Args:
        hf_subset: indexable, sized collection of row mappings (e.g. a Hugging
            Face ``datasets`` split) with "image" and "label" keys.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, hf_subset, transform=None):
        self.dataset = hf_subset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _decode_image(img_data):
        """Return a PIL image for any supported "image" cell encoding.

        Raises:
            ValueError: if a dict cell carries neither bytes nor a path.
        """
        if isinstance(img_data, dict):
            raw = img_data.get("bytes")
            if raw is not None:
                return Image.open(io.BytesIO(raw))
            # Encoded cells may reference a file on disk instead of inlining
            # the bytes (presumably HF ``Image(decode=False)`` style).
            path = img_data.get("path")
            if path is not None:
                return Image.open(path)
            raise ValueError("image dict has neither 'bytes' nor 'path'")
        if isinstance(img_data, Image.Image):
            return img_data
        if isinstance(img_data, (str, os.PathLike)):
            return Image.open(img_data)
        # Anything else is treated as raw encoded image bytes.
        return Image.open(io.BytesIO(img_data))

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``; image is transformed if configured."""
        row = self.dataset[index]
        img = self._decode_image(row["image"])
        label = row["label"]

        if self.transform is not None:
            img = self.transform(img)

        return img, label


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------

print("📡 Loading ImageNet-100...")

raw_train = load_dataset("clane9/imagenet-100", split="train")
raw_val = load_dataset("clane9/imagenet-100", split="validation")

train_dataset = ImageNet100ParquetDataset(raw_train, transform=transform_train)
val_dataset = ImageNet100ParquetDataset(raw_val, transform=transform_val)

# Settings shared by both loaders:
# - pin_memory only speeds up host -> CUDA copies, so it is enabled only when
#   training on CUDA (on CPU PyTorch just warns and skips pinning);
# - persistent_workers keeps the worker processes alive across epochs instead
#   of re-creating them (and, under "spawn", re-importing this script) each epoch.
_loader_kwargs = dict(
    num_workers=2,
    pin_memory=DEVICE.type == "cuda",
    persistent_workers=True,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE_TRAIN,
    shuffle=True,
    **_loader_kwargs,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE_VAL,
    shuffle=False,
    **_loader_kwargs,
)


# ---------------------------------------------------------------------------
# LookThem token-interaction layer
# ---------------------------------------------------------------------------

class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer.

    For every token ``t`` two independent per-token MLPs
    (``in_features -> hidden_dim -> 1`` with GELU) produce scalars ``m1[t]`` and
    ``m2[t]``. For every ordered pair of distinct tokens ``(i, j)``:

    * ``tanh(m1[i] / m2[j])`` scales ``x[i]`` and
    * ``tanh(m1[j] / m2[i])`` scales ``x[j]``,

    each after a per-token affine map indexed by ``j`` (``trans_w[j]``,
    ``trans_b[j]``). Output token ``i`` is the average of the two contributions
    over all ``j != i``.

    Input ``x`` has shape ``(batch, num_tokens, in_features)``; the output has
    the same shape.

    Args:
        num_tokens: number of tokens N; must be >= 2 because every token is
            averaged over the other N - 1.
        in_features: feature width of each token.
        hidden_dim: hidden width of the two scoring MLPs.
        eps: minimum magnitude of the ``m2`` denominators.

    Raises:
        ValueError: if ``num_tokens < 2``.
    """

    def __init__(self, num_tokens, in_features, hidden_dim, eps=1e-5):
        super().__init__()

        if num_tokens < 2:
            raise ValueError(f"num_tokens must be >= 2, got {num_tokens}")

        self.num_tokens = num_tokens
        self.eps = eps

        # Scoring MLP 1 (numerators): per-token weights stacked along dim 0.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Scoring MLP 2 (denominators).
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token scalar affine map applied to each pairwise score.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Initialise each stacked weight like a per-token ``nn.Linear``.

        ``kaiming_uniform_(w, a=sqrt(5))`` on a 3-D ``(N, fan_in, fan_out)``
        tensor treats ``dims[2:]`` as a receptive field and uses
        ``fan_in * fan_out`` as the fan-in, shrinking the ``*_w1`` weights by
        ``sqrt(hidden_dim)``. Sampling ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``
        with ``fan_in = w.shape[1]`` is the bound ``kaiming_uniform_(a=sqrt(5))``
        yields for a 2-D Linear weight; for ``*_w2`` and ``trans_w`` (trailing
        dim 1) it equals the previous bound.
        """
        for w in (self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w):
            bound = 1.0 / math.sqrt(w.shape[1])
            nn.init.uniform_(w, -bound, bound)

    @staticmethod
    def _token_mlp(x, w1, b1, w2, b2):
        """Apply a per-token 2-layer GELU MLP; returns shape ``(B, N, 1)``."""
        hidden = torch.einsum("bti,tij->btj", x, w1) + b1
        return torch.einsum("btj,tjk->btk", F.gelu(hidden), w2) + b2

    @staticmethod
    def _away_from_zero(v, eps):
        """Clamp ``v`` so that ``|v| >= eps`` while keeping its sign (0 -> +eps)."""
        return torch.where(v >= 0, v.clamp_min(eps), v.clamp_max(-eps))

    def forward(self, x):
        """Mix every token with all other tokens.

        Args:
            x: tensor of shape ``(batch, num_tokens, in_features)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        n = self.num_tokens

        # Per-token scalar scores, shape (B, N).
        m1 = self._token_mlp(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        ).squeeze(-1)
        m2 = self._token_mlp(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        ).squeeze(-1)

        # Keep |denominator| >= eps on both sides of zero. Adding a positive
        # epsilon to a signed value would only move the pole to -eps.
        denom = self._away_from_zero(m2, self.eps)

        # compare[b, i, j]  = tanh(m1[b, i] / denom[b, j])
        # compare2[b, i, j] = tanh(m1[b, j] / denom[b, i])
        compare = torch.tanh(m1.unsqueeze(2) / denom.unsqueeze(1))
        compare2 = torch.tanh(m1.unsqueeze(1) / denom.unsqueeze(2))

        # Affine transform indexed by the partner token j; zero the j == i terms.
        scale = self.trans_w.view(1, 1, n)
        shift = self.trans_b.view(1, 1, n)
        mask = 1.0 - torch.eye(n, device=x.device, dtype=compare.dtype)
        trans_compare = (compare * scale + shift) * mask
        trans_compare2 = (compare2 * scale + shift) * mask

        # out[b, i] = sum_{j != i} (tc[b,i,j] * x[b,i] + tc2[b,i,j] * x[b,j]) / (2 (N - 1)),
        # computed without materialising the (B, N, N, F) pairwise tensor.
        self_term = x * trans_compare.sum(dim=2, keepdim=True)
        cross_term = torch.matmul(trans_compare2, x)

        return (self_term + cross_term) / (2.0 * (n - 1))


# ---------------------------------------------------------------------------
# Residual MLP block
# ---------------------------------------------------------------------------

class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP: ``LayerNorm(x + Linear(Dropout(GELU(Linear(x)))))``.

    Input and output share the trailing feature width ``dim``.

    Args:
        dim: feature width of input and output.
        dropout: dropout probability between the two linear layers.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()

        # Keep the Sequential layout (block.0 / block.3 / norm) so saved
        # state_dicts stay loadable.
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the normalised sum of the MLP branch and the identity skip."""
        return self.norm(self.block(x) + x)


# ---------------------------------------------------------------------------
# Full model
# ---------------------------------------------------------------------------

class LookThemV76LiteResidual(nn.Module):
    """Two-stream CNN + LookThem token mixing + residual MLP classifier.

    ``stream_a`` (four stride-2 convs) and ``stream_b`` (one stride-2 conv)
    each end in a 64-channel map max-pooled to 8x8, i.e. 64 spatial tokens of
    64 features. Each token set passes through its own ``LookThemLayer``; the
    two results are concatenated feature-wise (64 tokens x 128 features),
    mixed by a third ``LookThemLayer``, compressed back to 64 features per
    token by a 1x1 ``Conv1d``, flattened to 4096 values and classified by an
    MLP with two ``LiteResidualBlock``s.

    Args:
        num_classes: number of output logits (default 100 for ImageNet-100).
    """

    def __init__(self, num_classes=100):
        super().__init__()

        # Coarse stream: every conv halves the resolution.
        self.stream_a = self._conv_stream(strides=(2, 2, 2, 2))
        # Fine stream: only the third conv downsamples.
        self.stream_b = self._conv_stream(strides=(1, 1, 2, 1))

        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)

        # 128 -> 64 features per token (tokens along the Conv1d length axis).
        self.compressor = nn.Conv1d(128, 64, kernel_size=1)

        # Token-feature dropout, shared by both streams.
        self.imageCorrupter = nn.Dropout(0.1)

        self.flatten = nn.Flatten()

        # 4096 = 64 tokens x 64 compressed features.
        self.input_proj = nn.Sequential(
            nn.Linear(4096, 256),
            nn.GELU(),
            nn.Dropout(0.08),
        )

        self.res1 = LiteResidualBlock(256, 0.05)
        self.res2 = LiteResidualBlock(256, 0.05)

        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _conv_stream(strides):
        """Build 4x ``Conv3x3 -> BatchNorm -> GELU`` (3->16->32->64->64) + 8x8 max-pool.

        Layer order matches the hand-written ``nn.Sequential`` layout, so the
        state_dict keys (``stream_x.0`` ... ``stream_x.12``) are unchanged.
        """
        channels = (3, 16, 32, 64, 64)
        layers = []
        for c_in, c_out, stride in zip(channels[:-1], channels[1:], strides):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]
        layers.append(nn.AdaptiveMaxPool2d((8, 8)))
        return nn.Sequential(*layers)

    def _stream_tokens(self, stream, lookthem, x):
        """Run one CNN stream and return its LookThem-mixed spatial tokens."""
        feat = stream(x)                            # (B, C, 8, 8)
        tokens = feat.flatten(2).transpose(1, 2)    # (B, 64 spatial tokens, C)
        tokens = self.imageCorrupter(tokens)
        return lookthem(tokens)

    def extract_features(self, x):
        """Return compressed token features of shape ``(B, 64, 64)``."""
        feat_a_lt = self._stream_tokens(self.stream_a, self.lookthemA, x)
        feat_b_lt = self._stream_tokens(self.stream_b, self.lookthemB, x)

        tokens_combined = torch.cat([feat_a_lt, feat_b_lt], dim=2)    # (B, 64, 128)
        out_lookthem = self.lookthem(tokens_combined).transpose(1, 2)  # (B, 128, 64)

        return self.compressor(out_lookthem)                          # (B, 64, 64)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)``."""
        x = self.flatten(self.extract_features(x))  # (B, 4096)
        x = self.input_proj(x)
        x = self.res2(self.res1(x))
        return self.head(x)


model = LookThemV76LiteResidual().to(DEVICE)

total_params = sum(p.numel() for p in model.parameters())

print(f"\n🧠 Total Parameters : {total_params:,}")

# Estimate the checkpoint payload from exactly what state_dict() holds (the
# object this script saves): parameters *and* buffers such as BatchNorm
# running statistics, each at its real element size rather than an assumed
# 4 bytes per value. Serialisation overhead is not included.
state_bytes = sum(
    tensor.numel() * tensor.element_size()
    for tensor in model.state_dict().values()
)
size_mb = state_bytes / (1024 * 1024)

print(f"📦 Estimated Size : {size_mb:.2f} MB")


# ---------------------------------------------------------------------------
# Loss, optimiser and learning-rate schedule
# ---------------------------------------------------------------------------

# Multi-class cross-entropy on the model's logits (default mean reduction).
criterion = nn.CrossEntropyLoss()

# AdamW (decoupled weight decay) over every model parameter.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Cosine decay over EPOCHS scheduler steps (one step per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=EPOCHS,
)


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

# DataLoader workers started with the "spawn"/"forkserver" start methods
# re-import this script as a module; the guard keeps them from running the
# training loop (and trying to start workers of their own).
if __name__ == "__main__":
    print("\n🚀 Training Started...\n")

    for epoch in range(EPOCHS):
        model.train()

        total_loss = 0.0  # sum of per-sample losses this epoch
        correct = 0
        total = 0

        for step, (data, target) in enumerate(train_loader):
            # non_blocking only matters for pinned host memory -> CUDA copies.
            data = data.to(DEVICE, non_blocking=True)
            target = target.to(DEVICE, non_blocking=True)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # The criterion returns the batch mean; weight it by batch size so
            # the epoch average is per sample even when the last batch is short.
            total_loss += loss.item() * batch_size
            total += batch_size

            _, predicted = output.max(1)
            correct += predicted.eq(target).sum().item()

            if (step + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1:02d}/{EPOCHS}] "
                    f"| Step [{step+1}/{len(train_loader)}] "
                    f"| Loss: {loss.item():.4f}"
                )

        # Record the LR actually used this epoch before the scheduler advances it.
        current_lr = optimizer.param_groups[0]["lr"]
        scheduler.step()

        acc = 100. * correct / total

        print(
            f"\n🏁 Epoch [{epoch+1:02d}/{EPOCHS}] "
            f"| Loss: {total_loss / total:.4f} "
            f"| Train Acc: {acc:.2f}% "
            f"| LR: {current_lr:.6f}\n"
        )


# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------

# Guarded so DataLoader worker processes that re-import this script (the
# "spawn"/"forkserver" start methods) do not run evaluation themselves.
if __name__ == "__main__":
    print("\n🧪 Validation...\n")

    model.eval()

    val_loss = 0.0  # sum of per-sample losses
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(DEVICE, non_blocking=True)
            target = target.to(DEVICE, non_blocking=True)

            output = model(data)
            loss = criterion(output, target)

            batch_size = target.size(0)
            # The criterion returns the batch mean; weight by batch size.
            val_loss += loss.item() * batch_size
            val_total += batch_size

            _, predicted = output.max(1)
            val_correct += predicted.eq(target).sum().item()

    val_acc = 100. * val_correct / val_total

    print(
        f"\n🏆 Validation Accuracy: {val_acc:.2f}%"
    )
    # Mean per-sample loss over the validation split.
    print(f"📉 Validation Loss: {val_loss / val_total:.4f}")


# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------

# Guarded so DataLoader worker processes that re-import this script (the
# "spawn"/"forkserver" start methods) cannot overwrite the checkpoint with an
# untrained model.
if __name__ == "__main__":
    torch.save(model.state_dict(), MODEL_SAVE_PATH)

    real_size = os.path.getsize(MODEL_SAVE_PATH) / (1024 * 1024)

    print("\n💾 MODEL SAVED!")
    print(f"📁 Path : {MODEL_SAVE_PATH}")
    print(f"📦 Size : {real_size:.2f} MB")