ASomeoneWhoInterestedWithAI's picture
Upload 3 files
20b63f9 verified
# ============================================================
# LOOKTHEM V7.6 FULL TRAINING + INFERENCE
# Backbone + Lite Residual Classifier
# ============================================================
import os
import io
import math
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from datasets import load_dataset
# ============================================================
# CONFIG
# ============================================================
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mini-batch sizes for the training and validation loaders.
BATCH_SIZE_TRAIN: int = 96
BATCH_SIZE_VAL: int = 32

# Optimiser settings: number of epochs, AdamW learning rate and weight decay.
EPOCHS: int = 20
LR: float = 1e-3
WEIGHT_DECAY: float = 1e-4

# Where the trained state_dict is written.
MODEL_SAVE_PATH: str = "LookThem_V76_Full_LiteResidual.pth"
# ============================================================
# TRANSFORM
# ============================================================
def _to_rgb(img):
    """Return *img* converted to 3-channel RGB.

    A named module-level function is used instead of a ``lambda`` so the
    transform pipelines can be pickled, which DataLoader worker processes
    started with the ``spawn`` method require.
    NOTE(review): spawn-based workers additionally need this script's
    top-level training code behind an ``if __name__ == "__main__":`` guard.
    """
    return img.convert("RGB")


# Per-channel normalisation statistics (the commonly used ImageNet mean/std),
# shared by both pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: RGB -> fixed 256x256 resize -> random horizontal flip
# -> tensor -> per-channel normalisation.
transform_train = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Validation pipeline: same steps without the random flip, so it is deterministic.
transform_val = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
# ============================================================
# DATASET
# ============================================================
class ImageNet100ParquetDataset(Dataset):
    """Map-style dataset wrapping a split of ``{"image", "label"}`` rows.

    The ``"image"`` field may be a decoded ``PIL.Image.Image``, raw encoded
    bytes, a ``{"bytes": ..., "path": ...}`` dict (bytes preferred, path used
    when bytes is ``None``), or a filesystem path.

    Args:
        hf_subset: Indexable, sized collection of rows (e.g. a Hugging Face
            ``datasets.Dataset`` split).
        transform: Optional callable applied to the PIL image before it is
            returned.
    """

    def __init__(self, hf_subset, transform=None):
        self.dataset = hf_subset
        self.transform = transform

    @staticmethod
    def _load_image(img_data):
        """Convert one raw ``"image"`` field into a PIL image.

        Raises:
            ValueError: if a dict carries neither ``bytes`` nor ``path`` data.
        """
        if isinstance(img_data, Image.Image):
            return img_data
        if isinstance(img_data, dict):
            raw = img_data.get("bytes")
            if raw is not None:
                return Image.open(io.BytesIO(raw))
            # Encoded-image dicts may reference a file instead of inlining bytes.
            path = img_data.get("path")
            if path is not None:
                return Image.open(path)
            raise ValueError("image dict has neither 'bytes' nor 'path' data")
        if isinstance(img_data, (str, os.PathLike)):
            return Image.open(img_data)
        # Anything else is treated as an encoded image byte buffer.
        return Image.open(io.BytesIO(img_data))

    def __getitem__(self, index):
        """Return ``(image, label)`` for row *index*, transformed if configured."""
        row = self.dataset[index]
        img = self._load_image(row["image"])
        label = row["label"]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.dataset)
# ============================================================
# LOAD DATASET
# ============================================================
print("📡 Loading ImageNet-100...")

# Fetch the train split first, then the validation split.
raw_train, raw_val = (
    load_dataset("clane9/imagenet-100", split=split_name)
    for split_name in ("train", "validation")
)

# Wrap each split with its matching preprocessing pipeline.
train_dataset = ImageNet100ParquetDataset(raw_train, transform=transform_train)
val_dataset = ImageNet100ParquetDataset(raw_val, transform=transform_val)

# Both loaders use 2 worker processes and pinned host memory; only the
# training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE_TRAIN,
    num_workers=2,
    pin_memory=True,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE_VAL,
    num_workers=2,
    pin_memory=True,
    shuffle=False,
)
# ============================================================
# LOOKTHEM LAYER
# ============================================================
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer with per-token weights.

    Each token is scored by two independent per-token two-layer MLPs
    ("mod1" and "mod2"), each yielding one scalar per token. For every
    ordered token pair (i, j) the score ratios are squashed with ``tanh``,
    passed through a per-token scalar affine map (indexed by the partner
    token j), and used to weight token i's and token j's features. Each
    token's output is the mean of these interactions over all other tokens.

    Args:
        num_tokens: Number of tokens ``N`` (must be >= 2, because the output
            averages over the ``N - 1`` other tokens).
        in_features: Feature size of each token.
        hidden_dim: Hidden width of the two scoring MLPs.

    Shapes:
        ``x``: ``(B, N, in_features)`` -> output ``(B, N, in_features)``.

    Raises:
        ValueError: if ``num_tokens < 2``.
    """

    # Minimum magnitude of the denominator in the score ratios.
    _EPS = 1e-5

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        if num_tokens < 2:
            raise ValueError(
                f"num_tokens must be >= 2 (got {num_tokens}); the layer "
                "averages over the other N - 1 tokens"
            )
        self.num_tokens = num_tokens
        # Scoring MLP 1: in_features -> hidden_dim -> 1, separate weights per token.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Scoring MLP 2: same structure; its output is the ratio denominator.
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token scalar affine map applied to the tanh-squashed ratios.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        # NOTE(review): for tensors with more than 2 dims, PyTorch computes
        # fan_in as size(1) * prod(size(2:)), i.e. in_features * hidden_dim
        # for the w1 tensors rather than in_features. Kept as-is so training
        # behaviour matches existing runs.
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    @staticmethod
    def _score(x, w1, b1, w2, b2):
        """Per-token two-layer GELU MLP: ``(B, N, in)`` -> ``(B, N, 1)``."""
        hidden = torch.einsum("bti,tij->btj", x, w1) + b1
        return torch.einsum("btj,tjk->btk", F.gelu(hidden), w2) + b2

    def forward(self, x):
        N = self.num_tokens
        out_m1 = self._score(x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2)
        out_m2 = self._score(x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2)

        # Move the denominator away from zero on the side of its own sign so
        # |denominator| >= eps. A plain "+ eps" does not protect negative
        # scores: a score of exactly -eps would become 0 and yield inf/NaN.
        # Non-negative scores still get exactly +eps.
        out_m2_safe = torch.where(
            out_m2 >= 0, out_m2 + self._EPS, out_m2 - self._EPS
        )

        # compare[b, i, j]  = tanh(m1[b, i] / m2[b, j])   -> (B, N, N, 1)
        # compare2[b, i, j] = tanh(m1[b, j] / m2[b, i])   -> (B, N, N, 1)
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Scalar affine map selected by the partner token j.
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = (
            torch.einsum("bije,jef->bijf", compare, self.trans_w) + bias_reshaped
        )
        trans_compare2 = (
            torch.einsum("bije,jef->bijf", compare2, self.trans_w) + bias_reshaped
        )

        # Weighted mix of token i's features and partner j's features: (B, N, N, F).
        interaksi = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Drop self-pairs (i == j) and average over the N - 1 partners.
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)
        return interaksi_masked.sum(dim=2) / (N - 1.0)
# ============================================================
# LITE RESIDUAL BLOCK
# ============================================================
class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP block computing ``LayerNorm(MLP(x) + x)``.

    The MLP is Linear -> GELU -> Dropout -> Linear and preserves width ``dim``.

    Args:
        dim: Feature size of the input and the output.
        dropout: Dropout probability inside the MLP.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        # Sub-module names and Sequential indices form the state_dict keys
        # (block.0, block.3, norm); keep them stable for saved checkpoints.
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(self.block(x) + x)
# ============================================================
# FULL MODEL
# ============================================================
class LookThemV76LiteResidual(nn.Module):
    """Two-stream CNN + LookThem token interaction + residual MLP classifier.

    Pipeline:
      * Stream A (strides 2,2,2,2) and stream B (strides 1,1,2,1) each map
        a 3-channel image to 64 channels, max-pooled to an 8x8 grid, i.e.
        64 tokens of 64 features.
      * Each stream's tokens pass through feature dropout and their own
        ``LookThemLayer``; the two results are concatenated to 128 features
        per token and mixed by a third ``LookThemLayer``.
      * A 1x1 ``Conv1d`` compresses back to 64 features per token, giving a
        4096-long flattened vector for the classifier, which outputs 100 logits.
    """

    @staticmethod
    def _conv_stream(strides):
        """Build Conv-BN-GELU x4 (3->16->32->64->64 channels) + 8x8 max pool.

        The flat layer order (conv, bn, gelu, ..., pool) matches the
        Sequential indices used in saved state_dicts.
        """
        channels = (3, 16, 32, 64, 64)
        layers = []
        for c_in, c_out, stride in zip(channels[:-1], channels[1:], strides):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]
        layers.append(nn.AdaptiveMaxPool2d((8, 8)))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Convolutional feature streams (aggressive vs. gentle downsampling).
        self.stream_a = self._conv_stream((2, 2, 2, 2))
        self.stream_b = self._conv_stream((1, 1, 2, 1))

        # Token-interaction layers: one per stream, then one over both.
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)
        # Per-token 128 -> 64 channel projection.
        self.compressor = nn.Conv1d(128, 64, kernel_size=1)
        # Dropout applied to each stream's tokens.
        self.imageCorrupter = nn.Dropout(0.1)

        # Classifier head over the 64 x 64 = 4096 flattened features.
        self.flatten = nn.Flatten()
        self.input_proj = nn.Sequential(
            nn.Linear(4096, 256),
            nn.GELU(),
            nn.Dropout(0.08),
        )
        self.res1 = LiteResidualBlock(256, 0.05)
        self.res2 = LiteResidualBlock(256, 0.05)
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 100),
        )

    def _stream_tokens(self, stream, interaction, x):
        """Run one conv stream and its LookThem layer -> ``(B, 64, 64)``."""
        # (B, 64, 8, 8) -> (B, 64 channels, 64 positions) -> (B, 64 tokens, 64 channels)
        tokens = stream(x).flatten(2).transpose(1, 2)
        return interaction(self.imageCorrupter(tokens))

    def extract_features(self, x):
        """Map images to compressed token features of shape ``(B, 64, 64)``."""
        tokens_a = self._stream_tokens(self.stream_a, self.lookthemA, x)
        tokens_b = self._stream_tokens(self.stream_b, self.lookthemB, x)
        mixed = self.lookthem(torch.cat([tokens_a, tokens_b], dim=2))
        # Conv1d expects (B, channels, length): put the 128 features on dim 1.
        return self.compressor(mixed.transpose(1, 2))

    def forward(self, x):
        """Return class logits of shape ``(B, 100)``."""
        features = self.flatten(self.extract_features(x))
        hidden = self.res2(self.res1(self.input_proj(features)))
        return self.head(hidden)
# ============================================================
# MODEL INIT
# ============================================================
model = LookThemV76LiteResidual().to(DEVICE)
# ============================================================
# PARAMETER COUNT
# ============================================================
# Count learnable parameters and estimate their size assuming 4 bytes each
# (float32). Buffers such as BatchNorm running statistics are not counted.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"\n🧠 Total Parameters : {total_params:,}")
size_mb = total_params * 4 / 1024 ** 2
print(f"📦 Estimated Size : {size_mb:.2f} MB")
# ============================================================
# LOSS & OPTIMIZER
# ============================================================
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine annealing with one half-cycle over T_max=EPOCHS scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# ============================================================
# TRAINING
# ============================================================
print("\n🚀 Training Started...\n")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    # Record the LR actually used during this epoch: scheduler.step() at the
    # end of the epoch advances it to the next epoch's value, so reading it
    # afterwards would report the wrong (next) learning rate.
    current_lr = optimizer.param_groups[0]["lr"]
    for step, (data, target) in enumerate(train_loader):
        # non_blocking lets the host->device copy overlap with compute when
        # the loader uses pinned memory (pin_memory=True); no-op on CPU.
        data = data.to(DEVICE, non_blocking=True)
        target = target.to(DEVICE, non_blocking=True)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Running statistics for the epoch summary.
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        if (step + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1:02d}/{EPOCHS}] "
                f"| Step [{step+1}/{len(train_loader)}] "
                f"| Loss: {loss.item():.4f}"
            )
    scheduler.step()
    acc = 100. * correct / total
    print(
        f"\n🏁 Epoch [{epoch+1:02d}/{EPOCHS}] "
        f"| Loss: {total_loss / len(train_loader):.4f} "
        f"| Train Acc: {acc:.2f}% "
        f"| LR: {current_lr:.6f}\n"
    )
# ============================================================
# VALIDATION
# ============================================================
print("\n🧪 Validation...\n")
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
    for data, target in val_loader:
        data = data.to(DEVICE, non_blocking=True)
        target = target.to(DEVICE, non_blocking=True)
        output = model(data)
        batch_size = target.size(0)
        # CrossEntropyLoss() (default reduction) returns the batch mean;
        # weight by batch size so a smaller final batch does not skew the
        # overall per-sample average.
        val_loss += criterion(output, target).item() * batch_size
        _, predicted = output.max(1)
        val_total += batch_size
        val_correct += predicted.eq(target).sum().item()
val_acc = 100. * val_correct / val_total
print(
    f"\n🏆 Validation Accuracy: {val_acc:.2f}%"
)
# Previously accumulated but never reported.
print(f"📉 Validation Loss: {val_loss / val_total:.4f}")
# ============================================================
# SAVE MODEL
# ============================================================
# Persist the model's state_dict (parameters and buffers), then report the
# on-disk file size in MiB.
torch.save(model.state_dict(), MODEL_SAVE_PATH)
real_size = os.path.getsize(MODEL_SAVE_PATH) / 1024 ** 2
print(
    "\n💾 MODEL SAVED!",
    f"📁 Path : {MODEL_SAVE_PATH}",
    f"📦 Size : {real_size:.2f} MB",
    sep="\n",
)