File size: 1,425 Bytes
229a9b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import re
import os
import matplotlib.pyplot as plt
def parse_losses(filepath):
    """Extract every logged loss value from a training log, in file order.

    The file content is scanned for dict-style entries that begin with
    ``{'loss':``. The value may be a quoted string (``{'loss': '0.42'``) or
    a bare number (``{'loss': 0.42``), optionally written with an exponent
    (``{'loss': 1e-05``).

    Args:
        filepath: Path to a UTF-8 encoded text log file.

    Returns:
        A list of floats, one per parseable loss entry, in the order the
        entries appear in the file. Entries whose value cannot be converted
        to ``float`` are skipped.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # One pattern for both spellings so results keep their order in the file:
    # group 1 is a quoted value, group 2 is a bare number.
    pattern = re.compile(
        r"\{'loss':\s*(?:'([^']+)'|([0-9.]+(?:[eE][-+]?[0-9]+)?))"
    )

    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    losses = []
    for match in pattern.finditer(content):
        raw = match.group(1) or match.group(2)
        try:
            losses.append(float(raw))
        except ValueError:
            # Non-numeric quoted text or a stray run of dots: not a loss.
            continue
    return losses
def main():
    """Plot the SFT loss curves of two Qwen runs and save the comparison.

    Reads losses from the two hard-coded log files under ``training_logs/``
    via ``parse_losses``, plots each series against its list index, writes
    the figure to ``results/qwen_sft_comparison.png`` (creating the
    ``results`` directory if needed) and prints the absolute output path.
    """
    qwen35_file = 'training_logs/qwen3.5-9b_sft.txt'
    qwen25_file = 'training_logs/qwen2.5-7b-instruct_sft.txt'

    qwen35_losses = parse_losses(qwen35_file)
    qwen25_losses = parse_losses(qwen25_file)

    plt.figure(figsize=(8, 6))
    plt.plot(qwen25_losses, marker='o', linestyle='-', color='blue', label='Qwen 2.5-7B SFT')
    plt.plot(qwen35_losses, marker='s', linestyle='-', color='green', label='Qwen 3.5-9B SFT')
    plt.title('SFT Loss Comparison: Qwen 2.5 vs 3.5')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    output_path = os.path.abspath('results/qwen_sft_comparison.png')
    # savefig does not create missing parent directories, so a fresh
    # checkout without results/ would otherwise fail here.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path)
    print(f"Saved plot to {output_path}")
# Run the comparison only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|