---
license: mit
datasets:
- pawlo2013/EarthLoc_2021_Database
base_model:
- facebook/dinov2-base
pipeline_tag: image-feature-extraction
---

## EarthLoc2 Model

EarthLoc2 combines a DINOv2-Base backbone with a SALAD aggregator and produces 3072-dimensional image descriptors.

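A SALAD aggregator pools the backbone's local patch tokens into one global descriptor. It softly assigns tokens to clusters with optimal transport (Sinkhorn iterations plus a "dustbin" that absorbs uninformative tokens), then concatenates the result with a projected global token. The NumPy sketch below illustrates that idea only and is not the trained EarthLoc2 aggregator. Random linear projections stand in for the learned MLPs. The sizes (16 clusters, 64-dim local features, 256-dim global part) are illustrative, and the dustbin mass of n − m is a SuperGlue-style convention assumed here.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def log_sinkhorn(log_k, log_a, log_b, n_iters=100):
    """Log-domain Sinkhorn: returns log P whose rows sum to exp(log_a) and columns to exp(log_b)."""
    u = np.zeros_like(log_a)
    v = np.zeros_like(log_b)
    for _ in range(n_iters):
        u = log_a - logsumexp(log_k + v[None, :], axis=1)  # match row marginals
        v = log_b - logsumexp(log_k + u[:, None], axis=0)  # match column marginals
    return log_k + u[:, None] + v[None, :]


def salad_like_aggregate(tokens, cls_token, w_score, w_local, w_global, dustbin_score=1.0):
    """Simplified SALAD-style aggregation (illustrative sketch with hypothetical weights)."""
    n, n_clusters = tokens.shape[0], w_score.shape[1]
    scores = tokens @ w_score                                       # (n, m) token-to-cluster scores
    scores = np.concatenate([scores, np.full((n, 1), dustbin_score)], axis=1)  # add dustbin column
    log_a = np.zeros(n)                                             # each token carries mass 1
    log_b = np.log(np.concatenate([np.ones(n_clusters), [n - n_clusters]]))  # clusters 1, dustbin n - m
    assign = np.exp(log_sinkhorn(scores, log_a, log_b))[:, :n_clusters]      # drop the dustbin
    local = tokens @ w_local                                        # (n, l) reduced local features
    clusters = assign.T @ local                                     # (m, l) soft-assigned sums
    clusters /= np.linalg.norm(clusters, axis=1, keepdims=True)     # intra-normalization per cluster
    desc = np.concatenate([cls_token @ w_global, clusters.ravel()])  # global part + cluster part
    return desc / np.linalg.norm(desc)                              # final L2 normalization


rng = np.random.default_rng(0)
d, n_clusters, d_local, d_global, n = 768, 16, 64, 256, 256  # toy sizes; 768 = ViT-B token width
tokens = rng.standard_normal((n, d))
cls_token = rng.standard_normal(d)
desc = salad_like_aggregate(
    tokens,
    cls_token,
    rng.standard_normal((d, n_clusters)) / np.sqrt(d),
    rng.standard_normal((d, d_local)) / np.sqrt(d),
    rng.standard_normal((d, d_global)) / np.sqrt(d),
)
print(desc.shape)  # (1280,) = 256 global + 16 clusters x 64
```

The descriptor length is the global size plus clusters × local size. The real model's configuration yields 3072 dimensions, but its exact split is not stated on this card.
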
It was trained on the original EarthLoc dataset (zoom levels 9, 10 and 11) within the -60° to 60° latitude range; polar regions are not supported.

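The coverage limits above can be checked before querying. This is a minimal sketch that assumes the zoom levels follow the standard Web Mercator (slippy-map) tiling used by common web maps. `is_covered` and `latlon_to_tile` are hypothetical helpers, not part of the EarthLoc2 code.

```python
import math

SUPPORTED_ZOOMS = (9, 10, 11)  # zoom levels listed on the model card
MAX_ABS_LAT = 60.0             # training range: -60 to 60 degrees latitude


def is_covered(lat: float) -> bool:
    """True if the latitude lies inside the training range (polar regions are unsupported)."""
    return -MAX_ABS_LAT <= lat <= MAX_ABS_LAT


def latlon_to_tile(lat: float, lon: float, zoom: int) -> tuple[int, int]:
    """Web Mercator slippy-map tile (x, y) containing a point; the tiling scheme is an assumption."""
    if zoom not in SUPPORTED_ZOOMS:
        raise ValueError(f"zoom must be one of {SUPPORTED_ZOOMS}, got {zoom}")
    if not is_covered(lat):
        raise ValueError(f"latitude {lat} is outside the supported -60..60 range")
    n = 2 ** zoom                                   # tiles per axis at this zoom level
    x = min(int((lon + 180.0) / 360.0 * n), n - 1)  # clamp lon = 180 into the last column
    lat_rad = math.radians(lat)
    y = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
    return x, y


print(latlon_to_tile(0.0, 0.0, 9))  # (256, 256)
print(is_covered(75.0))             # False: polar latitude
```
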
Training also included additional query images that are not part of the test or validation sets.

EarthLoc2 achieves an average R@10 of 90.6% on the original EarthLoc test and validation sets when retrieving against the whole database as is.

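R@K (Recall@K) is the percentage of queries for which at least one correct database image appears among the top K retrieved results. The "Average" columns in the table below appear to average this value over the evaluation sets. The helper below is a hypothetical sketch of the per-set computation. Which database items count as correct for a query is defined by the EarthLoc evaluation protocol, so here they are simply passed in as `positives`.

```python
import numpy as np


def recall_at_k(ranked_ids, positives, ks=(1, 10, 100)):
    """Recall@K in percent: share of queries with a correct database item in their top K.

    ranked_ids: (num_queries, num_candidates) database indices, best match first.
    positives:  positives[q] is the set of database indices counted as correct for query q.
    """
    ranked_ids = np.asarray(ranked_ids)
    recalls = {}
    for k in ks:
        hits = sum(
            bool(set(ranked_ids[q, :k].tolist()) & positives[q])
            for q in range(len(positives))
        )
        recalls[k] = 100.0 * hits / len(positives)
    return recalls


ranked = np.array([[3, 1, 2, 0],
                   [0, 2, 1, 3],
                   [2, 3, 0, 1]])
positives = [{3}, {1}, {0, 1}]
print(recall_at_k(ranked, positives, ks=(1, 3)))  # R@1 = 33.3..., R@3 = 100.0
```
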
Training ran for 5,000 iterations with a batch size of 96 and a learning rate of 0.0001. Only the last DINOv2 block and the aggregator were trainable.

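The partial fine-tuning described above (last backbone block plus aggregator trainable) can be sketched in PyTorch as follows. `ToyModel`, its `backbone.blocks` and `aggregator` attributes, and the choice of AdamW are assumptions for illustration only. The model card states just the learning rate (0.0001), batch size (96) and number of iterations (5,000).

```python
import torch
from torch import nn

LEARNING_RATE = 1e-4  # from the model card


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for the ViT backbone: an embedding layer plus a stack of blocks."""

    def __init__(self, dim: int = 32, depth: int = 4):
        super().__init__()
        self.patch_embed = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])


class ToyModel(nn.Module):
    """Hypothetical backbone + aggregator pair; only the attribute layout matters here."""

    def __init__(self, dim: int = 32, depth: int = 4, out_dim: int = 16):
        super().__init__()
        self.backbone = ToyBackbone(dim, depth)
        self.aggregator = nn.Linear(dim, out_dim)


def freeze_all_but_last_blocks(model: nn.Module, num_unfrozen: int = 1) -> None:
    """Freeze the backbone except its last `num_unfrozen` blocks; keep the aggregator trainable."""
    for p in model.backbone.parameters():
        p.requires_grad = False
    if num_unfrozen > 0:  # guard: blocks[-0:] would select every block
        for block in model.backbone.blocks[-num_unfrozen:]:
            for p in block.parameters():
                p.requires_grad = True
    for p in model.aggregator.parameters():
        p.requires_grad = True


model = ToyModel()
freeze_all_but_last_blocks(model, num_unfrozen=1)
# Hand only the trainable parameters to the optimizer (AdamW is an assumption)
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE
)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")  # 1584 = last block (1056) + aggregator (528)
```

Freezing before building the optimizer keeps frozen weights out of the optimizer state entirely, which saves memory for a large backbone.
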
To obtain predictions from the model, use the FAISS index https://huggingface.co/datasets/pawlo2013/EarthLoc2_FAISS, the 2021 database https://huggingface.co/datasets/pawlo2013/EarthLoc_2021_Database, and the inference Space https://huggingface.co/spaces/pawlo2013/EarthLoc2.

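Prediction amounts to a nearest-neighbour search. The query image's 3072-D descriptor is compared with the descriptors of all database tiles, and the best matches give the candidate locations. The sketch below runs an exact cosine-similarity search in NumPy. This produces the same ranking as an exact inner-product FAISS index (e.g. `faiss.IndexFlatIP`) over L2-normalized vectors. Whether the published index uses that index type and normalization is an assumption, and the random database is only a stand-in for the real descriptors.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale vectors to unit length along the last axis."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def top_k_matches(query_desc: np.ndarray, db_descs: np.ndarray, k: int = 10):
    """Exact cosine-similarity search: indices and scores of the k best database descriptors."""
    sims = l2_normalize(db_descs) @ l2_normalize(query_desc)  # (N,) similarity to every tile
    k = min(k, sims.shape[0])
    idx = np.argpartition(-sims, k - 1)[:k]  # the k most similar, in arbitrary order
    idx = idx[np.argsort(-sims[idx])]        # order them best first
    return idx, sims[idx]


rng = np.random.default_rng(0)
db_descs = rng.standard_normal((1000, 3072)).astype(np.float32)  # stand-in for real descriptors
query_desc = db_descs[42] + 0.01 * rng.standard_normal(3072).astype(np.float32)
idx, scores = top_k_matches(query_desc, db_descs, k=5)
print(idx[0])  # 42: the slightly perturbed copy of entry 42 retrieves it first
```

Mapping the returned indices back to tile footprints requires the database metadata, which is not shown here.
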
See the EarthLoc project page for more details about training, data, and use cases: https://earthloc-and-earthmatch.github.io/

| **Model** | **Average R@1 (%)** | **Average R@10 (%)** | **Average R@100 (%)** |
|------------|-----------------|------------------|-------------------|
| EarthLoc | 50.8 | 65.9 | 80.5 |
| **EarthLoc2** | **79.6** | **90.0** | **95.5** |

*Worldwide search: average results across the evaluation sets when all 2021 database images are encoded.*

*EarthLoc = ResNet + MixVPR; EarthLoc2 = DINOv2-B + SALAD-B + additional query data.*

## Loading and Inspecting the DINOv2 Feature Extractor Model

```python
from model import DINOv2FeatureExtractor  # local model.py shipped with the weights
import torch

# Select the device: GPU if available, otherwise CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Path to the pretrained weights
MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'

# Build the architecture: DINOv2 ViT-B/14 backbone (4 register tokens) + SALAD aggregator
model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",  # timm model identifier
    num_of_layers_to_unfreeze=0,  # no backbone blocks unfrozen (not needed for inference)
    desc_dim=768,                 # equals the ViT-B hidden size (768)
    aggregator_type="SALAD",
)

print('Loading model ...')
# Load the trained weights, mapping tensors onto the selected device
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(model_state_dict)

# Move the model to the device and switch to evaluation mode (disables dropout etc.)
model = model.to(DEVICE)
model.eval()
print('Model loaded.')

# Print parameter counts
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")

# Print the aggregator type
print(f"Aggregator type: {model.aggregator_type}")