| import torch |
| from safetensors.torch import load_file |
|
|
def load_model(path='model.safetensors'):
    """Load every tensor from a safetensors checkpoint.

    Args:
        path: Location of the ``.safetensors`` file; defaults to
            ``'model.safetensors'`` in the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file(path)`` returns: a mapping from
        parameter name to tensor, which ``hamming_distance`` consumes as its
        ``weights`` argument.
    """
    state = load_file(path)
    return state
|
|
def _fire(inputs, weights, name):
    """Evaluate the threshold unit ``name`` on ``inputs`` and return 0 or 1.

    The unit fires (returns 1) when ``inputs @ weight.T + bias >= 0``, where
    ``weight`` and ``bias`` are ``weights[f'{name}.weight']`` and
    ``weights[f'{name}.bias']``; otherwise it returns 0.
    """
    w = weights[f'{name}.weight']
    if w.dim() == 1:
        # Treat a flat weight vector as a single-output row so ``.T`` is only
        # ever applied to a 2-D tensor (``.T`` on non-2-D tensors is
        # deprecated in PyTorch).
        w = w.unsqueeze(0)
    # The checkpoint may store weights in a dtype other than the float32
    # inputs built here (e.g. float16); matmul does not promote dtypes, so
    # align them. This is a no-op for float32 weights.
    w = w.to(inputs.dtype)
    activation = inputs @ w.T + weights[f'{name}.bias']
    return int((activation >= 0).item())


def hamming_distance(a3, a2, a1, a0, b3, b2, b1, b0, weights):
    """Compute the Hamming distance between two 4-bit values. Returns (d2, d1, d0).

    The distance is computed by a feed-forward network of threshold units
    whose parameters come from ``weights`` (e.g. the mapping returned by
    ``load_model``). Layers, in evaluation order:

    1. ``layer1.diff_hi_{i}`` / ``layer1.diff_lo_{i}`` (i = 0..3): two units
       per bit position, each reading all eight input bits.
    2. ``layer2.diff_{i}``: combines the (hi, lo) pair of position ``i`` into
       one per-position flag (presumably "bits differ" — going by the name).
    3. ``layer3.ge{k}`` (k = 1..4): units over the four per-position flags
       (presumably "at least k positions differ" — going by the name).
    4. ``layer4.d2``, ``layer4.d1``, ``layer4.d0_part1``,
       ``layer4.d0_part2``: units over the four ``ge`` outputs.
    5. ``layer5.d0``: merges the two ``d0`` parts into the last output bit.

    Args:
        a3, a2, a1, a0: Bits of the first value, most significant first.
        b3, b2, b1, b0: Bits of the second value, most significant first.
        weights: Mapping of ``'<unit>.weight'`` / ``'<unit>.bias'`` names to
            tensors for every unit listed above.

    Returns:
        ``(d2, d1, d0)``: the distance as three bits, most significant first,
        i.e. ``distance == 4*d2 + 2*d1 + d0``.

    Raises:
        ValueError: If any input bit is not 0 or 1. The threshold network is
            only meaningful on binary inputs.
        KeyError: If a required unit is missing from ``weights``.
    """
    bits = (a3, a2, a1, a0, b3, b2, b1, b0)
    for bit in bits:
        if bit not in (0, 1):
            raise ValueError(
                f'hamming_distance expects bits (0 or 1), got {bit!r}')
    inp = torch.tensor([float(bit) for bit in bits])

    # Layer 1: two units per bit position over all eight input bits.
    diff_hi = [_fire(inp, weights, f'layer1.diff_hi_{i}') for i in range(4)]
    diff_lo = [_fire(inp, weights, f'layer1.diff_lo_{i}') for i in range(4)]

    # Layer 2: one flag per position from its (hi, lo) pair.
    diffs = [
        _fire(torch.tensor([float(hi), float(lo)]), weights, f'layer2.diff_{i}')
        for i, (hi, lo) in enumerate(zip(diff_hi, diff_lo))
    ]

    # Layer 3: ge1..ge4 over the four per-position flags.
    diff_tensor = torch.tensor([float(d) for d in diffs])
    ge = [_fire(diff_tensor, weights, f'layer3.ge{k}') for k in range(1, 5)]

    # Layer 4: d2 and d1 directly; d0 is split across two units.
    ge_tensor = torch.tensor([float(g) for g in ge])
    d2 = _fire(ge_tensor, weights, 'layer4.d2')
    d1 = _fire(ge_tensor, weights, 'layer4.d1')
    d0_parts = torch.tensor([
        float(_fire(ge_tensor, weights, 'layer4.d0_part1')),
        float(_fire(ge_tensor, weights, 'layer4.d0_part2')),
    ])

    # Layer 5: combine the two d0 parts into the least significant bit.
    d0 = _fire(d0_parts, weights, 'layer5.d0')

    return d2, d1, d0
|
|
if __name__ == '__main__':
    # Demo: run the network on a few fixed nibble pairs and print each distance.
    model_weights = load_model()
    print('Hamming Distance 4-bit examples:')
    examples = [(0b0000, 0b0000), (0b1111, 0b0000), (0b1010, 0b0101), (0b1100, 0b1010)]
    for x, y in examples:
        # Split each 4-bit value into its bits, most significant first.
        x_bits = [(x >> shift) & 1 for shift in (3, 2, 1, 0)]
        y_bits = [(y >> shift) & 1 for shift in (3, 2, 1, 0)]
        hi, mid, lo = hamming_distance(*x_bits, *y_bits, model_weights)
        # Reassemble the three returned bits (MSB first) into an integer.
        distance = (hi << 2) | (mid << 1) | lo
        print(f' HD({x:04b}, {y:04b}) = {distance}')
|
|