| import torch |
| from safetensors.torch import load_file |
|
|
def load_model(path='model.safetensors'):
    """Load all tensors stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``; a relative path is resolved against the
            current working directory.

    Returns:
        The result of ``safetensors.torch.load_file(path)``, i.e. a mapping
        from tensor name to tensor (``equals2`` indexes it by name).
    """
    state = load_file(path)
    return state
|
|
def _fire(x, w, name):
    """Evaluate one threshold neuron named ``name`` on input tensor ``x``.

    Computes ``(x * w[name + '.weight']).sum() + w[name + '.bias']`` and
    returns 1 when that value is >= 0, else 0 (a Heaviside step activation).
    """
    return int((x * w[f'{name}.weight']).sum() + w[f'{name}.bias'] >= 0)


def equals2(a1, a0, b1, b0, w):
    """Run the three-layer threshold network in ``w`` on four input bits.

    The neuron names (and/nor -> xnor -> eq) suggest the network tests
    whether the 2-bit numbers ``a1 a0`` and ``b1 b0`` are equal. The actual
    result depends entirely on the weights loaded into ``w``.

    Args:
        a1, a0: High and low bit of the first operand (converted via float()).
        b1, b0: High and low bit of the second operand (converted via float()).
        w: Mapping with ``'<layer>.<neuron>.weight'`` / ``'.bias'`` tensors for
            layer1.{and1,nor1,and0,nor0}, layer2.{xnor1,xnor0} and layer3.eq.

    Returns:
        int: 0 or 1, the output of the ``layer3.eq`` neuron.

    Raises:
        KeyError: if ``w`` is missing any of the neuron parameters above.
    """
    inp = torch.tensor([float(a1), float(a0), float(b1), float(b0)])

    # Layer 1: four neurons, each reading all four input bits.
    l1 = torch.tensor([
        float(_fire(inp, w, 'layer1.and1')),
        float(_fire(inp, w, 'layer1.nor1')),
        float(_fire(inp, w, 'layer1.and0')),
        float(_fire(inp, w, 'layer1.nor0')),
    ])

    # Layer 2: two neurons, each reading all four layer-1 outputs.
    l2 = torch.tensor([
        float(_fire(l1, w, 'layer2.xnor1')),
        float(_fire(l1, w, 'layer2.xnor0')),
    ])

    # Layer 3: a single output neuron combining the two layer-2 outputs.
    return _fire(l2, w, 'layer3.eq')
|
|
if __name__ == '__main__':
    # Print the network's output for every pair of 2-bit operands (0..3).
    weights = load_model()
    print('equals2 truth table:')
    pairs = [(a, b) for a in range(4) for b in range(4)]
    for a, b in pairs:
        # For 0 <= n < 4, divmod(n, 2) yields (high bit, low bit).
        a1, a0 = divmod(a, 2)
        b1, b0 = divmod(b, 2)
        result = equals2(a1, a0, b1, b0, weights)
        print(f' {a} == {b}? {result}')
|
|