| import json |
| import heapq |
| def get_top_k_indices(json_file_path, k): |
| """ |
| 读取JSON文件中的列表,返回最大的k个元素的索引 |
| |
| Args: |
| json_file_path: JSON文件路径 |
| k: 需要获取的最大元素的个数 |
| |
| Returns: |
| list: 按元素大小降序排列的索引列表 |
| |
| Raises: |
| FileNotFoundError: 文件不存在时抛出 |
| ValueError: k值无效或数据格式错误时抛出 |
| """ |
| |
| try: |
| with open(json_file_path, 'r', encoding='utf-8') as f: |
| data = json.load(f) |
| except FileNotFoundError: |
| raise FileNotFoundError(f"文件 {json_file_path} 不存在") |
| except json.JSONDecodeError: |
| raise ValueError("JSON文件格式错误") |
| |
| |
| if not isinstance(data, list): |
| raise ValueError("JSON文件内容不是列表") |
| |
| |
| if k <= 0 or k > len(data): |
| raise ValueError(f"k值无效,应在1到{len(data)}之间") |
| |
| |
| value_index_pairs = [(value, idx) for idx, value in enumerate(data)] |
| |
| |
| top_k_pairs = heapq.nlargest(k, value_index_pairs, key=lambda x: x[0]) |
| |
| |
| |
| |
| |
| |
| top_k_indices = [pair[1] for pair in top_k_pairs] |
| |
| return top_k_indices |
|
|
| a = get_top_k_indices('/mnt/bn/life-mllm/users/cxr/quantization/quantization_metric/metrics/alpha/alpha_mlp_Llama-2-7b-hf.json', 10) |
| print(a) |