# BitFinTrainer / scripts / test_inference.py
# Provenance (from Hugging Face file viewer, kept as a comment so the script stays runnable):
#   author: luohoa97 — commit: "Deploy BitNet-Transformer Trainer" (d5b7ee9, verified), 832 Bytes
import torch
from safetensors.torch import load_file
from trading_cli.strategy.ai.model import create_model
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_inference():
model = create_model(input_dim=9)
try:
model.load_state_dict(load_file("models/ai_fusion_bitnet.safetensors"))
model.eval()
logger.info("Model loaded successfully ✓")
# Test with random input
x = torch.randn(1, 9)
with torch.no_grad():
output = model(x)
logger.info(f"Output: {output}")
action = torch.argmax(output, dim=-1).item()
logger.info(f"Action: {action}")
except Exception as e:
logger.error(f"Inference test failed: {e}")
if __name__ == "__main__":
test_inference()