| import os | |
| from typing import Optional, Tuple, List, Dict, Union, Any | |
| from pprint import pprint | |
| from tqdm import tqdm | |
| import torch | |
| from functools import wraps | |
| import time | |
| from .utils import Result, batch_iterator | |
| from .input_assembler import AutoAssembler | |
| from .prompt_builder import PromptBuilder | |
| from .result_parser import ResultParser | |
| from .config_manager import ConfigManager | |
class AutoLLMReranker:
    """LLM-based reranker assembled from a configuration object.

    The constructor wires a ``PromptBuilder``, an LLM backend and a
    ``ResultParser`` into an assembler created by
    ``AutoAssembler.from_config``; :meth:`rerank` drives that assembler over
    an initial retrieval run.
    """

    @classmethod
    def from_prebuilt(cls, method_name, model_name_or_path, **kwargs) -> "AutoLLMReranker":
        """Build a reranker from a YAML config bundled in ``autollmrerank.configs``.

        Args:
            method_name: Stem of the bundled config file (``<method_name>.yaml``).
            model_name_or_path: Model identifier stored in the ``llm`` settings.
            **kwargs: Extra settings forwarded to ``ConfigManager`` and to the
                constructor. An optional ``llm`` dict is merged over the default
                ``llm`` settings and removed from ``kwargs`` before forwarding.

        Returns:
            A reranker instance built from the resolved config.
        """
        import importlib.resources as pkg_resources

        # The original computed the same path twice and "fell back" from one
        # to the other; a single lookup is equivalent.
        path = pkg_resources.files("autollmrerank.configs").joinpath(f"{method_name}.yaml")
        # TODO: figure out what else
        llmconfig = {'model_name_or_path': model_name_or_path}
        llmconfig.update(kwargs.pop('llm', {}))
        config = ConfigManager(path=path, llm=llmconfig, **kwargs).get_config()
        return cls(config, **kwargs)

    def timer(func):
        """Decorator that prints the wall-clock duration of each call to ``func``.

        Kept as a plain function in the class body so it can be applied with
        ``@timer`` to methods defined below it.
        """
        @wraps(func)  # keep the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            end = time.perf_counter()
            print(f"\n\n{func.__qualname__} took {end - start:.6f} seconds")
            return result
        return wrapper

    def __init__(self, config, **kwargs) -> None:
        """Create the prompt builder, LLM backend, result parser and assembler.

        Args:
            config: Configuration object; this method reads ``config.llm.*``
                and ``config.use_alphabetical`` and passes ``config`` on to
                ``PromptBuilder`` and ``AutoAssembler.from_config``.
            **kwargs: Accepted so ``from_prebuilt`` can forward its extras;
                not used here.

        Raises:
            ValueError: If ``config.llm.backend`` is not a supported backend.
        """
        self.config = config
        prompt_builder = PromptBuilder(config=config)

        # TODO: make it clearer loaded by argument
        backend = config.llm.backend
        if backend in ('vllm', 'vllm_dev'):
            # 'vllm' is routed to the same dev provider as 'vllm_dev'.
            from .llm_provider.vllm_dev import LLM
        elif backend in ('openai', 'request'):
            from .llm_provider.request import LLM
        else:
            # Without this branch ``LLM`` would be unbound below.
            raise ValueError(
                f"Unsupported llm backend {backend!r}; "
                "expected one of 'vllm', 'vllm_dev', 'openai', 'request'."
            )

        use_logits = config.llm.use_logits
        # Prefer the configured endpoint; fall back to the local default when
        # the config has no (or an empty) ``base_url``.
        base_url = getattr(config.llm, 'base_url', None) or 'http://localhost:8000/v1'

        agent = LLM(
            model_name_or_path=config.llm.model_name_or_path,
            temperature=config.llm.temperature,
            top_p=config.llm.top_p,
            logprobs=20 if use_logits else None,
            max_model_len=config.llm.max_model_len,
            max_tokens=5 if use_logits else 128,
            dtype=config.llm.dtype,
            num_gpus=max(1, int(torch.cuda.device_count())),
            base_url=base_url,
            api_key='EMPTY',
        )
        # agent.set_classification(target_ratings=[3,4,5])
        result_parser = ResultParser(use_alpha=config.use_alphabetical)

        # initialize the algorithm module
        self.assembler = AutoAssembler.from_config(
            config,
            prompt_builder=prompt_builder,
            llm_provider=agent,
            result_parser=result_parser,
        )

    @staticmethod
    def convert_run_to_result(run, queries=None, corpus=None):
        """Convert a run mapping into a list of ``Result`` objects.

        Args:
            run: Mapping ``qid -> {docid: score}``.
            queries: Mapping ``qid -> query``. Despite the ``None`` default it
                is indexed for every ``qid`` in ``run``.
            corpus: Mapping ``docid -> content``. Despite the ``None`` default
                it is indexed for every ``docid`` in ``run``.

        Returns:
            One ``Result`` per query, each hit being a dict with the keys
            ``docid``, ``score`` (cast to ``float``) and ``content_dict``.
        """
        results = []
        for qid, hits in run.items():
            query = queries[qid]
            hit_docs = [
                {'docid': docid, 'score': float(score), 'content_dict': corpus[docid]}
                for docid, score in hits.items()
            ]
            results.append(Result(qid=qid, query=query, hits=hit_docs))
        return results

    # TODO: Figure out another input format called `text_pairs=[(q1, [d1, d2, ...]), (q2, ...)]`.
    # This is more friendly for users who only have texts.
    def rerank(
        self,
        run: Dict[str, Dict[str, float]],
        queries: Dict[str, str],
        corpus: Dict[str, Dict[str, str]],
        query_batch_size: int = 32,
    ) -> Dict[str, Dict[str, float]]:
        """Rerank an initial run with the configured assembler.

        Args:
            run (Dict[str, Dict[str, float]]): The initial run to be reranked.
            queries (Dict[str, str]): A dictionary mapping query IDs to query strings.
            corpus (Dict[str, Dict[str, str]]): A dictionary mapping document IDs to
                their content and title (if applicable).
            query_batch_size (int): The number of queries (with their results) to
                process in each batch.

        Returns:
            A run mapping ``qid -> {docid: score}`` whose hits follow the order
            left by ``Result.sort_by(field='score')``.
        """
        init_results = self.convert_run_to_result(run, queries, corpus)

        # Ceiling division: the previous ``len // size + 1`` over-counted by one
        # whenever the number of queries was an exact multiple of the batch
        # size (assumes ``batch_iterator`` yields chunks of ``size`` items).
        num_batches = (len(init_results) + query_batch_size - 1) // query_batch_size
        rank_end = min(self.config.rank_end, self.config.top_k)

        reranked_results = []
        for batch_results in tqdm(
            batch_iterator(init_results, size=query_batch_size),
            desc=f"Reranking with query batch size {query_batch_size}",
            total=num_batches,
        ):
            batch_reranked_results = self.assembler.run(
                init_results=batch_results,
                rank_start=0,
                rank_end=rank_end,
                batch_size=query_batch_size,
                num_runs=self.config.num_runs,
            )
            reranked_results.extend(batch_reranked_results)

        # sort
        for r in reranked_results:
            r.sort_by(field='score')

        # convert back to run
        reranked_run = {}
        for result in reranked_results:
            reranked_run[result.qid] = {hit['docid']: hit['score'] for hit in result.hits}
        return reranked_run
if __name__ == "__main__":
    # Script entry point: load config from the CLI, rerank an input run,
    # write the reranked run in TREC format and report nDCG@10 before/after.
    import importlib

    import ir_measures

    # init config with CLI commands
    config = ConfigManager().get_config()
    config_dict = ConfigManager().get_config(return_dict=True)
    pprint(config_dict)

    # init reranker
    rankllm = AutoLLMReranker(config)

    # load data
    # The module name is absolute, so no ``package`` anchor is needed.
    loader = importlib.import_module(f"autollmrerank.loader_dev.{config.data.loader_type}")
    run = loader.load_run(config.data.input_run, topk=getattr(config.data, 'topk', 100))
    corpus, queries, qrels = loader.load(config.data.dataset_name, query_fields=None, doc_fields=None)
    # Evaluate only the queries that are present in the run.
    qrels = {qid: qrel for qid, qrel in qrels.items() if qid in run}

    # reranking
    reranked_run = rankllm.rerank(run=run, queries=queries, corpus=corpus, query_batch_size=config.data.batch_size)

    # output reranked result
    output_path = config.data.output_run
    if output_path is None:
        # Derive the output location from the input run path.
        output_path = config.data.input_run.replace('runs', f'runs/{config.rerank_mode}')

    # ``os.makedirs('')`` raises, so only create a directory when the output
    # path actually has a directory component.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        for qid, hits in reranked_run.items():
            for rank, (docid, score) in enumerate(hits.items(), start=1):
                f.write(f"{qid} Q0 {docid} {rank} {score} {config.rerank_mode}\n")

    # evaluation
    metrics = [ir_measures.nDCG@10]
    r1 = ir_measures.calc_aggregate(metrics, qrels, run)
    r2 = ir_measures.calc_aggregate(metrics, qrels, reranked_run)

    # print logs
    eval_log = {
        'rerank_mode': config.rerank_mode,
        'model_name_or_path': config.llm.model_name_or_path,
        'dataset_name': f"{config.data.loader_type}:{config.data.dataset_name}",
        'run_path': config.data.input_run,
        'original': r1,
        'reranked': r2
    }
    pprint(eval_log)
Xet Storage Details
- Size:
- 7.16 kB
- Xet hash:
- 193979b5d9c7234ee69e87a63fa2fb4f33b97ae38a180ff48dcb5af651341fd2
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.