Textilindo-AI / scripts /setup_textilindo_training.py
harismlnaslm's picture
Add complete scripts directory with training, testing, and deployment tools
e207dc8
raw
history blame
5.41 kB
#!/usr/bin/env python3
"""
Setup script for Textilindo AI Assistant training.
Downloads the base model and prepares the environment.
"""
import os
import sys
import yaml
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
# Configure the root logger so INFO-level (and higher) messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by load_config() and download_model().
logger = logging.getLogger(__name__)
def load_config(config_path):
    """Load the training configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed configuration as a dict, or None when the file cannot be
        read or parsed, or when its top-level value is not a mapping (this
        includes an empty file). Failures are logged, never raised.
    """
    try:
        # Decode as UTF-8 explicitly so the result does not depend on the
        # platform's locale encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except Exception as e:
        logger.error(f"Error loading config: {e}")
        return None

    # Callers index the result by key, so anything other than a mapping
    # (None for an empty file, a list, a scalar) is reported here instead of
    # failing later without a message.
    if not isinstance(config, dict):
        logger.error(
            f"Error loading config: expected a mapping in {config_path}, "
            f"got {type(config).__name__}"
        )
        return None
    return config
def download_model(config):
    """Download the base model and tokenizer and save them locally.

    Args:
        config: Mapping providing 'model_name' (identifier handed to
            ``from_pretrained``) and 'model_path' (directory used both as the
            download cache and as the save target).

    Returns:
        True when tokenizer and model were loaded and saved, False when any
        step of the download or save raised an exception (the error is
        logged).

    Raises:
        KeyError: If 'model_name' or 'model_path' is missing from ``config``.
    """
    model_name = config['model_name']
    model_path = config['model_path']
    logger.info(f"Downloading model: {model_name}")
    logger.info(f"Target path: {model_path}")

    # Create models directory
    Path(model_path).mkdir(parents=True, exist_ok=True)

    model_kwargs = {
        'torch_dtype': torch.float16,
        'trust_remote_code': True,
        'cache_dir': model_path,
        'low_cpu_mem_usage': True,
    }
    # 8-bit loading goes through bitsandbytes, which needs a CUDA device.
    # Requesting it on a CPU-only machine makes the download fail, so only
    # enable it when a GPU is actually present.
    if torch.cuda.is_available():
        model_kwargs['load_in_8bit'] = True
    else:
        logger.warning(
            "CUDA not available - loading model without 8-bit quantization"
        )

    try:
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository; only use with model sources you trust.
        logger.info("Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir=model_path
        )

        # Download model with memory optimization
        logger.info("Downloading model...")
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

        # Save to local path
        logger.info(f"Saving model to: {model_path}")
        tokenizer.save_pretrained(model_path)
        model.save_pretrained(model_path)
        logger.info("βœ… Model downloaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error downloading model: {e}")
        return False
def check_requirements():
    """Report whether the environment can run the training scripts.

    Prints one status line per check. Returns False immediately when the
    Python version is below 3.8 or PyTorch cannot be imported, and False
    after listing every required package that failed to import. A missing
    CUDA device only produces a warning. Returns True otherwise.
    """
    print("πŸ” Checking requirements...")

    # Interpreter version gate.
    if not sys.version_info >= (3, 8):
        print("❌ Python 3.8+ required")
        return False

    # PyTorch must be importable before anything else is probed.
    try:
        import torch
    except ImportError:
        print("❌ PyTorch not installed")
        return False
    else:
        print(f"βœ… PyTorch {torch.__version__}")

    # GPU is optional: report it when present, warn when absent.
    if not torch.cuda.is_available():
        print("⚠️ CUDA not available - training will be slower on CPU")
    else:
        print(f"βœ… CUDA available: {torch.cuda.get_device_name(0)}")
        total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f" GPU Memory: {total_gib:.1f} GB")

    # Try importing each training dependency and collect the failures.
    needed = ('transformers', 'peft', 'datasets', 'accelerate', 'bitsandbytes')
    unavailable = []
    for name in needed:
        try:
            __import__(name)
        except ImportError:
            unavailable.append(name)
            print(f"❌ {name}")
        else:
            print(f"βœ… {name}")

    if not unavailable:
        return True

    print(f"\n❌ Missing packages: {', '.join(unavailable)}")
    print(f"Install with: pip install {' '.join(unavailable)}")
    return False
def main():
    """Run the setup workflow from start to finish.

    Steps, in order: verify requirements, load
    ``configs/training_config.yaml``, download the base model unless
    ``config.json`` already exists under the configured model path, then
    confirm that the dataset and ``configs/system_prompt.md`` exist. Any
    failed step terminates the process with exit status 1.
    """
    print("πŸš€ Textilindo AI Assistant - Setup")
    print("=" * 50)

    # Step 1: environment check.
    if not check_requirements():
        print("\n❌ Requirements not met. Please install missing packages.")
        sys.exit(1)

    # Step 2: configuration file.
    config_path = "configs/training_config.yaml"
    if not os.path.exists(config_path):
        print(f"❌ Config file tidak ditemukan: {config_path}")
        sys.exit(1)

    config = load_config(config_path)
    if not config:
        sys.exit(1)

    # Step 3: base model, downloaded only when no local copy is present.
    model_path = config['model_path']
    has_local_model = os.path.exists(model_path) and os.path.exists(
        os.path.join(model_path, "config.json")
    )
    if not has_local_model:
        print("1️⃣ Downloading base model...")
        if not download_model(config):
            print("❌ Failed to download model")
            sys.exit(1)
    else:
        print(f"βœ… Model already exists: {model_path}")
        print("Skipping download...")

    # Step 4: dataset location.
    dataset_path = config['dataset_path']
    if os.path.exists(dataset_path):
        print(f"βœ… Dataset found: {dataset_path}")
    else:
        print(f"❌ Dataset tidak ditemukan: {dataset_path}")
        print("Please ensure your dataset is in the correct location")
        sys.exit(1)

    # Step 5: system prompt file.
    system_prompt_path = "configs/system_prompt.md"
    if os.path.exists(system_prompt_path):
        print(f"βœ… System prompt found: {system_prompt_path}")
    else:
        print(f"❌ System prompt tidak ditemukan: {system_prompt_path}")
        sys.exit(1)

    # Everything is in place: print the follow-up commands.
    closing_lines = (
        "\nβœ… Setup completed successfully!",
        "\nπŸ“‹ Next steps:",
        "1. Run training: python scripts/train_textilindo_ai.py",
        "2. Test model: python scripts/test_textilindo_ai.py",
        "3. Test with LoRA: python scripts/test_textilindo_ai.py --lora_path models/textilindo-ai-lora-YYYYMMDD_HHMMSS",
    )
    for line in closing_lines:
        print(line)
# Run the setup workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()