Jyo-K committed on
Commit
c5ad64e
·
verified ·
1 Parent(s): d1c266e

Upload tools.py

Browse files
Files changed (1) hide show
  1. tools.py +142 -0
tools.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import io
4
+ from typing import List
5
+ from PIL.Image import Image
6
+ from langchain_core.tools import tool
7
+ from pinecone import Pinecone
8
+
9
# ---------------------------------------------------------------------------
# Module configuration. Everything below is resolved once, at import time.
# ---------------------------------------------------------------------------

# Endpoint of the image classifier; the SWIN_MODEL_URL environment variable
# overrides the default URL.
SWIN_API_URL = os.getenv(
    "SWIN_MODEL_URL",
    "https://api-inference.huggingface.co/models/Jyo-K/skin_swin",
)

# Bearer token sent to both inference endpoints; None when the variable is unset.
HF_API_KEY = os.getenv("HF_API_KEY")

# Endpoint used to turn query text into an embedding vector (not overridable).
EMBEDDING_API_URL = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"

# Pinecone credentials and target index; both are None when unset and are
# validated lazily by get_pinecone_index().
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")

# Class names indexed by the numeric suffix of a "LABEL_<n>" prediction.
# Each entry carries an "N. " prefix that the image tool strips off before
# returning the name.
# NOTE(review): the names look like Turkish category labels — confirm they
# match the spellings stored in the vector database.
SWIN_LABELS = [
    "1. Enfeksiyonel",
    "2. Ekzama",
    "3. Akne",
    "4. Pigment",
    "5. Benign",
    "6. Malign",
]


# Cached Pinecone handles; populated on the first get_pinecone_index() call.
_pinecone_client = None
_pinecone_index = None
27
+
28
def get_pinecone_index():
    """Return the module-wide Pinecone index, creating it on first use.

    The first successful call builds a ``Pinecone`` client from
    ``PINECONE_API_KEY``, opens ``PINECONE_INDEX_NAME`` and stores both in
    the module globals; later calls return the stored index directly.

    Raises:
        ValueError: If ``PINECONE_API_KEY`` or ``PINECONE_INDEX_NAME`` is
            missing or empty.
    """
    global _pinecone_client, _pinecone_index

    # Fast path: the index was already created by an earlier call.
    if _pinecone_index is not None:
        return _pinecone_index

    # Both settings are required before any client object is constructed.
    if not (PINECONE_API_KEY and PINECONE_INDEX_NAME):
        raise ValueError("PINECONE_API_KEY or PINECONE_INDEX_NAME not set.")

    _pinecone_client = Pinecone(api_key=PINECONE_API_KEY)
    _pinecone_index = _pinecone_client.Index(PINECONE_INDEX_NAME)
    print("--- Pinecone Index Initialized ---")
    return _pinecone_index
39
+
40
def get_embedding_hf(text: str, timeout: float = 120.0) -> List[float]:
    """Return the embedding vector for ``text`` from the HF Inference API.

    Posts ``text`` to ``EMBEDDING_API_URL`` with the ``wait_for_model``
    option set and normalises the JSON reply to a single flat vector.

    Args:
        text: The query string to embed.
        timeout: Seconds ``requests`` may wait for the connection and for
            response data before giving up. Without a timeout a stalled
            endpoint would block the caller indefinitely.

    Returns:
        The embedding as a list of floats.

    Raises:
        ValueError: If ``HF_API_KEY`` is not set, or the API reply is not a
            non-empty JSON list.
        requests.RequestException: On connection problems, timeouts, or a
            non-2xx HTTP status.
    """
    if not HF_API_KEY:
        raise ValueError("HF_API_KEY not set. Cannot get embeddings.")

    response = requests.post(
        EMBEDDING_API_URL,
        headers={"Authorization": f"Bearer {HF_API_KEY}"},
        json={"inputs": text, "options": {"wait_for_model": True}},
        timeout=timeout,
    )
    response.raise_for_status()

    payload = response.json()
    if isinstance(payload, dict):
        # A JSON object instead of a vector is an error report; surface its
        # message rather than failing later with an opaque KeyError.
        raise ValueError(
            f"Unexpected embedding API response: {payload.get('error', payload)}"
        )
    if not isinstance(payload, list) or not payload:
        raise ValueError("Embedding API returned an empty or invalid response.")

    # The reply may be a batch ([[...]]) or, for a single string input, the
    # bare vector ([...]). Indexing [0] on the bare form would return one
    # float instead of the vector, so only unwrap when it is nested.
    first = payload[0]
    if isinstance(first, list):
        return first
    return payload
52
+
53
@tool
def tool_analyze_skin_image(image: Image) -> str:
    """
    Analyzes a PIL Image of a skin condition using the Swin Transformer
    Inference API and returns the top predicted disease name.
    """
    # NOTE: the docstring above is intentionally left unchanged — it is
    # presumably consumed by @tool as the description shown to the model.
    #
    # Contract: always returns a string. On success it is the predicted class
    # name with any "N. " prefix removed; on any failure it is a message
    # starting with "Error".
    if not HF_API_KEY:
        return "Error: Hugging Face API token not found."

    headers = {"Authorization": f"Bearer {HF_API_KEY}"}
    request_timeout = 60  # seconds; prevents hanging forever on a stalled endpoint

    # Modes Pillow's JPEG writer accepts as-is. Anything else (RGBA, P, LA, ...)
    # makes save() raise, so such images are converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    try:
        jpeg_ready = image if image.mode in jpeg_modes else image.convert("RGB")
        buffered = io.BytesIO()
        jpeg_ready.save(buffered, format="JPEG")
        img_data = buffered.getvalue()
    except Exception as e:
        # Encoding problems are reported like every other failure of this
        # tool instead of escaping as an exception.
        print(f"Image Analysis Tool Error: {e}")
        return f"Error: Could not encode image as JPEG: {e}"

    try:
        response = requests.post(
            SWIN_API_URL,
            headers=headers,
            data=img_data,
            timeout=request_timeout,
        )
        response.raise_for_status()
        api_output = response.json()

        if isinstance(api_output, dict) and 'error' in api_output:
            return f"Error from Swin API: {api_output['error']}"

        if not (isinstance(api_output, list) and api_output):
            return "Error: Invalid API response format from Swin model."

        top_prediction = max(api_output, key=lambda x: x['score'])
        label_name = top_prediction['label']

        if "LABEL_" in label_name:
            # Generic "LABEL_<n>" names are mapped through SWIN_LABELS.
            try:
                idx = int(label_name.split('_')[-1])
            except ValueError:
                return f"Error: Model returned unknown label {label_name}"
            # Reject negative indices explicitly: Python would otherwise wrap
            # around and silently pick a label from the end of the list.
            if not 0 <= idx < len(SWIN_LABELS):
                return f"Error: Model returned unknown label {label_name}"
            disease_name_with_prefix = SWIN_LABELS[idx]
        else:
            disease_name_with_prefix = label_name

        # Drop the leading "N. " numbering, if present.
        disease_name = disease_name_with_prefix.split('. ')[-1]
        print(f"Image Analysis Tool: Predicted '{disease_name}'")
        return disease_name

    except Exception as e:
        print(f"Image Analysis Tool Error: {e}")
        return f"Error during Swin API call: {e}"
102
+
103
@tool
def tool_fetch_disease_info(disease_name: str) -> dict:
    """
    Queries the Pinecone vector database to find symptoms and treatment
    information for a given disease name.
    """
    # NOTE: the docstring above is intentionally left unchanged — it is
    # presumably consumed by @tool as the description shown to the model.
    #
    # Contract: always returns a dict. On success it has the keys "disease",
    # "symptoms" (list of str), "treatment" and "context"; on any failure it
    # has a single "error" key.
    if not disease_name or not disease_name.strip():
        # Nothing meaningful to embed; skip the remote calls entirely.
        return {"error": "No disease name provided."}

    try:
        index = get_pinecone_index()
    except ValueError as e:
        # Missing configuration: pass the message through unchanged.
        return {"error": str(e)}
    except Exception as e:
        # Client construction / index lookup can fail for other reasons
        # (bad index name, network); report it like every other failure here
        # rather than letting it escape the tool.
        print(f"Vector DB Tool Error: {e}")
        return {"error": f"Error initializing Pinecone index: {e}"}

    min_score = 0.5  # matches scoring below this are treated as "not found"

    try:
        print(f"Vector DB Tool: Getting embedding for '{disease_name}'")
        query_embedding = get_embedding_hf(disease_name)

        query_response = index.query(
            vector=query_embedding,
            top_k=1,
            include_metadata=True,
        )

        matches = query_response.get('matches')
        if not matches or matches[0]['score'] < min_score:
            return {"error": f"No high-confidence information found for '{disease_name}' in the database."}

        metadata = matches[0]['metadata']

        # "symptoms" is expected as a comma-separated string; a list of
        # strings is accepted too so that either storage form works.
        raw_symptoms = metadata.get("symptoms", "")
        if isinstance(raw_symptoms, str):
            raw_symptoms = raw_symptoms.split(',')
        symptoms_list = [str(s).strip() for s in raw_symptoms if str(s).strip()]

        treatment = metadata.get("treatment", "No treatment information found.")

        return {
            "disease": metadata.get("disease", disease_name),
            "symptoms": symptoms_list,
            "treatment": treatment,
            "context": metadata.get("text_content", ""),
        }
    except Exception as e:
        print(f"Vector DB Tool Error: {e}")
        return {"error": f"Error during Pinecone query: {e}"}
142
+