"""
Comprehensive test suite for Mon tokenizer Hugging Face integration.

This script provides extensive testing for the Mon language tokenizer,
including functionality tests, performance benchmarks, and compatibility checks.
"""

import argparse
import logging
import sys
import time
from pathlib import Path

# torch is imported explicitly so the script fails fast when it is missing;
# transformers needs it for return_tensors="pt".
import torch  # noqa: F401
from transformers import AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


class MonTokenizerTester:
    """Comprehensive testing suite for Mon tokenizer."""

    def __init__(self, tokenizer_path: str = "."):
        """
        Initialize the tester.

        Args:
            tokenizer_path: Path to the tokenizer files
        """
        self.tokenizer_path = tokenizer_path
        self.tokenizer = None
        self.test_results = {}

    def load_tokenizer(self) -> bool:
        """
        Load the tokenizer for testing.

        Returns:
            bool: True if the tokenizer loaded successfully, False otherwise
        """
        try:
            logger.info(f"Loading tokenizer from: {self.tokenizer_path}")
            # local_files_only avoids any network access; trust_remote_code=False
            # refuses to execute custom tokenizer code shipped with the files.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                local_files_only=True,
                trust_remote_code=False,
            )

            logger.info("✓ Tokenizer loaded successfully")
            logger.info(f"  - Vocabulary size: {self.tokenizer.vocab_size:,}")
            logger.info(f"  - Model max length: {self.tokenizer.model_max_length:,}")
            logger.info(f"  - Tokenizer class: {self.tokenizer.__class__.__name__}")

            return True

        except Exception as e:
            logger.error(f"✗ Failed to load tokenizer: {e}")
            return False

    def test_basic_functionality(self) -> bool:
        """
        Test basic tokenizer functionality.

        Returns:
            bool: True if all basic tests pass, False otherwise
        """
        logger.info("=== Testing Basic Functionality ===")

        test_cases = [
            {
                "text": "ဘာသာမန်",
                "description": "Single Mon word",
                "expected_min_tokens": 1,
            },
            {
                "text": "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။",
                "description": "Complete Mon sentence",
                "expected_min_tokens": 3,
            },
            {
                "text": "မန်တံဂှ် မံင်ပ္ဍဲ ရးမန် ကဵု ရးသေံ။",
                "description": "Mon geographical text",
                "expected_min_tokens": 3,
            },
            {
                "text": "၁၂၃၄၅ ဂတာပ်ခ္ဍာ် ၂၀၂၄ သၞာံ",
                "description": "Mon numerals and dates",
                "expected_min_tokens": 2,
            },
            {
                "text": "အရေဝ်ဘာသာမန် ပ္ဍဲလောကဏအ် ဂွံဆဵုကေတ် ပ္ဍဲဍုင်သေံ ကဵု ဍုင်ဗၟာ ရ။",
                "description": "Complex Mon linguistics text",
                "expected_min_tokens": 5,
            },
        ]

        passed = 0
        total = len(test_cases)
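
        # Each case must produce at least the expected number of tokens and
        # survive an encode → decode round trip. For Mon script, a lossless
        # round trip is the main correctness signal, since stacked consonants
        # and combining marks are easy to corrupt during tokenization.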
        for i, test_case in enumerate(test_cases, 1):
            text = test_case["text"]
            description = test_case["description"]
            expected_min_tokens = test_case["expected_min_tokens"]

            try:
                start_time = time.time()
                tokens = self.tokenizer(text, return_tensors="pt")
                encoding_time = time.time() - start_time

                start_time = time.time()
                decoded = self.tokenizer.decode(
                    tokens["input_ids"][0],
                    skip_special_tokens=True,
                )
                decoding_time = time.time() - start_time

                token_count = tokens["input_ids"].shape[1]
                round_trip_success = text.strip() == decoded.strip()

                if token_count >= expected_min_tokens and round_trip_success:
                    logger.info(f"✓ Test {i}: {description}")
                    logger.info(f"  Tokens: {token_count}, Encoding: {encoding_time*1000:.2f}ms, "
                                f"Decoding: {decoding_time*1000:.2f}ms")
                    passed += 1
                else:
                    logger.warning(f"⚠ Test {i}: {description}")
                    if token_count < expected_min_tokens:
                        logger.warning(f"  Token count too low: {token_count} < {expected_min_tokens}")
                    if not round_trip_success:
                        logger.warning("  Round-trip failed:")
                        logger.warning(f"    Input:  '{text}'")
                        logger.warning(f"    Output: '{decoded}'")

            except Exception as e:
                logger.error(f"✗ Test {i}: {description} - ERROR: {e}")

        success = passed == total
        self.test_results["basic_functionality"] = {
            "passed": passed,
            "total": total,
            "success": success,
        }

        logger.info(f"Basic functionality: {passed}/{total} tests passed")
        return success

    def test_special_tokens(self) -> bool:
        """
        Test special token handling.

        Returns:
            bool: True if special token tests pass, False otherwise
        """
        logger.info("=== Testing Special Tokens ===")

        try:
            special_tokens = {
                "bos_token": self.tokenizer.bos_token,
                "eos_token": self.tokenizer.eos_token,
                "unk_token": self.tokenizer.unk_token,
                "pad_token": self.tokenizer.pad_token,
            }

            special_token_ids = {
                "bos_token_id": self.tokenizer.bos_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "unk_token_id": self.tokenizer.unk_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
            }

            logger.info("Special tokens:")
            for name, token in special_tokens.items():
                token_id = special_token_ids[f"{name}_id"]
                logger.info(f"  {name}: '{token}' (ID: {token_id})")
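
            # Encoding the same text with and without special tokens should
            # differ in length when BOS/EOS tokens are configured; a strict
            # increase in token count is used as the pass condition here.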
            test_text = "ဘာသာမန်"
            tokens_with_special = self.tokenizer(
                test_text,
                add_special_tokens=True,
                return_tensors="pt",
            )
            tokens_without_special = self.tokenizer(
                test_text,
                add_special_tokens=False,
                return_tensors="pt",
            )

            with_special_count = tokens_with_special["input_ids"].shape[1]
            without_special_count = tokens_without_special["input_ids"].shape[1]

            if with_special_count > without_special_count:
                logger.info("✓ Special tokens are properly added")
                success = True
            else:
                logger.warning("⚠ Special tokens may not be properly added")
                success = False

            self.test_results["special_tokens"] = {"success": success}
            return success

        except Exception as e:
            logger.error(f"✗ Special token test failed: {e}")
            self.test_results["special_tokens"] = {"success": False}
            return False

    def test_edge_cases(self) -> bool:
        """
        Test edge cases and error handling.

        Returns:
            bool: True if edge case tests pass, False otherwise
        """
        logger.info("=== Testing Edge Cases ===")

        edge_cases = [
            ("", "Empty string"),
            (" ", "Whitespace only"),
            ("a", "Single ASCII character"),
            ("123", "Numbers only"),
            ("!@#$%", "Special characters only"),
            ("ဘာသာမန်" * 100, "Very long text"),
            ("ဟ", "Single Mon character"),
            ("၀၁၂၃၄၅၆၇၈၉", "Mon numerals"),
        ]

        passed = 0
        total = len(edge_cases)
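
        # These cases only assert that encoding and decoding complete without
        # raising; content-level checks are covered by the basic-functionality
        # suite. Empty input in particular should tokenize to special tokens
        # only (or to nothing) rather than fail.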
        for text, description in edge_cases:
            try:
                tokens = self.tokenizer(text, return_tensors="pt")
                # Decode purely to confirm it does not raise; the decoded text
                # itself is not checked for edge cases.
                self.tokenizer.decode(tokens["input_ids"][0], skip_special_tokens=True)

                logger.info(f"✓ {description}: {tokens['input_ids'].shape[1]} tokens")
                passed += 1

            except Exception as e:
                logger.error(f"✗ {description}: {e}")

        success = passed == total
        self.test_results["edge_cases"] = {
            "passed": passed,
            "total": total,
            "success": success,
        }

        logger.info(f"Edge cases: {passed}/{total} tests passed")
        return success

    def test_performance_benchmark(self) -> bool:
        """
        Run performance benchmarks.

        Returns:
            bool: True if performance is acceptable, False otherwise
        """
        logger.info("=== Performance Benchmark ===")

        test_texts = [
            "ဘာသာမန်",
            "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။",
            ("အရေဝ်ဘာသာမန် ပ္ဍဲလောကဏအ် ဂွံဆဵုကေတ် ပ္ဍဲဍုင်သေံ ကဵု ဍုင်ဗၟာ ရ။ " * 10),
            ("မန်တံဂှ် မံင်ပ္ဍဲ ရးမန် ကဵု ရးသေံ။ " * 50),
        ]

        benchmark_results = []
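
        # Each text is encoded and decoded 10 times and the reported figures
        # are per-call averages, so very short texts are dominated by fixed
        # per-call overhead rather than raw throughput.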
        for i, text in enumerate(test_texts, 1):
            char_count = len(text)

            start_time = time.time()
            for _ in range(10):
                tokens = self.tokenizer(text, return_tensors="pt")
            encoding_time = (time.time() - start_time) / 10

            start_time = time.time()
            for _ in range(10):
                self.tokenizer.decode(tokens["input_ids"][0])
            decoding_time = (time.time() - start_time) / 10

            token_count = tokens["input_ids"].shape[1]

            result = {
                "text_length": char_count,
                "token_count": token_count,
                "encoding_time": encoding_time,
                "decoding_time": decoding_time,
                "chars_per_second": char_count / encoding_time if encoding_time > 0 else 0,
                "tokens_per_second": token_count / decoding_time if decoding_time > 0 else 0,
            }

            benchmark_results.append(result)

            logger.info(f"Text {i} ({char_count} chars, {token_count} tokens):")
            logger.info(f"  Encoding: {encoding_time*1000:.2f}ms ({result['chars_per_second']:.0f} chars/s)")
            logger.info(f"  Decoding: {decoding_time*1000:.2f}ms ({result['tokens_per_second']:.0f} tokens/s)")

        avg_encoding_time = sum(r["encoding_time"] for r in benchmark_results) / len(benchmark_results)
        avg_decoding_time = sum(r["decoding_time"] for r in benchmark_results) / len(benchmark_results)
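
        # The 1 s per-call threshold is a loose bound: it catches pathological
        # regressions (e.g. quadratic behaviour on long inputs) without being
        # sensitive to machine-to-machine variation.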
        success = avg_encoding_time < 1.0 and avg_decoding_time < 1.0

        self.test_results["performance"] = {
            "avg_encoding_time": avg_encoding_time,
            "avg_decoding_time": avg_decoding_time,
            "success": success,
            "details": benchmark_results,
        }

        logger.info(f"Performance benchmark: {'PASSED' if success else 'FAILED'}")
        return success

    def test_compatibility(self) -> bool:
        """
        Test compatibility with the transformers ecosystem.

        Returns:
            bool: True if compatibility tests pass, False otherwise
        """
        logger.info("=== Testing Compatibility ===")

        try:
            text = "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။"
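
            # Smoke-test both output styles: PyTorch tensors and plain Python
            # lists of ids (transformers returns lists when return_tensors is
            # None). The return values are only checked for absence of errors.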
            tokens_pt = self.tokenizer(text, return_tensors="pt")
            tokens_list = self.tokenizer(text, return_tensors=None)

            logger.info("✓ PyTorch tensor support")
            logger.info("✓ List output support")

            texts = [
                "ဘာသာမန်",
                "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။",
                "မန်တံဂှ် မံင်ပ္ဍဲ ရးမန် ကဵု ရးသေံ။",
            ]

            batch_tokens = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

            logger.info(f"✓ Batch processing: {batch_tokens['input_ids'].shape}")

            if "attention_mask" in batch_tokens:
                logger.info("✓ Attention mask generation")
            else:
                logger.warning("⚠ No attention mask generated")

            success = True

        except Exception as e:
            logger.error(f"✗ Compatibility test failed: {e}")
            success = False

        self.test_results["compatibility"] = {"success": success}
        return success

    def run_all_tests(self) -> bool:
        """
        Run all test suites.

        Returns:
            bool: True if all tests pass, False otherwise
        """
        logger.info("🚀 Starting Mon Tokenizer Test Suite")
        logger.info("=" * 50)

        if not self.load_tokenizer():
            return False

        test_suites = [
            ("Basic Functionality", self.test_basic_functionality),
            ("Special Tokens", self.test_special_tokens),
            ("Edge Cases", self.test_edge_cases),
            ("Performance Benchmark", self.test_performance_benchmark),
            ("Compatibility", self.test_compatibility),
        ]

        results = []
        for suite_name, test_func in test_suites:
            logger.info(f"\n--- {suite_name} ---")
            success = test_func()
            results.append((suite_name, success))
            logger.info(f"{suite_name}: {'✅ PASSED' if success else '❌ FAILED'}")

        logger.info("\n" + "=" * 50)
        logger.info("📊 TEST SUMMARY")
        logger.info("=" * 50)

        passed_suites = sum(1 for _, success in results if success)
        total_suites = len(results)

        for suite_name, success in results:
            status = "✅ PASSED" if success else "❌ FAILED"
            logger.info(f"{suite_name}: {status}")

        overall_success = passed_suites == total_suites
        logger.info(f"\nOverall Result: {passed_suites}/{total_suites} test suites passed")

        if overall_success:
            logger.info("🎉 ALL TESTS PASSED! Tokenizer is ready for production.")
        else:
            logger.error("⚠️ Some tests failed. Please review the issues above.")

        return overall_success

    def generate_test_report(self) -> str:
        """
        Generate a detailed test report.

        Returns:
            str: Formatted test report
        """
        if not self.test_results:
            return "No test results available. Run tests first."

        report = ["# Mon Tokenizer Test Report", ""]

        for test_name, result in self.test_results.items():
            report.append(f"## {test_name.replace('_', ' ').title()}")

            if isinstance(result, dict) and "success" in result:
                status = "✅ PASSED" if result["success"] else "❌ FAILED"
                report.append(f"Status: {status}")

                if "passed" in result and "total" in result:
                    report.append(f"Tests: {result['passed']}/{result['total']}")

            report.append("")

        return "\n".join(report)
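

# Programmatic use, as an alternative to the CLI below (a minimal sketch;
# the tokenizer directory path is an assumption):
#
#   tester = MonTokenizerTester(tokenizer_path="./mon-tokenizer")
#   if tester.load_tokenizer():
#       tester.test_basic_functionality()
#       print(tester.generate_test_report())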


def main():
    """Main entry point for the test script."""
    parser = argparse.ArgumentParser(
        description="Test Mon tokenizer Hugging Face integration"
    )
    parser.add_argument(
        "--tokenizer-path",
        default=".",
        help="Path to tokenizer files (default: current directory)",
    )
    parser.add_argument(
        "--report",
        action="store_true",
        help="Generate detailed test report",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    tester = MonTokenizerTester(tokenizer_path=args.tokenizer_path)
    success = tester.run_all_tests()

    if args.report:
        report = tester.generate_test_report()
        report_path = Path("test_report.md")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(report)
        logger.info(f"Test report saved to: {report_path}")

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()