import inspect

from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    EarlyStoppingCallback,
)


def _eval_strategy_key():
    """Return the ``TrainingArguments`` keyword that selects the evaluation schedule.

    Newer ``transformers`` releases renamed ``evaluation_strategy`` to
    ``eval_strategy`` and later removed the old name; older releases only know
    ``evaluation_strategy``. Inspecting the constructor signature lets this
    module work against either.
    """
    params = inspect.signature(TrainingArguments).parameters
    return "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"


def get_training_args(output_dir, **overrides):
    """Build the ``TrainingArguments`` used for continued fine-tuning.

    The defaults favour stability over speed: a small per-device batch with
    heavy gradient accumulation, a low learning rate, frequent evaluation and
    checkpointing, and tight gradient clipping.

    Args:
        output_dir: Directory where checkpoints (and trainer outputs) are written.
        **overrides: Optional ``TrainingArguments`` keyword arguments that replace
            the defaults below (e.g. ``learning_rate=2e-5``). Omitting them
            reproduces the default configuration exactly.

    Returns:
        A configured ``transformers.TrainingArguments`` instance.
    """
    config = {
        "output_dir": output_dir,
        "num_train_epochs": 3,  # Reduced epochs for continued training
        "per_device_train_batch_size": 2,  # Reduced batch size
        "per_device_eval_batch_size": 2,
        "gradient_accumulation_steps": 16,  # Increased for stability
        _eval_strategy_key(): "steps",
        "eval_steps": 25,  # More frequent evaluation
        "save_strategy": "steps",
        # With load_best_model_at_end=True, transformers requires save_steps to
        # be a multiple of eval_steps; keep these two in sync.
        "save_steps": 25,
        # Log training loss at a cadence comparable to evaluation; without this
        # the library default logs far less often than every 25 eval steps.
        "logging_strategy": "steps",
        "logging_steps": 10,
        "learning_rate": 1e-5,  # Lower learning rate for fine-tuning
        "weight_decay": 0.03,  # Increased for better regularization
        "warmup_ratio": 0.15,  # Increased warmup
        # NOTE: with the scheduler's default cycle count (1) this is a single
        # cosine decay with no restarts; pass lr_scheduler_kwargs via
        # ``overrides`` (on versions that support it) to get real restarts.
        "lr_scheduler_type": "cosine_with_restarts",
        "load_best_model_at_end": True,
        "metric_for_best_model": "eval_loss",
        "greater_is_better": False,  # lower eval_loss is better
        "fp16": True,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "report_to": "tensorboard",
        "remove_unused_columns": False,
        "optim": "adamw_torch_fused",
        "max_grad_norm": 0.3,  # Reduced for stability
    }
    config.update(overrides)
    return TrainingArguments(**config)