---
language: en
license: apache-2.0
tags:
- text-classification
- multi-label-classification
- tinybert
- pytorch
datasets:
- JayShah07/multi_label_reporting
metrics:
- accuracy
- f1
widget:
- text: "Show me my current holdings"
- text: "What are my capital gains for this year?"
- text: "Give me monthly scheme-wise returns"
---

# TinyBERT Dual Classifier for Investment Reporting

This model is a fine-tuned TinyBERT with two classification heads for multi-label classification of investment reporting queries.

## Model Description

- **Base Model**: TinyBERT (huawei-noah/TinyBERT_General_4L_312D)
- **Parameters**: ~14-15M
- **Architecture**: Single encoder with two independent classification heads
- **Task**: Multi-label classification (Module + Date)

## Labels

**Module Labels (6 classes)**:
- holdings
- capital_gains
- scheme_wise_returns
- investment_account_wise_returns
- portfolio_update
- None_module

**Date Labels (7 classes)**:
- Current Year
- Previous Year
- Daily
- Monthly
- Weekly
- Yearly
- None_date

## Performance

**Test Set Results**:
- Module Accuracy: 1.0000
- Module F1 Score: 1.0000
- Date Accuracy: 1.0000
- Date F1 Score: 1.0000

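For reference, the accuracy and F1 reported above can be computed per head. This is a minimal pure-Python sketch; the averaging mode is an assumption (macro shown here), and in practice `sklearn.metrics.f1_score` is the usual tool:

```python
def accuracy(y_true, y_pred):
    # Fraction of exactly-matching predictions
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def macro_f1(y_true, y_pred, num_classes):
    # Per-class F1, averaged with equal weight per class (macro average)
    f1s = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall)
                   if precision + recall else 0.0)
    return sum(f1s) / num_classes

# Toy example with 3 hypothetical module classes
y_true = [0, 1, 2, 1]
y_pred = [0, 1, 2, 1]
print(accuracy(y_true, y_pred))      # 1.0
print(macro_f1(y_true, y_pred, 3))   # 1.0
```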
## Usage
```python
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn

# Define the dual-head model
class TinyBERTDualClassifier(nn.Module):
    def __init__(self, num_module_labels, num_date_labels, dropout_rate=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("JayShah07/tinybert-dual-classifier")
        self.hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=dropout_rate)
        self.module_classifier = nn.Linear(self.hidden_size, num_module_labels)
        self.date_classifier = nn.Linear(self.hidden_size, num_date_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation for both heads
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        module_logits = self.module_classifier(cls_output)
        date_logits = self.date_classifier(cls_output)
        return module_logits, date_logits

# Load the classifier-head weights
classifier_config = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/JayShah07/tinybert-dual-classifier/resolve/main/classifier_heads.pt"
)

model = TinyBERTDualClassifier(
    num_module_labels=6,
    num_date_labels=7
)

model.module_classifier.load_state_dict(classifier_config['module_classifier'])
model.date_classifier.load_state_dict(classifier_config['date_classifier'])

tokenizer = AutoTokenizer.from_pretrained("JayShah07/tinybert-dual-classifier")

# Inference
model.eval()
text = "Show my holdings for this month"
inputs = tokenizer(text, return_tensors='pt', padding='max_length',
                   truncation=True, max_length=128)

with torch.no_grad():
    module_logits, date_logits = model(inputs['input_ids'], inputs['attention_mask'])
    module_pred = torch.argmax(module_logits, dim=1).item()
    date_pred = torch.argmax(date_logits, dim=1).item()

module_labels = ['holdings', 'capital_gains', 'scheme_wise_returns',
                 'investment_account_wise_returns', 'portfolio_update', 'None_module']
date_labels = ['Current Year', 'Previous Year', 'Daily', 'Monthly',
               'Weekly', 'Yearly', 'None_date']

print(f"Module: {module_labels[module_pred]}")
print(f"Date: {date_labels[date_pred]}")
```
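The snippet above takes the argmax and discards confidence. As a minimal sketch, logits can be turned into class probabilities with a softmax (in the torch code above, `torch.softmax(module_logits, dim=1)` does the same); the logit values here are hypothetical:

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical module-head logits for one query
logits = [4.2, 0.1, -1.3, 0.5, -0.8, -2.0]
probs = softmax(logits)
print(max(probs))                 # confidence of the predicted class
print(probs.index(max(probs)))    # predicted class index (0 -> 'holdings')
```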

## Training Details

- **Dataset**: JayShah07/multi_label_reporting
- **Training Samples**: 3097
- **Validation Samples**: 387
- **Test Samples**: 388
- **Epochs**: 10
- **Batch Size**: 16
- **Learning Rate**: 2e-05
- **Optimizer**: AdamW
- **Loss Function**: CrossEntropyLoss (separate for each head)

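The card lists a separate CrossEntropyLoss per head; combining them by simple summation is an assumption (the card does not state a weighting). A minimal pure-Python sketch of that combined objective for one example, with hypothetical logits:

```python
import math

def cross_entropy(logits, target):
    # Negative log-probability of the target class under a softmax
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

# Hypothetical logits for one example from each head
module_logits = [2.0, 0.1, -0.5, 0.3, -1.0, -2.0]     # 6 module classes
date_logits = [0.2, 1.8, -0.3, 0.0, -1.1, 0.4, -0.6]  # 7 date classes

# One loss per head, combined by summation (assumed; weighting not stated)
total_loss = cross_entropy(module_logits, 0) + cross_entropy(date_logits, 1)
print(total_loss)
```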
## Latency

Average inference latency on sample queries is reported as mean ± std; see the accompanying notebook for the detailed latency analysis.

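The notebook itself is not reproduced here; as a minimal sketch, per-query latency can be measured like this. The `predict` function is a hypothetical stand-in for the tokenize + model forward pass shown in the Usage section:

```python
import statistics
import time

def predict(text):
    # Hypothetical stand-in for tokenize + model forward pass
    return ("holdings", "Monthly")

queries = ["Show me my current holdings",
           "What are my capital gains for this year?",
           "Give me monthly scheme-wise returns"]

latencies_ms = []
for q in queries * 50:  # repeat queries for a more stable estimate
    start = time.perf_counter()
    predict(q)
    latencies_ms.append((time.perf_counter() - start) * 1000)

print(f"{statistics.mean(latencies_ms):.3f} ms ± {statistics.stdev(latencies_ms):.3f} ms")
```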
## Citation

If you use this model, please cite:
```bibtex
@misc{tinybert-dual-classifier,
  author = {Jay Shah},
  title = {TinyBERT Dual Classifier for Investment Reporting},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/JayShah07/tinybert-dual-classifier}}
}
```

## License

Apache 2.0