Spaces:
Running
on
T4
Running
on
T4
Update auditqa/process_chunks.py
Browse files- auditqa/process_chunks.py +21 -3
auditqa/process_chunks.py
CHANGED
|
@@ -9,10 +9,27 @@ from langchain_community.vectorstores import Qdrant
|
|
| 9 |
from qdrant_client import QdrantClient
|
| 10 |
from auditqa.reports import files, report_list
|
| 11 |
from langchain.docstore.document import Document
|
|
|
|
|
|
|
|
|
|
| 12 |
device = 'cuda' if cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def open_file(filepath):
|
| 17 |
with open(filepath) as file:
|
| 18 |
simple_json = json.load(file)
|
|
@@ -26,6 +43,7 @@ def load_chunks():
|
|
| 26 |
# we iterate through the files which contain information about its
|
| 27 |
# 'source'=='category', 'subtype', these are used in UI for document selection
|
| 28 |
# which will be used later for filtering database
|
|
|
|
| 29 |
all_documents = {}
|
| 30 |
categories = list(files.keys())
|
| 31 |
# iterate through 'source'
|
|
@@ -70,8 +88,8 @@ def load_chunks():
|
|
| 70 |
# define embedding model
|
| 71 |
embeddings = HuggingFaceEmbeddings(
|
| 72 |
model_kwargs = {'device': device},
|
| 73 |
-
encode_kwargs = {'normalize_embeddings':
|
| 74 |
-
model_name=
|
| 75 |
)
|
| 76 |
# placeholder for collection
|
| 77 |
qdrant_collections = {}
|
|
|
|
| 9 |
from qdrant_client import QdrantClient
|
| 10 |
from auditqa.reports import files, report_list
|
| 11 |
from langchain.docstore.document import Document
|
| 12 |
+
import configparser
|
| 13 |
+
|
| 14 |
+
# read all the necessary variables
|
| 15 |
device = 'cuda' if cuda.is_available() else 'cpu'
|
| 16 |
+
path_to_data = "./reports/"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
##---------------------fucntions -------------------------------------------##
|
| 20 |
+
def getconfig(configfile_path:str):
|
| 21 |
+
"""
|
| 22 |
+
configfile_path: file path of .cfg file
|
| 23 |
+
"""
|
| 24 |
|
| 25 |
+
config = configparser.ConfigParser()
|
| 26 |
|
| 27 |
+
try:
|
| 28 |
+
config.read_file(open(configfile_path))
|
| 29 |
+
return config
|
| 30 |
+
except:
|
| 31 |
+
logging.warning("config file not found")
|
| 32 |
+
|
| 33 |
def open_file(filepath):
|
| 34 |
with open(filepath) as file:
|
| 35 |
simple_json = json.load(file)
|
|
|
|
| 43 |
# we iterate through the files which contain information about its
|
| 44 |
# 'source'=='category', 'subtype', these are used in UI for document selection
|
| 45 |
# which will be used later for filtering database
|
| 46 |
+
config = getconfig("./model_params.cfg")
|
| 47 |
all_documents = {}
|
| 48 |
categories = list(files.keys())
|
| 49 |
# iterate through 'source'
|
|
|
|
| 88 |
# define embedding model
|
| 89 |
embeddings = HuggingFaceEmbeddings(
|
| 90 |
model_kwargs = {'device': device},
|
| 91 |
+
encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
|
| 92 |
+
model_name=config.get('retriever','MODEL')
|
| 93 |
)
|
| 94 |
# placeholder for collection
|
| 95 |
qdrant_collections = {}
|