File size: 1,425 Bytes
229a9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import re
import os
import matplotlib.pyplot as plt

def parse_losses(filepath):
    """Extract the training-loss values logged in a text file.

    The file is scanned for dict-style log entries that begin with
    ``{'loss': ...``. The value may be written either as a quoted string
    (``{'loss': '0.75'``) or as a bare number (``{'loss': 0.75``), including
    exponent notation such as ``9.5e-05``.

    Args:
        filepath: Path to a UTF-8 encoded log file.

    Returns:
        A list of floats in the order the entries appear in the file.
        Entries whose value cannot be converted to ``float`` are skipped.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    # One alternation handles both spellings in a single pass, so values keep
    # their file order even when quoted and bare entries are mixed.
    pattern = re.compile(
        r"\{'loss':\s*(?:'([^']+)'|([0-9.]+(?:[eE][-+]?[0-9]+)?))"
    )

    losses = []
    for match in pattern.finditer(content):
        quoted, bare = match.group(1), match.group(2)
        raw = quoted if quoted is not None else bare
        try:
            losses.append(float(raw))
        except ValueError:
            # Not a number (e.g. a placeholder string or a stray "."); skip it.
            continue
    return losses

def main():
    """Plot the SFT loss curves of the two Qwen runs and save the figure.

    Reads both training logs with ``parse_losses``, draws one line per model
    against the index of each parsed loss entry, and writes the PNG to
    ``results/qwen_sft_comparison.png`` (resolved against the current
    working directory). The output directory is created if it is missing.
    """
    qwen35_file = 'training_logs/qwen3.5-9b_sft.txt'
    qwen25_file = 'training_logs/qwen2.5-7b-instruct_sft.txt'

    qwen35_losses = parse_losses(qwen35_file)
    qwen25_losses = parse_losses(qwen25_file)

    plt.figure(figsize=(8, 6))

    # No explicit x values are passed, so each curve is plotted against the
    # position of the entry in its log.
    plt.plot(qwen25_losses, marker='o', linestyle='-', color='blue', label='Qwen 2.5-7B SFT')
    plt.plot(qwen35_losses, marker='s', linestyle='-', color='green', label='Qwen 3.5-9B SFT')

    plt.title('SFT Loss Comparison: Qwen 2.5 vs 3.5')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    output_path = os.path.abspath('results/qwen_sft_comparison.png')
    # savefig does not create missing directories; make sure results/ exists
    # so a fresh checkout does not fail at the very last step.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path)
    print(f"Saved plot to {output_path}")

# Generate the plot only when this file is run as a script, not on import.
if __name__ == '__main__':
    main()