# plot_only.py — merge partial compression-stability results and render the
# distribution plot offline (no recomputation of the analysis itself).
import json
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd  # required: seaborn's hue-grouped kdeplot needs a DataFrame
def main():
    """Merge partial JSON results from parallel workers and plot a KDE.

    Reads every ``partial_result_*.json`` file in ``--output_dir``, merges
    the per-algorithm lists of normalized edit distances, then writes two
    artifacts back into the same directory: a seaborn KDE plot
    (``stability_parallel_fixed.png``) and a JSON statistics summary
    (``final_stats_summary.json``).
    """
    parser = argparse.ArgumentParser()
    # Default path matches the output directory of the parallel analysis run.
    parser.add_argument("--output_dir", type=str, default="analysis_output_parallel",
                        help="Directory with partial .json results")
    args = parser.parse_args()

    print(f"📂 Reading results from: {args.output_dir}")

    # 1. Merge partial results produced by the parallel workers.
    final_results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
    files_found = 0

    if not os.path.exists(args.output_dir):
        print(f"❌ Error: Directory {args.output_dir} does not exist.")
        return

    for filename in os.listdir(args.output_dir):
        if filename.startswith("partial_result_") and filename.endswith(".json"):
            files_found += 1
            file_path = os.path.join(args.output_dir, filename)
            try:
                with open(file_path, 'r') as f:
                    data = json.load(f)
                for k in final_results:
                    if k in data:
                        final_results[k].extend(data[k])
            except Exception as e:
                # Bug fix: report WHICH file failed to parse; the original
                # printed a literal "(unknown)" placeholder instead.
                print(f"⚠️ Error reading {file_path}: {e}")

    print(f"✅ Merged data from {files_found} files.")

    # 2. Prepare plotting data.
    plot_records = []
    stats_summary = {}

    for algo, vals in final_results.items():
        if not vals:
            continue

        # Outlier filter applies to the PLOT only (values >= 2.0 are rare
        # extremes that would stretch the axis); the summary statistics
        # below are intentionally computed on the raw, unfiltered values.
        cleaned = [v for v in vals if v < 2.0]

        stats_summary[algo] = {
            "mean": float(np.mean(vals)),
            "median": float(np.median(vals)),
            "count": len(vals)
        }

        # One row per observation so seaborn can group by the hue column.
        for v in cleaned:
            plot_records.append({"Algorithm": algo, "Normalized Edit Distance": v})

    if not plot_records:
        print("❌ No valid data collected to plot.")
        return

    # Convert to a pandas DataFrame — hue-grouped kdeplot requires one.
    df = pd.DataFrame(plot_records)
    print(f"📊 Plotting {len(df)} data points...")

    # 3. Plot a per-algorithm KDE of the normalized edit distances.
    plt.figure(figsize=(12, 7))
    sns.set_style("whitegrid")
    sns.kdeplot(
        data=df,
        x="Normalized Edit Distance",
        hue="Algorithm",
        fill=True,
        common_norm=False,  # normalize each algorithm's density independently
        palette="tab10",
        alpha=0.5,
        linewidth=2
    )
    plt.title("Compression Stability Analysis (Impact of 10% Perturbation)")
    plt.xlabel("Normalized Levenshtein Distance (Lower = More Stable)")
    plt.ylabel("Density")
    plt.xlim(0, 1.2)  # focus on the 0–1.2 range where the mass lies

    output_img = os.path.join(args.output_dir, "stability_parallel_fixed.png")
    plt.savefig(output_img, dpi=300)
    plt.close()  # release the figure so repeated invocations don't leak memory
    print(f"🖼️ Plot saved to: {output_img}")

    # 4. Save the summary statistics alongside the plot.
    stats_file = os.path.join(args.output_dir, "final_stats_summary.json")
    with open(stats_file, 'w') as f:
        json.dump(stats_summary, f, indent=2)
    print(f"📄 Stats saved to: {stats_file}")

    # Print a brief console summary for quick inspection.
    print("\n=== Summary Stats ===")
    for algo, stat in stats_summary.items():
        print(f"{algo}: Mean={stat['mean']:.4f}, Count={stat['count']}")


if __name__ == "__main__":
    main()