🌾 HeadCount
A semantic segmentation model for counting wheat heads in field images, designed to support yield estimation, flowering time detection, and field maturity assessment.
Model Details
- Architecture: DeepLabV3+ with ResNet50 encoder
- Framework: PyTorch with segmentation-models-pytorch
- Input: RGB images (resized to 512×512)
- Output: 4-class segmentation (Background, Leaf, Stem, Head)
- Counting Method: Distance transform followed by peak detection on the predicted head mask
- Loss Function: Dice loss with inverse-frequency class weighting and an additional 1.5× boost for the Stem class
- Optimizer: Adam with CosineAnnealingLR scheduling
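The card does not spell out how the class weights are computed. The NumPy sketch below shows one common reading of "inverse-frequency weighting with a 1.5× stem boost", combined with a weighted soft Dice loss. The helper names (`class_weights`, `weighted_dice_loss`), the class index order (Background=0, Leaf=1, Stem=2, Head=3), normalising the weights to a mean of 1 and applying the boost after normalisation are all assumptions. The actual training code presumably operates on PyTorch tensors.

```python
import numpy as np

NUM_CLASSES = 4   # Background, Leaf, Stem, Head (index order assumed)
STEM_CLASS = 2    # assumed index of the Stem class
STEM_BOOST = 1.5  # extra multiplier on the Stem weight


def class_weights(labels, num_classes=NUM_CLASSES, stem_class=STEM_CLASS, stem_boost=STEM_BOOST):
    """Inverse-frequency class weights with an extra boost on one class (assumed scheme)."""
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(np.float64)
    # Clamp counts to 1 so a class absent from the labels does not divide by zero.
    freq = np.maximum(counts, 1.0) / counts.sum()
    weights = 1.0 / freq
    weights /= weights.mean()          # normalise so the average weight is 1
    weights[stem_class] *= stem_boost  # boost applied after normalisation
    return weights


def weighted_dice_loss(probs, labels, weights, eps=1e-6):
    """Class-weighted soft Dice loss.

    probs:  (..., C) array of per-pixel class probabilities (e.g. a softmax output)
    labels: (...) integer array of ground-truth class indices
    """
    num_classes = probs.shape[-1]
    p = probs.reshape(-1, num_classes)
    g = np.eye(num_classes)[labels.ravel()]  # one-hot ground truth, shape (N, C)
    intersection = (p * g).sum(axis=0)
    denom = p.sum(axis=0) + g.sum(axis=0)
    dice = (2.0 * intersection + eps) / (denom + eps)  # per-class Dice coefficient
    # Weighted mean of the per-class Dice losses.
    return float((weights * (1.0 - dice)).sum() / weights.sum())


# Example: weights from a toy label map, then the loss for a perfect prediction.
toy_labels = np.array([[0, 0, 1, 2], [0, 3, 1, 0]])
w = class_weights(toy_labels)
print(weighted_dice_loss(np.eye(NUM_CLASSES)[toy_labels], toy_labels, w))  # prints 0.0
```

In this sketch the boost multiplies the Stem weight on top of its inverse-frequency weight. It is presumably meant to help the thin, under-represented stem pixels, which still have the lowest F1 in the Performance table.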
Performance
| Class | F1 |
|---|---|
| Background | 0.858 |
| Leaf | 0.889 |
| Stem | 0.535 |
| Head | 0.897 |
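The card does not say how these scores were computed. Assume they are pixel-wise scores. In that case, per-class F1 is 2·TP / (2·TP + FP + FN), which is numerically identical to the Dice coefficient. A minimal NumPy sketch follows. `per_class_f1` is a hypothetical helper, and the class order is assumed.

```python
import numpy as np

CLASS_NAMES = ["Background", "Leaf", "Stem", "Head"]  # assumed index order


def per_class_f1(pred, target, num_classes=len(CLASS_NAMES)):
    """Pixel-wise F1 (equivalently the Dice coefficient) for each class.

    pred, target: integer arrays of class indices with the same shape.
    Returns NaN for a class that appears in neither array.
    """
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    scores = []
    for c in range(num_classes):
        p, t = pred == c, target == c
        tp = np.sum(p & t)   # predicted c, labelled c
        fp = np.sum(p & ~t)  # predicted c, labelled otherwise
        fn = np.sum(~p & t)  # labelled c, predicted otherwise
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else float("nan"))
    return scores


for name, f1 in zip(CLASS_NAMES, per_class_f1([[0, 1], [3, 3]], [[0, 1], [2, 3]])):
    print(f"{name}: {f1:.3f}")
```

To score a whole test set, one option is to accumulate TP, FP and FN over all images before dividing, rather than averaging per-image scores. The card does not state which aggregation produced the table.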
Example Usage
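```python
from PIL import Image

from inference import GWFSSModel  # inference.py shipped with this model repository

# Load the trained weights
model = GWFSSModel("model.pth")

# Load an image (forcing RGB in case it is greyscale or has an alpha channel) and segment it
image = Image.open("input.jpg").convert("RGB")
predictions = model.predict(image)

# Count wheat heads in the prediction
num_heads = model.count_heads(predictions)
print(f"🌾 {num_heads} heads detected")

# Create a visualisation by overlaying the prediction on the input image
overlay = model.overlay_mask(image, predictions, alpha=0.5, heads_only=True)
overlay.save("output.png")
```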
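`model.count_heads` comes from the `inference` module used above, and this card does not show its internals. The sketch below uses `scipy.ndimage` to illustrate the counting method listed under Model Details: a distance transform followed by peak detection on the head mask. The following are hypothetical choices for illustration, not the repository's implementation:

- the function name `count_heads_sketch`
- a Head class index of 3
- the `min_distance` and `min_radius` parameters

```python
import numpy as np
from scipy import ndimage

HEAD_CLASS = 3  # assumed index of Head (Background=0, Leaf=1, Stem=2, Head=3)


def count_heads_sketch(pred, head_class=HEAD_CLASS, min_distance=3, min_radius=1.0):
    """Estimate the number of heads in a class-index mask (illustrative only).

    pred: (H, W) integer array of predicted class indices.
    min_distance: a peak must be the maximum within this many pixels in each
        direction; larger values merge nearby peaks.
    min_radius: ignore peaks whose distance value is below this (suppresses specks).
    """
    head_mask = np.asarray(pred) == head_class
    # Distance from each head pixel to the nearest non-head pixel: pixels near
    # the centre of a head blob get the largest values.
    dist = ndimage.distance_transform_edt(head_mask)
    # A pixel is a peak if it equals the maximum of its
    # (2*min_distance+1) x (2*min_distance+1) neighbourhood.
    local_max = ndimage.maximum_filter(dist, size=2 * min_distance + 1)
    peaks = head_mask & (dist == local_max) & (dist >= min_radius)
    # Adjacent peak pixels (flat plateaus) are merged into a single detection.
    _, num_peaks = ndimage.label(peaks)
    return num_peaks


# Two separate 8x8 "heads" on a 40x40 mask.
demo = np.zeros((40, 40), dtype=int)
demo[5:13, 5:13] = HEAD_CLASS
demo[25:33, 25:33] = HEAD_CLASS
print(count_heads_sketch(demo))  # prints 2
```

In this sketch, `min_distance` trades off the two failure modes listed under Limitations. A small window lets an irregular or fragmented head blob yield several peaks, while a large window can suppress the peak of a neighbouring, partly overlapping head. scikit-image's `skimage.feature.peak_local_max` is a ready-made alternative to the hand-rolled maximum filter.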
Limitations
Best performance is achieved with overhead imagery under diffuse lighting. Known challenges include:
- Lighting Sensitivity: Bright or harsh lighting can cause over-segmentation, splitting single heads into multiple detections
- Overlapping Heads: Dense clusters with significant overlap are challenging to separate accurately
- Colour Dependency: Performance is lower on senesced plants
Training Data
This model was trained on GWFSS_v1.0_labelled, the labelled subset of the Global Wheat Full Semantic Organ Segmentation (GWFSS) dataset.
Base Model
- microsoft/resnet-50