"""
Unit tests for the functional API (bitlinear_python, greedy_ternary_decomposition, etc.).

These tests validate the correctness of the pure PyTorch reference implementations.
The test cases are as follows:

TestBitLinearPython (5 tests)
1. test_shape_correctness - Verifies output dimensions for 3D inputs
2. test_no_bias - Tests forward pass without bias term
3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1}
4. test_gamma_scaling - Verifies gamma scaling is applied correctly
5. test_numerical_correctness - Compares against manual torch computation

TestGreedyTernaryDecomposition (4 tests)
1. test_decomposition_shape - Checks output tensor shapes
2. test_ternary_values - Ensures all decomposed weights are ternary
3. test_reconstruction_error - Validates that error decreases with more components
4. test_single_component - Tests the k=1 edge case

TestMultiTernaryLinearPython (2 tests)
1. test_shape_correctness - Verifies output shape
2. test_equivalence_to_sum - Confirms equivalence to summing individual operations

TestActivationQuant (2 tests)
1. test_quantization_range - Validates quantization behavior and output
2. test_absmax_scaling - Tests per-token absmax scaling

TestFunctionalIntegration (3 tests)
1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward
2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear
3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation
"""
|
|
|
| import pytest
|
| import torch
|
| import torch.nn as nn
|
|
|
| from bitlinear.functional import (
|
| bitlinear_python,
|
| greedy_ternary_decomposition,
|
| multi_ternary_linear_python,
|
| activation_quant,
|
| )
|
|
|
|
|
class TestBitLinearPython:
    """Tests for the bitlinear_python reference implementation."""

    def test_shape_correctness(self):
        """Output shape must be (batch, seq, out_features) for a 3-D input."""
        batch_size, seq_len, in_features, out_features = 32, 128, 512, 1024
        x = torch.randn(batch_size, seq_len, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        bias = torch.zeros(out_features)

        output = bitlinear_python(x, W_ternary, gamma, bias)

        assert output.shape == (batch_size, seq_len, out_features)

    def test_no_bias(self):
        """Forward pass must work and stay finite when bias is omitted."""
        batch_size, in_features, out_features = 16, 256, 512
        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)

        output = bitlinear_python(x, W_ternary, gamma, bias=None)

        assert output.shape == (batch_size, out_features)
        assert not torch.isnan(output).any()

    def test_ternary_constraint(self):
        """Function must accept ternary weights {-1, 0, +1} and produce a valid output."""
        x = torch.randn(8, 64)
        W_ternary = torch.randint(-1, 2, (128, 64)).float()
        gamma = torch.ones(128)

        # Sanity check on the fixture: randint(-1, 2) only yields -1, 0, +1.
        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist())

        output = bitlinear_python(x, W_ternary, gamma)
        assert output.shape == (8, 128)
        assert not torch.isnan(output).any()

    def test_gamma_scaling(self):
        """Per-output-channel gamma must act as a pure post-matmul scale."""
        x = torch.randn(4, 32)
        W_ternary = torch.randint(-1, 2, (64, 32)).float()
        # Random gammas in [0.5, 2.5) so the scaling is non-trivial.
        gamma = torch.rand(64) * 2 + 0.5

        output_with_gamma = bitlinear_python(x, W_ternary, gamma, bias=None)

        # Running with gamma == 1 and scaling manually must give the same result.
        gamma_ones = torch.ones_like(gamma)
        output_no_gamma = bitlinear_python(x, W_ternary, gamma_ones, bias=None)
        output_manual_scale = output_no_gamma * gamma.unsqueeze(0)

        assert torch.allclose(output_with_gamma, output_manual_scale, atol=1e-5)

    def test_numerical_correctness(self):
        """Test numerical correctness against a manual torch computation.

        The reference is x @ W.T * gamma + bias, computed directly with
        torch.matmul (not nn.Linear), matching the module-level description.
        """
        in_features, out_features = 128, 256
        x = torch.randn(16, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        bias = torch.randn(out_features)

        output_bitlinear = bitlinear_python(x, W_ternary, gamma, bias)

        output_manual = torch.matmul(x, W_ternary.t()) * gamma.unsqueeze(0) + bias

        assert torch.allclose(output_bitlinear, output_manual, atol=1e-6)
|
class TestGreedyTernaryDecomposition:
    """Tests for greedy_ternary_decomposition function."""

    def test_decomposition_shape(self):
        """Decomposition must return (k, out, in) weights and (k, out) gammas."""
        dense = torch.randn(512, 768)
        num_components = 4
        W_ternary, gammas = greedy_ternary_decomposition(dense, num_components)

        assert W_ternary.shape == (num_components, 512, 768)
        assert gammas.shape == (num_components, 512)

    def test_ternary_values(self):
        """Every decomposed weight entry must lie in {-1, 0, +1}."""
        dense = torch.randn(64, 128)
        W_ternary, gammas = greedy_ternary_decomposition(dense, 2)

        observed = torch.unique(W_ternary).tolist()
        assert all(value in [-1.0, 0.0, 1.0] for value in observed), \
            f"Found non-ternary values: {observed}"

    def test_reconstruction_error(self):
        """More components must yield a strictly smaller reconstruction error."""
        dense = torch.randn(128, 256)
        errors = []

        for num_components in [1, 2, 4, 8]:
            W_ternary, gammas = greedy_ternary_decomposition(dense, num_components)

            # Rebuild the dense matrix as sum_i gamma_i * T_i (row-wise gammas).
            reconstruction = torch.zeros_like(dense)
            for i in range(num_components):
                reconstruction += gammas[i].unsqueeze(1) * W_ternary[i]

            errors.append(torch.norm(dense - reconstruction).item())

        # Each step up in k must reduce the Frobenius-norm error.
        for coarse, fine in zip(errors, errors[1:]):
            assert coarse > fine, f"Error not decreasing: {coarse} vs {fine}"

    def test_single_component(self):
        """k=1 reduces to a single ternary quantization with one gamma per row."""
        dense = torch.randn(32, 64)
        W_ternary, gammas = greedy_ternary_decomposition(dense, 1)

        assert W_ternary.shape == (1, 32, 64)
        assert gammas.shape == (1, 32)

        observed = torch.unique(W_ternary).tolist()
        assert all(value in [-1.0, 0.0, 1.0] for value in observed)
|
class TestMultiTernaryLinearPython:
    """Tests for multi_ternary_linear_python function."""

    def test_shape_correctness(self):
        """Output of the multi-ternary linear must be (batch, out_features)."""
        batch_size, in_features, out_features = 16, 128, 256
        num_components = 4

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(
            -1, 2, (num_components, out_features, in_features)
        ).float()
        gammas = torch.rand(num_components, out_features)
        bias = torch.randn(out_features)

        result = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        assert result.shape == (batch_size, out_features)

    def test_equivalence_to_sum(self):
        """Multi-ternary result must match summing k single-ternary passes."""
        batch_size, in_features, out_features = 8, 64, 128
        num_components = 3

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(
            -1, 2, (num_components, out_features, in_features)
        ).float()
        gammas = torch.rand(num_components, out_features)
        bias = torch.randn(out_features)

        output_multi = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        # Accumulate each component's contribution, then add the bias once.
        accumulated = torch.zeros(batch_size, out_features)
        for component in range(num_components):
            accumulated += bitlinear_python(
                x, W_ternary[component], gammas[component], bias=None
            )
        accumulated += bias

        assert torch.allclose(output_multi, accumulated, atol=1e-5)
|
class TestActivationQuant:
    """Tests for activation quantization."""

    def test_quantization_range(self):
        """Quantization preserves shape, changes values, and stays finite."""
        activations = torch.randn(16, 128, 512) * 10

        quantized = activation_quant(activations, bits=8)

        # Shape is untouched by quantization.
        assert quantized.shape == activations.shape
        # Quantization must actually modify the values (lossy at 8 bits).
        assert not torch.allclose(activations, quantized, atol=1e-6)
        # No NaNs or infs introduced.
        assert torch.isfinite(quantized).all()

    def test_absmax_scaling(self):
        """Per-token absmax scaling keeps quantized values close to the input."""
        # Two rows with different magnitudes exercise the per-token scale.
        activations = torch.tensor([
            [1.0, 2.0, 3.0, 4.0],
            [-5.0, -10.0, 5.0, 10.0],
        ])

        quantized = activation_quant(activations, bits=8)

        assert quantized.shape == (2, 4)
        assert torch.isfinite(quantized).all()

        # 8-bit quantization should stay within ~10% mean relative error.
        relative_error = torch.abs(activations - quantized) / (torch.abs(activations) + 1e-5)
        assert relative_error.mean() < 0.1
|
class TestFunctionalIntegration:
    """Integration tests combining multiple functional components."""

    def test_full_pipeline(self):
        """End-to-end: decompose a dense weight, then run multi-ternary forward."""
        in_features, out_features = 256, 512
        W_dense = torch.randn(out_features, in_features)

        # Decompose into 4 ternary components.
        num_components = 4
        W_ternary, gammas = greedy_ternary_decomposition(W_dense, num_components)

        batch_size = 16
        x = torch.randn(batch_size, in_features)
        bias = torch.randn(out_features)

        result = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        assert result.shape == (batch_size, out_features)
        assert torch.isfinite(result).all()

        # Compare against the dense reference; the approximation is lossy,
        # so only require the relative error to be bounded.
        dense_result = torch.matmul(x, W_dense.t()) + bias
        relative_error = torch.norm(result - dense_result) / torch.norm(dense_result)
        assert relative_error < 1.0

    def test_bitlinear_with_activation_quant(self):
        """Quantized activations must flow through bitlinear without issues."""
        batch_size, in_features, out_features = 8, 128, 256

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)

        # Quantize activations first, then apply the ternary linear layer.
        quantized_x = activation_quant(x, bits=8)
        result = bitlinear_python(quantized_x, W_ternary, gamma)

        assert result.shape == (batch_size, out_features)
        assert torch.isfinite(result).all()

    def test_multi_ternary_end_to_end(self):
        """For each k, the forward pass must equal matmul with the reconstructed W."""
        W = torch.randn(64, 128) * 0.1
        x = torch.randn(4, 128)

        for num_components in [1, 2, 4]:
            W_ternary, gammas = greedy_ternary_decomposition(W, num_components)
            result = multi_ternary_linear_python(x, W_ternary, gammas, bias=None)

            assert result.shape == (4, 64)
            assert torch.isfinite(result).all()

            # Rebuild the effective dense weight from the components.
            W_reconstructed = torch.zeros_like(W)
            for i in range(num_components):
                W_reconstructed += gammas[i].unsqueeze(1) * W_ternary[i]

            # Forward pass must be numerically equivalent to using the
            # reconstructed weight directly.
            expected = torch.matmul(x, W_reconstructed.t())
            assert torch.allclose(result, expected, atol=1e-4)
|