File size: 1,010 Bytes
8f51ef2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
import transformers
from echo_prime import EchoPrimeTextEncoder
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## load the echo encoder
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be deserialized on a CPU-only machine.
checkpoint = torch.load(
    "model_data/weights/echo_prime_encoder.pt", map_location=device
)
echo_encoder = torchvision.models.video.mvit_v2_s()
# Replace the last layer of the head with a linear layer that outputs 512
# features; this must happen before load_state_dict so the parameter shapes
# line up with the checkpoint (presumably trained with this head -- confirm).
echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512)
echo_encoder.load_state_dict(checkpoint)
echo_encoder.eval()
echo_encoder.to(device)
# Freeze every parameter of the encoder: it is used for inference only here.
echo_encoder.requires_grad_(False)

# Sanity check: push one all-zero clip through the encoder and report the
# shape of the result. no_grad avoids any autograd bookkeeping for the probe.
with torch.no_grad():
    dummy_clip = torch.zeros(1, 3, 16, 224, 224, device=device)
    print(f"Echo embedding shape is {echo_encoder(dummy_clip).shape}")
## load the text encoder
text_encoder = EchoPrimeTextEncoder(device=device)
# Pass map_location so the weights are remapped onto `device`; without it a
# checkpoint saved from CUDA tensors cannot be loaded on a CPU-only machine,
# which would defeat the CPU fallback chosen when `device` was created.
text_encoder_weights = torch.load(
    "model_data/weights/echo_prime_text_encoder.pt", map_location=device
)
text_encoder.load_state_dict(text_encoder_weights)
text_encoder.eval()
# produces 512 dimensional embedding
# The probe only reads the output shape, so skip autograd tracking.
with torch.no_grad():
    print(f"Text embedding shape is {text_encoder('Sample text').shape}")