| from typing import Dict, List |
| from jinja2 import Environment |
|
|
| from schemas import ExtractedRelation |
|
|
|
|
def build_visjs_graph(entities: List[str], relations: List[ExtractedRelation]) -> Dict[str, List[Dict]]:
    """Build the node and edge lists of a graph for display in the UI.

    Every distinct entity becomes one node. A node's integer ``id`` is the
    entity's position in first-occurrence order, so identical input always
    produces identical ids and node ordering. Every relation becomes one
    directed edge; a relation whose ``start`` or ``to`` is not among
    ``entities`` is skipped.

    Args:
        entities: Entity names. Duplicates collapse into a single node.
        relations: Objects exposing ``start``, ``to``, ``tag`` and
            ``description`` attributes.

    Returns:
        A dict with two keys: ``"nodes"``, a list of
        ``{"id", "label", "title"}`` dicts, and ``"edges"``, a list of
        ``{"from", "to", "label", "title", "arrows"}`` dicts.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. Iterating a
    # set of strings instead would make the ids depend on hash randomisation
    # and therefore differ from one interpreter run to the next.
    entity_to_id = {
        entity: idx for idx, entity in enumerate(dict.fromkeys(entities))
    }
    nodes = [
        {"id": idx, "label": entity, "title": entity}
        for entity, idx in entity_to_id.items()
    ]

    edges = []
    for rel in relations:
        start_id = entity_to_id.get(rel.start)
        end_id = entity_to_id.get(rel.to)
        if start_id is None or end_id is None:
            # At least one endpoint has no node, so the edge cannot be drawn.
            continue
        edges.append({
            "from": start_id,
            "to": end_id,
            "label": rel.tag,
            "title": rel.description,
            "arrows": "to",
        })

    return {"nodes": nodes, "edges": edges}
|
|
|
|
async def fmt_prompt(env: Environment, prompt_id: str, **args):
    """Look up a prompt template and render it asynchronously.

    Args:
        env: Environment the template is fetched from via ``get_template``.
        prompt_id: Name of the template to load.
        **args: Template variables, handed to the template as a single
            context dict.

    Returns:
        Whatever the template's ``render_async`` produces.

    NOTE(review): jinja2 documents ``render_async`` as requiring an
    environment created with ``enable_async=True`` - confirm at the call site.
    """
    return await env.get_template(prompt_id).render_async(args)
|
|