umaiku committed on
Commit df1b3de · verified · 1 Parent(s): e24fae8

Update app.py


Back to version before ChatGPT

Files changed (1)
  1. app.py +83 -217
app.py CHANGED
@@ -1,125 +1,19 @@
-import os
 import gradio as gr
-import pandas as pd
-from datetime import datetime
-
 from transformers import pipeline
 from huggingface_hub import InferenceClient, login, snapshot_download
-
 from langchain_community.vectorstores import FAISS, DistanceStrategy
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_core.vectorstores import VectorStore
+import os
+import pandas as pd
+from datetime import datetime
 
 from smolagents import Tool, HfApiModel, ToolCallingAgent
+from langchain_core.vectorstores import VectorStore
 
 
-# -------- Helpers & Compatibility --------
-
-def _warn_token():
-    hf_token = os.getenv("TOKEN") or os.getenv("HF_TOKEN")
-    if not hf_token:
-        print("[WARN] No HF token found in env (TOKEN or HF_TOKEN). Private models/endpoints may fail.")
-        return None
-    return hf_token
-
-def _login_hf():
-    token = _warn_token()
-    if token:
-        try:
-            login(token=token)
-        except TypeError:
-            # older huggingface_hub accepted positional
-            login(token)
-
-def _stream_chat(client: InferenceClient, messages, max_tokens: int, temperature: float, top_p: float):
-    """
-    Try new OpenAI-style streaming first, then older `chat_completion`, then fall back to text_generation.
-    Yields string chunks.
-    """
-    # 1) New: client.chat.completions.create(..., stream=True)
-    try:
-        chat = client.chat.completions.create(
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            stream=True,
-        )
-        for chunk in chat:
-            # choices[0].delta.content may be None in some chunks
-            if chunk and getattr(chunk, "choices", None):
-                delta = chunk.choices[0].delta
-                if delta and getattr(delta, "content", None):
-                    yield delta.content
-        return
-    except Exception as e_new:
-        # print for debug, but continue to fallback
-        print("[INFO] OpenAI-style chat.completions streaming not available:", repr(e_new))
-
-    # 2) Old: client.chat_completion(..., stream=True)
-    try:
-        old_stream = client.chat_completion(
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            stream=True,
-        )
-        # Old stream objects sometimes have .choices[0].delta.content, sometimes just .token
-        for chunk in old_stream:
-            text = None
-            try:
-                text = chunk.choices[0].delta.content  # may exist
-            except Exception:
-                pass
-            if not text:
-                # try common fallbacks
-                text = getattr(chunk, "token", None) or getattr(chunk, "text", None)
-            if text:
-                yield text
-        return
-    except Exception as e_old:
-        print("[INFO] Legacy chat_completion streaming not available:", repr(e_old))
-
-    # 3) Fallback: plain text_generation with a single concatenated prompt (no messages)
-    # The last user message should be the final prompt.
-    try:
-        final_prompt = ""
-        for m in messages:
-            role = m.get("role", "user")
-            content = m.get("content", "")
-            # simple role-tagged concat
-            final_prompt += f"{role.upper()}: {content}\n"
-        gen_stream = client.text_generation(
-            final_prompt,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            stream=True,
-            return_full_text=False,
-        )
-        for piece in gen_stream:
-            # piece may be string or an object with .token/.generated_text
-            if isinstance(piece, str):
-                yield piece
-            else:
-                text = getattr(piece, "token", None) or getattr(piece, "generated_text", None)
-                if text:
-                    yield text
-        return
-    except Exception as e_gen:
-        print("[ERROR] All HF streaming methods failed:", repr(e_gen))
-        yield "\n[Error] Unable to stream from the inference endpoint. Check model name, token, and HF API version.\n"
-
-
-# -------- Data / Vector Store --------
-
 class RetrieverTool(Tool):
     name = "retriever"
-    description = (
-        "Using semantic similarity in German, French, English and Italian, retrieves some documents "
-        "from the knowledge base that have the closest embeddings to the input query."
-    )
+    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
     inputs = {
         "query": {
             "type": "string",
@@ -128,169 +22,141 @@ class RetrieverTool(Tool):
     }
     output_type = "string"
 
-    def __init__(self, vectordb: VectorStore, df: pd.DataFrame, **kwargs):
+    def __init__(self, vectordb: VectorStore, **kwargs):
         super().__init__(**kwargs)
         self.vectordb = vectordb
-        self.df = df
 
     def forward(self, query: str) -> str:
         assert isinstance(query, str), "Your search query must be a string"
-        try:
-            docs = self.vectordb.similarity_search(query, k=7)
-        except Exception as e:
-            return f"[Retriever error] {e}"
+
+        docs = self.vectordb.similarity_search(
+            query,
+            k=7,
+        )
 
         spacer = " \n"
         context = ""
         nb_char = 100
-
+
         for doc in docs:
-            # Safe metadata access
-            meta = getattr(doc, "metadata", {}) or {}
-            case_ref = str(meta.get("case_ref", "") or "")
-            case_nb = str(meta.get("case_nb", "") or "")
-            case_date = str(meta.get("case_date", "") or "")
-            case_url = str(meta.get("case_url", "") or "")
-
-            # Try to find a surrounding extract from the master text
-            case_text_summary = ""
-            if case_url:
-                try:
-                    rows = self.df[self.df["case_url"] == case_url]
-                    if not rows.empty:
-                        case_text = str(rows.iloc[0]["case_text"])
-                        idx = case_text.find(doc.page_content)
-                        if idx >= 0:
-                            start = max(0, idx - nb_char)
-                            end = min(len(case_text), idx + len(doc.page_content) + nb_char)
-                            case_text_summary = case_text[start:end]
-                except Exception as e:
-                    # If anything goes wrong, fall back to page_content
-                    case_text_summary = doc.page_content
-
-            if not case_text_summary:
-                case_text_summary = doc.page_content
-
+            case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
+            index = case_text.find(doc.page_content)
+            start = max(0, index - nb_char)
+            end = min(len(case_text), index + len(doc.page_content) + nb_char)
+            case_text_summary = case_text[start:end]
+
             context += "#######" + spacer
-            context += "# Case number: " + (case_ref + " " + case_nb).strip() + spacer
-            source_name = "Swiss Federal Court" if case_ref == "ATF" else "European Court of Human Rights"
-            context += "# Case source: " + source_name + spacer
-            context += "# Case date: " + case_date + spacer
-            context += "# Case url: " + case_url + spacer
+            context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
+            context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
+            context += "# Case date: " + doc.metadata["case_date"] + spacer
+            context += "# Case url: " + doc.metadata["case_url"] + spacer
+            #context += "# Case text: " + doc.page_content + spacer
             context += "# Case extract: " + case_text_summary + spacer
 
-        return "\nRetrieved documents:\n" + context
 
+        return "\nRetrieved documents:\n" + context
 
-# -------- Init HF / Model / Index --------
 
-_login_hf()
-
-# Choose your model
-MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-# MODEL_ID = "swiss-ai/Apertus-8B-Instruct-2509"
-
-client = InferenceClient(MODEL_ID)
-
-# Pull the FAISS dataset snapshot and derive the index path
+"""
+For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+"""
+HF_TOKEN=os.getenv('TOKEN')
+login(HF_TOKEN)
+
+model = "meta-llama/Meta-Llama-3-8B-Instruct"
+#model = "swiss-ai/Apertus-8B-Instruct-2509"
+
+client = InferenceClient(model)
+
 folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
-index_dir = os.path.join(folder, "faiss_index_mpnet_cos")
-if not os.path.isdir(index_dir):
-    # Fallback: try current working directory if you’ve manually placed the index there
-    alt = os.path.join(os.getcwd(), "faiss_index_mpnet_cos")
-    if os.path.isdir(alt):
-        index_dir = alt
-    else:
-        print(f"[WARN] Could not find FAISS index directory at {index_dir} or {alt}. Check your dataset contents.")
 
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
 
-# Load FAISS (COSINE distance)
-vector_db = FAISS.load_local(
-    index_dir,
-    embeddings,
-    allow_dangerous_deserialization=True,
-    distance_strategy=DistanceStrategy.COSINE,
-)
-
-# Load your case dataframe
-CSV_PATH = os.path.join(folder, "bger_cedh_db 1954-2024.csv")
-if not os.path.isfile(CSV_PATH):
-    # also try local if you keep it next to the script
-    CSV_PATH = "bger_cedh_db 1954-2024.csv"
-df = pd.read_csv(CSV_PATH)
+vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
 
-retriever_tool = RetrieverTool(vector_db, df)
-agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(MODEL_ID))  # Not used directly, but kept if you expand.
+df = pd.read_csv("bger_cedh_db 1954-2024.csv")
 
-
-# -------- Chat callback --------
-
-def respond(
-    user_message: str,
-    history: list[tuple[str, str]],
-    system_message: str,
-    max_tokens: int,
-    temperature: float,
-    top_p: float,
-    score_threshold: float,
-):
+retriever_tool = RetrieverTool(vector_db)
+agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
+
+def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,):
+
     print(datetime.now())
-    print("[User]", user_message)
-
-    context = retriever_tool(user_message)
-
-    # Build the RAG prompt
-    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
+    context = retriever_tool(message)
+
+    print(message)
+
+    # is_law = client.text_generation(f"""Given the user question below, classify it as either being about "Law" or "Other".
+    #Do NOT respond with more than one word.
+    #Question:
+    #{message}""")
+
+    # print(is_law)
+
+    if True: #is_law.lower() != "other":
+        prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
 Respond only to the question asked, response should be relevant to the question and in the same language as the question.
 Provide the number of the source document when relevant, as well as the link to the document.
 If you cannot find information, do not give up and try calling your retriever again with different arguments!
 Always give url of the sources at the end and only answer in the language the question is asked.
-
+
 Question:
-{user_message}
-
+{message}
+
 {context}
 """
-
+    else:
+        prompt = f"""A user wrote the following message, please answer him to best of your knowledge in the language of his message:
+{message}"""
+
     messages = [{"role": "system", "content": system_message}]
 
-    # Rehydrate prior turns (user, assistant)
-    for u, a in history:
-        if u:
-            messages.append({"role": "user", "content": u})
-        if a:
-            messages.append({"role": "assistant", "content": a})
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
 
     messages.append({"role": "user", "content": prompt})
 
-    response_accum = ""
-    for chunk_text in _stream_chat(
-        client,
-        messages=messages,
+    response = ""
+
+
+    for message in client.chat_completion(
+        messages,
         max_tokens=max_tokens,
+        stream=True,
         temperature=temperature,
         top_p=top_p,
     ):
-        if chunk_text:
-            response_accum += chunk_text
-            yield response_accum
+        token = message.choices[0].delta.content
+
+        response += token
+        yield response
 
 
-# -------- Gradio UI --------
-
+"""
+For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+"""
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are assisting a jurist or a lawyer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"),
+        gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"),
         gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-p (nucleus sampling)",
+        ),
         gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
     ],
     description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
 )
 
+
 if __name__ == "__main__":
     print("Ready!")
-    demo.launch(debug=True)
+    demo.launch(debug=True)
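
A note on the restored forward(): str.find returns -1 when doc.page_content does not occur verbatim in the master case text (after whitespace normalisation, for example), in which case start/end silently produce a window taken from the beginning of the string. The deleted version guarded this with idx >= 0. A minimal defensive sketch of the same windowing logic, where expand_extract is a hypothetical helper, not a function in app.py:

def expand_extract(case_text: str, chunk: str, nb_char: int = 100) -> str:
    """Return the chunk plus up to nb_char characters of context on each side."""
    index = case_text.find(chunk)
    if index < 0:
        # Chunk not found verbatim: fall back to the chunk itself.
        return chunk
    start = max(0, index - nb_char)
    end = min(len(case_text), index + len(chunk) + nb_char)
    return case_text[start:end]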
 
 
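
For reference, an index like the faiss_index_mpnet_cos folder pulled from the umaiku/faiss_index dataset can be rebuilt with the same embedding model and cosine distance strategy. The following is a hypothetical build script, not part of this repo: the chunking parameters are guesses, and only the model name, the metadata keys read by forward(), and the CSV filename come from app.py.

import pandas as pd
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

df = pd.read_csv("bger_cedh_db 1954-2024.csv")
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

texts, metadatas = [], []
for _, row in df.iterrows():
    # One embedded chunk per split, carrying the metadata that forward() expects.
    for chunk in splitter.split_text(str(row["case_text"])):
        texts.append(chunk)
        metadatas.append({
            "case_ref": row["case_ref"],
            "case_nb": row["case_nb"],
            "case_date": row["case_date"],
            "case_url": row["case_url"],
        })

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
db = FAISS.from_texts(texts, embeddings, metadatas=metadatas, distance_strategy=DistanceStrategy.COSINE)
db.save_local("faiss_index_mpnet_cos")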
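
Finally, both versions construct a ToolCallingAgent that respond() never uses; retrieval is invoked directly through retriever_tool(message). If the agentic path were wired in instead, the call would look like this sketch (hypothetical usage with an invented example query, assuming smolagents' run() entry point):

# Hypothetical: let the agent decide when and how to call the retriever tool.
answer = agent.run("What are the conditions for pre-trial detention in Switzerland?")
print(answer)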