Spaces:
Running
on
Zero
Running
on
Zero
Commit
Β·
b6265c3
0
Parent(s):
Sync: Added functionality to save to S3 and save logs to DynamoDB when using cli_topics
Browse files- .dockerignore +27 -0
- .gitattributes +1 -0
- .github/workflows/ci.yml +196 -0
- .github/workflows/simple-test.yml +46 -0
- .github/workflows/sync_to_hf.yml +53 -0
- .gitignore +22 -0
- Dockerfile +166 -0
- README.md +176 -0
- app.py +0 -0
- cli_topics.py +1943 -0
- entrypoint.sh +18 -0
- example_data/case_note_headers_specific.csv +7 -0
- example_data/combined_case_notes.csv +19 -0
- example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_structured_summaries.xlsx +3 -0
- example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis.xlsx +3 -0
- example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis_grouped.xlsx +3 -0
- example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis.xlsx +3 -0
- example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis_zero_shot.xlsx +3 -0
- example_data/dummy_consultation_response.csv +31 -0
- example_data/dummy_consultation_response_themes.csv +26 -0
- intros/intro.txt +7 -0
- lambda_entrypoint.py +466 -0
- load_dynamo_logs.py +102 -0
- load_s3_logs.py +93 -0
- pyproject.toml +147 -0
- requirements.txt +29 -0
- requirements_cpu.txt +24 -0
- requirements_gpu.txt +28 -0
- requirements_lightweight.txt +18 -0
- test/README.md +87 -0
- test/__init__.py +5 -0
- test/mock_inference_server.py +225 -0
- test/mock_llm_calls.py +185 -0
- test/run_tests.py +34 -0
- test/test.py +1067 -0
- test/test_gui_only.py +189 -0
- tools/__init__.py +0 -0
- tools/auth.py +85 -0
- tools/aws_functions.py +387 -0
- tools/combine_sheets_into_xlsx.py +615 -0
- tools/config.py +950 -0
- tools/custom_csvlogger.py +333 -0
- tools/dedup_summaries.py +0 -0
- tools/example_table_outputs.py +94 -0
- tools/helper_functions.py +1245 -0
- tools/llm_api_call.py +0 -0
- tools/llm_funcs.py +1999 -0
- tools/prompts.py +260 -0
- windows_install_llama-cpp-python.txt +111 -0
.dockerignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pdf
|
| 2 |
+
*.url
|
| 3 |
+
*.jpg
|
| 4 |
+
*.png
|
| 5 |
+
*.ipynb
|
| 6 |
+
*.xls
|
| 7 |
+
*.xlsx
|
| 8 |
+
examples/*
|
| 9 |
+
output/*
|
| 10 |
+
tools/__pycache__/*
|
| 11 |
+
build/*
|
| 12 |
+
dist/*
|
| 13 |
+
logs/*
|
| 14 |
+
usage/*
|
| 15 |
+
feedback/*
|
| 16 |
+
test_code/*
|
| 17 |
+
test/tmp/*
|
| 18 |
+
unsloth_compiled_cache/*
|
| 19 |
+
.vscode/*
|
| 20 |
+
llm_topic_modelling.egg-info/*
|
| 21 |
+
input/
|
| 22 |
+
output/
|
| 23 |
+
logs/
|
| 24 |
+
usage/
|
| 25 |
+
feedback/
|
| 26 |
+
config/
|
| 27 |
+
tmp/
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.xlsx filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI/CD Pipeline
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ main ]
|
| 6 |
+
pull_request:
|
| 7 |
+
branches: [ main ]
|
| 8 |
+
#schedule:
|
| 9 |
+
# Run tests daily at 2 AM UTC
|
| 10 |
+
# - cron: '0 2 * * *'
|
| 11 |
+
|
| 12 |
+
permissions:
|
| 13 |
+
contents: read
|
| 14 |
+
actions: read
|
| 15 |
+
pull-requests: write
|
| 16 |
+
issues: write
|
| 17 |
+
|
| 18 |
+
env:
|
| 19 |
+
PYTHON_VERSION: "3.11"
|
| 20 |
+
|
| 21 |
+
jobs:
|
| 22 |
+
lint:
|
| 23 |
+
runs-on: ubuntu-latest
|
| 24 |
+
steps:
|
| 25 |
+
- uses: actions/checkout@v4
|
| 26 |
+
|
| 27 |
+
- name: Set up Python
|
| 28 |
+
uses: actions/setup-python@v4
|
| 29 |
+
with:
|
| 30 |
+
python-version: ${{ env.PYTHON_VERSION }}
|
| 31 |
+
|
| 32 |
+
- name: Install dependencies
|
| 33 |
+
run: |
|
| 34 |
+
python -m pip install --upgrade pip
|
| 35 |
+
pip install ruff black
|
| 36 |
+
|
| 37 |
+
- name: Run Ruff linter
|
| 38 |
+
run: ruff check .
|
| 39 |
+
|
| 40 |
+
- name: Run Black formatter check
|
| 41 |
+
run: black --check .
|
| 42 |
+
|
| 43 |
+
test-unit:
|
| 44 |
+
runs-on: ubuntu-latest
|
| 45 |
+
strategy:
|
| 46 |
+
matrix:
|
| 47 |
+
python-version: [3.11, 3.12, 3.13]
|
| 48 |
+
|
| 49 |
+
steps:
|
| 50 |
+
- uses: actions/checkout@v4
|
| 51 |
+
|
| 52 |
+
- name: Set up Python ${{ matrix.python-version }}
|
| 53 |
+
uses: actions/setup-python@v4
|
| 54 |
+
with:
|
| 55 |
+
python-version: ${{ matrix.python-version }}
|
| 56 |
+
|
| 57 |
+
- name: Cache pip dependencies
|
| 58 |
+
uses: actions/cache@v4
|
| 59 |
+
with:
|
| 60 |
+
path: ~/.cache/pip
|
| 61 |
+
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }}
|
| 62 |
+
restore-keys: |
|
| 63 |
+
${{ runner.os }}-pip-
|
| 64 |
+
|
| 65 |
+
- name: Install Python dependencies
|
| 66 |
+
run: |
|
| 67 |
+
python -m pip install --upgrade pip
|
| 68 |
+
pip install -r requirements_lightweight.txt
|
| 69 |
+
pip install pytest pytest-cov pytest-html pytest-xdist
|
| 70 |
+
|
| 71 |
+
- name: Verify example data files
|
| 72 |
+
run: |
|
| 73 |
+
echo "Checking if example data directory exists:"
|
| 74 |
+
ls -la example_data/ || echo "example_data directory not found"
|
| 75 |
+
echo "Checking for specific CSV files:"
|
| 76 |
+
ls -la example_data/*.csv || echo "No CSV files found"
|
| 77 |
+
|
| 78 |
+
- name: Run CLI and GUI tests
|
| 79 |
+
run: |
|
| 80 |
+
cd test
|
| 81 |
+
python run_tests.py
|
| 82 |
+
|
| 83 |
+
- name: Run tests with pytest
|
| 84 |
+
run: |
|
| 85 |
+
pytest test/test.py test/test_gui_only.py -v --tb=short --junitxml=test-results.xml
|
| 86 |
+
|
| 87 |
+
- name: Run tests with coverage
|
| 88 |
+
run: |
|
| 89 |
+
pytest test/test.py test/test_gui_only.py --cov=. --cov-report=xml --cov-report=html --cov-report=term
|
| 90 |
+
|
| 91 |
+
- name: Upload test results
|
| 92 |
+
uses: actions/upload-artifact@v4
|
| 93 |
+
if: always()
|
| 94 |
+
with:
|
| 95 |
+
name: test-results-python-${{ matrix.python-version }}
|
| 96 |
+
path: |
|
| 97 |
+
test-results.xml
|
| 98 |
+
htmlcov/
|
| 99 |
+
coverage.xml
|
| 100 |
+
|
| 101 |
+
test-integration:
|
| 102 |
+
runs-on: ubuntu-latest
|
| 103 |
+
needs: [lint, test-unit]
|
| 104 |
+
|
| 105 |
+
steps:
|
| 106 |
+
- uses: actions/checkout@v4
|
| 107 |
+
|
| 108 |
+
- name: Set up Python
|
| 109 |
+
uses: actions/setup-python@v4
|
| 110 |
+
with:
|
| 111 |
+
python-version: ${{ env.PYTHON_VERSION }}
|
| 112 |
+
|
| 113 |
+
- name: Install dependencies
|
| 114 |
+
run: |
|
| 115 |
+
python -m pip install --upgrade pip
|
| 116 |
+
pip install -r requirements_lightweight.txt
|
| 117 |
+
pip install pytest pytest-cov
|
| 118 |
+
|
| 119 |
+
- name: Verify example data files
|
| 120 |
+
run: |
|
| 121 |
+
echo "Checking if example data directory exists:"
|
| 122 |
+
ls -la example_data/
|
| 123 |
+
echo "Checking for specific CSV files:"
|
| 124 |
+
ls -la example_data/*.csv || echo "No CSV files found"
|
| 125 |
+
|
| 126 |
+
- name: Run integration tests (CLI and GUI)
|
| 127 |
+
run: |
|
| 128 |
+
cd test
|
| 129 |
+
python run_tests.py
|
| 130 |
+
|
| 131 |
+
- name: Test CLI help
|
| 132 |
+
run: |
|
| 133 |
+
python cli_topics.py --help
|
| 134 |
+
|
| 135 |
+
- name: Test CLI version
|
| 136 |
+
run: |
|
| 137 |
+
python -c "import sys; print(f'Python {sys.version}')"
|
| 138 |
+
|
| 139 |
+
security:
|
| 140 |
+
runs-on: ubuntu-latest
|
| 141 |
+
steps:
|
| 142 |
+
- uses: actions/checkout@v4
|
| 143 |
+
|
| 144 |
+
- name: Set up Python
|
| 145 |
+
uses: actions/setup-python@v4
|
| 146 |
+
with:
|
| 147 |
+
python-version: ${{ env.PYTHON_VERSION }}
|
| 148 |
+
|
| 149 |
+
- name: Install dependencies
|
| 150 |
+
run: |
|
| 151 |
+
python -m pip install --upgrade pip
|
| 152 |
+
pip install bandit
|
| 153 |
+
|
| 154 |
+
- name: Run bandit security check
|
| 155 |
+
run: |
|
| 156 |
+
bandit -r . -f json -o bandit-report.json || true
|
| 157 |
+
|
| 158 |
+
- name: Upload security report
|
| 159 |
+
uses: actions/upload-artifact@v4
|
| 160 |
+
if: always()
|
| 161 |
+
with:
|
| 162 |
+
name: security-report
|
| 163 |
+
path: bandit-report.json
|
| 164 |
+
|
| 165 |
+
build:
|
| 166 |
+
runs-on: ubuntu-latest
|
| 167 |
+
needs: [lint, test-unit]
|
| 168 |
+
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
| 169 |
+
|
| 170 |
+
steps:
|
| 171 |
+
- uses: actions/checkout@v4
|
| 172 |
+
|
| 173 |
+
- name: Set up Python
|
| 174 |
+
uses: actions/setup-python@v4
|
| 175 |
+
with:
|
| 176 |
+
python-version: ${{ env.PYTHON_VERSION }}
|
| 177 |
+
|
| 178 |
+
- name: Install build dependencies
|
| 179 |
+
run: |
|
| 180 |
+
python -m pip install --upgrade pip
|
| 181 |
+
pip install build twine
|
| 182 |
+
|
| 183 |
+
- name: Build package
|
| 184 |
+
run: |
|
| 185 |
+
python -m build
|
| 186 |
+
|
| 187 |
+
- name: Check package
|
| 188 |
+
run: |
|
| 189 |
+
twine check dist/*
|
| 190 |
+
|
| 191 |
+
- name: Upload build artifacts
|
| 192 |
+
uses: actions/upload-artifact@v4
|
| 193 |
+
with:
|
| 194 |
+
name: dist
|
| 195 |
+
path: dist/
|
| 196 |
+
|
.github/workflows/simple-test.yml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Simple Test Run
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ dev ]
|
| 6 |
+
pull_request:
|
| 7 |
+
branches: [ dev ]
|
| 8 |
+
|
| 9 |
+
permissions:
|
| 10 |
+
contents: read
|
| 11 |
+
actions: read
|
| 12 |
+
|
| 13 |
+
jobs:
|
| 14 |
+
test:
|
| 15 |
+
runs-on: ubuntu-latest
|
| 16 |
+
|
| 17 |
+
steps:
|
| 18 |
+
- uses: actions/checkout@v4
|
| 19 |
+
|
| 20 |
+
- name: Set up Python 3.11
|
| 21 |
+
uses: actions/setup-python@v4
|
| 22 |
+
with:
|
| 23 |
+
python-version: "3.11"
|
| 24 |
+
|
| 25 |
+
- name: Install Python dependencies
|
| 26 |
+
run: |
|
| 27 |
+
python -m pip install --upgrade pip
|
| 28 |
+
pip install -r requirements_lightweight.txt
|
| 29 |
+
pip install pytest pytest-cov
|
| 30 |
+
|
| 31 |
+
- name: Verify example data files
|
| 32 |
+
run: |
|
| 33 |
+
echo "Checking if example data directory exists:"
|
| 34 |
+
ls -la example_data/ || echo "example_data directory not found"
|
| 35 |
+
echo "Checking for specific CSV files:"
|
| 36 |
+
ls -la example_data/*.csv || echo "No CSV files found"
|
| 37 |
+
|
| 38 |
+
- name: Run CLI and GUI tests
|
| 39 |
+
run: |
|
| 40 |
+
cd test
|
| 41 |
+
python run_tests.py
|
| 42 |
+
|
| 43 |
+
- name: Run tests with pytest
|
| 44 |
+
run: |
|
| 45 |
+
pytest test/test.py test/test_gui_only.py -v --tb=short
|
| 46 |
+
|
.github/workflows/sync_to_hf.yml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync to Hugging Face hub
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
branches: [dev]
|
| 5 |
+
|
| 6 |
+
permissions:
|
| 7 |
+
contents: read
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
sync-to-hub:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
steps:
|
| 13 |
+
- uses: actions/checkout@v4
|
| 14 |
+
with:
|
| 15 |
+
fetch-depth: 1 # Only get the latest state
|
| 16 |
+
lfs: true # Download actual LFS files so they can be pushed
|
| 17 |
+
|
| 18 |
+
- name: Install Git LFS
|
| 19 |
+
run: git lfs install
|
| 20 |
+
|
| 21 |
+
- name: Recreate repo history (single-commit force push)
|
| 22 |
+
run: |
|
| 23 |
+
# 1. Capture the message BEFORE we delete the .git folder
|
| 24 |
+
COMMIT_MSG=$(git log -1 --pretty=%B)
|
| 25 |
+
echo "Syncing commit message: $COMMIT_MSG"
|
| 26 |
+
|
| 27 |
+
# 2. DELETE the .git folder.
|
| 28 |
+
# This turns the repo into a standard folder of files.
|
| 29 |
+
rm -rf .git
|
| 30 |
+
|
| 31 |
+
# 3. Re-initialize a brand new git repo
|
| 32 |
+
git init -b main
|
| 33 |
+
git config --global user.name "$HF_USERNAME"
|
| 34 |
+
git config --global user.email "$HF_EMAIL"
|
| 35 |
+
|
| 36 |
+
# 4. Re-install LFS (needs to be done after git init)
|
| 37 |
+
git lfs install
|
| 38 |
+
|
| 39 |
+
# 5. Add the remote
|
| 40 |
+
git remote add hf https://$HF_USERNAME:$HF_TOKEN@huggingface.co/spaces/$HF_USERNAME/$HF_REPO_ID
|
| 41 |
+
|
| 42 |
+
# 6. Add all files
|
| 43 |
+
# Since this is a fresh init, Git sees EVERY file as "New"
|
| 44 |
+
git add .
|
| 45 |
+
|
| 46 |
+
# 7. Commit and Force Push
|
| 47 |
+
git commit -m "Sync: $COMMIT_MSG"
|
| 48 |
+
git push --force hf main
|
| 49 |
+
env:
|
| 50 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 51 |
+
HF_USERNAME: ${{ secrets.HF_USERNAME }}
|
| 52 |
+
HF_EMAIL: ${{ secrets.HF_EMAIL }}
|
| 53 |
+
HF_REPO_ID: ${{ secrets.HF_REPO_ID }}
|
.gitignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pdf
|
| 2 |
+
*.url
|
| 3 |
+
*.jpg
|
| 4 |
+
*.png
|
| 5 |
+
*.ipynb
|
| 6 |
+
*.xls
|
| 7 |
+
*.pyc
|
| 8 |
+
examples/*
|
| 9 |
+
output/*
|
| 10 |
+
tools/__pycache__/*
|
| 11 |
+
build/*
|
| 12 |
+
dist/*
|
| 13 |
+
logs/*
|
| 14 |
+
usage/*
|
| 15 |
+
feedback/*
|
| 16 |
+
test_code/*
|
| 17 |
+
config/*
|
| 18 |
+
tmp/*
|
| 19 |
+
test/tmp/*
|
| 20 |
+
unsloth_compiled_cache/*
|
| 21 |
+
.vscode/*
|
| 22 |
+
llm_topic_modelling.egg-info/*
|
Dockerfile
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This Dockerfile is optimised for AWS ECS using Python 3.11, and assumes CPU inference with OpenBLAS for local models.
|
| 2 |
+
# Stage 1: Build dependencies and download models
|
| 3 |
+
FROM public.ecr.aws/docker/library/python:3.11.13-slim-trixie AS builder
|
| 4 |
+
|
| 5 |
+
# Install system dependencies.
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
build-essential \
|
| 8 |
+
gcc \
|
| 9 |
+
g++ \
|
| 10 |
+
cmake \
|
| 11 |
+
#libopenblas-dev \
|
| 12 |
+
pkg-config \
|
| 13 |
+
python3-dev \
|
| 14 |
+
libffi-dev \
|
| 15 |
+
&& apt-get clean \
|
| 16 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 17 |
+
|
| 18 |
+
WORKDIR /src
|
| 19 |
+
|
| 20 |
+
COPY requirements_lightweight.txt .
|
| 21 |
+
|
| 22 |
+
# Set environment variables for OpenBLAS - not necessary if not building from source
|
| 23 |
+
# ENV OPENBLAS_VERBOSE=1
|
| 24 |
+
# ENV CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS"
|
| 25 |
+
|
| 26 |
+
ARG INSTALL_TORCH=False
|
| 27 |
+
ENV INSTALL_TORCH=${INSTALL_TORCH}
|
| 28 |
+
|
| 29 |
+
RUN if [ "$INSTALL_TORCH" = "True" ]; then \
|
| 30 |
+
pip install --no-cache-dir --target=/install torch==2.9.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu; \
|
| 31 |
+
fi
|
| 32 |
+
|
| 33 |
+
ARG INSTALL_LLAMA_CPP_PYTHON=False
|
| 34 |
+
ENV INSTALL_LLAMA_CPP_PYTHON=${INSTALL_LLAMA_CPP_PYTHON}
|
| 35 |
+
|
| 36 |
+
RUN if [ "$INSTALL_LLAMA_CPP_PYTHON" = "True" ]; then \
|
| 37 |
+
pip install --no-cache-dir --target=/install https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl; \
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
RUN pip install --no-cache-dir --target=/install -r requirements_lightweight.txt
|
| 41 |
+
|
| 42 |
+
RUN rm requirements_lightweight.txt
|
| 43 |
+
|
| 44 |
+
# ===================================================================
|
| 45 |
+
# Stage 2: A common 'base' for both Lambda and Gradio
|
| 46 |
+
# ===================================================================
|
| 47 |
+
FROM public.ecr.aws/docker/library/python:3.11.13-slim-trixie AS base
|
| 48 |
+
|
| 49 |
+
# Set build-time and runtime environment variable for whether to run in Gradio mode or Lambda mode
|
| 50 |
+
ARG APP_MODE=gradio
|
| 51 |
+
ENV APP_MODE=${APP_MODE}
|
| 52 |
+
|
| 53 |
+
# Install runtime system dependencies
|
| 54 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 55 |
+
libopenblas0 \
|
| 56 |
+
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
| 57 |
+
|
| 58 |
+
ENV APP_HOME=/home/user
|
| 59 |
+
|
| 60 |
+
# Set env variables for Gradio & other apps
|
| 61 |
+
ENV GRADIO_TEMP_DIR=/tmp/gradio_tmp/ \
|
| 62 |
+
MPLCONFIGDIR=/tmp/matplotlib_cache/ \
|
| 63 |
+
GRADIO_OUTPUT_FOLDER=$APP_HOME/app/output/ \
|
| 64 |
+
GRADIO_INPUT_FOLDER=$APP_HOME/app/input/ \
|
| 65 |
+
FEEDBACK_LOGS_FOLDER=$APP_HOME/app/feedback/ \
|
| 66 |
+
ACCESS_LOGS_FOLDER=$APP_HOME/app/logs/ \
|
| 67 |
+
USAGE_LOGS_FOLDER=$APP_HOME/app/usage/ \
|
| 68 |
+
CONFIG_FOLDER=$APP_HOME/app/config/ \
|
| 69 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 70 |
+
GRADIO_SERVER_PORT=7860 \
|
| 71 |
+
PATH=$APP_HOME/.local/bin:$PATH \
|
| 72 |
+
PYTHONPATH=$APP_HOME/app \
|
| 73 |
+
PYTHONUNBUFFERED=1 \
|
| 74 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 75 |
+
GRADIO_ALLOW_FLAGGING=never \
|
| 76 |
+
GRADIO_NUM_PORTS=1 \
|
| 77 |
+
GRADIO_THEME=huggingface \
|
| 78 |
+
SYSTEM=spaces
|
| 79 |
+
|
| 80 |
+
# Copy Python packages from the builder stage
|
| 81 |
+
COPY --from=builder /install /usr/local/lib/python3.11/site-packages/
|
| 82 |
+
COPY --from=builder /install/bin /usr/local/bin/
|
| 83 |
+
|
| 84 |
+
# Copy your application code and entrypoint
|
| 85 |
+
COPY . ${APP_HOME}/app
|
| 86 |
+
COPY entrypoint.sh ${APP_HOME}/app/entrypoint.sh
|
| 87 |
+
# Fix line endings and set execute permissions
|
| 88 |
+
RUN sed -i 's/\r$//' ${APP_HOME}/app/entrypoint.sh \
|
| 89 |
+
&& chmod +x ${APP_HOME}/app/entrypoint.sh
|
| 90 |
+
|
| 91 |
+
WORKDIR ${APP_HOME}/app
|
| 92 |
+
|
| 93 |
+
# ===================================================================
|
| 94 |
+
# FINAL Stage 3: The Lambda Image (runs as root for simplicity)
|
| 95 |
+
# ===================================================================
|
| 96 |
+
FROM base AS lambda
|
| 97 |
+
# Set runtime ENV for Lambda mode
|
| 98 |
+
ENV APP_MODE=lambda
|
| 99 |
+
ENTRYPOINT ["/home/user/app/entrypoint.sh"]
|
| 100 |
+
CMD ["lambda_entrypoint.lambda_handler"]
|
| 101 |
+
|
| 102 |
+
# ===================================================================
|
| 103 |
+
# FINAL Stage 4: The Gradio Image (runs as a secure, non-root user)
|
| 104 |
+
# ===================================================================
|
| 105 |
+
FROM base AS gradio
|
| 106 |
+
# Set runtime ENV for Gradio mode
|
| 107 |
+
ENV APP_MODE=gradio
|
| 108 |
+
|
| 109 |
+
# Create non-root user
|
| 110 |
+
RUN useradd -m -u 1000 user
|
| 111 |
+
|
| 112 |
+
# Create the base application directory and set its ownership
|
| 113 |
+
RUN mkdir -p ${APP_HOME}/app && chown user:user ${APP_HOME}/app
|
| 114 |
+
|
| 115 |
+
# Create required sub-folders within the app directory and set their permissions
|
| 116 |
+
RUN mkdir -p \
|
| 117 |
+
${APP_HOME}/app/output \
|
| 118 |
+
${APP_HOME}/app/input \
|
| 119 |
+
${APP_HOME}/app/logs \
|
| 120 |
+
${APP_HOME}/app/usage \
|
| 121 |
+
${APP_HOME}/app/feedback \
|
| 122 |
+
${APP_HOME}/app/config \
|
| 123 |
+
&& chown user:user \
|
| 124 |
+
${APP_HOME}/app/output \
|
| 125 |
+
${APP_HOME}/app/input \
|
| 126 |
+
${APP_HOME}/app/logs \
|
| 127 |
+
${APP_HOME}/app/usage \
|
| 128 |
+
${APP_HOME}/app/feedback \
|
| 129 |
+
${APP_HOME}/app/config \
|
| 130 |
+
&& chmod 755 \
|
| 131 |
+
${APP_HOME}/app/output \
|
| 132 |
+
${APP_HOME}/app/input \
|
| 133 |
+
${APP_HOME}/app/logs \
|
| 134 |
+
${APP_HOME}/app/usage \
|
| 135 |
+
${APP_HOME}/app/feedback \
|
| 136 |
+
${APP_HOME}/app/config
|
| 137 |
+
|
| 138 |
+
# Now handle the /tmp directories
|
| 139 |
+
RUN mkdir -p /tmp/gradio_tmp /tmp/matplotlib_cache /tmp /var/tmp \
|
| 140 |
+
&& chown user:user /tmp /var/tmp /tmp/gradio_tmp /tmp/matplotlib_cache \
|
| 141 |
+
&& chmod 1777 /tmp /var/tmp /tmp/gradio_tmp /tmp/matplotlib_cache
|
| 142 |
+
|
| 143 |
+
# Fix apply user ownership to all files in the home directory
|
| 144 |
+
RUN chown -R user:user /home/user
|
| 145 |
+
|
| 146 |
+
# Set permissions for Python executable
|
| 147 |
+
RUN chmod 755 /usr/local/bin/python
|
| 148 |
+
|
| 149 |
+
# Declare volumes
|
| 150 |
+
VOLUME ["/tmp/matplotlib_cache"]
|
| 151 |
+
VOLUME ["/tmp/gradio_tmp"]
|
| 152 |
+
VOLUME ["/home/user/app/output"]
|
| 153 |
+
VOLUME ["/home/user/app/input"]
|
| 154 |
+
VOLUME ["/home/user/app/logs"]
|
| 155 |
+
VOLUME ["/home/user/app/usage"]
|
| 156 |
+
VOLUME ["/home/user/app/feedback"]
|
| 157 |
+
VOLUME ["/home/user/app/config"]
|
| 158 |
+
VOLUME ["/tmp"]
|
| 159 |
+
VOLUME ["/var/tmp"]
|
| 160 |
+
|
| 161 |
+
USER user
|
| 162 |
+
|
| 163 |
+
EXPOSE $GRADIO_SERVER_PORT
|
| 164 |
+
|
| 165 |
+
ENTRYPOINT ["/home/user/app/entrypoint.sh"]
|
| 166 |
+
CMD ["python", "app.py"]
|
README.md
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Large language model topic modelling
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.0.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
+
license: agpl-3.0
|
| 11 |
+
short_description: Create thematic summaries for open text data with LLMs
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Large language model topic modelling
|
| 15 |
+
|
| 16 |
+
Version: 0.6.0
|
| 17 |
+
|
| 18 |
+
Extract topics and summarise outputs using Large Language Models (LLMs, Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure, or AWS Bedrock models (e.g. Claude, Nova models). The app will query the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets on the main app page, which will show you example outputs from a local model run. API keys for AWS, Azure, and Gemini services can be entered on the settings page (note that Gemini has a free public API).
|
| 19 |
+
|
| 20 |
+
NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** to check for harmful outputs, hallucinations, and accuracy.
|
| 21 |
+
|
| 22 |
+
Basic use:
|
| 23 |
+
1. On the front page, choose your model for inference. Gemma 3/GPT-OSS will use 'on-device' inference. Calls to Gemini or AWS will require an API key that can be input on the 'LLM and topic extraction' page.
|
| 24 |
+
1. Upload a csv/xlsx/parquet file containing at least one open text column.
|
| 25 |
+
2. Select the relevant open text column from the dropdown.
|
| 26 |
+
3. If you have your own suggested (zero shot) topics, upload this (see examples folder for an example file)
|
| 27 |
+
4. Write a one sentence description of the consultation/context of the open text.
|
| 28 |
+
5. Click 'Extract topics, deduplicate, and summarise'. This will run through the whole analysis process from topic extraction, to topic deduplication, to topic-level and overall summaries.
|
| 29 |
+
6. A summary xlsx file workbook will be created on the front page in the box 'Overall summary xlsx file'. This will combine all the results from the different processes into one workbook.
|
| 30 |
+
|
| 31 |
+
# Installation guide
|
| 32 |
+
|
| 33 |
+
Here is a step-by-step guide to clone the repository, create a virtual environment, and install dependencies from the relevant `requirements` file. This guide assumes you have **Git** and **Python 3.11** installed.
|
| 34 |
+
|
| 35 |
+
-----
|
| 36 |
+
|
| 37 |
+
### Step 1: Clone the Git Repository
|
| 38 |
+
|
| 39 |
+
First, you need to copy the project files to your local machine. Navigate to the directory where you want to store the project using the `cd` (change directory) command. Then, use `git clone` with the repository's URL.
|
| 40 |
+
|
| 41 |
+
1. **Clone the repo:**
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
git clone https://github.com/seanpedrick-case/llm_topic_modelling.git
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
2. **Navigate into the new project folder:**
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
cd llm_topic_modelling
|
| 51 |
+
```
|
| 52 |
+
-----
|
| 53 |
+
|
| 54 |
+
### Step 2: Create and Activate a Virtual Environment
|
| 55 |
+
|
| 56 |
+
A virtual environment is a self-contained directory that holds a specific Python interpreter and its own set of installed packages. This is crucial for isolating your project's dependencies.
|
| 57 |
+
|
| 58 |
+
NOTE: Alternatively you could also create and activate a Conda environment instead of using venv below.
|
| 59 |
+
|
| 60 |
+
1. **Create the virtual environment:** We'll use Python's built-in `venv` module. It's common practice to name the environment folder `.venv`.
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
python -m venv .venv
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
*This command tells Python to create a new virtual environment in a folder named `.venv`.*
|
| 67 |
+
|
| 68 |
+
2. **Activate the environment:** You must "activate" the environment to start using it. The command differs based on your operating system and shell.
|
| 69 |
+
|
| 70 |
+
* **On macOS / Linux (bash/zsh):**
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
source .venv/bin/activate
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
* **On Windows (Command Prompt):**
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
.\.venv\Scripts\activate
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
* **On Windows (PowerShell):**
|
| 83 |
+
|
| 84 |
+
```powershell
|
| 85 |
+
.\.venv\Scripts\Activate.ps1
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
You'll know it's active because your command prompt will be prefixed with `(.venv)`.
|
| 89 |
+
|
| 90 |
+
-----
|
| 91 |
+
|
| 92 |
+
### Step 3: Install Dependencies
|
| 93 |
+
|
| 94 |
+
Now that your virtual environment is active, you can install all the required packages. Here you have two options, install from the pyproject.toml file (recommended), or install from requirements files.
|
| 95 |
+
|
| 96 |
+
1. **Install from pyproject.toml (recommended)**
|
| 97 |
+
|
| 98 |
+
You can install the 'lightweight' version of the app to access all available cloud provider or local inference (e.g. llama server, vLLM server) APIs. This version will not allow you to run local models such as Gemma 12b or GPT-OSS-20b 'in-app', i.e. accessible from the GUI interface directly. However, you will have access to AWS, Gemma, or Azure/OpenAI models with appropriate API keys. Use the following command in your environment to install the relevant packages:
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
pip install .
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
#### Install torch (optional)
|
| 105 |
+
|
| 106 |
+
If you want to run inference with transformers with full/quantised models, and the associated Unsloth package, you can run the following command for CPU inference. For GPU inference, please refer to the requirements_gpu.txt guide, and the 'Install from a requirements file' section below:
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
pip install .[torch]
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
#### Install llama-cpp-python (optional)
|
| 113 |
+
|
| 114 |
+
You can run quantised GGUF models in-app using llama-cpp-python. However, installation of this package is not always straightforward, particularly considering that wheels are not available for the latest version apart from for linux. This package is not being updated regularly, and so support may be removed for this package in future. Long term I would advise instead looking into running GGUF models using llama-server and calling the API from this app using the lightweight version (details here: https://github.com/ggml-org/llama.cpp).
|
| 115 |
+
|
| 116 |
+
If you do want to install llama-cpp-python in app, first try the following command:
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
pip install .[llamacpp]
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
This will install the CPU version of llama-cpp-python. If you want GPU support, first I would try using pip install with specific wheels for your system, e.g. for Linux: See files in https://github.com/abetlen/llama-cpp-python/releases/tag/v0.3.16-cu124 . If you are still struggling, see here for more details on installation here: https://llama-cpp-python.readthedocs.io/en/latest
|
| 123 |
+
|
| 124 |
+
**NOTE:** A sister repository contains [llama-cpp-python 3.16 wheels for Python version 3.11/10](https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/tag/v0.1.0) so that users can avoid having to build the package from source. I also have a guide to building the package on a Windows system [here](https://github.com/seanpedrick-case/llm_topic_modelling/blob/main/windows_install_llama-cpp-python.txt).
|
| 125 |
+
|
| 126 |
+
#### Install mcp version of gradio
|
| 127 |
+
|
| 128 |
+
You can install an mcp-compatible version of gradio for this app with the following command:
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
pip install .[mcp]
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
2. **Install from a requirements file (not recommended)**
|
| 135 |
+
|
| 136 |
+
The repo provides several requirements files that are relevant for different situations. To start, I advise installing using the **requirements_lightweight.txt** file, which installs the app with access to all cloud provider or local inference (e.g. llama server, vLLM server) APIs. This approach is much simpler as a first step, and avoids issues with potentially complicated llama-cpp-python installation and GPU management described below.
|
| 137 |
+
|
| 138 |
+
If you want to run models locally 'in app', then you have two further requirements files to choose from:
|
| 139 |
+
|
| 140 |
+
- **requirements_cpu.txt**: Used for Python 3.11 CPU-only environments. Uncomment the requirements under 'Windows' for Windows compatibility. Make sure you have [Openblas](https://github.com/OpenMathLib/OpenBLAS) installed!
|
| 141 |
+
- **requirements_gpu.txt**: Used for Python 3.11 GPU-enabled environments. Uncomment the requirements under 'Windows' for Windows compatibility (CUDA 12.4).
|
| 142 |
+
|
| 143 |
+
Example The below instructions will guide you in how to install the GPU-enabled version of the app for local inference.
|
| 144 |
+
|
| 145 |
+
**Install packages for local model 'in-app' inference from the requirements file:**
|
| 146 |
+
```bash
|
| 147 |
+
pip install -r requirements_gpu.txt
|
| 148 |
+
```
|
| 149 |
+
*This command reads every package name listed in the file and installs it into your `.venv` environment.*
|
| 150 |
+
|
| 151 |
+
NOTE: If default llama-cpp-python installation does not work when installing from the above, go into the requirements_gpu.txt file and uncomment the lines to install a wheel for llama-cpp-python 0.3.16 relevant to your system.
|
| 152 |
+
|
| 153 |
+
### Step 4: Verify CUDA compatibility (if using a GPU environment)
|
| 154 |
+
|
| 155 |
+
Install the relevant toolkit for CUDA 12.4 from here: https://developer.nvidia.com/cuda-12-4-0-download-archive
|
| 156 |
+
|
| 157 |
+
Restart your computer
|
| 158 |
+
|
| 159 |
+
Ensure you have the latest drivers for your NVIDIA GPU. Check your current version and memory availability by running nvidia-smi
|
| 160 |
+
|
| 161 |
+
In command line, CUDA compatibility can be checked by running nvcc --version
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
### Step 5: Ensure you have compatible NVIDIA drivers
|
| 165 |
+
|
| 166 |
+
Make sure you have the latest NVIDIA drivers installed on your system for your GPU (be careful in particular if using WSL that you have drivers compatible with this). Official drivers can be found here: https://www.nvidia.com/en-us/drivers
|
| 167 |
+
|
| 168 |
+
Current drivers can be found by running nvidia-smi in command line
|
| 169 |
+
|
| 170 |
+
### Step 6: Run the app
|
| 171 |
+
|
| 172 |
+
Go to the app project directory. Run python app.py
|
| 173 |
+
|
| 174 |
+
### Step 7: (optional) change default configuration
|
| 175 |
+
|
| 176 |
+
A number of configuration options can be seen the tools/config.py file. You can either pass in these variables as environment variables, or you can create a file in config/app_config.env to read this into the app on initialisation.
|
app.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cli_topics.py
ADDED
|
@@ -0,0 +1,1943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
import boto3
|
| 9 |
+
import botocore
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
from tools.aws_functions import download_file_from_s3, export_outputs_to_s3
|
| 13 |
+
from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output
|
| 14 |
+
from tools.config import (
|
| 15 |
+
API_URL,
|
| 16 |
+
AWS_ACCESS_KEY,
|
| 17 |
+
AWS_REGION,
|
| 18 |
+
AWS_SECRET_KEY,
|
| 19 |
+
AZURE_OPENAI_API_KEY,
|
| 20 |
+
AZURE_OPENAI_INFERENCE_ENDPOINT,
|
| 21 |
+
BATCH_SIZE_DEFAULT,
|
| 22 |
+
CHOSEN_INFERENCE_SERVER_MODEL,
|
| 23 |
+
CSV_USAGE_LOG_HEADERS,
|
| 24 |
+
DEDUPLICATION_THRESHOLD,
|
| 25 |
+
DEFAULT_COST_CODE,
|
| 26 |
+
DEFAULT_SAMPLED_SUMMARIES,
|
| 27 |
+
DYNAMODB_USAGE_LOG_HEADERS,
|
| 28 |
+
GEMINI_API_KEY,
|
| 29 |
+
GRADIO_TEMP_DIR,
|
| 30 |
+
HF_TOKEN,
|
| 31 |
+
INPUT_FOLDER,
|
| 32 |
+
LLM_MAX_NEW_TOKENS,
|
| 33 |
+
LLM_SEED,
|
| 34 |
+
LLM_TEMPERATURE,
|
| 35 |
+
MAX_TIME_FOR_LOOP,
|
| 36 |
+
OUTPUT_DEBUG_FILES,
|
| 37 |
+
OUTPUT_FOLDER,
|
| 38 |
+
RUN_AWS_FUNCTIONS,
|
| 39 |
+
S3_OUTPUTS_BUCKET,
|
| 40 |
+
S3_OUTPUTS_FOLDER,
|
| 41 |
+
SAVE_LOGS_TO_CSV,
|
| 42 |
+
SAVE_LOGS_TO_DYNAMODB,
|
| 43 |
+
SAVE_OUTPUTS_TO_S3,
|
| 44 |
+
SESSION_OUTPUT_FOLDER,
|
| 45 |
+
USAGE_LOG_DYNAMODB_TABLE_NAME,
|
| 46 |
+
USAGE_LOG_FILE_NAME,
|
| 47 |
+
USAGE_LOGS_FOLDER,
|
| 48 |
+
convert_string_to_boolean,
|
| 49 |
+
default_model_choice,
|
| 50 |
+
default_model_source,
|
| 51 |
+
model_name_map,
|
| 52 |
+
)
|
| 53 |
+
from tools.dedup_summaries import (
|
| 54 |
+
deduplicate_topics,
|
| 55 |
+
deduplicate_topics_llm,
|
| 56 |
+
overall_summary,
|
| 57 |
+
wrapper_summarise_output_topics_per_group,
|
| 58 |
+
)
|
| 59 |
+
from tools.helper_functions import (
|
| 60 |
+
load_in_data_file,
|
| 61 |
+
load_in_previous_data_files,
|
| 62 |
+
)
|
| 63 |
+
from tools.llm_api_call import (
|
| 64 |
+
all_in_one_pipeline,
|
| 65 |
+
validate_topics_wrapper,
|
| 66 |
+
wrapper_extract_topics_per_column_value,
|
| 67 |
+
)
|
| 68 |
+
from tools.prompts import (
|
| 69 |
+
add_existing_topics_prompt,
|
| 70 |
+
add_existing_topics_system_prompt,
|
| 71 |
+
initial_table_prompt,
|
| 72 |
+
initial_table_system_prompt,
|
| 73 |
+
single_para_summary_format_prompt,
|
| 74 |
+
two_para_summary_format_prompt,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _generate_session_hash() -> str:
|
| 79 |
+
"""Generate a unique session hash for logging purposes."""
|
| 80 |
+
return str(uuid.uuid4())[:8]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _download_s3_file_if_needed(
|
| 84 |
+
file_path: str,
|
| 85 |
+
default_filename: str = "downloaded_file",
|
| 86 |
+
aws_access_key: str = "",
|
| 87 |
+
aws_secret_key: str = "",
|
| 88 |
+
aws_region: str = "",
|
| 89 |
+
) -> str:
|
| 90 |
+
"""
|
| 91 |
+
Download a file from S3 if the path starts with 's3://' or 'S3://', otherwise return the path as-is.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
file_path: File path (either local or S3 URL)
|
| 95 |
+
default_filename: Default filename to use if S3 key doesn't have a filename
|
| 96 |
+
aws_access_key: AWS access key ID (optional, uses environment/config if not provided)
|
| 97 |
+
aws_secret_key: AWS secret access key (optional, uses environment/config if not provided)
|
| 98 |
+
aws_region: AWS region (optional, uses environment/config if not provided)
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Local file path (downloaded from S3 or original path)
|
| 102 |
+
"""
|
| 103 |
+
if not file_path:
|
| 104 |
+
return file_path
|
| 105 |
+
|
| 106 |
+
# Check for S3 URL (case-insensitive)
|
| 107 |
+
file_path_stripped = file_path.strip()
|
| 108 |
+
file_path_upper = file_path_stripped.upper()
|
| 109 |
+
if not file_path_upper.startswith("S3://"):
|
| 110 |
+
return file_path
|
| 111 |
+
|
| 112 |
+
# Ensure temp directory exists
|
| 113 |
+
os.makedirs(GRADIO_TEMP_DIR, exist_ok=True)
|
| 114 |
+
|
| 115 |
+
# Parse S3 URL: s3://bucket/key (preserve original case for bucket/key)
|
| 116 |
+
# Remove 's3://' prefix (case-insensitive)
|
| 117 |
+
s3_path = (
|
| 118 |
+
file_path_stripped.split("://", 1)[1]
|
| 119 |
+
if "://" in file_path_stripped
|
| 120 |
+
else file_path_stripped
|
| 121 |
+
)
|
| 122 |
+
# Split bucket and key (first '/' separates bucket from key)
|
| 123 |
+
if "/" in s3_path:
|
| 124 |
+
bucket_name_s3, s3_key = s3_path.split("/", 1)
|
| 125 |
+
else:
|
| 126 |
+
# If no key provided, use bucket name as key (unlikely but handle it)
|
| 127 |
+
bucket_name_s3 = s3_path
|
| 128 |
+
s3_key = ""
|
| 129 |
+
|
| 130 |
+
# Get the filename from the S3 key
|
| 131 |
+
filename = os.path.basename(s3_key) if s3_key else bucket_name_s3
|
| 132 |
+
if not filename:
|
| 133 |
+
filename = default_filename
|
| 134 |
+
|
| 135 |
+
# Create local file path in temp directory
|
| 136 |
+
local_file_path = os.path.join(GRADIO_TEMP_DIR, filename)
|
| 137 |
+
|
| 138 |
+
# Download file from S3
|
| 139 |
+
try:
|
| 140 |
+
download_file_from_s3(
|
| 141 |
+
bucket_name=bucket_name_s3,
|
| 142 |
+
key=s3_key,
|
| 143 |
+
local_file_path=local_file_path,
|
| 144 |
+
aws_access_key_textbox=aws_access_key,
|
| 145 |
+
aws_secret_key_textbox=aws_secret_key,
|
| 146 |
+
aws_region_textbox=aws_region,
|
| 147 |
+
)
|
| 148 |
+
print(f"S3 file downloaded successfully: {file_path} -> {local_file_path}")
|
| 149 |
+
return local_file_path
|
| 150 |
+
except Exception as e:
|
| 151 |
+
print(f"Error downloading file from S3 ({file_path}): {e}")
|
| 152 |
+
raise Exception(f"Failed to download file from S3: {e}")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_username_and_folders(
|
| 156 |
+
username: str = "",
|
| 157 |
+
output_folder_textbox: str = OUTPUT_FOLDER,
|
| 158 |
+
input_folder_textbox: str = INPUT_FOLDER,
|
| 159 |
+
session_output_folder: bool = SESSION_OUTPUT_FOLDER,
|
| 160 |
+
):
|
| 161 |
+
"""Generate session hash and set up output/input folders."""
|
| 162 |
+
# Generate session hash for logging. Either from input user name or generated
|
| 163 |
+
if username:
|
| 164 |
+
out_session_hash = username
|
| 165 |
+
else:
|
| 166 |
+
out_session_hash = _generate_session_hash()
|
| 167 |
+
|
| 168 |
+
if session_output_folder:
|
| 169 |
+
output_folder = output_folder_textbox + out_session_hash + "/"
|
| 170 |
+
input_folder = input_folder_textbox + out_session_hash + "/"
|
| 171 |
+
else:
|
| 172 |
+
output_folder = output_folder_textbox
|
| 173 |
+
input_folder = input_folder_textbox
|
| 174 |
+
|
| 175 |
+
if not os.path.exists(output_folder):
|
| 176 |
+
os.makedirs(output_folder, exist_ok=True)
|
| 177 |
+
if not os.path.exists(input_folder):
|
| 178 |
+
os.makedirs(input_folder, exist_ok=True)
|
| 179 |
+
|
| 180 |
+
return (
|
| 181 |
+
out_session_hash,
|
| 182 |
+
output_folder,
|
| 183 |
+
out_session_hash,
|
| 184 |
+
input_folder,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def upload_outputs_to_s3_if_enabled(
|
| 189 |
+
output_files: list,
|
| 190 |
+
base_file_name: str = None,
|
| 191 |
+
session_hash: str = "",
|
| 192 |
+
s3_output_folder: str = S3_OUTPUTS_FOLDER,
|
| 193 |
+
s3_bucket: str = S3_OUTPUTS_BUCKET,
|
| 194 |
+
save_outputs_to_s3: bool = None,
|
| 195 |
+
):
|
| 196 |
+
"""
|
| 197 |
+
Upload output files to S3 if SAVE_OUTPUTS_TO_S3 is enabled.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
output_files: List of output file paths to upload
|
| 201 |
+
base_file_name: Base file name (input file) for organizing S3 folder structure
|
| 202 |
+
session_hash: Session hash to include in S3 path
|
| 203 |
+
s3_output_folder: S3 output folder path
|
| 204 |
+
s3_bucket: S3 bucket name
|
| 205 |
+
save_outputs_to_s3: Override for SAVE_OUTPUTS_TO_S3 config (if None, uses config value)
|
| 206 |
+
"""
|
| 207 |
+
# Use provided value or fall back to config
|
| 208 |
+
if save_outputs_to_s3 is None:
|
| 209 |
+
save_outputs_to_s3 = convert_string_to_boolean(SAVE_OUTPUTS_TO_S3)
|
| 210 |
+
|
| 211 |
+
if not save_outputs_to_s3:
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
if not s3_bucket:
|
| 215 |
+
print("Warning: S3_OUTPUTS_BUCKET not configured. Skipping S3 upload.")
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
if not output_files:
|
| 219 |
+
print("No output files to upload to S3.")
|
| 220 |
+
return
|
| 221 |
+
|
| 222 |
+
# Filter out empty/None values and ensure files exist
|
| 223 |
+
valid_files = []
|
| 224 |
+
for file_path in output_files:
|
| 225 |
+
if file_path and os.path.exists(file_path):
|
| 226 |
+
valid_files.append(file_path)
|
| 227 |
+
elif file_path:
|
| 228 |
+
print(f"Warning: Output file does not exist, skipping: {file_path}")
|
| 229 |
+
|
| 230 |
+
if not valid_files:
|
| 231 |
+
print("No valid output files to upload to S3.")
|
| 232 |
+
return
|
| 233 |
+
|
| 234 |
+
# Construct S3 output folder path
|
| 235 |
+
# Include session hash if provided and SESSION_OUTPUT_FOLDER is enabled
|
| 236 |
+
s3_folder_path = s3_output_folder or ""
|
| 237 |
+
if session_hash and convert_string_to_boolean(SESSION_OUTPUT_FOLDER):
|
| 238 |
+
if s3_folder_path and not s3_folder_path.endswith("/"):
|
| 239 |
+
s3_folder_path += "/"
|
| 240 |
+
s3_folder_path += session_hash + "/"
|
| 241 |
+
|
| 242 |
+
print(f"\nUploading {len(valid_files)} output file(s) to S3...")
|
| 243 |
+
try:
|
| 244 |
+
export_outputs_to_s3(
|
| 245 |
+
file_list_state=valid_files,
|
| 246 |
+
s3_output_folder_state_value=s3_folder_path,
|
| 247 |
+
save_outputs_to_s3_flag=True,
|
| 248 |
+
base_file_state=base_file_name,
|
| 249 |
+
s3_bucket=s3_bucket,
|
| 250 |
+
)
|
| 251 |
+
except Exception as e:
|
| 252 |
+
print(f"Warning: Failed to upload outputs to S3: {e}")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def write_usage_log(
|
| 256 |
+
session_hash: str,
|
| 257 |
+
file_name: str,
|
| 258 |
+
text_column: str,
|
| 259 |
+
model_choice: str,
|
| 260 |
+
conversation_metadata: str,
|
| 261 |
+
input_tokens: int,
|
| 262 |
+
output_tokens: int,
|
| 263 |
+
number_of_calls: int,
|
| 264 |
+
estimated_time_taken: float,
|
| 265 |
+
cost_code: str = DEFAULT_COST_CODE,
|
| 266 |
+
save_to_csv: bool = SAVE_LOGS_TO_CSV,
|
| 267 |
+
save_to_dynamodb: bool = SAVE_LOGS_TO_DYNAMODB,
|
| 268 |
+
include_conversation_metadata: bool = False,
|
| 269 |
+
):
|
| 270 |
+
"""
|
| 271 |
+
Write usage log entry to CSV file and/or DynamoDB.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
session_hash: Session identifier
|
| 275 |
+
file_name: Name of the input file
|
| 276 |
+
text_column: Column name used for analysis (as list for CSV)
|
| 277 |
+
model_choice: LLM model used
|
| 278 |
+
conversation_metadata: Metadata string
|
| 279 |
+
input_tokens: Number of input tokens
|
| 280 |
+
output_tokens: Number of output tokens
|
| 281 |
+
number_of_calls: Number of LLM calls
|
| 282 |
+
estimated_time_taken: Time taken in seconds
|
| 283 |
+
cost_code: Cost code for tracking
|
| 284 |
+
save_to_csv: Whether to save to CSV
|
| 285 |
+
save_to_dynamodb: Whether to save to DynamoDB
|
| 286 |
+
include_conversation_metadata: Whether to include conversation metadata in the log
|
| 287 |
+
"""
|
| 288 |
+
# Convert boolean parameters if they're strings
|
| 289 |
+
if isinstance(save_to_csv, str):
|
| 290 |
+
save_to_csv = convert_string_to_boolean(save_to_csv)
|
| 291 |
+
if isinstance(save_to_dynamodb, str):
|
| 292 |
+
save_to_dynamodb = convert_string_to_boolean(save_to_dynamodb)
|
| 293 |
+
|
| 294 |
+
# Return early if neither logging method is enabled
|
| 295 |
+
if not save_to_csv and not save_to_dynamodb:
|
| 296 |
+
return
|
| 297 |
+
|
| 298 |
+
if not conversation_metadata:
|
| 299 |
+
conversation_metadata = ""
|
| 300 |
+
|
| 301 |
+
# Ensure usage logs folder exists
|
| 302 |
+
os.makedirs(USAGE_LOGS_FOLDER, exist_ok=True)
|
| 303 |
+
|
| 304 |
+
# Construct full file path
|
| 305 |
+
usage_log_file_path = os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 306 |
+
|
| 307 |
+
# Prepare data row - order matches app.py component order
|
| 308 |
+
# session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice,
|
| 309 |
+
# conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num,
|
| 310 |
+
# number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop
|
| 311 |
+
data = [
|
| 312 |
+
session_hash,
|
| 313 |
+
file_name,
|
| 314 |
+
(
|
| 315 |
+
text_column
|
| 316 |
+
if isinstance(text_column, str)
|
| 317 |
+
else (text_column[0] if text_column else "")
|
| 318 |
+
),
|
| 319 |
+
model_choice,
|
| 320 |
+
conversation_metadata if conversation_metadata else "",
|
| 321 |
+
input_tokens,
|
| 322 |
+
output_tokens,
|
| 323 |
+
number_of_calls,
|
| 324 |
+
estimated_time_taken,
|
| 325 |
+
cost_code,
|
| 326 |
+
]
|
| 327 |
+
|
| 328 |
+
# Add id and timestamp
|
| 329 |
+
generated_id = str(uuid.uuid4())
|
| 330 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
| 331 |
+
data.extend([generated_id, timestamp])
|
| 332 |
+
|
| 333 |
+
# Use custom headers if available, otherwise use default
|
| 334 |
+
# Note: CSVLogger_custom uses component labels, but we need to match what collect_output_csvs_and_create_excel_output expects
|
| 335 |
+
if CSV_USAGE_LOG_HEADERS and len(CSV_USAGE_LOG_HEADERS) == len(data):
|
| 336 |
+
headers = CSV_USAGE_LOG_HEADERS
|
| 337 |
+
else:
|
| 338 |
+
# Default headers - these should match what CSVLogger_custom creates from Gradio component labels
|
| 339 |
+
# The components are: session_hash_textbox, original_data_file_name_textbox, in_colnames,
|
| 340 |
+
# model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num,
|
| 341 |
+
# number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop
|
| 342 |
+
# Since these are hidden components without labels, CSVLogger_custom uses component variable names
|
| 343 |
+
# or default labels. We need to match what collect_output_csvs_and_create_excel_output expects:
|
| 344 |
+
# "Total LLM calls", "Total input tokens", "Total output tokens"
|
| 345 |
+
# But the actual CSV from Gradio likely has: "Number of calls", "Input tokens", "Output tokens"
|
| 346 |
+
# Let's use the names that match what the Excel function expects
|
| 347 |
+
headers = [
|
| 348 |
+
"Session hash",
|
| 349 |
+
"Reference data file name",
|
| 350 |
+
"Select the open text column of interest. In an Excel file, this shows columns across all sheets.",
|
| 351 |
+
"Large language model for topic extraction and summarisation",
|
| 352 |
+
"Conversation metadata",
|
| 353 |
+
"Total input tokens", # Changed from "Input tokens" to match Excel function
|
| 354 |
+
"Total output tokens", # Changed from "Output tokens" to match Excel function
|
| 355 |
+
"Total LLM calls", # Changed from "Number of calls" to match Excel function
|
| 356 |
+
"Estimated time taken (seconds)",
|
| 357 |
+
"Cost code",
|
| 358 |
+
"id",
|
| 359 |
+
"timestamp",
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
# Write to CSV if enabled
|
| 363 |
+
if save_to_csv:
|
| 364 |
+
# Ensure usage logs folder exists
|
| 365 |
+
os.makedirs(USAGE_LOGS_FOLDER, exist_ok=True)
|
| 366 |
+
|
| 367 |
+
# Construct full file path
|
| 368 |
+
usage_log_file_path = os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 369 |
+
|
| 370 |
+
# Write to CSV
|
| 371 |
+
file_exists = os.path.exists(usage_log_file_path)
|
| 372 |
+
with open(
|
| 373 |
+
usage_log_file_path, "a", newline="", encoding="utf-8-sig"
|
| 374 |
+
) as csvfile:
|
| 375 |
+
writer = csv.writer(csvfile)
|
| 376 |
+
if not file_exists:
|
| 377 |
+
# Write headers if file doesn't exist
|
| 378 |
+
writer.writerow(headers)
|
| 379 |
+
writer.writerow(data)
|
| 380 |
+
|
| 381 |
+
# Write to DynamoDB if enabled
|
| 382 |
+
if save_to_dynamodb:
|
| 383 |
+
# DynamoDB logging implementation
|
| 384 |
+
print("Saving to DynamoDB")
|
| 385 |
+
|
| 386 |
+
try:
|
| 387 |
+
# Connect to DynamoDB
|
| 388 |
+
if RUN_AWS_FUNCTIONS == "1":
|
| 389 |
+
try:
|
| 390 |
+
print("Connecting to DynamoDB via existing SSO connection")
|
| 391 |
+
dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION)
|
| 392 |
+
dynamodb.meta.client.list_tables()
|
| 393 |
+
except Exception as e:
|
| 394 |
+
print("No SSO credentials found:", e)
|
| 395 |
+
if AWS_ACCESS_KEY and AWS_SECRET_KEY:
|
| 396 |
+
print("Trying DynamoDB credentials from environment variables")
|
| 397 |
+
dynamodb = boto3.resource(
|
| 398 |
+
"dynamodb",
|
| 399 |
+
aws_access_key_id=AWS_ACCESS_KEY,
|
| 400 |
+
aws_secret_access_key=AWS_SECRET_KEY,
|
| 401 |
+
region_name=AWS_REGION,
|
| 402 |
+
)
|
| 403 |
+
else:
|
| 404 |
+
raise Exception(
|
| 405 |
+
"AWS credentials for DynamoDB logging not found"
|
| 406 |
+
)
|
| 407 |
+
else:
|
| 408 |
+
raise Exception("AWS credentials for DynamoDB logging not found")
|
| 409 |
+
|
| 410 |
+
# Get table name from config
|
| 411 |
+
dynamodb_table_name = USAGE_LOG_DYNAMODB_TABLE_NAME
|
| 412 |
+
if not dynamodb_table_name:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
"USAGE_LOG_DYNAMODB_TABLE_NAME not configured. Cannot save to DynamoDB."
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# Determine headers for DynamoDB
|
| 418 |
+
# Use DYNAMODB_USAGE_LOG_HEADERS if available and matches data length,
|
| 419 |
+
# otherwise use CSV_USAGE_LOG_HEADERS if it matches, otherwise use default headers
|
| 420 |
+
# Note: headers and data are guaranteed to have the same length and include id/timestamp
|
| 421 |
+
if DYNAMODB_USAGE_LOG_HEADERS and len(DYNAMODB_USAGE_LOG_HEADERS) == len(
|
| 422 |
+
data
|
| 423 |
+
):
|
| 424 |
+
dynamodb_headers = list(DYNAMODB_USAGE_LOG_HEADERS) # Make a copy
|
| 425 |
+
elif CSV_USAGE_LOG_HEADERS and len(CSV_USAGE_LOG_HEADERS) == len(data):
|
| 426 |
+
dynamodb_headers = list(CSV_USAGE_LOG_HEADERS) # Make a copy
|
| 427 |
+
else:
|
| 428 |
+
# Use the headers we created which are guaranteed to match data
|
| 429 |
+
dynamodb_headers = headers
|
| 430 |
+
|
| 431 |
+
# Check if table exists, create if it doesn't
|
| 432 |
+
try:
|
| 433 |
+
table = dynamodb.Table(dynamodb_table_name)
|
| 434 |
+
table.load()
|
| 435 |
+
except botocore.exceptions.ClientError as e:
|
| 436 |
+
if e.response["Error"]["Code"] == "ResourceNotFoundException":
|
| 437 |
+
print(
|
| 438 |
+
f"Table '{dynamodb_table_name}' does not exist. Creating it..."
|
| 439 |
+
)
|
| 440 |
+
attribute_definitions = [
|
| 441 |
+
{
|
| 442 |
+
"AttributeName": "id",
|
| 443 |
+
"AttributeType": "S",
|
| 444 |
+
}
|
| 445 |
+
]
|
| 446 |
+
|
| 447 |
+
table = dynamodb.create_table(
|
| 448 |
+
TableName=dynamodb_table_name,
|
| 449 |
+
KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
|
| 450 |
+
AttributeDefinitions=attribute_definitions,
|
| 451 |
+
BillingMode="PAY_PER_REQUEST",
|
| 452 |
+
)
|
| 453 |
+
# Wait until the table exists
|
| 454 |
+
table.meta.client.get_waiter("table_exists").wait(
|
| 455 |
+
TableName=dynamodb_table_name
|
| 456 |
+
)
|
| 457 |
+
time.sleep(5)
|
| 458 |
+
print(f"Table '{dynamodb_table_name}' created successfully.")
|
| 459 |
+
else:
|
| 460 |
+
raise
|
| 461 |
+
|
| 462 |
+
# Prepare the DynamoDB item to upload
|
| 463 |
+
# Map the headers to values (headers and data should match in length)
|
| 464 |
+
if len(dynamodb_headers) == len(data):
|
| 465 |
+
item = {
|
| 466 |
+
header: str(value) for header, value in zip(dynamodb_headers, data)
|
| 467 |
+
}
|
| 468 |
+
else:
|
| 469 |
+
# Fallback: use the default headers which are guaranteed to match data
|
| 470 |
+
print(
|
| 471 |
+
f"Warning: DynamoDB headers length ({len(dynamodb_headers)}) doesn't match data length ({len(data)}). Using default headers."
|
| 472 |
+
)
|
| 473 |
+
item = {header: str(value) for header, value in zip(headers, data)}
|
| 474 |
+
|
| 475 |
+
# Upload to DynamoDB
|
| 476 |
+
table.put_item(Item=item)
|
| 477 |
+
print("Successfully uploaded log to DynamoDB")
|
| 478 |
+
|
| 479 |
+
except Exception as e:
|
| 480 |
+
print(f"Could not upload log to DynamoDB due to: {e}")
|
| 481 |
+
import traceback
|
| 482 |
+
|
| 483 |
+
traceback.print_exc()
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# --- Main CLI Function ---
|
| 487 |
+
def main(direct_mode_args={}):
|
| 488 |
+
"""
|
| 489 |
+
A unified command-line interface for topic extraction, validation, deduplication, and summarisation.
|
| 490 |
+
|
| 491 |
+
Args:
|
| 492 |
+
direct_mode_args (dict, optional): Dictionary of arguments for direct mode execution.
|
| 493 |
+
If provided, uses these instead of parsing command line arguments.
|
| 494 |
+
"""
|
| 495 |
+
parser = argparse.ArgumentParser(
|
| 496 |
+
description="A versatile CLI for topic extraction, validation, deduplication, and summarisation using LLMs.",
|
| 497 |
+
formatter_class=argparse.RawTextHelpFormatter,
|
| 498 |
+
epilog="""
|
| 499 |
+
Examples:
|
| 500 |
+
|
| 501 |
+
To run these, you need to do the following:
|
| 502 |
+
|
| 503 |
+
- Open a terminal window
|
| 504 |
+
|
| 505 |
+
- CD to the app folder that contains this file (cli_topics.py)
|
| 506 |
+
|
| 507 |
+
- Load the virtual environment using either conda or venv depending on your setup
|
| 508 |
+
|
| 509 |
+
- Run one of the example commands below
|
| 510 |
+
|
| 511 |
+
- The examples below use the free Gemini 2.5 Flash Lite model, that is free with an API key that you can get from here: https://aistudio.google.com/api-keys. You can either set this or API keys for other services as an environment variable (e.g. in config/app_config.py. See the file tools/config.py for more details about variables relevant to each service) or you can set them manually at the time of the function call via command line arguments, such as the following:
|
| 512 |
+
|
| 513 |
+
Google/Gemini: --google_api_key
|
| 514 |
+
AWS Bedrock: --aws_access_key, --aws_secret_key, --aws_region
|
| 515 |
+
Hugging Face (for model download): --hf_token
|
| 516 |
+
Azure/OpenAI: --azure_api_key, --azure_endpoint
|
| 517 |
+
Inference Server endpoint for local models(e.g. llama server, vllm): --api_url
|
| 518 |
+
|
| 519 |
+
- Use --create_xlsx_output to create an Excel file combining all CSV outputs after task completion
|
| 520 |
+
|
| 521 |
+
- Look in the output/ folder to see output files:
|
| 522 |
+
|
| 523 |
+
# Topic Extraction
|
| 524 |
+
|
| 525 |
+
## Extract topics from a CSV file with default settings:
|
| 526 |
+
python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note"
|
| 527 |
+
|
| 528 |
+
## Extract topics with custom model and context:
|
| 529 |
+
python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note" --model_choice "gemini-2.5-flash-lite" --context "Social Care case notes for young people"
|
| 530 |
+
|
| 531 |
+
## Extract topics with grouping:
|
| 532 |
+
python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note" --group_by "Client"
|
| 533 |
+
|
| 534 |
+
## Extract topics with candidate topics (zero-shot):
|
| 535 |
+
python cli_topics.py --task extract --input_file example_data/dummy_consultation_response.csv --text_column "Response text" --candidate_topics example_data/dummy_consultation_response_themes.csv
|
| 536 |
+
|
| 537 |
+
# Topic Validation
|
| 538 |
+
|
| 539 |
+
## Validate previously extracted topics:
|
| 540 |
+
python cli_topics.py --task validate --input_file example_data/combined_case_notes.csv --text_column "Case Note" --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv
|
| 541 |
+
|
| 542 |
+
# Deduplication
|
| 543 |
+
|
| 544 |
+
Note: you will need to change the reference to previous output files to match the exact file names created from the previous task. This includes the relative path to the app folder. Also, the function will create an xlsx output file by default. the --input_file and --text_column arguments are needed for this, unless you pass in --no_xlsx_output as seen below.
|
| 545 |
+
|
| 546 |
+
## Deduplicate topics using fuzzy matching:
|
| 547 |
+
python cli_topics.py --task deduplicate --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv --similarity_threshold 90 --no_xlsx_output
|
| 548 |
+
|
| 549 |
+
## Deduplicate topics using LLM:
|
| 550 |
+
python cli_topics.py --task deduplicate --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv --method llm --model_choice "gemini-2.5-flash-lite" --no_xlsx_output
|
| 551 |
+
|
| 552 |
+
# Summarisation
|
| 553 |
+
|
| 554 |
+
Note: you will need to change the reference to previous output files to match the exact file names created from the previous task. This includes the relative path to the app folder. Also, the function will create an xlsx output file by default. the --input_file and --text_column arguments are needed for this, unless you pass in --no_xlsx_output as seen below.
|
| 555 |
+
|
| 556 |
+
## Summarise topics:
|
| 557 |
+
python cli_topics.py --task summarise --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv --model_choice "gemini-2.5-flash-lite" --no_xlsx_output
|
| 558 |
+
|
| 559 |
+
## Create overall summary:
|
| 560 |
+
python cli_topics.py --task overall_summary --previous_output_files output/combined_case_notes_col_Case_Note_unique_topics.csv --model_choice "gemini-2.5-flash-lite" --no_xlsx_output
|
| 561 |
+
|
| 562 |
+
# All-in-one pipeline
|
| 563 |
+
|
| 564 |
+
## Run complete pipeline (extract, deduplicate, summarise):
|
| 565 |
+
python cli_topics.py --task all_in_one --input_file example_data/combined_case_notes.csv --text_column "Case Note" --model_choice "gemini-2.5-flash-lite"
|
| 566 |
+
|
| 567 |
+
""",
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# --- Task Selection ---
|
| 571 |
+
task_group = parser.add_argument_group("Task Selection")
|
| 572 |
+
task_group.add_argument(
|
| 573 |
+
"--task",
|
| 574 |
+
choices=[
|
| 575 |
+
"extract",
|
| 576 |
+
"validate",
|
| 577 |
+
"deduplicate",
|
| 578 |
+
"summarise",
|
| 579 |
+
"overall_summary",
|
| 580 |
+
"all_in_one",
|
| 581 |
+
],
|
| 582 |
+
default="extract",
|
| 583 |
+
help="Task to perform: extract (topic extraction), validate (validate topics), deduplicate (deduplicate topics), summarise (summarise topics), overall_summary (create overall summary), or all_in_one (complete pipeline).",
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
# --- General Arguments ---
|
| 587 |
+
general_group = parser.add_argument_group("General Options")
|
| 588 |
+
general_group.add_argument(
|
| 589 |
+
"--input_file",
|
| 590 |
+
nargs="+",
|
| 591 |
+
help="Path to the input file(s) to process. Separate multiple files with a space, and use quotes if there are spaces in the file name.",
|
| 592 |
+
)
|
| 593 |
+
general_group.add_argument(
|
| 594 |
+
"--output_dir", default=OUTPUT_FOLDER, help="Directory for all output files."
|
| 595 |
+
)
|
| 596 |
+
general_group.add_argument(
|
| 597 |
+
"--input_dir", default=INPUT_FOLDER, help="Directory for all input files."
|
| 598 |
+
)
|
| 599 |
+
general_group.add_argument(
|
| 600 |
+
"--text_column",
|
| 601 |
+
help="Name of the text column to process (required for extract, validate, and all_in_one tasks).",
|
| 602 |
+
)
|
| 603 |
+
general_group.add_argument(
|
| 604 |
+
"--previous_output_files",
|
| 605 |
+
nargs="+",
|
| 606 |
+
help="Path(s) to previous output files (reference_table and/or unique_topics files) for validate, deduplicate, summarise, and overall_summary tasks.",
|
| 607 |
+
)
|
| 608 |
+
general_group.add_argument(
|
| 609 |
+
"--username", default="", help="Username for the session."
|
| 610 |
+
)
|
| 611 |
+
general_group.add_argument(
|
| 612 |
+
"--save_to_user_folders",
|
| 613 |
+
default=SESSION_OUTPUT_FOLDER,
|
| 614 |
+
help="Whether to save to user folders or not.",
|
| 615 |
+
)
|
| 616 |
+
general_group.add_argument(
|
| 617 |
+
"--excel_sheets",
|
| 618 |
+
nargs="+",
|
| 619 |
+
default=list(),
|
| 620 |
+
help="Specific Excel sheet names to process.",
|
| 621 |
+
)
|
| 622 |
+
general_group.add_argument(
|
| 623 |
+
"--group_by",
|
| 624 |
+
help="Column name to group results by.",
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# --- Model Configuration ---
|
| 628 |
+
model_group = parser.add_argument_group("Model Configuration")
|
| 629 |
+
model_group.add_argument(
|
| 630 |
+
"--model_choice",
|
| 631 |
+
default=default_model_choice,
|
| 632 |
+
help=f"LLM model to use. Default: {default_model_choice}",
|
| 633 |
+
)
|
| 634 |
+
model_group.add_argument(
|
| 635 |
+
"--model_source",
|
| 636 |
+
default=default_model_source,
|
| 637 |
+
help=f"Model source (e.g., 'Google', 'AWS', 'Local'). Default: {default_model_source}",
|
| 638 |
+
)
|
| 639 |
+
model_group.add_argument(
|
| 640 |
+
"--temperature",
|
| 641 |
+
type=float,
|
| 642 |
+
default=LLM_TEMPERATURE,
|
| 643 |
+
help=f"Temperature for LLM generation. Default: {LLM_TEMPERATURE}",
|
| 644 |
+
)
|
| 645 |
+
model_group.add_argument(
|
| 646 |
+
"--batch_size",
|
| 647 |
+
type=int,
|
| 648 |
+
default=BATCH_SIZE_DEFAULT,
|
| 649 |
+
help=f"Number of responses to submit in a single LLM query. Default: {BATCH_SIZE_DEFAULT}",
|
| 650 |
+
)
|
| 651 |
+
model_group.add_argument(
|
| 652 |
+
"--max_tokens",
|
| 653 |
+
type=int,
|
| 654 |
+
default=LLM_MAX_NEW_TOKENS,
|
| 655 |
+
help=f"Maximum tokens for LLM generation. Default: {LLM_MAX_NEW_TOKENS}",
|
| 656 |
+
)
|
| 657 |
+
model_group.add_argument(
|
| 658 |
+
"--google_api_key",
|
| 659 |
+
default=GEMINI_API_KEY,
|
| 660 |
+
help="Google API key for Gemini models.",
|
| 661 |
+
)
|
| 662 |
+
model_group.add_argument(
|
| 663 |
+
"--aws_access_key",
|
| 664 |
+
default=AWS_ACCESS_KEY,
|
| 665 |
+
help="AWS Access Key ID for Bedrock models.",
|
| 666 |
+
)
|
| 667 |
+
model_group.add_argument(
|
| 668 |
+
"--aws_secret_key",
|
| 669 |
+
default=AWS_SECRET_KEY,
|
| 670 |
+
help="AWS Secret Access Key for Bedrock models.",
|
| 671 |
+
)
|
| 672 |
+
model_group.add_argument(
|
| 673 |
+
"--aws_region",
|
| 674 |
+
default=AWS_REGION,
|
| 675 |
+
help="AWS region for Bedrock models.",
|
| 676 |
+
)
|
| 677 |
+
model_group.add_argument(
|
| 678 |
+
"--hf_token",
|
| 679 |
+
default=HF_TOKEN,
|
| 680 |
+
help="Hugging Face token for downloading gated models.",
|
| 681 |
+
)
|
| 682 |
+
model_group.add_argument(
|
| 683 |
+
"--azure_api_key",
|
| 684 |
+
default=AZURE_OPENAI_API_KEY,
|
| 685 |
+
help="Azure/OpenAI API key for Azure/OpenAI models.",
|
| 686 |
+
)
|
| 687 |
+
model_group.add_argument(
|
| 688 |
+
"--azure_endpoint",
|
| 689 |
+
default=AZURE_OPENAI_INFERENCE_ENDPOINT,
|
| 690 |
+
help="Azure Inference endpoint URL.",
|
| 691 |
+
)
|
| 692 |
+
model_group.add_argument(
|
| 693 |
+
"--api_url",
|
| 694 |
+
default=API_URL,
|
| 695 |
+
help=f"Inference server API URL (for local models). Default: {API_URL}",
|
| 696 |
+
)
|
| 697 |
+
model_group.add_argument(
|
| 698 |
+
"--inference_server_model",
|
| 699 |
+
default=CHOSEN_INFERENCE_SERVER_MODEL,
|
| 700 |
+
help=f"Inference server model name to use. Default: {CHOSEN_INFERENCE_SERVER_MODEL}",
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
# --- Topic Extraction Arguments ---
|
| 704 |
+
extract_group = parser.add_argument_group("Topic Extraction Options")
|
| 705 |
+
extract_group.add_argument(
|
| 706 |
+
"--context",
|
| 707 |
+
default="",
|
| 708 |
+
help="Context sentence to provide to the LLM for topic extraction.",
|
| 709 |
+
)
|
| 710 |
+
extract_group.add_argument(
|
| 711 |
+
"--candidate_topics",
|
| 712 |
+
help="Path to CSV file with candidate topics for zero-shot extraction.",
|
| 713 |
+
)
|
| 714 |
+
extract_group.add_argument(
|
| 715 |
+
"--force_zero_shot",
|
| 716 |
+
choices=["Yes", "No"],
|
| 717 |
+
default="No",
|
| 718 |
+
help="Force responses into suggested topics. Default: No",
|
| 719 |
+
)
|
| 720 |
+
extract_group.add_argument(
|
| 721 |
+
"--force_single_topic",
|
| 722 |
+
choices=["Yes", "No"],
|
| 723 |
+
default="No",
|
| 724 |
+
help="Ask the model to assign responses to only a single topic. Default: No",
|
| 725 |
+
)
|
| 726 |
+
extract_group.add_argument(
|
| 727 |
+
"--produce_structured_summary",
|
| 728 |
+
choices=["Yes", "No"],
|
| 729 |
+
default="No",
|
| 730 |
+
help="Produce structured summaries using suggested topics as headers. Default: No",
|
| 731 |
+
)
|
| 732 |
+
extract_group.add_argument(
|
| 733 |
+
"--sentiment",
|
| 734 |
+
choices=[
|
| 735 |
+
"Negative or Positive",
|
| 736 |
+
"Negative, Neutral, or Positive",
|
| 737 |
+
"Do not assess sentiment",
|
| 738 |
+
],
|
| 739 |
+
default="Negative or Positive",
|
| 740 |
+
help="Response sentiment analysis option. Default: Negative or Positive",
|
| 741 |
+
)
|
| 742 |
+
extract_group.add_argument(
|
| 743 |
+
"--additional_summary_instructions",
|
| 744 |
+
default="",
|
| 745 |
+
help="Additional instructions for summary format.",
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# --- Validation Arguments ---
|
| 749 |
+
validate_group = parser.add_argument_group("Topic Validation Options")
|
| 750 |
+
validate_group.add_argument(
|
| 751 |
+
"--additional_validation_issues",
|
| 752 |
+
default="",
|
| 753 |
+
help="Additional validation issues for the model to consider (bullet-point list).",
|
| 754 |
+
)
|
| 755 |
+
validate_group.add_argument(
|
| 756 |
+
"--show_previous_table",
|
| 757 |
+
choices=["Yes", "No"],
|
| 758 |
+
default="Yes",
|
| 759 |
+
help="Provide response data to validation process. Default: Yes",
|
| 760 |
+
)
|
| 761 |
+
validate_group.add_argument(
|
| 762 |
+
"--output_debug_files",
|
| 763 |
+
choices=["True", "False"],
|
| 764 |
+
default=OUTPUT_DEBUG_FILES,
|
| 765 |
+
help=f"Output debug files. Default: {OUTPUT_DEBUG_FILES}",
|
| 766 |
+
)
|
| 767 |
+
validate_group.add_argument(
|
| 768 |
+
"--max_time_for_loop",
|
| 769 |
+
type=int,
|
| 770 |
+
default=MAX_TIME_FOR_LOOP,
|
| 771 |
+
help=f"Maximum time for validation loop in seconds. Default: {MAX_TIME_FOR_LOOP}",
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
# --- Deduplication Arguments ---
|
| 775 |
+
dedup_group = parser.add_argument_group("Deduplication Options")
|
| 776 |
+
dedup_group.add_argument(
|
| 777 |
+
"--method",
|
| 778 |
+
choices=["fuzzy", "llm"],
|
| 779 |
+
default="fuzzy",
|
| 780 |
+
help="Deduplication method: fuzzy (fuzzy matching) or llm (LLM semantic matching). Default: fuzzy",
|
| 781 |
+
)
|
| 782 |
+
dedup_group.add_argument(
|
| 783 |
+
"--similarity_threshold",
|
| 784 |
+
type=int,
|
| 785 |
+
default=DEDUPLICATION_THRESHOLD,
|
| 786 |
+
help=f"Similarity threshold (0-100) for fuzzy matching. Default: {DEDUPLICATION_THRESHOLD}",
|
| 787 |
+
)
|
| 788 |
+
dedup_group.add_argument(
|
| 789 |
+
"--merge_sentiment",
|
| 790 |
+
choices=["Yes", "No"],
|
| 791 |
+
default="No",
|
| 792 |
+
help="Merge sentiment values together for duplicate subtopics. Default: No",
|
| 793 |
+
)
|
| 794 |
+
dedup_group.add_argument(
|
| 795 |
+
"--merge_general_topics",
|
| 796 |
+
choices=["Yes", "No"],
|
| 797 |
+
default="Yes",
|
| 798 |
+
help="Merge general topic values together for duplicate subtopics. Default: Yes",
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
# --- Summarisation Arguments ---
|
| 802 |
+
summarise_group = parser.add_argument_group("Summarisation Options")
|
| 803 |
+
summarise_group.add_argument(
|
| 804 |
+
"--summary_format",
|
| 805 |
+
choices=["two_paragraph", "single_paragraph"],
|
| 806 |
+
default="two_paragraph",
|
| 807 |
+
help="Summary format type. Default: two_paragraph",
|
| 808 |
+
)
|
| 809 |
+
summarise_group.add_argument(
|
| 810 |
+
"--sample_reference_table",
|
| 811 |
+
choices=["True", "False"],
|
| 812 |
+
default="True",
|
| 813 |
+
help="Sample reference table (recommended for large datasets). Default: True",
|
| 814 |
+
)
|
| 815 |
+
summarise_group.add_argument(
|
| 816 |
+
"--no_of_sampled_summaries",
|
| 817 |
+
type=int,
|
| 818 |
+
default=DEFAULT_SAMPLED_SUMMARIES,
|
| 819 |
+
help=f"Number of summaries per group. Default: {DEFAULT_SAMPLED_SUMMARIES}",
|
| 820 |
+
)
|
| 821 |
+
summarise_group.add_argument(
|
| 822 |
+
"--random_seed",
|
| 823 |
+
type=int,
|
| 824 |
+
default=LLM_SEED,
|
| 825 |
+
help=f"Random seed for sampling. Default: {LLM_SEED}",
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
# --- Output Format Arguments ---
|
| 829 |
+
output_group = parser.add_argument_group("Output Format Options")
|
| 830 |
+
output_group.add_argument(
|
| 831 |
+
"--no_xlsx_output",
|
| 832 |
+
dest="create_xlsx_output",
|
| 833 |
+
action="store_false",
|
| 834 |
+
default=True,
|
| 835 |
+
help="Disable creation of Excel (.xlsx) output file. By default, Excel output is created.",
|
| 836 |
+
)
|
| 837 |
+
|
| 838 |
+
# --- Logging Arguments ---
|
| 839 |
+
logging_group = parser.add_argument_group("Logging Options")
|
| 840 |
+
logging_group.add_argument(
|
| 841 |
+
"--save_logs_to_csv",
|
| 842 |
+
default=SAVE_LOGS_TO_CSV,
|
| 843 |
+
help="Save processing logs to CSV files.",
|
| 844 |
+
)
|
| 845 |
+
logging_group.add_argument(
|
| 846 |
+
"--save_logs_to_dynamodb",
|
| 847 |
+
default=SAVE_LOGS_TO_DYNAMODB,
|
| 848 |
+
help="Save processing logs to DynamoDB.",
|
| 849 |
+
)
|
| 850 |
+
logging_group.add_argument(
|
| 851 |
+
"--usage_logs_folder",
|
| 852 |
+
default=USAGE_LOGS_FOLDER,
|
| 853 |
+
help="Directory for usage log files.",
|
| 854 |
+
)
|
| 855 |
+
logging_group.add_argument(
|
| 856 |
+
"--cost_code",
|
| 857 |
+
default=DEFAULT_COST_CODE,
|
| 858 |
+
help="Cost code for tracking usage.",
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
# Parse arguments - either from command line or direct mode
|
| 862 |
+
if direct_mode_args:
|
| 863 |
+
# Use direct mode arguments
|
| 864 |
+
args = argparse.Namespace(**direct_mode_args)
|
| 865 |
+
else:
|
| 866 |
+
# Parse command line arguments
|
| 867 |
+
args = parser.parse_args()
|
| 868 |
+
|
| 869 |
+
# --- Handle S3 file downloads ---
|
| 870 |
+
# Get AWS credentials from args or fall back to config values
|
| 871 |
+
aws_access_key = getattr(args, "aws_access_key", None) or AWS_ACCESS_KEY or ""
|
| 872 |
+
aws_secret_key = getattr(args, "aws_secret_key", None) or AWS_SECRET_KEY or ""
|
| 873 |
+
aws_region = getattr(args, "aws_region", None) or AWS_REGION or ""
|
| 874 |
+
|
| 875 |
+
# Download input files from S3 if needed
|
| 876 |
+
# Note: args.input_file is typically a list (from CLI nargs="+" or from direct mode)
|
| 877 |
+
# but we also handle pipe-separated strings for compatibility
|
| 878 |
+
if args.input_file:
|
| 879 |
+
if isinstance(args.input_file, list):
|
| 880 |
+
# Handle list of files (may include S3 paths)
|
| 881 |
+
downloaded_files = []
|
| 882 |
+
for file_path in args.input_file:
|
| 883 |
+
downloaded_path = _download_s3_file_if_needed(
|
| 884 |
+
file_path,
|
| 885 |
+
aws_access_key=aws_access_key,
|
| 886 |
+
aws_secret_key=aws_secret_key,
|
| 887 |
+
aws_region=aws_region,
|
| 888 |
+
)
|
| 889 |
+
downloaded_files.append(downloaded_path)
|
| 890 |
+
args.input_file = downloaded_files
|
| 891 |
+
elif isinstance(args.input_file, str):
|
| 892 |
+
# Handle pipe-separated string (for direct mode compatibility)
|
| 893 |
+
if "|" in args.input_file:
|
| 894 |
+
file_list = [f.strip() for f in args.input_file.split("|") if f.strip()]
|
| 895 |
+
downloaded_files = []
|
| 896 |
+
for file_path in file_list:
|
| 897 |
+
downloaded_path = _download_s3_file_if_needed(
|
| 898 |
+
file_path,
|
| 899 |
+
aws_access_key=aws_access_key,
|
| 900 |
+
aws_secret_key=aws_secret_key,
|
| 901 |
+
aws_region=aws_region,
|
| 902 |
+
)
|
| 903 |
+
downloaded_files.append(downloaded_path)
|
| 904 |
+
args.input_file = downloaded_files
|
| 905 |
+
else:
|
| 906 |
+
# Single file path
|
| 907 |
+
args.input_file = [
|
| 908 |
+
_download_s3_file_if_needed(
|
| 909 |
+
args.input_file,
|
| 910 |
+
aws_access_key=aws_access_key,
|
| 911 |
+
aws_secret_key=aws_secret_key,
|
| 912 |
+
aws_region=aws_region,
|
| 913 |
+
)
|
| 914 |
+
]
|
| 915 |
+
|
| 916 |
+
# Download candidate topics file from S3 if needed
|
| 917 |
+
if args.candidate_topics:
|
| 918 |
+
args.candidate_topics = _download_s3_file_if_needed(
|
| 919 |
+
args.candidate_topics,
|
| 920 |
+
default_filename="downloaded_candidate_topics",
|
| 921 |
+
aws_access_key=aws_access_key,
|
| 922 |
+
aws_secret_key=aws_secret_key,
|
| 923 |
+
aws_region=aws_region,
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
# --- Override model_choice with inference_server_model if provided ---
|
| 927 |
+
# If inference_server_model is explicitly provided, use it to override model_choice
|
| 928 |
+
# This allows users to specify which inference-server model to use
|
| 929 |
+
if args.inference_server_model:
|
| 930 |
+
# Check if the current model_choice is an inference-server model
|
| 931 |
+
model_source = model_name_map.get(args.model_choice, {}).get(
|
| 932 |
+
"source", default_model_source
|
| 933 |
+
)
|
| 934 |
+
# If model_source is "inference-server" OR if inference_server_model is explicitly provided
|
| 935 |
+
# (different from default), use it
|
| 936 |
+
if (
|
| 937 |
+
model_source == "inference-server"
|
| 938 |
+
or args.inference_server_model != CHOSEN_INFERENCE_SERVER_MODEL
|
| 939 |
+
):
|
| 940 |
+
args.model_choice = args.inference_server_model
|
| 941 |
+
# Ensure the model is registered in model_name_map with inference-server source
|
| 942 |
+
if args.model_choice not in model_name_map:
|
| 943 |
+
model_name_map[args.model_choice] = {
|
| 944 |
+
"short_name": args.model_choice,
|
| 945 |
+
"source": "inference-server",
|
| 946 |
+
}
|
| 947 |
+
# Also update the model_source to ensure it's set correctly
|
| 948 |
+
model_name_map[args.model_choice]["source"] = "inference-server"
|
| 949 |
+
|
| 950 |
+
# --- Initial Setup ---
|
| 951 |
+
# Convert string boolean variables to boolean
|
| 952 |
+
args.save_to_user_folders = convert_string_to_boolean(args.save_to_user_folders)
|
| 953 |
+
args.save_logs_to_csv = convert_string_to_boolean(str(args.save_logs_to_csv))
|
| 954 |
+
args.save_logs_to_dynamodb = convert_string_to_boolean(
|
| 955 |
+
str(args.save_logs_to_dynamodb)
|
| 956 |
+
)
|
| 957 |
+
args.sample_reference_table = args.sample_reference_table == "True"
|
| 958 |
+
args.output_debug_files = args.output_debug_files == "True"
|
| 959 |
+
|
| 960 |
+
# Get username and folders
|
| 961 |
+
(
|
| 962 |
+
session_hash,
|
| 963 |
+
args.output_dir,
|
| 964 |
+
_,
|
| 965 |
+
args.input_dir,
|
| 966 |
+
) = get_username_and_folders(
|
| 967 |
+
username=args.username,
|
| 968 |
+
output_folder_textbox=args.output_dir,
|
| 969 |
+
input_folder_textbox=args.input_dir,
|
| 970 |
+
session_output_folder=args.save_to_user_folders,
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
+
print(
|
| 974 |
+
f"Conducting analyses with user {args.username or session_hash}. Outputs will be saved to {args.output_dir}."
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
# --- Route to the Correct Workflow Based on Task ---
|
| 978 |
+
|
| 979 |
+
# Validate input_file requirement for tasks that need it
|
| 980 |
+
if args.task in ["extract", "validate", "all_in_one"] and not args.input_file:
|
| 981 |
+
print(f"Error: --input_file is required for '{args.task}' task.")
|
| 982 |
+
return
|
| 983 |
+
|
| 984 |
+
if (
|
| 985 |
+
args.task in ["validate", "deduplicate", "summarise", "overall_summary"]
|
| 986 |
+
and not args.previous_output_files
|
| 987 |
+
):
|
| 988 |
+
print(f"Error: --previous_output_files is required for '{args.task}' task.")
|
| 989 |
+
return
|
| 990 |
+
|
| 991 |
+
if args.task in ["extract", "validate", "all_in_one"] and not args.text_column:
|
| 992 |
+
print(f"Error: --text_column is required for '{args.task}' task.")
|
| 993 |
+
return
|
| 994 |
+
|
| 995 |
+
start_time = time.time()
|
| 996 |
+
|
| 997 |
+
try:
|
| 998 |
+
# Task 1: Extract Topics
|
| 999 |
+
if args.task == "extract":
|
| 1000 |
+
print("--- Starting Topic Extraction Workflow... ---")
|
| 1001 |
+
|
| 1002 |
+
# Load data file
|
| 1003 |
+
if isinstance(args.input_file, str):
|
| 1004 |
+
args.input_file = [args.input_file]
|
| 1005 |
+
|
| 1006 |
+
file_data, file_name, total_number_of_batches = load_in_data_file(
|
| 1007 |
+
file_paths=args.input_file,
|
| 1008 |
+
in_colnames=[args.text_column],
|
| 1009 |
+
batch_size=args.batch_size,
|
| 1010 |
+
in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
# Prepare candidate topics if provided
|
| 1014 |
+
candidate_topics = None
|
| 1015 |
+
if args.candidate_topics:
|
| 1016 |
+
candidate_topics = args.candidate_topics
|
| 1017 |
+
|
| 1018 |
+
# Determine summary format prompt
|
| 1019 |
+
summary_format_prompt = (
|
| 1020 |
+
two_para_summary_format_prompt
|
| 1021 |
+
if args.summary_format == "two_paragraph"
|
| 1022 |
+
else single_para_summary_format_prompt
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
# Run extraction
|
| 1026 |
+
(
|
| 1027 |
+
display_markdown,
|
| 1028 |
+
master_topic_df_state,
|
| 1029 |
+
master_unique_topics_df_state,
|
| 1030 |
+
master_reference_df_state,
|
| 1031 |
+
topic_extraction_output_files,
|
| 1032 |
+
text_output_file_list_state,
|
| 1033 |
+
latest_batch_completed,
|
| 1034 |
+
log_files_output,
|
| 1035 |
+
log_files_output_list_state,
|
| 1036 |
+
conversation_metadata_textbox,
|
| 1037 |
+
estimated_time_taken_number,
|
| 1038 |
+
deduplication_input_files,
|
| 1039 |
+
summarisation_input_files,
|
| 1040 |
+
modifiable_unique_topics_df_state,
|
| 1041 |
+
modification_input_files,
|
| 1042 |
+
in_join_files,
|
| 1043 |
+
missing_df_state,
|
| 1044 |
+
input_tokens_num,
|
| 1045 |
+
output_tokens_num,
|
| 1046 |
+
number_of_calls_num,
|
| 1047 |
+
output_messages_textbox,
|
| 1048 |
+
logged_content_df,
|
| 1049 |
+
) = wrapper_extract_topics_per_column_value(
|
| 1050 |
+
grouping_col=args.group_by,
|
| 1051 |
+
in_data_file=args.input_file,
|
| 1052 |
+
file_data=file_data,
|
| 1053 |
+
initial_existing_topics_table=pd.DataFrame(),
|
| 1054 |
+
initial_existing_reference_df=pd.DataFrame(),
|
| 1055 |
+
initial_existing_topic_summary_df=pd.DataFrame(),
|
| 1056 |
+
initial_unique_table_df_display_table_markdown="",
|
| 1057 |
+
original_file_name=file_name,
|
| 1058 |
+
total_number_of_batches=total_number_of_batches,
|
| 1059 |
+
in_api_key=args.google_api_key,
|
| 1060 |
+
temperature=args.temperature,
|
| 1061 |
+
chosen_cols=[args.text_column],
|
| 1062 |
+
model_choice=args.model_choice,
|
| 1063 |
+
candidate_topics=candidate_topics,
|
| 1064 |
+
initial_first_loop_state=True,
|
| 1065 |
+
initial_all_metadata_content_str="",
|
| 1066 |
+
initial_latest_batch_completed=0,
|
| 1067 |
+
initial_time_taken=0,
|
| 1068 |
+
batch_size=args.batch_size,
|
| 1069 |
+
context_textbox=args.context,
|
| 1070 |
+
sentiment_checkbox=args.sentiment,
|
| 1071 |
+
force_zero_shot_radio=args.force_zero_shot,
|
| 1072 |
+
in_excel_sheets=args.excel_sheets,
|
| 1073 |
+
force_single_topic_radio=args.force_single_topic,
|
| 1074 |
+
produce_structured_summary_radio=args.produce_structured_summary,
|
| 1075 |
+
aws_access_key_textbox=args.aws_access_key,
|
| 1076 |
+
aws_secret_key_textbox=args.aws_secret_key,
|
| 1077 |
+
aws_region_textbox=args.aws_region,
|
| 1078 |
+
hf_api_key_textbox=args.hf_token,
|
| 1079 |
+
azure_api_key_textbox=args.azure_api_key,
|
| 1080 |
+
azure_endpoint_textbox=args.azure_endpoint,
|
| 1081 |
+
output_folder=args.output_dir,
|
| 1082 |
+
existing_logged_content=list(),
|
| 1083 |
+
additional_instructions_summary_format=args.additional_summary_instructions,
|
| 1084 |
+
additional_validation_issues_provided="",
|
| 1085 |
+
show_previous_table="Yes",
|
| 1086 |
+
api_url=args.api_url if args.api_url else API_URL,
|
| 1087 |
+
max_tokens=args.max_tokens,
|
| 1088 |
+
model_name_map=model_name_map,
|
| 1089 |
+
max_time_for_loop=99999,
|
| 1090 |
+
reasoning_suffix="",
|
| 1091 |
+
CHOSEN_LOCAL_MODEL_TYPE="",
|
| 1092 |
+
output_debug_files=str(args.output_debug_files),
|
| 1093 |
+
model=None,
|
| 1094 |
+
tokenizer=None,
|
| 1095 |
+
assistant_model=None,
|
| 1096 |
+
max_rows=999999,
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
end_time = time.time()
|
| 1100 |
+
processing_time = end_time - start_time
|
| 1101 |
+
|
| 1102 |
+
print("\n--- Topic Extraction Complete ---")
|
| 1103 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 1104 |
+
print(f"\nOutput files saved to: {args.output_dir}")
|
| 1105 |
+
if topic_extraction_output_files:
|
| 1106 |
+
print("Generated Files:", sorted(topic_extraction_output_files))
|
| 1107 |
+
|
| 1108 |
+
# Write usage log (before Excel creation so it can be included in Excel)
|
| 1109 |
+
write_usage_log(
|
| 1110 |
+
session_hash=session_hash,
|
| 1111 |
+
file_name=file_name,
|
| 1112 |
+
text_column=args.text_column,
|
| 1113 |
+
model_choice=args.model_choice,
|
| 1114 |
+
conversation_metadata=conversation_metadata_textbox or "",
|
| 1115 |
+
input_tokens=input_tokens_num or 0,
|
| 1116 |
+
output_tokens=output_tokens_num or 0,
|
| 1117 |
+
number_of_calls=number_of_calls_num or 0,
|
| 1118 |
+
estimated_time_taken=estimated_time_taken_number or processing_time,
|
| 1119 |
+
cost_code=args.cost_code,
|
| 1120 |
+
save_to_csv=args.save_logs_to_csv,
|
| 1121 |
+
save_to_dynamodb=args.save_logs_to_dynamodb,
|
| 1122 |
+
)
|
| 1123 |
+
|
| 1124 |
+
# Create Excel output if requested
|
| 1125 |
+
xlsx_files = []
|
| 1126 |
+
if args.create_xlsx_output:
|
| 1127 |
+
print("\nCreating Excel output file...")
|
| 1128 |
+
try:
|
| 1129 |
+
xlsx_files, _ = collect_output_csvs_and_create_excel_output(
|
| 1130 |
+
in_data_files=args.input_file,
|
| 1131 |
+
chosen_cols=[args.text_column],
|
| 1132 |
+
reference_data_file_name_textbox=file_name,
|
| 1133 |
+
in_group_col=args.group_by,
|
| 1134 |
+
model_choice=args.model_choice,
|
| 1135 |
+
master_reference_df_state=master_reference_df_state,
|
| 1136 |
+
master_unique_topics_df_state=master_unique_topics_df_state,
|
| 1137 |
+
summarised_output_df=pd.DataFrame(), # No summaries yet
|
| 1138 |
+
missing_df_state=missing_df_state,
|
| 1139 |
+
excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1140 |
+
usage_logs_location=(
|
| 1141 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1142 |
+
if args.save_logs_to_csv
|
| 1143 |
+
else ""
|
| 1144 |
+
),
|
| 1145 |
+
model_name_map=model_name_map,
|
| 1146 |
+
output_folder=args.output_dir,
|
| 1147 |
+
structured_summaries=args.produce_structured_summary,
|
| 1148 |
+
)
|
| 1149 |
+
if xlsx_files:
|
| 1150 |
+
print(f"Excel output created: {sorted(xlsx_files)}")
|
| 1151 |
+
except Exception as e:
|
| 1152 |
+
print(f"Warning: Could not create Excel output: {e}")
|
| 1153 |
+
|
| 1154 |
+
# Upload outputs to S3 if enabled
|
| 1155 |
+
all_output_files = (
|
| 1156 |
+
list(topic_extraction_output_files)
|
| 1157 |
+
if topic_extraction_output_files
|
| 1158 |
+
else []
|
| 1159 |
+
)
|
| 1160 |
+
if xlsx_files:
|
| 1161 |
+
all_output_files.extend(xlsx_files)
|
| 1162 |
+
upload_outputs_to_s3_if_enabled(
|
| 1163 |
+
output_files=all_output_files,
|
| 1164 |
+
base_file_name=file_name,
|
| 1165 |
+
session_hash=session_hash,
|
| 1166 |
+
)
|
| 1167 |
+
|
| 1168 |
+
# Task 2: Validate Topics
|
| 1169 |
+
elif args.task == "validate":
|
| 1170 |
+
print("--- Starting Topic Validation Workflow... ---")
|
| 1171 |
+
|
| 1172 |
+
# Load data file
|
| 1173 |
+
if isinstance(args.input_file, str):
|
| 1174 |
+
args.input_file = [args.input_file]
|
| 1175 |
+
|
| 1176 |
+
file_data, file_name, total_number_of_batches = load_in_data_file(
|
| 1177 |
+
file_paths=args.input_file,
|
| 1178 |
+
in_colnames=[args.text_column],
|
| 1179 |
+
batch_size=args.batch_size,
|
| 1180 |
+
in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
# Load previous output files
|
| 1184 |
+
(
|
| 1185 |
+
reference_df,
|
| 1186 |
+
topic_summary_df,
|
| 1187 |
+
latest_batch_completed_no_loop,
|
| 1188 |
+
deduplication_input_files_status,
|
| 1189 |
+
working_data_file_name_textbox,
|
| 1190 |
+
unique_topics_table_file_name_textbox,
|
| 1191 |
+
) = load_in_previous_data_files(args.previous_output_files)
|
| 1192 |
+
|
| 1193 |
+
# Run validation
|
| 1194 |
+
(
|
| 1195 |
+
display_markdown,
|
| 1196 |
+
master_topic_df_state,
|
| 1197 |
+
master_unique_topics_df_state,
|
| 1198 |
+
master_reference_df_state,
|
| 1199 |
+
validation_output_files,
|
| 1200 |
+
text_output_file_list_state,
|
| 1201 |
+
latest_batch_completed,
|
| 1202 |
+
log_files_output,
|
| 1203 |
+
log_files_output_list_state,
|
| 1204 |
+
conversation_metadata_textbox,
|
| 1205 |
+
estimated_time_taken_number,
|
| 1206 |
+
deduplication_input_files,
|
| 1207 |
+
summarisation_input_files,
|
| 1208 |
+
modifiable_unique_topics_df_state,
|
| 1209 |
+
modification_input_files,
|
| 1210 |
+
in_join_files,
|
| 1211 |
+
missing_df_state,
|
| 1212 |
+
input_tokens_num,
|
| 1213 |
+
output_tokens_num,
|
| 1214 |
+
number_of_calls_num,
|
| 1215 |
+
output_messages_textbox,
|
| 1216 |
+
logged_content_df,
|
| 1217 |
+
) = validate_topics_wrapper(
|
| 1218 |
+
file_data=file_data,
|
| 1219 |
+
reference_df=reference_df,
|
| 1220 |
+
topic_summary_df=topic_summary_df,
|
| 1221 |
+
file_name=working_data_file_name_textbox,
|
| 1222 |
+
chosen_cols=[args.text_column],
|
| 1223 |
+
batch_size=args.batch_size,
|
| 1224 |
+
model_choice=args.model_choice,
|
| 1225 |
+
in_api_key=args.google_api_key,
|
| 1226 |
+
temperature=args.temperature,
|
| 1227 |
+
max_tokens=args.max_tokens,
|
| 1228 |
+
azure_api_key_textbox=args.azure_api_key,
|
| 1229 |
+
azure_endpoint_textbox=args.azure_endpoint,
|
| 1230 |
+
reasoning_suffix="",
|
| 1231 |
+
group_name=args.group_by or "All",
|
| 1232 |
+
produce_structured_summary_radio=args.produce_structured_summary,
|
| 1233 |
+
force_zero_shot_radio=args.force_zero_shot,
|
| 1234 |
+
force_single_topic_radio=args.force_single_topic,
|
| 1235 |
+
context_textbox=args.context,
|
| 1236 |
+
additional_instructions_summary_format=args.additional_summary_instructions,
|
| 1237 |
+
output_folder=args.output_dir,
|
| 1238 |
+
output_debug_files=str(args.output_debug_files),
|
| 1239 |
+
original_full_file_name=file_name,
|
| 1240 |
+
additional_validation_issues_provided=args.additional_validation_issues,
|
| 1241 |
+
max_time_for_loop=args.max_time_for_loop,
|
| 1242 |
+
in_data_files=args.input_file,
|
| 1243 |
+
sentiment_checkbox=args.sentiment,
|
| 1244 |
+
logged_content=None,
|
| 1245 |
+
show_previous_table=args.show_previous_table,
|
| 1246 |
+
aws_access_key_textbox=args.aws_access_key,
|
| 1247 |
+
aws_secret_key_textbox=args.aws_secret_key,
|
| 1248 |
+
aws_region_textbox=args.aws_region,
|
| 1249 |
+
api_url=args.api_url if args.api_url else API_URL,
|
| 1250 |
+
)
|
| 1251 |
+
|
| 1252 |
+
end_time = time.time()
|
| 1253 |
+
processing_time = end_time - start_time
|
| 1254 |
+
|
| 1255 |
+
print("\n--- Topic Validation Complete ---")
|
| 1256 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 1257 |
+
print(f"\nOutput files saved to: {args.output_dir}")
|
| 1258 |
+
if validation_output_files:
|
| 1259 |
+
print("Generated Files:", sorted(validation_output_files))
|
| 1260 |
+
|
| 1261 |
+
# Write usage log
|
| 1262 |
+
write_usage_log(
|
| 1263 |
+
session_hash=session_hash,
|
| 1264 |
+
file_name=file_name,
|
| 1265 |
+
text_column=args.text_column,
|
| 1266 |
+
model_choice=args.model_choice,
|
| 1267 |
+
conversation_metadata=conversation_metadata_textbox or "",
|
| 1268 |
+
input_tokens=input_tokens_num or 0,
|
| 1269 |
+
output_tokens=output_tokens_num or 0,
|
| 1270 |
+
number_of_calls=number_of_calls_num or 0,
|
| 1271 |
+
estimated_time_taken=estimated_time_taken_number or processing_time,
|
| 1272 |
+
cost_code=args.cost_code,
|
| 1273 |
+
save_to_csv=args.save_logs_to_csv,
|
| 1274 |
+
save_to_dynamodb=args.save_logs_to_dynamodb,
|
| 1275 |
+
)
|
| 1276 |
+
|
| 1277 |
+
# Create Excel output if requested
|
| 1278 |
+
if args.create_xlsx_output:
|
| 1279 |
+
print("\nCreating Excel output file...")
|
| 1280 |
+
try:
|
| 1281 |
+
xlsx_files, _ = collect_output_csvs_and_create_excel_output(
|
| 1282 |
+
in_data_files=args.input_file,
|
| 1283 |
+
chosen_cols=[args.text_column],
|
| 1284 |
+
reference_data_file_name_textbox=file_name,
|
| 1285 |
+
in_group_col=args.group_by,
|
| 1286 |
+
model_choice=args.model_choice,
|
| 1287 |
+
master_reference_df_state=master_reference_df_state,
|
| 1288 |
+
master_unique_topics_df_state=master_unique_topics_df_state,
|
| 1289 |
+
summarised_output_df=pd.DataFrame(), # No summaries yet
|
| 1290 |
+
missing_df_state=missing_df_state,
|
| 1291 |
+
excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1292 |
+
usage_logs_location=(
|
| 1293 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1294 |
+
if args.save_logs_to_csv
|
| 1295 |
+
else ""
|
| 1296 |
+
),
|
| 1297 |
+
model_name_map=model_name_map,
|
| 1298 |
+
output_folder=args.output_dir,
|
| 1299 |
+
structured_summaries=args.produce_structured_summary,
|
| 1300 |
+
)
|
| 1301 |
+
if xlsx_files:
|
| 1302 |
+
print(f"Excel output created: {sorted(xlsx_files)}")
|
| 1303 |
+
except Exception as e:
|
| 1304 |
+
print(f"Warning: Could not create Excel output: {e}")
|
| 1305 |
+
|
| 1306 |
+
# Task 3: Deduplicate Topics
|
| 1307 |
+
elif args.task == "deduplicate":
|
| 1308 |
+
print("--- Starting Topic Deduplication Workflow... ---")
|
| 1309 |
+
|
| 1310 |
+
# Load previous output files
|
| 1311 |
+
(
|
| 1312 |
+
reference_df,
|
| 1313 |
+
topic_summary_df,
|
| 1314 |
+
latest_batch_completed_no_loop,
|
| 1315 |
+
deduplication_input_files_status,
|
| 1316 |
+
working_data_file_name_textbox,
|
| 1317 |
+
unique_topics_table_file_name_textbox,
|
| 1318 |
+
) = load_in_previous_data_files(args.previous_output_files)
|
| 1319 |
+
|
| 1320 |
+
if args.method == "fuzzy":
|
| 1321 |
+
# Fuzzy matching deduplication
|
| 1322 |
+
(
|
| 1323 |
+
ref_df_after_dedup,
|
| 1324 |
+
unique_df_after_dedup,
|
| 1325 |
+
summarisation_input_files,
|
| 1326 |
+
log_files_output,
|
| 1327 |
+
summarised_output_markdown,
|
| 1328 |
+
) = deduplicate_topics(
|
| 1329 |
+
reference_df=reference_df,
|
| 1330 |
+
topic_summary_df=topic_summary_df,
|
| 1331 |
+
reference_table_file_name=working_data_file_name_textbox,
|
| 1332 |
+
unique_topics_table_file_name=unique_topics_table_file_name_textbox,
|
| 1333 |
+
in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1334 |
+
merge_sentiment=args.merge_sentiment,
|
| 1335 |
+
merge_general_topics=args.merge_general_topics,
|
| 1336 |
+
score_threshold=args.similarity_threshold,
|
| 1337 |
+
in_data_files=args.input_file if args.input_file else list(),
|
| 1338 |
+
chosen_cols=[args.text_column] if args.text_column else list(),
|
| 1339 |
+
output_folder=args.output_dir,
|
| 1340 |
+
)
|
| 1341 |
+
else:
|
| 1342 |
+
# LLM deduplication
|
| 1343 |
+
model_source = model_name_map.get(args.model_choice, {}).get(
|
| 1344 |
+
"source", default_model_source
|
| 1345 |
+
)
|
| 1346 |
+
(
|
| 1347 |
+
ref_df_after_dedup,
|
| 1348 |
+
unique_df_after_dedup,
|
| 1349 |
+
summarisation_input_files,
|
| 1350 |
+
log_files_output,
|
| 1351 |
+
summarised_output_markdown,
|
| 1352 |
+
input_tokens_num,
|
| 1353 |
+
output_tokens_num,
|
| 1354 |
+
number_of_calls_num,
|
| 1355 |
+
estimated_time_taken_number,
|
| 1356 |
+
) = deduplicate_topics_llm(
|
| 1357 |
+
reference_df=reference_df,
|
| 1358 |
+
topic_summary_df=topic_summary_df,
|
| 1359 |
+
reference_table_file_name=working_data_file_name_textbox,
|
| 1360 |
+
unique_topics_table_file_name=unique_topics_table_file_name_textbox,
|
| 1361 |
+
model_choice=args.model_choice,
|
| 1362 |
+
in_api_key=args.google_api_key,
|
| 1363 |
+
temperature=args.temperature,
|
| 1364 |
+
model_source=model_source,
|
| 1365 |
+
bedrock_runtime=None,
|
| 1366 |
+
local_model=None,
|
| 1367 |
+
tokenizer=None,
|
| 1368 |
+
assistant_model=None,
|
| 1369 |
+
in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1370 |
+
merge_sentiment=args.merge_sentiment,
|
| 1371 |
+
merge_general_topics=args.merge_general_topics,
|
| 1372 |
+
in_data_files=args.input_file if args.input_file else list(),
|
| 1373 |
+
chosen_cols=[args.text_column] if args.text_column else list(),
|
| 1374 |
+
output_folder=args.output_dir,
|
| 1375 |
+
candidate_topics=(
|
| 1376 |
+
args.candidate_topics if args.candidate_topics else None
|
| 1377 |
+
),
|
| 1378 |
+
azure_endpoint=args.azure_endpoint,
|
| 1379 |
+
output_debug_files=str(args.output_debug_files),
|
| 1380 |
+
api_url=args.api_url if args.api_url else API_URL,
|
| 1381 |
+
)
|
| 1382 |
+
|
| 1383 |
+
end_time = time.time()
|
| 1384 |
+
processing_time = end_time - start_time
|
| 1385 |
+
|
| 1386 |
+
print("\n--- Topic Deduplication Complete ---")
|
| 1387 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 1388 |
+
print(f"\nOutput files saved to: {args.output_dir}")
|
| 1389 |
+
if summarisation_input_files:
|
| 1390 |
+
print("Generated Files:", sorted(summarisation_input_files))
|
| 1391 |
+
|
| 1392 |
+
# Write usage log (only for LLM deduplication which has token counts)
|
| 1393 |
+
if args.method == "llm":
|
| 1394 |
+
# Extract token counts from LLM deduplication result
|
| 1395 |
+
llm_input_tokens = (
|
| 1396 |
+
input_tokens_num if "input_tokens_num" in locals() else 0
|
| 1397 |
+
)
|
| 1398 |
+
llm_output_tokens = (
|
| 1399 |
+
output_tokens_num if "output_tokens_num" in locals() else 0
|
| 1400 |
+
)
|
| 1401 |
+
llm_calls = (
|
| 1402 |
+
number_of_calls_num if "number_of_calls_num" in locals() else 0
|
| 1403 |
+
)
|
| 1404 |
+
llm_time = (
|
| 1405 |
+
estimated_time_taken_number
|
| 1406 |
+
if "estimated_time_taken_number" in locals()
|
| 1407 |
+
else processing_time
|
| 1408 |
+
)
|
| 1409 |
+
|
| 1410 |
+
write_usage_log(
|
| 1411 |
+
session_hash=session_hash,
|
| 1412 |
+
file_name=working_data_file_name_textbox,
|
| 1413 |
+
text_column=args.text_column if args.text_column else "",
|
| 1414 |
+
model_choice=args.model_choice,
|
| 1415 |
+
conversation_metadata="",
|
| 1416 |
+
input_tokens=llm_input_tokens,
|
| 1417 |
+
output_tokens=llm_output_tokens,
|
| 1418 |
+
number_of_calls=llm_calls,
|
| 1419 |
+
estimated_time_taken=llm_time,
|
| 1420 |
+
cost_code=args.cost_code,
|
| 1421 |
+
save_to_csv=args.save_logs_to_csv,
|
| 1422 |
+
save_to_dynamodb=args.save_logs_to_dynamodb,
|
| 1423 |
+
)
|
| 1424 |
+
|
| 1425 |
+
# Create Excel output if requested
|
| 1426 |
+
xlsx_files = []
|
| 1427 |
+
if args.create_xlsx_output:
|
| 1428 |
+
print("\nCreating Excel output file...")
|
| 1429 |
+
try:
|
| 1430 |
+
# Use the deduplicated dataframes
|
| 1431 |
+
xlsx_files, _ = collect_output_csvs_and_create_excel_output(
|
| 1432 |
+
in_data_files=args.input_file if args.input_file else [],
|
| 1433 |
+
chosen_cols=[args.text_column] if args.text_column else [],
|
| 1434 |
+
reference_data_file_name_textbox=working_data_file_name_textbox,
|
| 1435 |
+
in_group_col=args.group_by,
|
| 1436 |
+
model_choice=args.model_choice,
|
| 1437 |
+
master_reference_df_state=ref_df_after_dedup,
|
| 1438 |
+
master_unique_topics_df_state=unique_df_after_dedup,
|
| 1439 |
+
summarised_output_df=pd.DataFrame(), # No summaries yet
|
| 1440 |
+
missing_df_state=pd.DataFrame(),
|
| 1441 |
+
excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1442 |
+
usage_logs_location=(
|
| 1443 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1444 |
+
if args.save_logs_to_csv
|
| 1445 |
+
else ""
|
| 1446 |
+
),
|
| 1447 |
+
model_name_map=model_name_map,
|
| 1448 |
+
output_folder=args.output_dir,
|
| 1449 |
+
structured_summaries=args.produce_structured_summary,
|
| 1450 |
+
)
|
| 1451 |
+
if xlsx_files:
|
| 1452 |
+
print(f"Excel output created: {sorted(xlsx_files)}")
|
| 1453 |
+
except Exception as e:
|
| 1454 |
+
print(f"Warning: Could not create Excel output: {e}")
|
| 1455 |
+
|
| 1456 |
+
# Upload outputs to S3 if enabled
|
| 1457 |
+
all_output_files = (
|
| 1458 |
+
list(summarisation_input_files) if summarisation_input_files else []
|
| 1459 |
+
)
|
| 1460 |
+
if xlsx_files:
|
| 1461 |
+
all_output_files.extend(xlsx_files)
|
| 1462 |
+
upload_outputs_to_s3_if_enabled(
|
| 1463 |
+
output_files=all_output_files,
|
| 1464 |
+
base_file_name=working_data_file_name_textbox,
|
| 1465 |
+
session_hash=session_hash,
|
| 1466 |
+
)
|
| 1467 |
+
|
| 1468 |
+
# Task 4: Summarise Topics
|
| 1469 |
+
elif args.task == "summarise":
|
| 1470 |
+
print("--- Starting Topic Summarisation Workflow... ---")
|
| 1471 |
+
|
| 1472 |
+
# Load previous output files
|
| 1473 |
+
(
|
| 1474 |
+
reference_df,
|
| 1475 |
+
topic_summary_df,
|
| 1476 |
+
latest_batch_completed_no_loop,
|
| 1477 |
+
deduplication_input_files_status,
|
| 1478 |
+
working_data_file_name_textbox,
|
| 1479 |
+
unique_topics_table_file_name_textbox,
|
| 1480 |
+
) = load_in_previous_data_files(args.previous_output_files)
|
| 1481 |
+
|
| 1482 |
+
# Determine summary format prompt
|
| 1483 |
+
summary_format_prompt = (
|
| 1484 |
+
two_para_summary_format_prompt
|
| 1485 |
+
if args.summary_format == "two_paragraph"
|
| 1486 |
+
else single_para_summary_format_prompt
|
| 1487 |
+
)
|
| 1488 |
+
|
| 1489 |
+
# Run summarisation
|
| 1490 |
+
(
|
| 1491 |
+
summary_reference_table_sample_state,
|
| 1492 |
+
master_unique_topics_df_revised_summaries_state,
|
| 1493 |
+
master_reference_df_revised_summaries_state,
|
| 1494 |
+
summary_output_files,
|
| 1495 |
+
summarised_outputs_list,
|
| 1496 |
+
latest_summary_completed_num,
|
| 1497 |
+
conversation_metadata_textbox,
|
| 1498 |
+
summarised_output_markdown,
|
| 1499 |
+
log_files_output,
|
| 1500 |
+
overall_summarisation_input_files,
|
| 1501 |
+
input_tokens_num,
|
| 1502 |
+
output_tokens_num,
|
| 1503 |
+
number_of_calls_num,
|
| 1504 |
+
estimated_time_taken_number,
|
| 1505 |
+
output_messages_textbox,
|
| 1506 |
+
logged_content_df,
|
| 1507 |
+
) = wrapper_summarise_output_topics_per_group(
|
| 1508 |
+
grouping_col=args.group_by,
|
| 1509 |
+
sampled_reference_table_df=reference_df.copy(), # Will be sampled if sample_reference_table=True
|
| 1510 |
+
topic_summary_df=topic_summary_df,
|
| 1511 |
+
reference_table_df=reference_df,
|
| 1512 |
+
model_choice=args.model_choice,
|
| 1513 |
+
in_api_key=args.google_api_key,
|
| 1514 |
+
temperature=args.temperature,
|
| 1515 |
+
reference_data_file_name=working_data_file_name_textbox,
|
| 1516 |
+
summarised_outputs=list(),
|
| 1517 |
+
latest_summary_completed=0,
|
| 1518 |
+
out_metadata_str="",
|
| 1519 |
+
in_data_files=args.input_file if args.input_file else list(),
|
| 1520 |
+
in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1521 |
+
chosen_cols=[args.text_column] if args.text_column else list(),
|
| 1522 |
+
log_output_files=list(),
|
| 1523 |
+
summarise_format_radio=summary_format_prompt,
|
| 1524 |
+
output_folder=args.output_dir,
|
| 1525 |
+
context_textbox=args.context,
|
| 1526 |
+
aws_access_key_textbox=args.aws_access_key,
|
| 1527 |
+
aws_secret_key_textbox=args.aws_secret_key,
|
| 1528 |
+
aws_region_textbox=args.aws_region,
|
| 1529 |
+
model_name_map=model_name_map,
|
| 1530 |
+
hf_api_key_textbox=args.hf_token,
|
| 1531 |
+
azure_endpoint_textbox=args.azure_endpoint,
|
| 1532 |
+
existing_logged_content=list(),
|
| 1533 |
+
sample_reference_table=args.sample_reference_table,
|
| 1534 |
+
no_of_sampled_summaries=args.no_of_sampled_summaries,
|
| 1535 |
+
random_seed=args.random_seed,
|
| 1536 |
+
api_url=args.api_url if args.api_url else API_URL,
|
| 1537 |
+
additional_summary_instructions_provided=args.additional_summary_instructions,
|
| 1538 |
+
output_debug_files=str(args.output_debug_files),
|
| 1539 |
+
reasoning_suffix="",
|
| 1540 |
+
local_model=None,
|
| 1541 |
+
tokenizer=None,
|
| 1542 |
+
assistant_model=None,
|
| 1543 |
+
do_summaries="Yes",
|
| 1544 |
+
)
|
| 1545 |
+
|
| 1546 |
+
end_time = time.time()
|
| 1547 |
+
processing_time = end_time - start_time
|
| 1548 |
+
|
| 1549 |
+
print("\n--- Topic Summarisation Complete ---")
|
| 1550 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 1551 |
+
print(f"\nOutput files saved to: {args.output_dir}")
|
| 1552 |
+
if summary_output_files:
|
| 1553 |
+
print("Generated Files:", sorted(summary_output_files))
|
| 1554 |
+
|
| 1555 |
+
# Write usage log
|
| 1556 |
+
write_usage_log(
|
| 1557 |
+
session_hash=session_hash,
|
| 1558 |
+
file_name=working_data_file_name_textbox,
|
| 1559 |
+
text_column=args.text_column if args.text_column else "",
|
| 1560 |
+
model_choice=args.model_choice,
|
| 1561 |
+
conversation_metadata=conversation_metadata_textbox or "",
|
| 1562 |
+
input_tokens=input_tokens_num or 0,
|
| 1563 |
+
output_tokens=output_tokens_num or 0,
|
| 1564 |
+
number_of_calls=number_of_calls_num or 0,
|
| 1565 |
+
estimated_time_taken=estimated_time_taken_number or processing_time,
|
| 1566 |
+
cost_code=args.cost_code,
|
| 1567 |
+
save_to_csv=args.save_logs_to_csv,
|
| 1568 |
+
save_to_dynamodb=args.save_logs_to_dynamodb,
|
| 1569 |
+
)
|
| 1570 |
+
|
| 1571 |
+
# Create Excel output if requested
|
| 1572 |
+
xlsx_files = []
|
| 1573 |
+
if args.create_xlsx_output:
|
| 1574 |
+
print("\nCreating Excel output file...")
|
| 1575 |
+
try:
|
| 1576 |
+
xlsx_files, _ = collect_output_csvs_and_create_excel_output(
|
| 1577 |
+
in_data_files=args.input_file if args.input_file else [],
|
| 1578 |
+
chosen_cols=[args.text_column] if args.text_column else [],
|
| 1579 |
+
reference_data_file_name_textbox=working_data_file_name_textbox,
|
| 1580 |
+
in_group_col=args.group_by,
|
| 1581 |
+
model_choice=args.model_choice,
|
| 1582 |
+
master_reference_df_state=master_reference_df_revised_summaries_state,
|
| 1583 |
+
master_unique_topics_df_state=master_unique_topics_df_revised_summaries_state,
|
| 1584 |
+
summarised_output_df=pd.DataFrame(), # Summaries are in the revised dataframes
|
| 1585 |
+
missing_df_state=pd.DataFrame(),
|
| 1586 |
+
excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1587 |
+
usage_logs_location=(
|
| 1588 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1589 |
+
if args.save_logs_to_csv
|
| 1590 |
+
else ""
|
| 1591 |
+
),
|
| 1592 |
+
model_name_map=model_name_map,
|
| 1593 |
+
output_folder=args.output_dir,
|
| 1594 |
+
structured_summaries=args.produce_structured_summary,
|
| 1595 |
+
)
|
| 1596 |
+
if xlsx_files:
|
| 1597 |
+
print(f"Excel output created: {sorted(xlsx_files)}")
|
| 1598 |
+
except Exception as e:
|
| 1599 |
+
print(f"Warning: Could not create Excel output: {e}")
|
| 1600 |
+
|
| 1601 |
+
# Upload outputs to S3 if enabled
|
| 1602 |
+
all_output_files = (
|
| 1603 |
+
list(summary_output_files) if summary_output_files else []
|
| 1604 |
+
)
|
| 1605 |
+
if xlsx_files:
|
| 1606 |
+
all_output_files.extend(xlsx_files)
|
| 1607 |
+
upload_outputs_to_s3_if_enabled(
|
| 1608 |
+
output_files=all_output_files,
|
| 1609 |
+
base_file_name=working_data_file_name_textbox,
|
| 1610 |
+
session_hash=session_hash,
|
| 1611 |
+
)
|
| 1612 |
+
|
| 1613 |
+
# Task 5: Overall Summary
|
| 1614 |
+
elif args.task == "overall_summary":
|
| 1615 |
+
print("--- Starting Overall Summary Workflow... ---")
|
| 1616 |
+
|
| 1617 |
+
# Load previous output files
|
| 1618 |
+
(
|
| 1619 |
+
reference_df,
|
| 1620 |
+
topic_summary_df,
|
| 1621 |
+
latest_batch_completed_no_loop,
|
| 1622 |
+
deduplication_input_files_status,
|
| 1623 |
+
working_data_file_name_textbox,
|
| 1624 |
+
unique_topics_table_file_name_textbox,
|
| 1625 |
+
) = load_in_previous_data_files(args.previous_output_files)
|
| 1626 |
+
|
| 1627 |
+
# Run overall summary
|
| 1628 |
+
(
|
| 1629 |
+
overall_summary_output_files,
|
| 1630 |
+
overall_summarised_output_markdown,
|
| 1631 |
+
summarised_output_df,
|
| 1632 |
+
conversation_metadata_textbox,
|
| 1633 |
+
input_tokens_num,
|
| 1634 |
+
output_tokens_num,
|
| 1635 |
+
number_of_calls_num,
|
| 1636 |
+
estimated_time_taken_number,
|
| 1637 |
+
output_messages_textbox,
|
| 1638 |
+
logged_content_df,
|
| 1639 |
+
) = overall_summary(
|
| 1640 |
+
topic_summary_df=topic_summary_df,
|
| 1641 |
+
model_choice=args.model_choice,
|
| 1642 |
+
in_api_key=args.google_api_key,
|
| 1643 |
+
temperature=args.temperature,
|
| 1644 |
+
reference_data_file_name=working_data_file_name_textbox,
|
| 1645 |
+
output_folder=args.output_dir,
|
| 1646 |
+
chosen_cols=[args.text_column] if args.text_column else list(),
|
| 1647 |
+
context_textbox=args.context,
|
| 1648 |
+
aws_access_key_textbox=args.aws_access_key,
|
| 1649 |
+
aws_secret_key_textbox=args.aws_secret_key,
|
| 1650 |
+
aws_region_textbox=args.aws_region,
|
| 1651 |
+
model_name_map=model_name_map,
|
| 1652 |
+
hf_api_key_textbox=args.hf_token,
|
| 1653 |
+
azure_endpoint_textbox=args.azure_endpoint,
|
| 1654 |
+
existing_logged_content=list(),
|
| 1655 |
+
api_url=args.api_url if args.api_url else API_URL,
|
| 1656 |
+
output_debug_files=str(args.output_debug_files),
|
| 1657 |
+
log_output_files=list(),
|
| 1658 |
+
reasoning_suffix="",
|
| 1659 |
+
local_model=None,
|
| 1660 |
+
tokenizer=None,
|
| 1661 |
+
assistant_model=None,
|
| 1662 |
+
do_summaries="Yes",
|
| 1663 |
+
)
|
| 1664 |
+
|
| 1665 |
+
end_time = time.time()
|
| 1666 |
+
processing_time = end_time - start_time
|
| 1667 |
+
|
| 1668 |
+
print("\n--- Overall Summary Complete ---")
|
| 1669 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 1670 |
+
print(f"\nOutput files saved to: {args.output_dir}")
|
| 1671 |
+
if overall_summary_output_files:
|
| 1672 |
+
print("Generated Files:", sorted(overall_summary_output_files))
|
| 1673 |
+
|
| 1674 |
+
# Write usage log
|
| 1675 |
+
write_usage_log(
|
| 1676 |
+
session_hash=session_hash,
|
| 1677 |
+
file_name=working_data_file_name_textbox,
|
| 1678 |
+
text_column=args.text_column if args.text_column else "",
|
| 1679 |
+
model_choice=args.model_choice,
|
| 1680 |
+
conversation_metadata=conversation_metadata_textbox or "",
|
| 1681 |
+
input_tokens=input_tokens_num or 0,
|
| 1682 |
+
output_tokens=output_tokens_num or 0,
|
| 1683 |
+
number_of_calls=number_of_calls_num or 0,
|
| 1684 |
+
estimated_time_taken=estimated_time_taken_number or processing_time,
|
| 1685 |
+
cost_code=args.cost_code,
|
| 1686 |
+
save_to_csv=args.save_logs_to_csv,
|
| 1687 |
+
save_to_dynamodb=args.save_logs_to_dynamodb,
|
| 1688 |
+
)
|
| 1689 |
+
|
| 1690 |
+
# Create Excel output if requested
|
| 1691 |
+
xlsx_files = []
|
| 1692 |
+
if args.create_xlsx_output:
|
| 1693 |
+
print("\nCreating Excel output file...")
|
| 1694 |
+
try:
|
| 1695 |
+
xlsx_files, _ = collect_output_csvs_and_create_excel_output(
|
| 1696 |
+
in_data_files=args.input_file if args.input_file else [],
|
| 1697 |
+
chosen_cols=[args.text_column] if args.text_column else [],
|
| 1698 |
+
reference_data_file_name_textbox=working_data_file_name_textbox,
|
| 1699 |
+
in_group_col=args.group_by,
|
| 1700 |
+
model_choice=args.model_choice,
|
| 1701 |
+
master_reference_df_state=reference_df, # Use original reference_df
|
| 1702 |
+
master_unique_topics_df_state=topic_summary_df, # Use original topic_summary_df
|
| 1703 |
+
summarised_output_df=summarised_output_df,
|
| 1704 |
+
missing_df_state=pd.DataFrame(),
|
| 1705 |
+
excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1706 |
+
usage_logs_location=(
|
| 1707 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1708 |
+
if args.save_logs_to_csv
|
| 1709 |
+
else ""
|
| 1710 |
+
),
|
| 1711 |
+
model_name_map=model_name_map,
|
| 1712 |
+
output_folder=args.output_dir,
|
| 1713 |
+
structured_summaries=args.produce_structured_summary,
|
| 1714 |
+
)
|
| 1715 |
+
if xlsx_files:
|
| 1716 |
+
print(f"Excel output created: {sorted(xlsx_files)}")
|
| 1717 |
+
except Exception as e:
|
| 1718 |
+
print(f"Warning: Could not create Excel output: {e}")
|
| 1719 |
+
|
| 1720 |
+
# Upload outputs to S3 if enabled
|
| 1721 |
+
all_output_files = (
|
| 1722 |
+
list(overall_summary_output_files)
|
| 1723 |
+
if overall_summary_output_files
|
| 1724 |
+
else []
|
| 1725 |
+
)
|
| 1726 |
+
if xlsx_files:
|
| 1727 |
+
all_output_files.extend(xlsx_files)
|
| 1728 |
+
upload_outputs_to_s3_if_enabled(
|
| 1729 |
+
output_files=all_output_files,
|
| 1730 |
+
base_file_name=working_data_file_name_textbox,
|
| 1731 |
+
session_hash=session_hash,
|
| 1732 |
+
)
|
| 1733 |
+
|
| 1734 |
+
# Task 6: All-in-One Pipeline
|
| 1735 |
+
elif args.task == "all_in_one":
|
| 1736 |
+
print("--- Starting All-in-One Pipeline Workflow... ---")
|
| 1737 |
+
|
| 1738 |
+
# Load data file
|
| 1739 |
+
if isinstance(args.input_file, str):
|
| 1740 |
+
args.input_file = [args.input_file]
|
| 1741 |
+
|
| 1742 |
+
file_data, file_name, total_number_of_batches = load_in_data_file(
|
| 1743 |
+
file_paths=args.input_file,
|
| 1744 |
+
in_colnames=[args.text_column],
|
| 1745 |
+
batch_size=args.batch_size,
|
| 1746 |
+
in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1747 |
+
)
|
| 1748 |
+
|
| 1749 |
+
# Prepare candidate topics if provided
|
| 1750 |
+
candidate_topics = None
|
| 1751 |
+
if args.candidate_topics:
|
| 1752 |
+
candidate_topics = args.candidate_topics
|
| 1753 |
+
|
| 1754 |
+
# Determine summary format prompt
|
| 1755 |
+
summary_format_prompt = (
|
| 1756 |
+
two_para_summary_format_prompt
|
| 1757 |
+
if args.summary_format == "two_paragraph"
|
| 1758 |
+
else single_para_summary_format_prompt
|
| 1759 |
+
)
|
| 1760 |
+
|
| 1761 |
+
# Run all-in-one pipeline
|
| 1762 |
+
(
|
| 1763 |
+
display_markdown,
|
| 1764 |
+
master_topic_df_state,
|
| 1765 |
+
master_unique_topics_df_state,
|
| 1766 |
+
master_reference_df_state,
|
| 1767 |
+
topic_extraction_output_files,
|
| 1768 |
+
text_output_file_list_state,
|
| 1769 |
+
latest_batch_completed,
|
| 1770 |
+
log_files_output,
|
| 1771 |
+
log_files_output_list_state,
|
| 1772 |
+
conversation_metadata_textbox,
|
| 1773 |
+
estimated_time_taken_number,
|
| 1774 |
+
deduplication_input_files,
|
| 1775 |
+
summarisation_input_files,
|
| 1776 |
+
modifiable_unique_topics_df_state,
|
| 1777 |
+
modification_input_files,
|
| 1778 |
+
in_join_files,
|
| 1779 |
+
missing_df_state,
|
| 1780 |
+
input_tokens_num,
|
| 1781 |
+
output_tokens_num,
|
| 1782 |
+
number_of_calls_num,
|
| 1783 |
+
output_messages_textbox,
|
| 1784 |
+
summary_reference_table_sample_state,
|
| 1785 |
+
summarised_references_markdown,
|
| 1786 |
+
master_unique_topics_df_revised_summaries_state,
|
| 1787 |
+
master_reference_df_revised_summaries_state,
|
| 1788 |
+
summary_output_files,
|
| 1789 |
+
summarised_outputs_list,
|
| 1790 |
+
latest_summary_completed_num,
|
| 1791 |
+
overall_summarisation_input_files,
|
| 1792 |
+
overall_summary_output_files,
|
| 1793 |
+
overall_summarised_output_markdown,
|
| 1794 |
+
summarised_output_df,
|
| 1795 |
+
logged_content_df,
|
| 1796 |
+
) = all_in_one_pipeline(
|
| 1797 |
+
grouping_col=args.group_by,
|
| 1798 |
+
in_data_files=args.input_file,
|
| 1799 |
+
file_data=file_data,
|
| 1800 |
+
existing_topics_table=pd.DataFrame(),
|
| 1801 |
+
existing_reference_df=pd.DataFrame(),
|
| 1802 |
+
existing_topic_summary_df=pd.DataFrame(),
|
| 1803 |
+
unique_table_df_display_table_markdown="",
|
| 1804 |
+
original_file_name=file_name,
|
| 1805 |
+
total_number_of_batches=total_number_of_batches,
|
| 1806 |
+
in_api_key=args.google_api_key,
|
| 1807 |
+
temperature=args.temperature,
|
| 1808 |
+
chosen_cols=[args.text_column],
|
| 1809 |
+
model_choice=args.model_choice,
|
| 1810 |
+
candidate_topics=candidate_topics,
|
| 1811 |
+
first_loop_state=True,
|
| 1812 |
+
conversation_metadata_text="",
|
| 1813 |
+
latest_batch_completed=0,
|
| 1814 |
+
time_taken_so_far=0,
|
| 1815 |
+
initial_table_prompt_text=initial_table_prompt,
|
| 1816 |
+
initial_table_system_prompt_text=initial_table_system_prompt,
|
| 1817 |
+
add_existing_topics_system_prompt_text=add_existing_topics_system_prompt,
|
| 1818 |
+
add_existing_topics_prompt_text=add_existing_topics_prompt,
|
| 1819 |
+
number_of_prompts_used=1,
|
| 1820 |
+
batch_size=args.batch_size,
|
| 1821 |
+
context_text=args.context,
|
| 1822 |
+
sentiment_choice=args.sentiment,
|
| 1823 |
+
force_zero_shot_choice=args.force_zero_shot,
|
| 1824 |
+
in_excel_sheets=args.excel_sheets,
|
| 1825 |
+
force_single_topic_choice=args.force_single_topic,
|
| 1826 |
+
produce_structures_summary_choice=args.produce_structured_summary,
|
| 1827 |
+
aws_access_key_text=args.aws_access_key,
|
| 1828 |
+
aws_secret_key_text=args.aws_secret_key,
|
| 1829 |
+
aws_region_text=args.aws_region,
|
| 1830 |
+
hf_api_key_text=args.hf_token,
|
| 1831 |
+
azure_api_key_text=args.azure_api_key,
|
| 1832 |
+
azure_endpoint_text=args.azure_endpoint,
|
| 1833 |
+
output_folder=args.output_dir,
|
| 1834 |
+
merge_sentiment=args.merge_sentiment,
|
| 1835 |
+
merge_general_topics=args.merge_general_topics,
|
| 1836 |
+
score_threshold=args.similarity_threshold,
|
| 1837 |
+
summarise_format=summary_format_prompt,
|
| 1838 |
+
random_seed=args.random_seed,
|
| 1839 |
+
log_files_output_list_state=list(),
|
| 1840 |
+
model_name_map_state=model_name_map,
|
| 1841 |
+
usage_logs_location=(
|
| 1842 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1843 |
+
if args.save_logs_to_csv
|
| 1844 |
+
else ""
|
| 1845 |
+
),
|
| 1846 |
+
existing_logged_content=list(),
|
| 1847 |
+
additional_instructions_summary_format=args.additional_summary_instructions,
|
| 1848 |
+
additional_validation_issues_provided="",
|
| 1849 |
+
show_previous_table="Yes",
|
| 1850 |
+
sample_reference_table_checkbox=args.sample_reference_table,
|
| 1851 |
+
api_url=args.api_url if args.api_url else API_URL,
|
| 1852 |
+
output_debug_files=str(args.output_debug_files),
|
| 1853 |
+
model=None,
|
| 1854 |
+
tokenizer=None,
|
| 1855 |
+
assistant_model=None,
|
| 1856 |
+
max_rows=999999,
|
| 1857 |
+
)
|
| 1858 |
+
|
| 1859 |
+
end_time = time.time()
|
| 1860 |
+
processing_time = end_time - start_time
|
| 1861 |
+
|
| 1862 |
+
print("\n--- All-in-One Pipeline Complete ---")
|
| 1863 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 1864 |
+
print(f"\nOutput files saved to: {args.output_dir}")
|
| 1865 |
+
if overall_summary_output_files:
|
| 1866 |
+
print("Generated Files:", sorted(overall_summary_output_files))
|
| 1867 |
+
|
| 1868 |
+
# Write usage log
|
| 1869 |
+
write_usage_log(
|
| 1870 |
+
session_hash=session_hash,
|
| 1871 |
+
file_name=file_name,
|
| 1872 |
+
text_column=args.text_column,
|
| 1873 |
+
model_choice=args.model_choice,
|
| 1874 |
+
conversation_metadata=conversation_metadata_textbox or "",
|
| 1875 |
+
input_tokens=input_tokens_num or 0,
|
| 1876 |
+
output_tokens=output_tokens_num or 0,
|
| 1877 |
+
number_of_calls=number_of_calls_num or 0,
|
| 1878 |
+
estimated_time_taken=estimated_time_taken_number or processing_time,
|
| 1879 |
+
cost_code=args.cost_code,
|
| 1880 |
+
save_to_csv=args.save_logs_to_csv,
|
| 1881 |
+
save_to_dynamodb=args.save_logs_to_dynamodb,
|
| 1882 |
+
)
|
| 1883 |
+
|
| 1884 |
+
# Create Excel output if requested
|
| 1885 |
+
xlsx_files = []
|
| 1886 |
+
if args.create_xlsx_output:
|
| 1887 |
+
print("\nCreating Excel output file...")
|
| 1888 |
+
try:
|
| 1889 |
+
xlsx_files, _ = collect_output_csvs_and_create_excel_output(
|
| 1890 |
+
in_data_files=args.input_file,
|
| 1891 |
+
chosen_cols=[args.text_column],
|
| 1892 |
+
reference_data_file_name_textbox=file_name,
|
| 1893 |
+
in_group_col=args.group_by,
|
| 1894 |
+
model_choice=args.model_choice,
|
| 1895 |
+
master_reference_df_state=master_reference_df_revised_summaries_state,
|
| 1896 |
+
master_unique_topics_df_state=master_unique_topics_df_revised_summaries_state,
|
| 1897 |
+
summarised_output_df=summarised_output_df,
|
| 1898 |
+
missing_df_state=missing_df_state,
|
| 1899 |
+
excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
|
| 1900 |
+
usage_logs_location=(
|
| 1901 |
+
os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
|
| 1902 |
+
if args.save_logs_to_csv
|
| 1903 |
+
else ""
|
| 1904 |
+
),
|
| 1905 |
+
model_name_map=model_name_map,
|
| 1906 |
+
output_folder=args.output_dir,
|
| 1907 |
+
structured_summaries=args.produce_structured_summary,
|
| 1908 |
+
)
|
| 1909 |
+
if xlsx_files:
|
| 1910 |
+
print(f"Excel output created: {sorted(xlsx_files)}")
|
| 1911 |
+
except Exception as e:
|
| 1912 |
+
print(f"Warning: Could not create Excel output: {e}")
|
| 1913 |
+
|
| 1914 |
+
# Upload outputs to S3 if enabled
|
| 1915 |
+
# Collect all output files from the pipeline
|
| 1916 |
+
all_output_files = []
|
| 1917 |
+
if topic_extraction_output_files:
|
| 1918 |
+
all_output_files.extend(topic_extraction_output_files)
|
| 1919 |
+
if overall_summary_output_files:
|
| 1920 |
+
all_output_files.extend(overall_summary_output_files)
|
| 1921 |
+
if xlsx_files:
|
| 1922 |
+
all_output_files.extend(xlsx_files)
|
| 1923 |
+
upload_outputs_to_s3_if_enabled(
|
| 1924 |
+
output_files=all_output_files,
|
| 1925 |
+
base_file_name=file_name,
|
| 1926 |
+
session_hash=session_hash,
|
| 1927 |
+
)
|
| 1928 |
+
|
| 1929 |
+
else:
|
| 1930 |
+
print(f"Error: Invalid task '{args.task}'.")
|
| 1931 |
+
print(
|
| 1932 |
+
"Valid options: 'extract', 'validate', 'deduplicate', 'summarise', 'overall_summary', or 'all_in_one'"
|
| 1933 |
+
)
|
| 1934 |
+
|
| 1935 |
+
except Exception as e:
|
| 1936 |
+
print(f"\nAn error occurred during the workflow: {e}")
|
| 1937 |
+
import traceback
|
| 1938 |
+
|
| 1939 |
+
traceback.print_exc()
|
| 1940 |
+
|
| 1941 |
+
|
| 1942 |
+
if __name__ == "__main__":
|
| 1943 |
+
main()
|
entrypoint.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
# Exit immediately if a command exits with a non-zero status.
|
| 4 |
+
set -e
|
| 5 |
+
|
| 6 |
+
echo "Starting in APP_MODE: $APP_MODE"
|
| 7 |
+
|
| 8 |
+
# --- Start the app based on mode ---
|
| 9 |
+
|
| 10 |
+
if [ "$APP_MODE" = "lambda" ]; then
|
| 11 |
+
echo "Starting in Lambda mode..."
|
| 12 |
+
# The CMD from Dockerfile will be passed as "$@"
|
| 13 |
+
exec python -m awslambdaric "$@"
|
| 14 |
+
else
|
| 15 |
+
echo "Starting in Gradio mode..."
|
| 16 |
+
exec python app.py
|
| 17 |
+
fi
|
| 18 |
+
|
example_data/case_note_headers_specific.csv
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ο»ΏGeneral Topic,Subtopic
|
| 2 |
+
Mental health,Anger
|
| 3 |
+
Mental health,Social issues
|
| 4 |
+
Physical health,General
|
| 5 |
+
Physical health,Substance misuse
|
| 6 |
+
Behaviour at school,Behaviour at school
|
| 7 |
+
Trends over time,Trends over time
|
example_data/combined_case_notes.csv
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Date,Social Worker,Client,Case Note
|
| 2 |
+
"January 3, 2023",Jane Smith,Alex D.,"Met with Alex at school following reports of increased absences and declining grades. Alex appeared sullen and avoided eye contact. When prompted about school, Alex expressed feelings of isolation and stated, ""No one gets me."" Scheduled a follow-up meeting to further explore these feelings."
|
| 3 |
+
"January 17, 2023",Jane Smith,Alex D.,"Met with Alex at the community center. Alex displayed sudden outbursts of anger when discussing home life, particularly in relation to a new stepfather. Alex mentioned occasional substance use, but did not specify which substances. Recommended a comprehensive assessment."
|
| 4 |
+
"February 5, 2023",Jane Smith,Alex D.,Home visit conducted. Alex's mother reported frequent arguments at home. She expressed concerns about Alex's new group of friends and late-night outings. Noted potential signs of substance abuse. Suggested family counseling.
|
| 5 |
+
"February 21, 2023",Jane Smith,Alex D.,"Met with Alex alone at my office. Alex appeared more agitated than in previous meetings. There were visible signs of self-harm on Alex's arms. When questioned, Alex became defensive. Immediate referral made to a mental health professional."
|
| 6 |
+
"March 10, 2023",Jane Smith,Alex D.,Attended joint session with Alex and a therapist. Alex shared feelings of hopelessness and admitted to occasional thoughts of self-harm. Therapist recommended a comprehensive mental health evaluation and ongoing therapy.
|
| 7 |
+
"March 25, 2023",Jane Smith,Alex D.,"Received a call from Alex's school about a physical altercation with another student. Met with Alex, who displayed high levels of frustration and admitted to the use of alcohol. Discussed the importance of seeking help and finding positive coping mechanisms. Recommended enrollment in an anger management program."
|
| 8 |
+
"April 15, 2023",Jane Smith,Alex D.,Met with Alex and mother to discuss progress. Alex's mother expressed concerns about Alex's increasing aggression at home. Alex acknowledged the issues but blamed others for provoking the behavior. It was decided that a more intensive intervention may be needed.
|
| 9 |
+
"April 30, 2023",Jane Smith,Alex D.,"Met with Alex and a psychiatrist. Psychiatrist diagnosed Alex with Oppositional Defiant Disorder (ODD) and co-morbid substance use disorder. A treatment plan was discussed, including medication, therapy, and family counseling."
|
| 10 |
+
"May 20, 2023",Jane Smith,Alex D.,"Met with Alex to discuss progress. Alex has started attending group therapy and has shown slight improvements in behavior. Still, concerns remain about substance use. Discussed potential for a short-term residential treatment program."
|
| 11 |
+
"January 3, 2023",Jane Smith,Jamie L.,"Met with Jamie at school after receiving reports of consistent tardiness and decreased participation in class. Jamie appeared withdrawn and exhibited signs of sadness. When asked about feelings, Jamie expressed feeling ""empty"" and ""hopeless"" at times. Scheduled a follow-up meeting to further explore these feelings."
|
| 12 |
+
"January 17, 2023",Jane Smith,Jamie L.,"Met with Jamie at the community center. Jamie shared feelings of low self-worth, mentioning that it's hard to find motivation for daily tasks. Discussed potential triggers and learned about recent family financial struggles. Recommended counseling and possible group therapy for peer support."
|
| 13 |
+
"February 5, 2023",Jane Smith,Jamie L.,Home visit conducted. Jamie's parents shared concerns about Jamie's increasing withdrawal from family activities and lack of interest in hobbies. Parents mentioned that Jamie spends a lot of time alone in the room. Suggested family therapy to open communication channels.
|
| 14 |
+
"February 21, 2023",Jane Smith,Jamie L.,Met with Jamie in my office. Jamie opened up about feelings of isolation and mentioned difficulty sleeping. No signs of self-harm or suicidal ideation were noted. Recommended a comprehensive mental health assessment to better understand the depth of the depression.
|
| 15 |
+
"March 10, 2023",Jane Smith,Jamie L.,"Attended a joint session with Jamie and a therapist. The therapist noted signs of moderate depression. Together, we discussed coping strategies and potential interventions. Jamie showed interest in art therapy."
|
| 16 |
+
"March 25, 2023",Jane Smith,Jamie L.,"Received feedback from Jamie's school that academic performance has slightly improved. However, social interactions remain limited. Encouraged Jamie to join school clubs or groups to foster connection."
|
| 17 |
+
"April 15, 2023",Jane Smith,Jamie L.,"Met with Jamie and parents to discuss progress. Parents have observed slight improvements in mood on some days, but overall, Jamie still appears to struggle. It was decided to explore medication as a potential aid alongside therapy."
|
| 18 |
+
"April 30, 2023",Jane Smith,Jamie L.,Met with Jamie and a psychiatrist. The psychiatrist diagnosed Jamie with Major Depressive Disorder (MDD) and suggested considering antidepressant medication. Discussed the potential benefits and side effects. Jamie and parents will think it over.
|
| 19 |
+
"May 20, 2023",Jane Smith,Jamie L.,"Jamie has started on a low dose of an antidepressant. Initial feedback is positive, with some improvement in mood and energy levels. Will continue monitoring and adjusting as necessary."
|
example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_structured_summaries.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:322a081b29d4fb40ccae7d47aa74fda772a002eda576ddc98d6acc86366cff11
|
| 3 |
+
size 13502
|
example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b3dcc1ea155169c23d913043b1ad87da2f2912be36d9fb1521c72ee05b8dcf36
|
| 3 |
+
size 25299
|
example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis_grouped.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e1eaede9af75b6ab695b1cfc6c01ec875abf14521249ba7257bd4bb0afd7ee8
|
| 3 |
+
size 28673
|
example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:30947f1355eacc74c92d09b766e8e3d71092b9a240e7f8acd381874b7d7ebcb3
|
| 3 |
+
size 24673
|
example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis_zero_shot.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5f0e36143d8362391e3b11d1c20e3a2a1b7536b8f0c972e3d44644eb9ae4e82
|
| 3 |
+
size 27592
|
example_data/dummy_consultation_response.csv
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Response Reference,Object to or Support application,Response text
|
| 2 |
+
R1,Object,I strongly object to the proposed five-storey apartment block on Main Street. It is completely out of keeping with the existing character of the area and will overshadow the existing buildings.
|
| 3 |
+
R2,Support,"I fully support the proposed development. The town needs more housing, and this development will provide much-needed homes."
|
| 4 |
+
R3,Object,The proposed development is too tall and will have a negative impact on the views from the surrounding area.
|
| 5 |
+
R4,Object,The loss of the well-loved cafe will be a great loss to the community.
|
| 6 |
+
R5,Support,The development will bring much-needed investment to the area and create jobs.
|
| 7 |
+
R6,Object,The increased traffic generated by the development will cause congestion on Main Street.
|
| 8 |
+
R7,Support,The development will provide much-needed affordable housing.
|
| 9 |
+
R8,Object,The development will have a negative impact on the local environment.
|
| 10 |
+
R9,Support,The development will improve the appearance of Main Street.
|
| 11 |
+
R10,Object,The development will overshadow the existing buildings and make them feel cramped.
|
| 12 |
+
R11,Support,The development will provide much-needed amenities for the local community.
|
| 13 |
+
R12,Object,The development will have a negative impact on the local wildlife.
|
| 14 |
+
R13,Support,The development will help to revitalise the town centre.
|
| 15 |
+
R14,Object,The development will increase noise pollution in the area.
|
| 16 |
+
R15,Support,The development will provide much-needed parking spaces.
|
| 17 |
+
R16,Object,The development will have a negative impact on the local businesses.
|
| 18 |
+
R17,Support,The development will provide much-needed green space.
|
| 19 |
+
R18,Object,The development will have a negative impact on the local heritage.
|
| 20 |
+
R19,Support,The development will provide much-needed facilities for young people.
|
| 21 |
+
R20,Object,The development will have a negative impact on the local schools.
|
| 22 |
+
R21,Support,The development will provide much-needed social housing.
|
| 23 |
+
R22,Object,The development will have a negative impact on the local infrastructure.
|
| 24 |
+
R23,Support,The development will provide much-needed jobs for local people.
|
| 25 |
+
R24,Object,The development will have a negative impact on the local economy.
|
| 26 |
+
R25,Support,The development will provide much-needed community facilities.
|
| 27 |
+
R26,Object,The development will have a negative impact on the local amenities.
|
| 28 |
+
R27,Support,The development will provide much-needed housing for young people.
|
| 29 |
+
R28,Object,The development will have a negative impact on the local character.
|
| 30 |
+
R29,Support,The development will provide much-needed housing for families.
|
| 31 |
+
R30,Object,The development will have a negative impact on the local quality of life.
|
example_data/dummy_consultation_response_themes.csv
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ο»Ώtopics
|
| 2 |
+
Need for family housing
|
| 3 |
+
Impact on the character of the area
|
| 4 |
+
Amenities for the local community
|
| 5 |
+
Revitalisation of the town centre
|
| 6 |
+
Impact on local wildlife
|
| 7 |
+
Parking
|
| 8 |
+
Impact on local businesses
|
| 9 |
+
Green space
|
| 10 |
+
Noise pollution
|
| 11 |
+
Impact on local heritage
|
| 12 |
+
Facilities for young people
|
| 13 |
+
Impact on local schools
|
| 14 |
+
Impact on views
|
| 15 |
+
Loss of cafe
|
| 16 |
+
Investment and job creation
|
| 17 |
+
Traffic congestion
|
| 18 |
+
Affordable housing
|
| 19 |
+
Impact on the local environment
|
| 20 |
+
Improvement of main street
|
| 21 |
+
Impact on local infrastructure
|
| 22 |
+
Investment and job creation
|
| 23 |
+
Impact on local schools
|
| 24 |
+
Provision of community facilities
|
| 25 |
+
Impact on local heritage
|
| 26 |
+
Impact on quality of life
|
intros/intro.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Create thematic summaries from your data
|
| 2 |
+
|
| 3 |
+
Extract topics and summarise open text using Large Language Models (LLMs). The model will loop through all text rows to find the most relevant general topics and subtopics, and provide a short summary of each. If you have specific topics in mind, you can enter them in 'Provide a list of specific topics' below.
|
| 4 |
+
|
| 5 |
+
NOTE: LLMs are not 100% accurate and may produce biased or incorrect responses. All files downloaded from this app **need to be checked by a human** before they are used in further outputs.
|
| 6 |
+
|
| 7 |
+
Unsure of how to use this app? Try an example by clicking on one of the example datasets below to see typical outputs the app can produce. There is also a user guide provided alongside this app - please ask your system administrator if you do not have access.
|
lambda_entrypoint.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import boto3
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# Import the main function from your CLI script
|
| 8 |
+
from cli_topics import main as cli_main
|
| 9 |
+
from tools.config import (
|
| 10 |
+
AWS_REGION,
|
| 11 |
+
BATCH_SIZE_DEFAULT,
|
| 12 |
+
DEDUPLICATION_THRESHOLD,
|
| 13 |
+
DEFAULT_COST_CODE,
|
| 14 |
+
DEFAULT_SAMPLED_SUMMARIES,
|
| 15 |
+
LLM_MAX_NEW_TOKENS,
|
| 16 |
+
LLM_SEED,
|
| 17 |
+
LLM_TEMPERATURE,
|
| 18 |
+
OUTPUT_DEBUG_FILES,
|
| 19 |
+
SAVE_LOGS_TO_CSV,
|
| 20 |
+
SAVE_LOGS_TO_DYNAMODB,
|
| 21 |
+
SESSION_OUTPUT_FOLDER,
|
| 22 |
+
USAGE_LOGS_FOLDER,
|
| 23 |
+
convert_string_to_boolean,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _get_env_list(env_var_name: str | list[str] | None) -> list[str]:
|
| 28 |
+
"""Parses a comma-separated environment variable into a list of strings."""
|
| 29 |
+
if isinstance(env_var_name, list):
|
| 30 |
+
return env_var_name
|
| 31 |
+
if env_var_name is None:
|
| 32 |
+
return []
|
| 33 |
+
|
| 34 |
+
# Handle string input
|
| 35 |
+
value = str(env_var_name).strip()
|
| 36 |
+
if not value or value == "[]":
|
| 37 |
+
return []
|
| 38 |
+
|
| 39 |
+
# Remove brackets if present (e.g., "[item1, item2]" -> "item1, item2")
|
| 40 |
+
if value.startswith("[") and value.endswith("]"):
|
| 41 |
+
value = value[1:-1]
|
| 42 |
+
|
| 43 |
+
# Remove quotes and split by comma
|
| 44 |
+
value = value.replace('"', "").replace("'", "")
|
| 45 |
+
if not value:
|
| 46 |
+
return []
|
| 47 |
+
|
| 48 |
+
# Split by comma and filter out any empty strings
|
| 49 |
+
return [s.strip() for s in value.split(",") if s.strip()]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
print("Lambda entrypoint loading...")
|
| 53 |
+
|
| 54 |
+
# Initialize S3 client outside the handler for connection reuse
|
| 55 |
+
s3_client = boto3.client("s3", region_name=os.getenv("AWS_REGION", AWS_REGION))
|
| 56 |
+
print("S3 client initialised")
|
| 57 |
+
|
| 58 |
+
# Lambda's only writable directory is /tmp. Ensure that all temporary files are stored in this directory.
|
| 59 |
+
TMP_DIR = "/tmp"
|
| 60 |
+
INPUT_DIR = os.path.join(TMP_DIR, "input")
|
| 61 |
+
OUTPUT_DIR = os.path.join(TMP_DIR, "output")
|
| 62 |
+
os.environ["GRADIO_TEMP_DIR"] = os.path.join(TMP_DIR, "gradio_tmp")
|
| 63 |
+
os.environ["MPLCONFIGDIR"] = os.path.join(TMP_DIR, "matplotlib_cache")
|
| 64 |
+
os.environ["FEEDBACK_LOGS_FOLDER"] = os.path.join(TMP_DIR, "feedback")
|
| 65 |
+
os.environ["ACCESS_LOGS_FOLDER"] = os.path.join(TMP_DIR, "logs")
|
| 66 |
+
os.environ["USAGE_LOGS_FOLDER"] = os.path.join(TMP_DIR, "usage")
|
| 67 |
+
|
| 68 |
+
# Define compatible file types for processing
|
| 69 |
+
COMPATIBLE_FILE_TYPES = {
|
| 70 |
+
".csv",
|
| 71 |
+
".xlsx",
|
| 72 |
+
".xls",
|
| 73 |
+
".parquet",
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def download_file_from_s3(bucket_name, key, download_path):
|
| 78 |
+
"""Download a file from S3 to the local filesystem."""
|
| 79 |
+
try:
|
| 80 |
+
s3_client.download_file(bucket_name, key, download_path)
|
| 81 |
+
print(f"Successfully downloaded s3://{bucket_name}/{key} to {download_path}")
|
| 82 |
+
except Exception as e:
|
| 83 |
+
print(f"Error downloading from S3: {e}")
|
| 84 |
+
raise
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def upload_directory_to_s3(local_directory, bucket_name, s3_prefix):
|
| 88 |
+
"""Upload all files from a local directory to an S3 prefix."""
|
| 89 |
+
for root, _, files in os.walk(local_directory):
|
| 90 |
+
for file_name in files:
|
| 91 |
+
local_file_path = os.path.join(root, file_name)
|
| 92 |
+
# Create a relative path to maintain directory structure if needed
|
| 93 |
+
relative_path = os.path.relpath(local_file_path, local_directory)
|
| 94 |
+
output_key = os.path.join(s3_prefix, relative_path).replace("\\", "/")
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
s3_client.upload_file(local_file_path, bucket_name, output_key)
|
| 98 |
+
print(
|
| 99 |
+
f"Successfully uploaded {local_file_path} to s3://{bucket_name}/{output_key}"
|
| 100 |
+
)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Error uploading to S3: {e}")
|
| 103 |
+
raise
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def lambda_handler(event, context):
|
| 107 |
+
print(f"Received event: {json.dumps(event)}")
|
| 108 |
+
|
| 109 |
+
# 1. Setup temporary directories
|
| 110 |
+
os.makedirs(INPUT_DIR, exist_ok=True)
|
| 111 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 112 |
+
|
| 113 |
+
# 2. Extract information from the event
|
| 114 |
+
# Assumes the event is triggered by S3 and may contain an 'arguments' payload
|
| 115 |
+
try:
|
| 116 |
+
record = event["Records"][0]
|
| 117 |
+
bucket_name = record["s3"]["bucket"]["name"]
|
| 118 |
+
input_key = record["s3"]["object"]["key"]
|
| 119 |
+
|
| 120 |
+
# The user metadata can be used to pass arguments
|
| 121 |
+
# This is more robust than embedding them in the main event body
|
| 122 |
+
try:
|
| 123 |
+
response = s3_client.head_object(Bucket=bucket_name, Key=input_key)
|
| 124 |
+
metadata = response.get("Metadata", dict())
|
| 125 |
+
print(f"S3 object metadata: {metadata}")
|
| 126 |
+
|
| 127 |
+
# Arguments can be passed as a JSON string in metadata
|
| 128 |
+
arguments_str = metadata.get("arguments", "{}")
|
| 129 |
+
print(f"Arguments string from metadata: '{arguments_str}'")
|
| 130 |
+
|
| 131 |
+
if arguments_str and arguments_str != "{}":
|
| 132 |
+
arguments = json.loads(arguments_str)
|
| 133 |
+
print(f"Successfully parsed arguments from metadata: {arguments}")
|
| 134 |
+
else:
|
| 135 |
+
arguments = dict()
|
| 136 |
+
print("No arguments found in metadata, using empty dictionary")
|
| 137 |
+
except Exception as e:
|
| 138 |
+
print(f"Warning: Could not parse metadata arguments: {e}")
|
| 139 |
+
print("Using empty arguments dictionary")
|
| 140 |
+
arguments = dict()
|
| 141 |
+
|
| 142 |
+
except (KeyError, IndexError) as e:
|
| 143 |
+
print(
|
| 144 |
+
f"Could not parse S3 event record: {e}. Checking for direct invocation payload."
|
| 145 |
+
)
|
| 146 |
+
# Fallback for direct invocation (e.g., from Step Functions or manual test)
|
| 147 |
+
bucket_name = event.get("bucket_name")
|
| 148 |
+
input_key = event.get("input_key")
|
| 149 |
+
arguments = event.get("arguments", dict())
|
| 150 |
+
if not all([bucket_name, input_key]):
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"Missing 'bucket_name' or 'input_key' in direct invocation event."
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Log file type information
|
| 156 |
+
file_extension = os.path.splitext(input_key)[1].lower()
|
| 157 |
+
print(f"Detected file extension: '{file_extension}'")
|
| 158 |
+
|
| 159 |
+
# 3. Download the main input file
|
| 160 |
+
input_file_path = os.path.join(INPUT_DIR, os.path.basename(input_key))
|
| 161 |
+
download_file_from_s3(bucket_name, input_key, input_file_path)
|
| 162 |
+
|
| 163 |
+
# 3.1. Validate file type compatibility
|
| 164 |
+
is_env_file = input_key.lower().endswith(".env")
|
| 165 |
+
|
| 166 |
+
if not is_env_file and file_extension not in COMPATIBLE_FILE_TYPES:
|
| 167 |
+
error_message = f"File type '{file_extension}' is not supported for processing. Compatible file types are: {', '.join(sorted(COMPATIBLE_FILE_TYPES))}"
|
| 168 |
+
print(f"ERROR: {error_message}")
|
| 169 |
+
print(f"File was not processed due to unsupported file type: {file_extension}")
|
| 170 |
+
return {
|
| 171 |
+
"statusCode": 400,
|
| 172 |
+
"body": json.dumps(
|
| 173 |
+
{
|
| 174 |
+
"error": "Unsupported file type",
|
| 175 |
+
"message": error_message,
|
| 176 |
+
"supported_types": list(COMPATIBLE_FILE_TYPES),
|
| 177 |
+
"received_type": file_extension,
|
| 178 |
+
"file_processed": False,
|
| 179 |
+
}
|
| 180 |
+
),
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
print(f"File type '{file_extension}' is compatible for processing")
|
| 184 |
+
if is_env_file:
|
| 185 |
+
print("Processing .env file for configuration")
|
| 186 |
+
else:
|
| 187 |
+
print(f"Processing {file_extension} file for topic modelling")
|
| 188 |
+
|
| 189 |
+
# 3.5. Check if the downloaded file is a .env file and handle accordingly
|
| 190 |
+
actual_input_file_path = input_file_path
|
| 191 |
+
if input_key.lower().endswith(".env"):
|
| 192 |
+
print("Detected .env file, loading environment variables...")
|
| 193 |
+
|
| 194 |
+
# Load environment variables from the .env file
|
| 195 |
+
print(f"Loading .env file from: {input_file_path}")
|
| 196 |
+
|
| 197 |
+
# Check if file exists and is readable
|
| 198 |
+
if os.path.exists(input_file_path):
|
| 199 |
+
print(".env file exists and is readable")
|
| 200 |
+
with open(input_file_path, "r") as f:
|
| 201 |
+
content = f.read()
|
| 202 |
+
print(f".env file content preview: {content[:200]}...")
|
| 203 |
+
else:
|
| 204 |
+
print(f"ERROR: .env file does not exist at {input_file_path}")
|
| 205 |
+
|
| 206 |
+
load_dotenv(input_file_path, override=True)
|
| 207 |
+
print("Environment variables loaded from .env file")
|
| 208 |
+
|
| 209 |
+
# Extract the actual input file path from environment variables
|
| 210 |
+
env_input_file = os.getenv("INPUT_FILE")
|
| 211 |
+
|
| 212 |
+
if env_input_file:
|
| 213 |
+
print(f"Found input file path in environment: {env_input_file}")
|
| 214 |
+
|
| 215 |
+
# If the path is an S3 path, download it
|
| 216 |
+
if env_input_file.startswith("s3://"):
|
| 217 |
+
# Parse S3 path: s3://bucket/key
|
| 218 |
+
s3_path_parts = env_input_file[5:].split("/", 1)
|
| 219 |
+
if len(s3_path_parts) == 2:
|
| 220 |
+
env_bucket = s3_path_parts[0]
|
| 221 |
+
env_key = s3_path_parts[1]
|
| 222 |
+
actual_input_file_path = os.path.join(
|
| 223 |
+
INPUT_DIR, os.path.basename(env_key)
|
| 224 |
+
)
|
| 225 |
+
print(
|
| 226 |
+
f"Downloading actual input file from s3://{env_bucket}/{env_key}"
|
| 227 |
+
)
|
| 228 |
+
download_file_from_s3(env_bucket, env_key, actual_input_file_path)
|
| 229 |
+
else:
|
| 230 |
+
print("Warning: Invalid S3 path format in environment variable")
|
| 231 |
+
actual_input_file_path = input_file_path
|
| 232 |
+
else:
|
| 233 |
+
# Assume it's a local path or relative path
|
| 234 |
+
actual_input_file_path = env_input_file
|
| 235 |
+
print(
|
| 236 |
+
f"Using input file path from environment: {actual_input_file_path}"
|
| 237 |
+
)
|
| 238 |
+
else:
|
| 239 |
+
print("Warning: No input file path found in environment variables")
|
| 240 |
+
# Fall back to using the .env file itself (though this might not be what we want)
|
| 241 |
+
actual_input_file_path = input_file_path
|
| 242 |
+
else:
|
| 243 |
+
print("File is not a .env file, proceeding with normal processing")
|
| 244 |
+
|
| 245 |
+
# 4. Prepare arguments for the CLI function
|
| 246 |
+
# This dictionary should mirror the arguments that cli_topics.main() expects via direct_mode_args
|
| 247 |
+
|
| 248 |
+
cli_args = {
|
| 249 |
+
# Task Selection
|
| 250 |
+
"task": arguments.get("task", os.getenv("DIRECT_MODE_TASK", "extract")),
|
| 251 |
+
# General Arguments
|
| 252 |
+
"input_file": [actual_input_file_path] if actual_input_file_path else None,
|
| 253 |
+
"output_dir": arguments.get(
|
| 254 |
+
"output_dir", os.getenv("DIRECT_MODE_OUTPUT_DIR", OUTPUT_DIR)
|
| 255 |
+
),
|
| 256 |
+
"input_dir": arguments.get("input_dir", INPUT_DIR),
|
| 257 |
+
"text_column": arguments.get(
|
| 258 |
+
"text_column", os.getenv("DIRECT_MODE_TEXT_COLUMN", "")
|
| 259 |
+
),
|
| 260 |
+
"previous_output_files": _get_env_list(
|
| 261 |
+
arguments.get(
|
| 262 |
+
"previous_output_files",
|
| 263 |
+
os.getenv("DIRECT_MODE_PREVIOUS_OUTPUT_FILES", list()),
|
| 264 |
+
)
|
| 265 |
+
),
|
| 266 |
+
"username": arguments.get("username", os.getenv("DIRECT_MODE_USERNAME", "")),
|
| 267 |
+
"save_to_user_folders": convert_string_to_boolean(
|
| 268 |
+
arguments.get(
|
| 269 |
+
"save_to_user_folders",
|
| 270 |
+
os.getenv("SESSION_OUTPUT_FOLDER", str(SESSION_OUTPUT_FOLDER)),
|
| 271 |
+
)
|
| 272 |
+
),
|
| 273 |
+
"excel_sheets": _get_env_list(
|
| 274 |
+
arguments.get("excel_sheets", os.getenv("DIRECT_MODE_EXCEL_SHEETS", list()))
|
| 275 |
+
),
|
| 276 |
+
"group_by": arguments.get("group_by", os.getenv("DIRECT_MODE_GROUP_BY", "")),
|
| 277 |
+
# Model Configuration
|
| 278 |
+
"model_choice": arguments.get(
|
| 279 |
+
"model_choice", os.getenv("DIRECT_MODE_MODEL_CHOICE", "")
|
| 280 |
+
),
|
| 281 |
+
"temperature": float(
|
| 282 |
+
arguments.get(
|
| 283 |
+
"temperature",
|
| 284 |
+
os.getenv("DIRECT_MODE_TEMPERATURE", str(LLM_TEMPERATURE)),
|
| 285 |
+
)
|
| 286 |
+
),
|
| 287 |
+
"batch_size": int(
|
| 288 |
+
arguments.get(
|
| 289 |
+
"batch_size",
|
| 290 |
+
os.getenv("DIRECT_MODE_BATCH_SIZE", str(BATCH_SIZE_DEFAULT)),
|
| 291 |
+
)
|
| 292 |
+
),
|
| 293 |
+
"max_tokens": int(
|
| 294 |
+
arguments.get(
|
| 295 |
+
"max_tokens",
|
| 296 |
+
os.getenv("DIRECT_MODE_MAX_TOKENS", str(LLM_MAX_NEW_TOKENS)),
|
| 297 |
+
)
|
| 298 |
+
),
|
| 299 |
+
"google_api_key": arguments.get(
|
| 300 |
+
"google_api_key", os.getenv("GEMINI_API_KEY", "")
|
| 301 |
+
),
|
| 302 |
+
"aws_access_key": None, # Use IAM Role instead of keys
|
| 303 |
+
"aws_secret_key": None, # Use IAM Role instead of keys
|
| 304 |
+
"aws_region": os.getenv("AWS_REGION", AWS_REGION),
|
| 305 |
+
"hf_token": arguments.get("hf_token", os.getenv("HF_TOKEN", "")),
|
| 306 |
+
"azure_api_key": arguments.get(
|
| 307 |
+
"azure_api_key", os.getenv("AZURE_OPENAI_API_KEY", "")
|
| 308 |
+
),
|
| 309 |
+
"azure_endpoint": arguments.get(
|
| 310 |
+
"azure_endpoint", os.getenv("AZURE_OPENAI_INFERENCE_ENDPOINT", "")
|
| 311 |
+
),
|
| 312 |
+
"api_url": arguments.get("api_url", os.getenv("API_URL", "")),
|
| 313 |
+
"inference_server_model": arguments.get(
|
| 314 |
+
"inference_server_model", os.getenv("CHOSEN_INFERENCE_SERVER_MODEL", "")
|
| 315 |
+
),
|
| 316 |
+
# Topic Extraction Arguments
|
| 317 |
+
"context": arguments.get("context", os.getenv("DIRECT_MODE_CONTEXT", "")),
|
| 318 |
+
"candidate_topics": arguments.get(
|
| 319 |
+
"candidate_topics", os.getenv("DIRECT_MODE_CANDIDATE_TOPICS", "")
|
| 320 |
+
),
|
| 321 |
+
"force_zero_shot": arguments.get(
|
| 322 |
+
"force_zero_shot", os.getenv("DIRECT_MODE_FORCE_ZERO_SHOT", "No")
|
| 323 |
+
),
|
| 324 |
+
"force_single_topic": arguments.get(
|
| 325 |
+
"force_single_topic", os.getenv("DIRECT_MODE_FORCE_SINGLE_TOPIC", "No")
|
| 326 |
+
),
|
| 327 |
+
"produce_structured_summary": arguments.get(
|
| 328 |
+
"produce_structured_summary",
|
| 329 |
+
os.getenv("DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY", "No"),
|
| 330 |
+
),
|
| 331 |
+
"sentiment": arguments.get(
|
| 332 |
+
"sentiment", os.getenv("DIRECT_MODE_SENTIMENT", "Negative or Positive")
|
| 333 |
+
),
|
| 334 |
+
"additional_summary_instructions": arguments.get(
|
| 335 |
+
"additional_summary_instructions",
|
| 336 |
+
os.getenv("DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS", ""),
|
| 337 |
+
),
|
| 338 |
+
# Validation Arguments
|
| 339 |
+
"additional_validation_issues": arguments.get(
|
| 340 |
+
"additional_validation_issues",
|
| 341 |
+
os.getenv("DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES", ""),
|
| 342 |
+
),
|
| 343 |
+
"show_previous_table": arguments.get(
|
| 344 |
+
"show_previous_table", os.getenv("DIRECT_MODE_SHOW_PREVIOUS_TABLE", "Yes")
|
| 345 |
+
),
|
| 346 |
+
"output_debug_files": arguments.get(
|
| 347 |
+
"output_debug_files", str(OUTPUT_DEBUG_FILES)
|
| 348 |
+
),
|
| 349 |
+
"max_time_for_loop": int(
|
| 350 |
+
arguments.get("max_time_for_loop", os.getenv("MAX_TIME_FOR_LOOP", "99999"))
|
| 351 |
+
),
|
| 352 |
+
# Deduplication Arguments
|
| 353 |
+
"method": arguments.get(
|
| 354 |
+
"method", os.getenv("DIRECT_MODE_DEDUPLICATION_METHOD", "fuzzy")
|
| 355 |
+
),
|
| 356 |
+
"similarity_threshold": int(
|
| 357 |
+
arguments.get(
|
| 358 |
+
"similarity_threshold",
|
| 359 |
+
os.getenv("DEDUPLICATION_THRESHOLD", DEDUPLICATION_THRESHOLD),
|
| 360 |
+
)
|
| 361 |
+
),
|
| 362 |
+
"merge_sentiment": arguments.get(
|
| 363 |
+
"merge_sentiment", os.getenv("DIRECT_MODE_MERGE_SENTIMENT", "No")
|
| 364 |
+
),
|
| 365 |
+
"merge_general_topics": arguments.get(
|
| 366 |
+
"merge_general_topics", os.getenv("DIRECT_MODE_MERGE_GENERAL_TOPICS", "Yes")
|
| 367 |
+
),
|
| 368 |
+
# Summarisation Arguments
|
| 369 |
+
"summary_format": arguments.get(
|
| 370 |
+
"summary_format", os.getenv("DIRECT_MODE_SUMMARY_FORMAT", "two_paragraph")
|
| 371 |
+
),
|
| 372 |
+
"sample_reference_table": arguments.get(
|
| 373 |
+
"sample_reference_table",
|
| 374 |
+
os.getenv("DIRECT_MODE_SAMPLE_REFERENCE_TABLE", "True"),
|
| 375 |
+
),
|
| 376 |
+
"no_of_sampled_summaries": int(
|
| 377 |
+
arguments.get(
|
| 378 |
+
"no_of_sampled_summaries",
|
| 379 |
+
os.getenv("DEFAULT_SAMPLED_SUMMARIES", DEFAULT_SAMPLED_SUMMARIES),
|
| 380 |
+
)
|
| 381 |
+
),
|
| 382 |
+
"random_seed": int(
|
| 383 |
+
arguments.get("random_seed", os.getenv("LLM_SEED", LLM_SEED))
|
| 384 |
+
),
|
| 385 |
+
# Output Format Arguments
|
| 386 |
+
"create_xlsx_output": convert_string_to_boolean(
|
| 387 |
+
arguments.get(
|
| 388 |
+
"create_xlsx_output",
|
| 389 |
+
os.getenv("DIRECT_MODE_CREATE_XLSX_OUTPUT", "True"),
|
| 390 |
+
)
|
| 391 |
+
),
|
| 392 |
+
# Logging Arguments
|
| 393 |
+
"save_logs_to_csv": convert_string_to_boolean(
|
| 394 |
+
arguments.get(
|
| 395 |
+
"save_logs_to_csv", os.getenv("SAVE_LOGS_TO_CSV", str(SAVE_LOGS_TO_CSV))
|
| 396 |
+
)
|
| 397 |
+
),
|
| 398 |
+
"save_logs_to_dynamodb": convert_string_to_boolean(
|
| 399 |
+
arguments.get(
|
| 400 |
+
"save_logs_to_dynamodb",
|
| 401 |
+
os.getenv("SAVE_LOGS_TO_DYNAMODB", str(SAVE_LOGS_TO_DYNAMODB)),
|
| 402 |
+
)
|
| 403 |
+
),
|
| 404 |
+
"usage_logs_folder": arguments.get("usage_logs_folder", USAGE_LOGS_FOLDER),
|
| 405 |
+
"cost_code": arguments.get(
|
| 406 |
+
"cost_code", os.getenv("DEFAULT_COST_CODE", DEFAULT_COST_CODE)
|
| 407 |
+
),
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
# Download optional files if they are specified
|
| 411 |
+
candidate_topics_key = arguments.get("candidate_topics_s3_key")
|
| 412 |
+
if candidate_topics_key:
|
| 413 |
+
candidate_topics_path = os.path.join(INPUT_DIR, "candidate_topics.csv")
|
| 414 |
+
download_file_from_s3(bucket_name, candidate_topics_key, candidate_topics_path)
|
| 415 |
+
cli_args["candidate_topics"] = candidate_topics_path
|
| 416 |
+
|
| 417 |
+
# Download previous output files if they are S3 keys
|
| 418 |
+
if cli_args["previous_output_files"]:
|
| 419 |
+
downloaded_previous_files = []
|
| 420 |
+
for prev_file in cli_args["previous_output_files"]:
|
| 421 |
+
if prev_file.startswith("s3://"):
|
| 422 |
+
# Parse S3 path
|
| 423 |
+
s3_path_parts = prev_file[5:].split("/", 1)
|
| 424 |
+
if len(s3_path_parts) == 2:
|
| 425 |
+
prev_bucket = s3_path_parts[0]
|
| 426 |
+
prev_key = s3_path_parts[1]
|
| 427 |
+
local_prev_path = os.path.join(
|
| 428 |
+
INPUT_DIR, os.path.basename(prev_key)
|
| 429 |
+
)
|
| 430 |
+
download_file_from_s3(prev_bucket, prev_key, local_prev_path)
|
| 431 |
+
downloaded_previous_files.append(local_prev_path)
|
| 432 |
+
else:
|
| 433 |
+
downloaded_previous_files.append(prev_file)
|
| 434 |
+
else:
|
| 435 |
+
downloaded_previous_files.append(prev_file)
|
| 436 |
+
cli_args["previous_output_files"] = downloaded_previous_files
|
| 437 |
+
|
| 438 |
+
# 5. Execute the main application logic
|
| 439 |
+
try:
|
| 440 |
+
print("--- Starting CLI Topics Main Function ---")
|
| 441 |
+
print(
|
| 442 |
+
f"Arguments passed to cli_main: {json.dumps({k: v for k, v in cli_args.items() if k not in ['aws_access_key', 'aws_secret_key']}, default=str)}"
|
| 443 |
+
)
|
| 444 |
+
cli_main(direct_mode_args=cli_args)
|
| 445 |
+
print("--- CLI Topics Main Function Finished ---")
|
| 446 |
+
except Exception as e:
|
| 447 |
+
print(f"An error occurred during CLI execution: {e}")
|
| 448 |
+
import traceback
|
| 449 |
+
|
| 450 |
+
traceback.print_exc()
|
| 451 |
+
# Optionally, re-raise the exception to make the Lambda fail
|
| 452 |
+
raise
|
| 453 |
+
|
| 454 |
+
# 6. Upload results back to S3
|
| 455 |
+
output_s3_prefix = f"output/{os.path.splitext(os.path.basename(input_key))[0]}"
|
| 456 |
+
print(
|
| 457 |
+
f"Uploading contents of {OUTPUT_DIR} to s3://{bucket_name}/{output_s3_prefix}/"
|
| 458 |
+
)
|
| 459 |
+
upload_directory_to_s3(OUTPUT_DIR, bucket_name, output_s3_prefix)
|
| 460 |
+
|
| 461 |
+
return {
|
| 462 |
+
"statusCode": 200,
|
| 463 |
+
"body": json.dumps(
|
| 464 |
+
f"Processing complete for {input_key}. Output saved to s3://{bucket_name}/{output_s3_prefix}/"
|
| 465 |
+
),
|
| 466 |
+
}
|
load_dynamo_logs.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import datetime
|
| 3 |
+
from decimal import Decimal
|
| 4 |
+
|
| 5 |
+
import boto3
|
| 6 |
+
|
| 7 |
+
from tools.config import (
|
| 8 |
+
AWS_REGION,
|
| 9 |
+
OUTPUT_FOLDER,
|
| 10 |
+
USAGE_LOG_DYNAMODB_TABLE_NAME,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
# Replace with your actual table name and region
|
| 14 |
+
TABLE_NAME = USAGE_LOG_DYNAMODB_TABLE_NAME # Choose as appropriate
|
| 15 |
+
REGION = AWS_REGION
|
| 16 |
+
CSV_OUTPUT = OUTPUT_FOLDER + "dynamodb_logs_export.csv"
|
| 17 |
+
|
| 18 |
+
# Create DynamoDB resource
|
| 19 |
+
dynamodb = boto3.resource("dynamodb", region_name=REGION)
|
| 20 |
+
table = dynamodb.Table(TABLE_NAME)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Helper function to convert Decimal to float or int
|
| 24 |
+
def convert_types(item):
|
| 25 |
+
new_item = {}
|
| 26 |
+
for key, value in item.items():
|
| 27 |
+
# Handle Decimals first
|
| 28 |
+
if isinstance(value, Decimal):
|
| 29 |
+
new_item[key] = int(value) if value % 1 == 0 else float(value)
|
| 30 |
+
# Handle Strings that might be dates
|
| 31 |
+
elif isinstance(value, str):
|
| 32 |
+
try:
|
| 33 |
+
# Attempt to parse a common ISO 8601 format.
|
| 34 |
+
# The .replace() handles the 'Z' for Zulu/UTC time.
|
| 35 |
+
dt_obj = datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))
|
| 36 |
+
# Now that we have a datetime object, format it as desired
|
| 37 |
+
new_item[key] = dt_obj.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
| 38 |
+
except (ValueError, TypeError):
|
| 39 |
+
# If it fails to parse, it's just a regular string
|
| 40 |
+
new_item[key] = value
|
| 41 |
+
# Handle all other types
|
| 42 |
+
else:
|
| 43 |
+
new_item[key] = value
|
| 44 |
+
return new_item
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Paginated scan
|
| 48 |
+
def scan_table():
|
| 49 |
+
items = []
|
| 50 |
+
response = table.scan()
|
| 51 |
+
items.extend(response["Items"])
|
| 52 |
+
|
| 53 |
+
while "LastEvaluatedKey" in response:
|
| 54 |
+
response = table.scan(ExclusiveStartKey=response["LastEvaluatedKey"])
|
| 55 |
+
items.extend(response["Items"])
|
| 56 |
+
|
| 57 |
+
return items
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Export to CSV
|
| 61 |
+
# Export to CSV
|
| 62 |
+
def export_to_csv(items, output_path, fields_to_drop: list = None):
|
| 63 |
+
if not items:
|
| 64 |
+
print("No items found.")
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
+
# Use a set for efficient lookup
|
| 68 |
+
drop_set = set(fields_to_drop or [])
|
| 69 |
+
|
| 70 |
+
# Get a comprehensive list of all possible headers from all items
|
| 71 |
+
all_keys = set()
|
| 72 |
+
for item in items:
|
| 73 |
+
all_keys.update(item.keys())
|
| 74 |
+
|
| 75 |
+
# Determine the final fieldnames by subtracting the ones to drop
|
| 76 |
+
fieldnames = sorted(list(all_keys - drop_set))
|
| 77 |
+
|
| 78 |
+
print("Final CSV columns will be:", fieldnames)
|
| 79 |
+
|
| 80 |
+
with open(output_path, "w", newline="", encoding="utf-8-sig") as csvfile:
|
| 81 |
+
# The key fix is here: extrasaction='ignore'
|
| 82 |
+
# restval='' is also good practice to handle rows that are missing a key
|
| 83 |
+
writer = csv.DictWriter(
|
| 84 |
+
csvfile, fieldnames=fieldnames, extrasaction="ignore", restval=""
|
| 85 |
+
)
|
| 86 |
+
writer.writeheader()
|
| 87 |
+
|
| 88 |
+
for item in items:
|
| 89 |
+
# The convert_types function can now return the full dict,
|
| 90 |
+
# and the writer will simply ignore the extra fields.
|
| 91 |
+
writer.writerow(convert_types(item))
|
| 92 |
+
|
| 93 |
+
print(f"Exported {len(items)} items to {output_path}")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Run export
|
| 97 |
+
items = scan_table()
|
| 98 |
+
export_to_csv(
|
| 99 |
+
items,
|
| 100 |
+
CSV_OUTPUT,
|
| 101 |
+
fields_to_drop=["Query metadata - usage counts and other parameters"],
|
| 102 |
+
)
|
load_s3_logs.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from io import StringIO
|
| 3 |
+
|
| 4 |
+
import boto3
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from tools.config import (
|
| 8 |
+
AWS_ACCESS_KEY,
|
| 9 |
+
AWS_REGION,
|
| 10 |
+
AWS_SECRET_KEY,
|
| 11 |
+
DOCUMENT_REDACTION_BUCKET,
|
| 12 |
+
OUTPUT_FOLDER,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# Combine together log files that can be then used for e.g. dashboarding and financial tracking.
|
| 16 |
+
|
| 17 |
+
# S3 setup. Try to use provided keys (needs S3 permissions), otherwise assume AWS SSO connection
|
| 18 |
+
if AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION:
|
| 19 |
+
s3 = boto3.client(
|
| 20 |
+
"s3",
|
| 21 |
+
aws_access_key_id=AWS_ACCESS_KEY,
|
| 22 |
+
aws_secret_access_key=AWS_SECRET_KEY,
|
| 23 |
+
region_name=AWS_REGION,
|
| 24 |
+
)
|
| 25 |
+
else:
|
| 26 |
+
s3 = boto3.client("s3")
|
| 27 |
+
|
| 28 |
+
bucket_name = DOCUMENT_REDACTION_BUCKET
|
| 29 |
+
prefix = "usage/" # 'feedback/' # 'logs/' # Change as needed - top-level folder where logs are stored
|
| 30 |
+
earliest_date = "20250409" # Earliest date of logs folder retrieved
|
| 31 |
+
latest_date = "20250423" # Latest date of logs folder retrieved
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Function to list all files in a folder
|
| 35 |
+
def list_files_in_s3(bucket, prefix):
|
| 36 |
+
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
|
| 37 |
+
if "Contents" in response:
|
| 38 |
+
return [content["Key"] for content in response["Contents"]]
|
| 39 |
+
return []
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Function to filter date range
|
| 43 |
+
def is_within_date_range(date_str, start_date, end_date):
|
| 44 |
+
date_obj = datetime.strptime(date_str, "%Y%m%d")
|
| 45 |
+
return start_date <= date_obj <= end_date
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Define the date range
|
| 49 |
+
start_date = datetime.strptime(earliest_date, "%Y%m%d") # Replace with your start date
|
| 50 |
+
end_date = datetime.strptime(latest_date, "%Y%m%d") # Replace with your end date
|
| 51 |
+
|
| 52 |
+
# List all subfolders under 'usage/'
|
| 53 |
+
all_files = list_files_in_s3(bucket_name, prefix)
|
| 54 |
+
|
| 55 |
+
# Filter based on date range
|
| 56 |
+
log_files = []
|
| 57 |
+
for file in all_files:
|
| 58 |
+
parts = file.split("/")
|
| 59 |
+
if len(parts) >= 3:
|
| 60 |
+
date_str = parts[1]
|
| 61 |
+
if (
|
| 62 |
+
is_within_date_range(date_str, start_date, end_date)
|
| 63 |
+
and parts[-1] == "log.csv"
|
| 64 |
+
):
|
| 65 |
+
log_files.append(file)
|
| 66 |
+
|
| 67 |
+
# Download, read and concatenate CSV files into a pandas DataFrame
|
| 68 |
+
df_list = []
|
| 69 |
+
for log_file in log_files:
|
| 70 |
+
# Download the file
|
| 71 |
+
obj = s3.get_object(Bucket=bucket_name, Key=log_file)
|
| 72 |
+
try:
|
| 73 |
+
csv_content = (
|
| 74 |
+
obj["Body"].read().decode("utf-8")
|
| 75 |
+
) # Suggest trying latin-1 instead of utf-8 if this fails
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print("Could not load in log file:", log_file, "due to:", e)
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
# Read CSV content into pandas DataFrame
|
| 81 |
+
df = pd.read_csv(StringIO(csv_content))
|
| 82 |
+
|
| 83 |
+
df_list.append(df)
|
| 84 |
+
|
| 85 |
+
# Concatenate all DataFrames
|
| 86 |
+
if df_list:
|
| 87 |
+
concatenated_df = pd.concat(df_list, ignore_index=True)
|
| 88 |
+
|
| 89 |
+
# Save the concatenated DataFrame to a CSV file
|
| 90 |
+
concatenated_df.to_csv(OUTPUT_FOLDER + "consolidated_s3_logs.csv", index=False)
|
| 91 |
+
print("Consolidated CSV saved as 'consolidated_s3_logs.csv'")
|
| 92 |
+
else:
|
| 93 |
+
print("No log files found in the given date range.")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "llm_topic_modelling"
|
| 7 |
+
version = "0.6.0"
|
| 8 |
+
description = "Generate thematic summaries from open text in tabular data files with a large language model."
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
readme = "README.md"
|
| 11 |
+
authors = [
|
| 12 |
+
{ name = "Sean Pedrick-Case", email = "spedrickcase@lambeth.gov.uk" },
|
| 13 |
+
]
|
| 14 |
+
maintainers = [
|
| 15 |
+
{ name = "Sean Pedrick-Case", email = "spedrickcase@lambeth.gov.uk" },
|
| 16 |
+
]
|
| 17 |
+
keywords = [
|
| 18 |
+
"topic-modelling",
|
| 19 |
+
"topic-modeling",
|
| 20 |
+
"llm",
|
| 21 |
+
"large-language-models",
|
| 22 |
+
"thematic-analysis",
|
| 23 |
+
"text-analysis",
|
| 24 |
+
"nlp",
|
| 25 |
+
"natural-language-processing",
|
| 26 |
+
"text-summarization",
|
| 27 |
+
"text-summarisation",
|
| 28 |
+
"thematic-summaries",
|
| 29 |
+
"gradio",
|
| 30 |
+
"data-analysis",
|
| 31 |
+
"tabular-data",
|
| 32 |
+
"excel",
|
| 33 |
+
"csv",
|
| 34 |
+
"open-text",
|
| 35 |
+
"text-mining"
|
| 36 |
+
]
|
| 37 |
+
classifiers = [
|
| 38 |
+
"Development Status :: 4 - Beta",
|
| 39 |
+
"Intended Audience :: Developers",
|
| 40 |
+
"Intended Audience :: Science/Research",
|
| 41 |
+
"Intended Audience :: Information Technology",
|
| 42 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 43 |
+
"Topic :: Text Processing :: Linguistic",
|
| 44 |
+
"Topic :: Text Processing :: Markup",
|
| 45 |
+
"Topic :: Scientific/Engineering :: Information Analysis",
|
| 46 |
+
"Programming Language :: Python :: 3",
|
| 47 |
+
"Programming Language :: Python :: 3.10",
|
| 48 |
+
"Programming Language :: Python :: 3.11",
|
| 49 |
+
"Programming Language :: Python :: 3.12",
|
| 50 |
+
"Programming Language :: Python :: 3.13",
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
dependencies = [
|
| 54 |
+
"gradio==6.0.2",
|
| 55 |
+
"transformers==4.57.2",
|
| 56 |
+
"spaces==0.42.1",
|
| 57 |
+
"boto3==1.42.1",
|
| 58 |
+
"pandas<=2.3.3",
|
| 59 |
+
"pyarrow>=21.0.0",
|
| 60 |
+
"openpyxl>=3.1.5",
|
| 61 |
+
"markdown>=3.7",
|
| 62 |
+
"tabulate>=0.9.0",
|
| 63 |
+
"lxml>=5.3.0",
|
| 64 |
+
"google-genai<=1.52.0",
|
| 65 |
+
"openai<=2.8.1",
|
| 66 |
+
"html5lib>=1.1",
|
| 67 |
+
"beautifulsoup4>=4.12.3",
|
| 68 |
+
"rapidfuzz>=3.13.0",
|
| 69 |
+
"python-dotenv>=1.1.0"
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
[project.optional-dependencies]
|
| 73 |
+
dev = ["pytest"]
|
| 74 |
+
test = ["pytest", "pytest-cov"]
|
| 75 |
+
|
| 76 |
+
# Extra dependencies for VLM models
|
| 77 |
+
# For torch you should use --index-url https://download.pytorch.org/whl/cu128. Additionally installs the unsloth package
|
| 78 |
+
torch = [
|
| 79 |
+
"torch<=2.9.1",
|
| 80 |
+
"torchvision",
|
| 81 |
+
"accelerate",
|
| 82 |
+
"bitsandbytes",
|
| 83 |
+
"unsloth==2025.11.6",
|
| 84 |
+
"unsloth_zoo==2025.11.6",
|
| 85 |
+
"timm",
|
| 86 |
+
"xformers"
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
# If you want to install llama-cpp-python in GPU mode, use cmake.args="-DGGML_CUDA=on" . If that doesn't work, try specific wheels for your system, e.g. for Linux see files in https://github.com/JamePeng/llama-cpp-python/releases. More details on installation here: https://llama-cpp-python.readthedocs.io/en/latest
|
| 90 |
+
llamacpp = [
|
| 91 |
+
"llama-cpp-python>=0.3.16",
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
# Run Gradio as an mcp server
|
| 95 |
+
mcp = [
|
| 96 |
+
"gradio[mcp]==6.0.2"
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
[project.urls]
|
| 100 |
+
Homepage = "https://github.com/seanpedrick-case/llm_topic_modelling"
|
| 101 |
+
repository = "https://github.com/seanpedrick-case/llm_topic_modelling"
|
| 102 |
+
|
| 103 |
+
[tool.setuptools]
|
| 104 |
+
packages = ["tools"]
|
| 105 |
+
py-modules = ["app"]
|
| 106 |
+
|
| 107 |
+
# Configuration for Ruff linter:
|
| 108 |
+
[tool.ruff]
|
| 109 |
+
line-length = 88
|
| 110 |
+
|
| 111 |
+
[tool.ruff.lint]
|
| 112 |
+
select = ["E", "F", "I"]
|
| 113 |
+
ignore = [
|
| 114 |
+
"E501", # line-too-long (handled with Black)
|
| 115 |
+
"E402", # module-import-not-at-top-of-file (sometimes needed for conditional imports)
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
[tool.ruff.lint.per-file-ignores]
|
| 119 |
+
"__init__.py" = ["F401"] # Allow unused imports in __init__.py
|
| 120 |
+
|
| 121 |
+
# Configuration for a Black formatter:
|
| 122 |
+
[tool.black]
|
| 123 |
+
line-length = 88
|
| 124 |
+
target-version = ['py310']
|
| 125 |
+
|
| 126 |
+
# Configuration for pytest:
|
| 127 |
+
[tool.pytest.ini_options]
|
| 128 |
+
filterwarnings = [
|
| 129 |
+
"ignore::DeprecationWarning:click.parser",
|
| 130 |
+
"ignore::DeprecationWarning:weasel.util.config",
|
| 131 |
+
"ignore::DeprecationWarning:builtin type",
|
| 132 |
+
"ignore::DeprecationWarning:websockets.legacy",
|
| 133 |
+
"ignore::DeprecationWarning:websockets.server",
|
| 134 |
+
"ignore::DeprecationWarning:spacy.cli._util",
|
| 135 |
+
"ignore::DeprecationWarning:weasel.util.config",
|
| 136 |
+
"ignore::DeprecationWarning:importlib._bootstrap",
|
| 137 |
+
]
|
| 138 |
+
testpaths = ["test"]
|
| 139 |
+
python_files = ["test_*.py", "*_test.py"]
|
| 140 |
+
python_classes = ["Test*"]
|
| 141 |
+
python_functions = ["test_*"]
|
| 142 |
+
addopts = [
|
| 143 |
+
"-v",
|
| 144 |
+
"--tb=short",
|
| 145 |
+
"--strict-markers",
|
| 146 |
+
"--disable-warnings",
|
| 147 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Note that this requirements file is optimised for Hugging Face spaces / Python 3.10. Please use requirements_no_local.txt for installation without local model inference (simplest approach to get going). Please use requirements_cpu.txt for CPU instances and requirements_gpu.txt for GPU instances using Python 3.11
|
| 2 |
+
gradio==6.0.2
|
| 3 |
+
transformers==4.57.2
|
| 4 |
+
spaces==0.42.1
|
| 5 |
+
boto3>=1.42.1
|
| 6 |
+
pandas>=2.3.3
|
| 7 |
+
pyarrow>=21.0.0
|
| 8 |
+
openpyxl>=3.1.5
|
| 9 |
+
markdown>=3.7
|
| 10 |
+
tabulate>=0.9.0
|
| 11 |
+
lxml>=5.3.0
|
| 12 |
+
google-genai>=1.52.0
|
| 13 |
+
openai>=2.8.1
|
| 14 |
+
html5lib>=1.1
|
| 15 |
+
beautifulsoup4>=4.12.3
|
| 16 |
+
rapidfuzz>=3.13.0
|
| 17 |
+
python-dotenv>=1.1.0
|
| 18 |
+
# GPU (for huggingface instance)
|
| 19 |
+
# Torch/Unsloth and llama-cpp-python
|
| 20 |
+
# Latest compatible with CUDA 12.4
|
| 21 |
+
torch<=2.9.1 --extra-index-url https://download.pytorch.org/whl/cu128
|
| 22 |
+
unsloth[cu128-torch280]<=2025.11.6
|
| 23 |
+
unsloth_zoo<=2025.11.6
|
| 24 |
+
timm
|
| 25 |
+
# llama-cpp-python direct wheel link for GPU compatible version 3.17 for use with Python 3.10 and Hugging Face
|
| 26 |
+
https://github.com/JamePeng/llama-cpp-python/releases/download/v0.3.17-cu128-Basic-linux-20251202/llama_cpp_python-0.3.17-cp310-cp310-linux_x86_64.whl
|
| 27 |
+
#https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
|
| 28 |
+
|
| 29 |
+
|
requirements_cpu.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==6.0.2
|
| 2 |
+
transformers==4.57.2
|
| 3 |
+
spaces==0.42.1
|
| 4 |
+
pandas>=2.3.3
|
| 5 |
+
boto3>=1.42.1
|
| 6 |
+
pyarrow>=21.0.0
|
| 7 |
+
openpyxl>=3.1.5
|
| 8 |
+
markdown>=3.7
|
| 9 |
+
tabulate>=0.9.0
|
| 10 |
+
lxml>=5.3.0
|
| 11 |
+
google-genai>=1.52.0
|
| 12 |
+
openai>=2.8.1
|
| 13 |
+
html5lib>=1.1
|
| 14 |
+
beautifulsoup4>=4.12.3
|
| 15 |
+
rapidfuzz>=3.13.0
|
| 16 |
+
python-dotenv>=1.1.0
|
| 17 |
+
torch<=2.9.1 --extra-index-url https://download.pytorch.org/whl/cpu
|
| 18 |
+
llama-cpp-python==0.3.16 -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
|
| 19 |
+
# Direct wheel links if above doesn't work
|
| 20 |
+
# I have created CPU Linux, Python 3.11 compatible wheels:
|
| 21 |
+
# https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
|
| 22 |
+
# Windows, Python 3.11 compatible CPU wheels available:
|
| 23 |
+
# https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-win_amd64_cpu_openblas.whl
|
| 24 |
+
# If above doesn't work for Windows, try looking at'windows_install_llama-cpp-python.txt' for instructions on how to build from source
|
requirements_gpu.txt
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
gradio==6.0.2
|
| 3 |
+
transformers==4.57.2
|
| 4 |
+
spaces==0.42.1
|
| 5 |
+
boto3>=1.42.1
|
| 6 |
+
pandas>=2.3.3
|
| 7 |
+
pyarrow>=21.0.0
|
| 8 |
+
openpyxl>=3.1.5
|
| 9 |
+
markdown>=3.7
|
| 10 |
+
tabulate>=0.9.0
|
| 11 |
+
lxml>=5.3.0
|
| 12 |
+
google-genai>=1.52.0
|
| 13 |
+
openai>=2.8.1
|
| 14 |
+
html5lib>=1.1
|
| 15 |
+
beautifulsoup4>=4.12.3
|
| 16 |
+
rapidfuzz>=3.13.0
|
| 17 |
+
python-dotenv>=1.1.0
|
| 18 |
+
# Torch/Unsloth
|
| 19 |
+
# Latest compatible with CUDA 12.4
|
| 20 |
+
torch<=2.9.1 --extra-index-url https://download.pytorch.org/whl/cu128
|
| 21 |
+
unsloth[cu128-torch280]<=2025.11.6 # Refer here for more details on installation: https://pypi.org/project/unsloth
|
| 22 |
+
unsloth_zoo<=2025.11.6
|
| 23 |
+
# Additional for Windows and CUDA 12.4 older GPUS (RTX 3x or similar):
|
| 24 |
+
#triton-windows<3.3
|
| 25 |
+
timm
|
| 26 |
+
# Llama CPP Python
|
| 27 |
+
llama-cpp-python>=0.3.16 -C cmake.args="-DGGML_CUDA=on"
|
| 28 |
+
# If above doesn't work, try specific wheels for your system, see files in https://github.com/JamePeng/llama-cpp-python/releases for different python versions
|
requirements_lightweight.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This requirements file is optimised for AWS ECS using Python 3.11 alongside the Dockerfile, without local torch and llama-cpp-python. For AWS ECS, torch and llama-cpp-python are optionally installed in the main Dockerfile
|
| 2 |
+
gradio==6.0.2
|
| 3 |
+
transformers==4.57.2
|
| 4 |
+
spaces==0.42.1
|
| 5 |
+
boto3>=1.42.1
|
| 6 |
+
pandas>=2.3.3
|
| 7 |
+
pyarrow>=21.0.0
|
| 8 |
+
openpyxl>=3.1.5
|
| 9 |
+
markdown>=3.7
|
| 10 |
+
tabulate>=0.9.0
|
| 11 |
+
lxml>=5.3.0
|
| 12 |
+
google-genai>=1.52.0
|
| 13 |
+
openai>=2.8.1
|
| 14 |
+
html5lib>=1.1
|
| 15 |
+
beautifulsoup4>=4.12.3
|
| 16 |
+
rapidfuzz>=3.13.0
|
| 17 |
+
python-dotenv>=1.1.0
|
| 18 |
+
awslambdaric==3.1.1
|
test/README.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Test Suite for LLM Topic Modeller
|
| 2 |
+
|
| 3 |
+
This test suite provides comprehensive testing for the CLI interface (`cli_topics.py`) and GUI application (`app.py`).
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The test suite includes:
|
| 8 |
+
- **CLI Tests**: Tests based on examples from the `cli_topics.py` epilog
|
| 9 |
+
- **GUI Tests**: Tests to verify the Gradio interface loads correctly
|
| 10 |
+
- **Mock Inference Server**: A dummy inference-server endpoint that avoids API costs during testing
|
| 11 |
+
|
| 12 |
+
## Structure
|
| 13 |
+
|
| 14 |
+
- `test.py`: Main test suite with CLI tests
|
| 15 |
+
- `test_gui_only.py`: GUI-specific tests
|
| 16 |
+
- `mock_inference_server.py`: Mock HTTP server that mimics an inference-server API
|
| 17 |
+
- `run_tests.py`: Test runner script
|
| 18 |
+
- `__init__.py`: Package initialization
|
| 19 |
+
|
| 20 |
+
## Running Tests
|
| 21 |
+
|
| 22 |
+
### Run All Tests
|
| 23 |
+
|
| 24 |
+
From the project root directory:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
python test/run_tests.py
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Or from the test directory:
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
python run_tests.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### Run Only CLI Tests
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
python -m unittest test.test.TestCLITopicsExamples
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### Run Only GUI Tests
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
python test/test_gui_only.py
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Mock Inference Server
|
| 49 |
+
|
| 50 |
+
The test suite uses a mock inference server to avoid API costs during testing. The mock server:
|
| 51 |
+
|
| 52 |
+
- Listens on `localhost:8080` by default
|
| 53 |
+
- Responds to `/v1/chat/completions` endpoint
|
| 54 |
+
- Returns valid markdown table responses that satisfy validation requirements
|
| 55 |
+
- Provides token counts for usage tracking
|
| 56 |
+
|
| 57 |
+
The mock server is automatically started before tests and stopped after tests complete.
|
| 58 |
+
|
| 59 |
+
## Test Coverage
|
| 60 |
+
|
| 61 |
+
The CLI tests cover:
|
| 62 |
+
|
| 63 |
+
1. **Topic Extraction**
|
| 64 |
+
- Default settings
|
| 65 |
+
- Custom model and context
|
| 66 |
+
- Grouping by column
|
| 67 |
+
- Zero-shot extraction with candidate topics
|
| 68 |
+
|
| 69 |
+
2. **Topic Deduplication**
|
| 70 |
+
- Fuzzy matching
|
| 71 |
+
- LLM-based deduplication
|
| 72 |
+
|
| 73 |
+
3. **All-in-One Pipeline**
|
| 74 |
+
- Complete workflow (extract, deduplicate, summarise)
|
| 75 |
+
|
| 76 |
+
## Requirements
|
| 77 |
+
|
| 78 |
+
- Python 3.7+
|
| 79 |
+
- All dependencies from `requirements.txt`
|
| 80 |
+
- Example data files in `example_data/` directory
|
| 81 |
+
|
| 82 |
+
## Notes
|
| 83 |
+
|
| 84 |
+
- Tests will be skipped if required example files are not found
|
| 85 |
+
- The mock server must be running for CLI tests to work
|
| 86 |
+
- Tests use temporary output directories that are cleaned up after execution
|
| 87 |
+
|
test/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test suite for LLM Topic Modeller CLI.
|
| 3 |
+
|
| 4 |
+
This package contains tests for the CLI interface and GUI application.
|
| 5 |
+
"""
|
test/mock_inference_server.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Mock inference server for testing CLI topic extraction without API costs.
|
| 4 |
+
|
| 5 |
+
This server mimics an inference-server API endpoint and returns dummy
|
| 6 |
+
responses that satisfy the validation requirements (markdown tables with |).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import threading
|
| 11 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MockInferenceServerHandler(BaseHTTPRequestHandler):
|
| 16 |
+
"""HTTP request handler for the mock inference server."""
|
| 17 |
+
|
| 18 |
+
def _generate_mock_response(self, prompt: str, system_prompt: str) -> str:
|
| 19 |
+
"""
|
| 20 |
+
Generate a mock response that satisfies validation requirements.
|
| 21 |
+
|
| 22 |
+
The response must:
|
| 23 |
+
- Be longer than 120 characters
|
| 24 |
+
- Contain a markdown table (with | characters)
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
prompt: The user prompt
|
| 28 |
+
system_prompt: The system prompt
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
A mock markdown table response
|
| 32 |
+
"""
|
| 33 |
+
# Generate a simple markdown table that satisfies the validation
|
| 34 |
+
# This mimics a topic extraction table response
|
| 35 |
+
mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
|
| 36 |
+
|-----------|---------------|-----------|-----------|
|
| 37 |
+
| 1 | Test Topic | Test Subtopic | Positive |
|
| 38 |
+
| 2 | Another Topic | Another Subtopic | Neutral |
|
| 39 |
+
| 3 | Third Topic | Third Subtopic | Negative |
|
| 40 |
+
|
| 41 |
+
This is a mock response from the test inference server. The actual content would be generated by a real LLM model, but for testing purposes, this dummy response allows us to verify that the CLI commands work correctly without incurring API costs."""
|
| 42 |
+
|
| 43 |
+
return mock_table
|
| 44 |
+
|
| 45 |
+
def _estimate_tokens(self, text: str) -> int:
|
| 46 |
+
"""Estimate token count (rough approximation: ~4 characters per token)."""
|
| 47 |
+
return max(1, len(text) // 4)
|
| 48 |
+
|
| 49 |
+
def do_POST(self):
|
| 50 |
+
"""Handle POST requests to /v1/chat/completions."""
|
| 51 |
+
print(f"[Mock Server] Received POST request to: {self.path}")
|
| 52 |
+
if self.path == "/v1/chat/completions":
|
| 53 |
+
try:
|
| 54 |
+
# Read request body
|
| 55 |
+
content_length = int(self.headers.get("Content-Length", 0))
|
| 56 |
+
print(f"[Mock Server] Content-Length: {content_length}")
|
| 57 |
+
body = self.rfile.read(content_length)
|
| 58 |
+
payload = json.loads(body.decode("utf-8"))
|
| 59 |
+
print("[Mock Server] Payload received, processing...")
|
| 60 |
+
|
| 61 |
+
# Extract messages
|
| 62 |
+
messages = payload.get("messages", [])
|
| 63 |
+
system_prompt = ""
|
| 64 |
+
user_prompt = ""
|
| 65 |
+
|
| 66 |
+
for msg in messages:
|
| 67 |
+
role = msg.get("role", "")
|
| 68 |
+
content = msg.get("content", "")
|
| 69 |
+
if role == "system":
|
| 70 |
+
system_prompt = content
|
| 71 |
+
elif role == "user":
|
| 72 |
+
user_prompt = content
|
| 73 |
+
|
| 74 |
+
# Generate mock response
|
| 75 |
+
response_text = self._generate_mock_response(user_prompt, system_prompt)
|
| 76 |
+
|
| 77 |
+
# Estimate tokens
|
| 78 |
+
input_tokens = self._estimate_tokens(system_prompt + "\n" + user_prompt)
|
| 79 |
+
output_tokens = self._estimate_tokens(response_text)
|
| 80 |
+
|
| 81 |
+
# Check if streaming is requested
|
| 82 |
+
stream = payload.get("stream", False)
|
| 83 |
+
|
| 84 |
+
if stream:
|
| 85 |
+
# Handle streaming response
|
| 86 |
+
self.send_response(200)
|
| 87 |
+
self.send_header("Content-Type", "text/event-stream")
|
| 88 |
+
self.send_header("Cache-Control", "no-cache")
|
| 89 |
+
self.send_header("Connection", "keep-alive")
|
| 90 |
+
self.end_headers()
|
| 91 |
+
|
| 92 |
+
# Send streaming chunks
|
| 93 |
+
chunk_size = 20 # Characters per chunk
|
| 94 |
+
for i in range(0, len(response_text), chunk_size):
|
| 95 |
+
chunk = response_text[i : i + chunk_size]
|
| 96 |
+
chunk_data = {
|
| 97 |
+
"choices": [
|
| 98 |
+
{
|
| 99 |
+
"delta": {"content": chunk},
|
| 100 |
+
"index": 0,
|
| 101 |
+
"finish_reason": None,
|
| 102 |
+
}
|
| 103 |
+
]
|
| 104 |
+
}
|
| 105 |
+
self.wfile.write(f"data: {json.dumps(chunk_data)}\n\n".encode())
|
| 106 |
+
self.wfile.flush()
|
| 107 |
+
|
| 108 |
+
# Send final done message
|
| 109 |
+
self.wfile.write(b"data: [DONE]\n\n")
|
| 110 |
+
self.wfile.flush()
|
| 111 |
+
else:
|
| 112 |
+
# Handle non-streaming response
|
| 113 |
+
response_data = {
|
| 114 |
+
"choices": [
|
| 115 |
+
{
|
| 116 |
+
"index": 0,
|
| 117 |
+
"finish_reason": "stop",
|
| 118 |
+
"message": {
|
| 119 |
+
"role": "assistant",
|
| 120 |
+
"content": response_text,
|
| 121 |
+
},
|
| 122 |
+
}
|
| 123 |
+
],
|
| 124 |
+
"usage": {
|
| 125 |
+
"prompt_tokens": input_tokens,
|
| 126 |
+
"completion_tokens": output_tokens,
|
| 127 |
+
"total_tokens": input_tokens + output_tokens,
|
| 128 |
+
},
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
self.send_response(200)
|
| 132 |
+
self.send_header("Content-Type", "application/json")
|
| 133 |
+
self.end_headers()
|
| 134 |
+
self.wfile.write(json.dumps(response_data).encode())
|
| 135 |
+
|
| 136 |
+
except Exception as e:
|
| 137 |
+
self.send_response(500)
|
| 138 |
+
self.send_header("Content-Type", "application/json")
|
| 139 |
+
self.end_headers()
|
| 140 |
+
error_response = {"error": {"message": str(e), "type": "server_error"}}
|
| 141 |
+
self.wfile.write(json.dumps(error_response).encode())
|
| 142 |
+
else:
|
| 143 |
+
self.send_response(404)
|
| 144 |
+
self.end_headers()
|
| 145 |
+
|
| 146 |
+
def log_message(self, format, *args):
|
| 147 |
+
"""Log messages for debugging."""
|
| 148 |
+
# Enable logging for debugging
|
| 149 |
+
print(f"[Mock Server] {format % args}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class MockInferenceServer:
|
| 153 |
+
"""Mock inference server that can be started and stopped for testing."""
|
| 154 |
+
|
| 155 |
+
def __init__(self, host: str = "localhost", port: int = 8080):
|
| 156 |
+
"""
|
| 157 |
+
Initialize the mock server.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
host: Host to bind to (default: localhost)
|
| 161 |
+
port: Port to bind to (default: 8080)
|
| 162 |
+
"""
|
| 163 |
+
self.host = host
|
| 164 |
+
self.port = port
|
| 165 |
+
self.server: Optional[HTTPServer] = None
|
| 166 |
+
self.server_thread: Optional[threading.Thread] = None
|
| 167 |
+
self.running = False
|
| 168 |
+
|
| 169 |
+
def start(self):
|
| 170 |
+
"""Start the mock server in a separate thread."""
|
| 171 |
+
if self.running:
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
def run_server():
|
| 175 |
+
self.server = HTTPServer((self.host, self.port), MockInferenceServerHandler)
|
| 176 |
+
self.running = True
|
| 177 |
+
self.server.serve_forever()
|
| 178 |
+
|
| 179 |
+
self.server_thread = threading.Thread(target=run_server, daemon=True)
|
| 180 |
+
self.server_thread.start()
|
| 181 |
+
|
| 182 |
+
# Wait a moment for server to start
|
| 183 |
+
import time
|
| 184 |
+
|
| 185 |
+
time.sleep(0.5)
|
| 186 |
+
|
| 187 |
+
def stop(self):
|
| 188 |
+
"""Stop the mock server."""
|
| 189 |
+
if self.server and self.running:
|
| 190 |
+
self.server.shutdown()
|
| 191 |
+
self.server.server_close()
|
| 192 |
+
self.running = False
|
| 193 |
+
|
| 194 |
+
def get_url(self) -> str:
|
| 195 |
+
"""Get the server URL."""
|
| 196 |
+
return f"http://{self.host}:{self.port}"
|
| 197 |
+
|
| 198 |
+
def __enter__(self):
|
| 199 |
+
"""Context manager entry."""
|
| 200 |
+
self.start()
|
| 201 |
+
return self
|
| 202 |
+
|
| 203 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 204 |
+
"""Context manager exit."""
|
| 205 |
+
self.stop()
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
|
| 209 |
+
# Test the server
|
| 210 |
+
print("Starting mock inference server on http://localhost:8080")
|
| 211 |
+
print("Press Ctrl+C to stop")
|
| 212 |
+
|
| 213 |
+
server = MockInferenceServer()
|
| 214 |
+
try:
|
| 215 |
+
server.start()
|
| 216 |
+
print(f"Server running at {server.get_url()}")
|
| 217 |
+
# Keep running
|
| 218 |
+
while True:
|
| 219 |
+
import time
|
| 220 |
+
|
| 221 |
+
time.sleep(1)
|
| 222 |
+
except KeyboardInterrupt:
|
| 223 |
+
print("\nStopping server...")
|
| 224 |
+
server.stop()
|
| 225 |
+
print("Server stopped")
|
test/mock_llm_calls.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Mock LLM function calls for testing CLI topic extraction without API costs.
|
| 4 |
+
|
| 5 |
+
This module patches requests.post to intercept HTTP calls to inference servers
|
| 6 |
+
and return mock responses instead.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# Store original requests if it exists
|
| 13 |
+
_original_requests = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _generate_mock_response(prompt: str, system_prompt: str) -> str:
|
| 17 |
+
"""
|
| 18 |
+
Generate a mock response that satisfies validation requirements.
|
| 19 |
+
|
| 20 |
+
The response must:
|
| 21 |
+
- Be longer than 120 characters
|
| 22 |
+
- Contain a markdown table (with | characters)
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
prompt: The user prompt
|
| 26 |
+
system_prompt: The system prompt
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
A mock markdown table response
|
| 30 |
+
"""
|
| 31 |
+
# Generate a simple markdown table that satisfies the validation
|
| 32 |
+
# This mimics a topic extraction table response
|
| 33 |
+
mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
|
| 34 |
+
|-----------|---------------|-----------|-----------|
|
| 35 |
+
| 1 | Test Topic | Test Subtopic | Positive |
|
| 36 |
+
| 2 | Another Topic | Another Subtopic | Neutral |
|
| 37 |
+
| 3 | Third Topic | Third Subtopic | Negative |
|
| 38 |
+
|
| 39 |
+
This is a mock response from the test inference server. The actual content would be generated by a real LLM model, but for testing purposes, this dummy response allows us to verify that the CLI commands work correctly without incurring API costs."""
|
| 40 |
+
|
| 41 |
+
return mock_table
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _estimate_tokens(text: str) -> int:
|
| 45 |
+
"""Estimate token count (rough approximation: ~4 characters per token)."""
|
| 46 |
+
return max(1, len(text) // 4)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def mock_requests_post(url, **kwargs):
|
| 50 |
+
"""
|
| 51 |
+
Mock version of requests.post that intercepts inference-server calls.
|
| 52 |
+
|
| 53 |
+
Returns a mock response object that mimics the real requests.Response.
|
| 54 |
+
"""
|
| 55 |
+
# Only mock inference-server URLs
|
| 56 |
+
if "/v1/chat/completions" not in url:
|
| 57 |
+
# For non-inference-server URLs, use real requests
|
| 58 |
+
import requests
|
| 59 |
+
|
| 60 |
+
return requests.post(url, **kwargs)
|
| 61 |
+
|
| 62 |
+
# Extract payload
|
| 63 |
+
payload = kwargs.get("json", {})
|
| 64 |
+
messages = payload.get("messages", [])
|
| 65 |
+
|
| 66 |
+
# Extract prompts
|
| 67 |
+
system_prompt = ""
|
| 68 |
+
user_prompt = ""
|
| 69 |
+
for msg in messages:
|
| 70 |
+
role = msg.get("role", "")
|
| 71 |
+
content = msg.get("content", "")
|
| 72 |
+
if role == "system":
|
| 73 |
+
system_prompt = content
|
| 74 |
+
elif role == "user":
|
| 75 |
+
user_prompt = content
|
| 76 |
+
|
| 77 |
+
# Generate mock response
|
| 78 |
+
response_text = _generate_mock_response(user_prompt, system_prompt)
|
| 79 |
+
|
| 80 |
+
# Estimate tokens
|
| 81 |
+
input_tokens = _estimate_tokens(system_prompt + "\n" + user_prompt)
|
| 82 |
+
output_tokens = _estimate_tokens(response_text)
|
| 83 |
+
|
| 84 |
+
# Check if streaming is requested
|
| 85 |
+
stream = payload.get("stream", False)
|
| 86 |
+
|
| 87 |
+
if stream:
|
| 88 |
+
# For streaming, create a mock response with iter_lines
|
| 89 |
+
class MockStreamResponse:
|
| 90 |
+
def __init__(self, text):
|
| 91 |
+
self.text = text
|
| 92 |
+
self.status_code = 200
|
| 93 |
+
self.lines = []
|
| 94 |
+
# Simulate streaming chunks
|
| 95 |
+
chunk_size = 20
|
| 96 |
+
for i in range(0, len(text), chunk_size):
|
| 97 |
+
chunk = text[i : i + chunk_size]
|
| 98 |
+
chunk_data = {
|
| 99 |
+
"choices": [
|
| 100 |
+
{
|
| 101 |
+
"delta": {"content": chunk},
|
| 102 |
+
"index": 0,
|
| 103 |
+
"finish_reason": None,
|
| 104 |
+
}
|
| 105 |
+
]
|
| 106 |
+
}
|
| 107 |
+
self.lines.append(f"data: {json.dumps(chunk_data)}\n\n".encode())
|
| 108 |
+
self.lines.append(b"data: [DONE]\n\n")
|
| 109 |
+
self._line_index = 0
|
| 110 |
+
|
| 111 |
+
def raise_for_status(self):
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
def iter_lines(self):
|
| 115 |
+
for line in self.lines:
|
| 116 |
+
yield line
|
| 117 |
+
|
| 118 |
+
return MockStreamResponse(response_text)
|
| 119 |
+
else:
|
| 120 |
+
# For non-streaming, create a simple mock response
|
| 121 |
+
class MockResponse:
|
| 122 |
+
def __init__(self, text, input_tokens, output_tokens):
|
| 123 |
+
self._json_data = {
|
| 124 |
+
"choices": [
|
| 125 |
+
{
|
| 126 |
+
"index": 0,
|
| 127 |
+
"finish_reason": "stop",
|
| 128 |
+
"message": {
|
| 129 |
+
"role": "assistant",
|
| 130 |
+
"content": text,
|
| 131 |
+
},
|
| 132 |
+
}
|
| 133 |
+
],
|
| 134 |
+
"usage": {
|
| 135 |
+
"prompt_tokens": input_tokens,
|
| 136 |
+
"completion_tokens": output_tokens,
|
| 137 |
+
"total_tokens": input_tokens + output_tokens,
|
| 138 |
+
},
|
| 139 |
+
}
|
| 140 |
+
self.status_code = 200
|
| 141 |
+
|
| 142 |
+
def raise_for_status(self):
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
def json(self):
|
| 146 |
+
return self._json_data
|
| 147 |
+
|
| 148 |
+
return MockResponse(response_text, input_tokens, output_tokens)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def apply_mock_patches():
|
| 152 |
+
"""
|
| 153 |
+
Apply patches to mock HTTP requests.
|
| 154 |
+
This should be called before importing modules that use requests.
|
| 155 |
+
"""
|
| 156 |
+
global _original_requests
|
| 157 |
+
|
| 158 |
+
try:
|
| 159 |
+
import requests
|
| 160 |
+
|
| 161 |
+
_original_requests = requests.post
|
| 162 |
+
requests.post = mock_requests_post
|
| 163 |
+
print("[Mock] Patched requests.post for inference-server calls")
|
| 164 |
+
except ImportError:
|
| 165 |
+
# requests not imported yet, will be patched when imported
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def restore_original():
|
| 170 |
+
"""Restore original requests.post if it was patched."""
|
| 171 |
+
global _original_requests
|
| 172 |
+
if _original_requests:
|
| 173 |
+
try:
|
| 174 |
+
import requests
|
| 175 |
+
|
| 176 |
+
requests.post = _original_requests
|
| 177 |
+
_original_requests = None
|
| 178 |
+
print("[Mock] Restored original requests.post")
|
| 179 |
+
except ImportError:
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# Auto-apply patches if TEST_MODE environment variable is set
|
| 184 |
+
if os.environ.get("TEST_MODE") == "1" or os.environ.get("USE_MOCK_LLM") == "1":
|
| 185 |
+
apply_mock_patches()
|
test/run_tests.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple script to run the CLI topics test suite.
|
| 4 |
+
|
| 5 |
+
This script demonstrates how to run the comprehensive test suite
|
| 6 |
+
that covers all the examples from the CLI epilog.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
# Add the parent directory to the path so we can import the test module
|
| 13 |
+
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
sys.path.insert(0, parent_dir)
|
| 15 |
+
|
| 16 |
+
# Import test functions
|
| 17 |
+
from test.test import run_all_tests
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
|
| 20 |
+
print("Starting LLM Topic Modeller Test Suite...")
|
| 21 |
+
print("This will test:")
|
| 22 |
+
print("- CLI examples from the epilog")
|
| 23 |
+
print("- GUI application functionality")
|
| 24 |
+
print("Using a mock inference-server to avoid API costs.")
|
| 25 |
+
print("=" * 60)
|
| 26 |
+
|
| 27 |
+
success = run_all_tests()
|
| 28 |
+
|
| 29 |
+
if success:
|
| 30 |
+
print("\nπ All tests passed successfully!")
|
| 31 |
+
sys.exit(0)
|
| 32 |
+
else:
|
| 33 |
+
print("\nβ Some tests failed. Check the output above for details.")
|
| 34 |
+
sys.exit(1)
|
test/test.py
ADDED
|
@@ -0,0 +1,1067 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
import tempfile
|
| 6 |
+
import time
|
| 7 |
+
import unittest
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
# Mock LLM calls are automatically applied via environment variables
|
| 11 |
+
# No need to import - the mock patches are applied when USE_MOCK_LLM=1 is set
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def run_cli_topics(
|
| 15 |
+
script_path: str,
|
| 16 |
+
task: str,
|
| 17 |
+
output_dir: str,
|
| 18 |
+
input_file: Optional[str] = None,
|
| 19 |
+
text_column: Optional[str] = None,
|
| 20 |
+
previous_output_files: Optional[List[str]] = None,
|
| 21 |
+
timeout: int = 600, # 10-minute timeout
|
| 22 |
+
# General Arguments
|
| 23 |
+
username: Optional[str] = None,
|
| 24 |
+
save_to_user_folders: Optional[bool] = None,
|
| 25 |
+
excel_sheets: Optional[List[str]] = None,
|
| 26 |
+
group_by: Optional[str] = None,
|
| 27 |
+
# Model Configuration
|
| 28 |
+
model_choice: Optional[str] = None,
|
| 29 |
+
temperature: Optional[float] = None,
|
| 30 |
+
batch_size: Optional[int] = None,
|
| 31 |
+
max_tokens: Optional[int] = None,
|
| 32 |
+
api_url: Optional[str] = None,
|
| 33 |
+
inference_server_model: Optional[str] = None,
|
| 34 |
+
# Topic Extraction Arguments
|
| 35 |
+
context: Optional[str] = None,
|
| 36 |
+
candidate_topics: Optional[str] = None,
|
| 37 |
+
force_zero_shot: Optional[str] = None,
|
| 38 |
+
force_single_topic: Optional[str] = None,
|
| 39 |
+
produce_structured_summary: Optional[str] = None,
|
| 40 |
+
sentiment: Optional[str] = None,
|
| 41 |
+
additional_summary_instructions: Optional[str] = None,
|
| 42 |
+
# Validation Arguments
|
| 43 |
+
additional_validation_issues: Optional[str] = None,
|
| 44 |
+
show_previous_table: Optional[str] = None,
|
| 45 |
+
output_debug_files: Optional[str] = None,
|
| 46 |
+
max_time_for_loop: Optional[int] = None,
|
| 47 |
+
# Deduplication Arguments
|
| 48 |
+
method: Optional[str] = None,
|
| 49 |
+
similarity_threshold: Optional[int] = None,
|
| 50 |
+
merge_sentiment: Optional[str] = None,
|
| 51 |
+
merge_general_topics: Optional[str] = None,
|
| 52 |
+
# Summarisation Arguments
|
| 53 |
+
summary_format: Optional[str] = None,
|
| 54 |
+
sample_reference_table: Optional[str] = None,
|
| 55 |
+
no_of_sampled_summaries: Optional[int] = None,
|
| 56 |
+
random_seed: Optional[int] = None,
|
| 57 |
+
# Output Format Arguments
|
| 58 |
+
create_xlsx_output: Optional[bool] = None,
|
| 59 |
+
# Logging Arguments
|
| 60 |
+
save_logs_to_csv: Optional[bool] = None,
|
| 61 |
+
save_logs_to_dynamodb: Optional[bool] = None,
|
| 62 |
+
cost_code: Optional[str] = None,
|
| 63 |
+
) -> bool:
|
| 64 |
+
"""
|
| 65 |
+
Executes the cli_topics.py script with specified arguments using a subprocess.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
script_path (str): The path to the cli_topics.py script.
|
| 69 |
+
task (str): The main task to perform ('extract', 'validate', 'deduplicate', 'summarise', 'overall_summary', or 'all_in_one').
|
| 70 |
+
output_dir (str): The path to the directory for output files.
|
| 71 |
+
input_file (str, optional): Path to the input file to process.
|
| 72 |
+
text_column (str, optional): Name of the text column to process.
|
| 73 |
+
previous_output_files (List[str], optional): Path(s) to previous output files.
|
| 74 |
+
timeout (int): Timeout in seconds for the subprocess.
|
| 75 |
+
|
| 76 |
+
All other arguments match the CLI arguments from cli_topics.py.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
bool: True if the script executed successfully, False otherwise.
|
| 80 |
+
"""
|
| 81 |
+
# 1. Get absolute paths and perform pre-checks
|
| 82 |
+
script_abs_path = os.path.abspath(script_path)
|
| 83 |
+
output_abs_dir = os.path.abspath(output_dir)
|
| 84 |
+
|
| 85 |
+
# Handle input file based on task
|
| 86 |
+
if task in ["extract", "validate", "all_in_one"] and input_file is None:
|
| 87 |
+
raise ValueError(f"Input file is required for '{task}' task")
|
| 88 |
+
|
| 89 |
+
if input_file:
|
| 90 |
+
input_abs_path = os.path.abspath(input_file)
|
| 91 |
+
if not os.path.isfile(input_abs_path):
|
| 92 |
+
raise FileNotFoundError(f"Input file not found: {input_abs_path}")
|
| 93 |
+
|
| 94 |
+
if not os.path.isfile(script_abs_path):
|
| 95 |
+
raise FileNotFoundError(f"Script not found: {script_abs_path}")
|
| 96 |
+
|
| 97 |
+
if not os.path.isdir(output_abs_dir):
|
| 98 |
+
# Create the output directory if it doesn't exist
|
| 99 |
+
print(f"Output directory not found. Creating: {output_abs_dir}")
|
| 100 |
+
os.makedirs(output_abs_dir)
|
| 101 |
+
|
| 102 |
+
script_folder = os.path.dirname(script_abs_path)
|
| 103 |
+
|
| 104 |
+
# 2. Dynamically build the command list
|
| 105 |
+
command = [
|
| 106 |
+
"python",
|
| 107 |
+
script_abs_path,
|
| 108 |
+
"--output_dir",
|
| 109 |
+
output_abs_dir,
|
| 110 |
+
"--task",
|
| 111 |
+
task,
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
# Add input_file only if it's not None
|
| 115 |
+
if input_file:
|
| 116 |
+
command.extend(["--input_file", input_abs_path])
|
| 117 |
+
|
| 118 |
+
# Add general arguments
|
| 119 |
+
if text_column:
|
| 120 |
+
command.extend(["--text_column", text_column])
|
| 121 |
+
if previous_output_files:
|
| 122 |
+
command.extend(["--previous_output_files"] + previous_output_files)
|
| 123 |
+
if username:
|
| 124 |
+
command.extend(["--username", username])
|
| 125 |
+
if save_to_user_folders is not None:
|
| 126 |
+
command.extend(["--save_to_user_folders", str(save_to_user_folders)])
|
| 127 |
+
if excel_sheets:
|
| 128 |
+
command.append("--excel_sheets")
|
| 129 |
+
command.extend(excel_sheets)
|
| 130 |
+
if group_by:
|
| 131 |
+
command.extend(["--group_by", group_by])
|
| 132 |
+
|
| 133 |
+
# Add model configuration arguments
|
| 134 |
+
if model_choice:
|
| 135 |
+
command.extend(["--model_choice", model_choice])
|
| 136 |
+
if temperature is not None:
|
| 137 |
+
command.extend(["--temperature", str(temperature)])
|
| 138 |
+
if batch_size is not None:
|
| 139 |
+
command.extend(["--batch_size", str(batch_size)])
|
| 140 |
+
if max_tokens is not None:
|
| 141 |
+
command.extend(["--max_tokens", str(max_tokens)])
|
| 142 |
+
if api_url:
|
| 143 |
+
command.extend(["--api_url", api_url])
|
| 144 |
+
if inference_server_model:
|
| 145 |
+
command.extend(["--inference_server_model", inference_server_model])
|
| 146 |
+
|
| 147 |
+
# Add topic extraction arguments
|
| 148 |
+
if context:
|
| 149 |
+
command.extend(["--context", context])
|
| 150 |
+
if candidate_topics:
|
| 151 |
+
command.extend(["--candidate_topics", candidate_topics])
|
| 152 |
+
if force_zero_shot:
|
| 153 |
+
command.extend(["--force_zero_shot", force_zero_shot])
|
| 154 |
+
if force_single_topic:
|
| 155 |
+
command.extend(["--force_single_topic", force_single_topic])
|
| 156 |
+
if produce_structured_summary:
|
| 157 |
+
command.extend(["--produce_structured_summary", produce_structured_summary])
|
| 158 |
+
if sentiment:
|
| 159 |
+
command.extend(["--sentiment", sentiment])
|
| 160 |
+
if additional_summary_instructions:
|
| 161 |
+
command.extend(
|
| 162 |
+
["--additional_summary_instructions", additional_summary_instructions]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Add validation arguments
|
| 166 |
+
if additional_validation_issues:
|
| 167 |
+
command.extend(["--additional_validation_issues", additional_validation_issues])
|
| 168 |
+
if show_previous_table:
|
| 169 |
+
command.extend(["--show_previous_table", show_previous_table])
|
| 170 |
+
if output_debug_files:
|
| 171 |
+
command.extend(["--output_debug_files", output_debug_files])
|
| 172 |
+
if max_time_for_loop is not None:
|
| 173 |
+
command.extend(["--max_time_for_loop", str(max_time_for_loop)])
|
| 174 |
+
|
| 175 |
+
# Add deduplication arguments
|
| 176 |
+
if method:
|
| 177 |
+
command.extend(["--method", method])
|
| 178 |
+
if similarity_threshold is not None:
|
| 179 |
+
command.extend(["--similarity_threshold", str(similarity_threshold)])
|
| 180 |
+
if merge_sentiment:
|
| 181 |
+
command.extend(["--merge_sentiment", merge_sentiment])
|
| 182 |
+
if merge_general_topics:
|
| 183 |
+
command.extend(["--merge_general_topics", merge_general_topics])
|
| 184 |
+
|
| 185 |
+
# Add summarisation arguments
|
| 186 |
+
if summary_format:
|
| 187 |
+
command.extend(["--summary_format", summary_format])
|
| 188 |
+
if sample_reference_table:
|
| 189 |
+
command.extend(["--sample_reference_table", sample_reference_table])
|
| 190 |
+
if no_of_sampled_summaries is not None:
|
| 191 |
+
command.extend(["--no_of_sampled_summaries", str(no_of_sampled_summaries)])
|
| 192 |
+
if random_seed is not None:
|
| 193 |
+
command.extend(["--random_seed", str(random_seed)])
|
| 194 |
+
|
| 195 |
+
# Add output format arguments
|
| 196 |
+
if create_xlsx_output is False:
|
| 197 |
+
command.append("--no_xlsx_output")
|
| 198 |
+
|
| 199 |
+
# Add logging arguments
|
| 200 |
+
if save_logs_to_csv is not None:
|
| 201 |
+
command.extend(["--save_logs_to_csv", str(save_logs_to_csv)])
|
| 202 |
+
if save_logs_to_dynamodb is not None:
|
| 203 |
+
command.extend(["--save_logs_to_dynamodb", str(save_logs_to_dynamodb)])
|
| 204 |
+
if cost_code:
|
| 205 |
+
command.extend(["--cost_code", cost_code])
|
| 206 |
+
|
| 207 |
+
# Filter out None values before joining
|
| 208 |
+
command_str = " ".join(str(arg) for arg in command if arg is not None)
|
| 209 |
+
print(f"Executing command: {command_str}")
|
| 210 |
+
|
| 211 |
+
# 3. Execute the command using subprocess
|
| 212 |
+
try:
|
| 213 |
+
# Use unbuffered output to avoid hanging
|
| 214 |
+
env = os.environ.copy()
|
| 215 |
+
env["PYTHONUNBUFFERED"] = "1"
|
| 216 |
+
# Ensure inference server is enabled for testing
|
| 217 |
+
env["RUN_INFERENCE_SERVER"] = "1"
|
| 218 |
+
# Enable mock mode
|
| 219 |
+
env["USE_MOCK_LLM"] = "1"
|
| 220 |
+
env["TEST_MODE"] = "1"
|
| 221 |
+
|
| 222 |
+
result = subprocess.Popen(
|
| 223 |
+
command,
|
| 224 |
+
stdout=subprocess.PIPE,
|
| 225 |
+
stderr=subprocess.STDOUT, # Combine stderr with stdout to avoid deadlocks
|
| 226 |
+
text=True,
|
| 227 |
+
cwd=script_folder, # Important for relative paths within the script
|
| 228 |
+
env=env,
|
| 229 |
+
bufsize=0, # Unbuffered
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Read output in real-time to avoid deadlocks
|
| 233 |
+
start_time = time.time()
|
| 234 |
+
|
| 235 |
+
# For Windows, we need a different approach
|
| 236 |
+
if sys.platform == "win32":
|
| 237 |
+
# On Windows, use communicate with timeout
|
| 238 |
+
try:
|
| 239 |
+
stdout, stderr = result.communicate(timeout=timeout)
|
| 240 |
+
except subprocess.TimeoutExpired:
|
| 241 |
+
result.kill()
|
| 242 |
+
stdout, stderr = result.communicate()
|
| 243 |
+
raise subprocess.TimeoutExpired(result.args, timeout)
|
| 244 |
+
else:
|
| 245 |
+
# On Unix, we can use select for real-time reading
|
| 246 |
+
import select
|
| 247 |
+
|
| 248 |
+
stdout_lines = []
|
| 249 |
+
while result.poll() is None:
|
| 250 |
+
ready, _, _ = select.select([result.stdout], [], [], 0.1)
|
| 251 |
+
if ready:
|
| 252 |
+
line = result.stdout.readline()
|
| 253 |
+
if line:
|
| 254 |
+
print(line.rstrip(), flush=True)
|
| 255 |
+
stdout_lines.append(line)
|
| 256 |
+
# Check timeout
|
| 257 |
+
if time.time() - start_time > timeout:
|
| 258 |
+
result.kill()
|
| 259 |
+
raise subprocess.TimeoutExpired(result.args, timeout)
|
| 260 |
+
|
| 261 |
+
# Read remaining output
|
| 262 |
+
remaining = result.stdout.read()
|
| 263 |
+
if remaining:
|
| 264 |
+
print(remaining, end="", flush=True)
|
| 265 |
+
stdout_lines.append(remaining)
|
| 266 |
+
|
| 267 |
+
stdout = "".join(stdout_lines)
|
| 268 |
+
stderr = "" # Combined with stdout
|
| 269 |
+
|
| 270 |
+
print("--- SCRIPT STDOUT ---")
|
| 271 |
+
if stdout:
|
| 272 |
+
print(stdout)
|
| 273 |
+
print("--- SCRIPT STDERR ---")
|
| 274 |
+
if stderr:
|
| 275 |
+
print(stderr)
|
| 276 |
+
print("---------------------")
|
| 277 |
+
|
| 278 |
+
# Analyze the output for errors and success indicators
|
| 279 |
+
analysis = analyze_test_output(stdout, stderr)
|
| 280 |
+
|
| 281 |
+
if analysis["has_errors"]:
|
| 282 |
+
print("β Errors detected in output:")
|
| 283 |
+
for i, error_type in enumerate(analysis["error_types"]):
|
| 284 |
+
print(f" {i+1}. {error_type}")
|
| 285 |
+
if analysis["error_messages"]:
|
| 286 |
+
print(" Error messages:")
|
| 287 |
+
for msg in analysis["error_messages"][
|
| 288 |
+
:3
|
| 289 |
+
]: # Show first 3 error messages
|
| 290 |
+
print(f" - {msg}")
|
| 291 |
+
return False
|
| 292 |
+
elif result.returncode == 0:
|
| 293 |
+
success_msg = "β
Script executed successfully."
|
| 294 |
+
if analysis["success_indicators"]:
|
| 295 |
+
success_msg += f" (Success indicators: {', '.join(analysis['success_indicators'][:3])})"
|
| 296 |
+
print(success_msg)
|
| 297 |
+
return True
|
| 298 |
+
else:
|
| 299 |
+
print(f"β Command failed with return code {result.returncode}")
|
| 300 |
+
return False
|
| 301 |
+
|
| 302 |
+
except subprocess.TimeoutExpired:
|
| 303 |
+
result.kill()
|
| 304 |
+
print(f"β Subprocess timed out after {timeout} seconds.")
|
| 305 |
+
return False
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"β An unexpected error occurred: {e}")
|
| 308 |
+
return False
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def analyze_test_output(stdout: str, stderr: str) -> dict:
|
| 312 |
+
"""
|
| 313 |
+
Analyze test output to provide detailed error information.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
stdout (str): Standard output from the test
|
| 317 |
+
stderr (str): Standard error from the test
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
dict: Analysis results with error details
|
| 321 |
+
"""
|
| 322 |
+
combined_output = (stdout or "") + (stderr or "")
|
| 323 |
+
|
| 324 |
+
analysis = {
|
| 325 |
+
"has_errors": False,
|
| 326 |
+
"error_types": [],
|
| 327 |
+
"error_messages": [],
|
| 328 |
+
"success_indicators": [],
|
| 329 |
+
"warning_indicators": [],
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
# Error patterns
|
| 333 |
+
error_patterns = {
|
| 334 |
+
"An error occurred": "General error message",
|
| 335 |
+
"Error:": "Error prefix",
|
| 336 |
+
"Exception:": "Exception occurred",
|
| 337 |
+
"Traceback": "Python traceback",
|
| 338 |
+
"Failed to": "Operation failure",
|
| 339 |
+
"Cannot": "Operation not possible",
|
| 340 |
+
"Unable to": "Operation not possible",
|
| 341 |
+
"KeyError:": "Missing key/dictionary error",
|
| 342 |
+
"AttributeError:": "Missing attribute error",
|
| 343 |
+
"TypeError:": "Type mismatch error",
|
| 344 |
+
"ValueError:": "Invalid value error",
|
| 345 |
+
"FileNotFoundError:": "File not found",
|
| 346 |
+
"ImportError:": "Import failure",
|
| 347 |
+
"ModuleNotFoundError:": "Module not found",
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
# Success indicators
|
| 351 |
+
success_patterns = [
|
| 352 |
+
"Successfully",
|
| 353 |
+
"Completed",
|
| 354 |
+
"Finished",
|
| 355 |
+
"Processed",
|
| 356 |
+
"Complete",
|
| 357 |
+
"Output files saved",
|
| 358 |
+
]
|
| 359 |
+
|
| 360 |
+
# Warning indicators
|
| 361 |
+
warning_patterns = ["Warning:", "WARNING:", "Deprecated", "DeprecationWarning"]
|
| 362 |
+
|
| 363 |
+
# Check for errors
|
| 364 |
+
for pattern, description in error_patterns.items():
|
| 365 |
+
if pattern.lower() in combined_output.lower():
|
| 366 |
+
analysis["has_errors"] = True
|
| 367 |
+
analysis["error_types"].append(description)
|
| 368 |
+
|
| 369 |
+
# Extract the actual error message
|
| 370 |
+
lines = combined_output.split("\n")
|
| 371 |
+
for line in lines:
|
| 372 |
+
if pattern.lower() in line.lower():
|
| 373 |
+
analysis["error_messages"].append(line.strip())
|
| 374 |
+
|
| 375 |
+
# Check for success indicators
|
| 376 |
+
for pattern in success_patterns:
|
| 377 |
+
if pattern.lower() in combined_output.lower():
|
| 378 |
+
analysis["success_indicators"].append(pattern)
|
| 379 |
+
|
| 380 |
+
# Check for warnings
|
| 381 |
+
for pattern in warning_patterns:
|
| 382 |
+
if pattern.lower() in combined_output.lower():
|
| 383 |
+
analysis["warning_indicators"].append(pattern)
|
| 384 |
+
|
| 385 |
+
return analysis
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def run_app_direct_mode(
|
| 389 |
+
app_path: str,
|
| 390 |
+
task: str,
|
| 391 |
+
output_dir: str,
|
| 392 |
+
input_file: Optional[str] = None,
|
| 393 |
+
text_column: Optional[str] = None,
|
| 394 |
+
previous_output_files: Optional[List[str]] = None,
|
| 395 |
+
timeout: int = 600,
|
| 396 |
+
# General Arguments
|
| 397 |
+
username: Optional[str] = None,
|
| 398 |
+
save_to_user_folders: Optional[bool] = None,
|
| 399 |
+
excel_sheets: Optional[List[str]] = None,
|
| 400 |
+
group_by: Optional[str] = None,
|
| 401 |
+
# Model Configuration
|
| 402 |
+
model_choice: Optional[str] = None,
|
| 403 |
+
temperature: Optional[float] = None,
|
| 404 |
+
batch_size: Optional[int] = None,
|
| 405 |
+
max_tokens: Optional[int] = None,
|
| 406 |
+
api_url: Optional[str] = None,
|
| 407 |
+
inference_server_model: Optional[str] = None,
|
| 408 |
+
# Topic Extraction Arguments
|
| 409 |
+
context: Optional[str] = None,
|
| 410 |
+
candidate_topics: Optional[str] = None,
|
| 411 |
+
force_zero_shot: Optional[str] = None,
|
| 412 |
+
force_single_topic: Optional[str] = None,
|
| 413 |
+
produce_structured_summary: Optional[str] = None,
|
| 414 |
+
sentiment: Optional[str] = None,
|
| 415 |
+
additional_summary_instructions: Optional[str] = None,
|
| 416 |
+
# Validation Arguments
|
| 417 |
+
additional_validation_issues: Optional[str] = None,
|
| 418 |
+
show_previous_table: Optional[str] = None,
|
| 419 |
+
output_debug_files: Optional[str] = None,
|
| 420 |
+
max_time_for_loop: Optional[int] = None,
|
| 421 |
+
# Deduplication Arguments
|
| 422 |
+
method: Optional[str] = None,
|
| 423 |
+
similarity_threshold: Optional[int] = None,
|
| 424 |
+
merge_sentiment: Optional[str] = None,
|
| 425 |
+
merge_general_topics: Optional[str] = None,
|
| 426 |
+
# Summarisation Arguments
|
| 427 |
+
summary_format: Optional[str] = None,
|
| 428 |
+
sample_reference_table: Optional[str] = None,
|
| 429 |
+
no_of_sampled_summaries: Optional[int] = None,
|
| 430 |
+
random_seed: Optional[int] = None,
|
| 431 |
+
# Output Format Arguments
|
| 432 |
+
create_xlsx_output: Optional[bool] = None,
|
| 433 |
+
# Logging Arguments
|
| 434 |
+
save_logs_to_csv: Optional[bool] = None,
|
| 435 |
+
save_logs_to_dynamodb: Optional[bool] = None,
|
| 436 |
+
cost_code: Optional[str] = None,
|
| 437 |
+
) -> bool:
|
| 438 |
+
"""
|
| 439 |
+
Executes the app.py script in direct mode with specified environment variables.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
app_path (str): The path to the app.py script.
|
| 443 |
+
task (str): The main task to perform ('extract', 'validate', 'deduplicate', 'summarise', 'overall_summary', or 'all_in_one').
|
| 444 |
+
output_dir (str): The path to the directory for output files.
|
| 445 |
+
input_file (str, optional): Path to the input file to process.
|
| 446 |
+
text_column (str, optional): Name of the text column to process.
|
| 447 |
+
previous_output_files (List[str], optional): Path(s) to previous output files.
|
| 448 |
+
timeout (int): Timeout in seconds for the subprocess.
|
| 449 |
+
|
| 450 |
+
All other arguments match the CLI arguments from cli_topics.py, but are set as environment variables.
|
| 451 |
+
|
| 452 |
+
Returns:
|
| 453 |
+
bool: True if the script executed successfully, False otherwise.
|
| 454 |
+
"""
|
| 455 |
+
# 1. Get absolute paths and perform pre-checks
|
| 456 |
+
app_abs_path = os.path.abspath(app_path)
|
| 457 |
+
output_abs_dir = os.path.abspath(output_dir)
|
| 458 |
+
|
| 459 |
+
# Handle input file based on task
|
| 460 |
+
if task in ["extract", "validate", "all_in_one"] and input_file is None:
|
| 461 |
+
raise ValueError(f"Input file is required for '{task}' task")
|
| 462 |
+
|
| 463 |
+
if input_file:
|
| 464 |
+
input_abs_path = os.path.abspath(input_file)
|
| 465 |
+
if not os.path.isfile(input_abs_path):
|
| 466 |
+
raise FileNotFoundError(f"Input file not found: {input_abs_path}")
|
| 467 |
+
|
| 468 |
+
if not os.path.isfile(app_abs_path):
|
| 469 |
+
raise FileNotFoundError(f"App script not found: {app_abs_path}")
|
| 470 |
+
|
| 471 |
+
if not os.path.isdir(output_abs_dir):
|
| 472 |
+
# Create the output directory if it doesn't exist
|
| 473 |
+
print(f"Output directory not found. Creating: {output_abs_dir}")
|
| 474 |
+
os.makedirs(output_abs_dir)
|
| 475 |
+
|
| 476 |
+
script_folder = os.path.dirname(app_abs_path)
|
| 477 |
+
|
| 478 |
+
# 2. Build environment variables for direct mode
|
| 479 |
+
env = os.environ.copy()
|
| 480 |
+
env["PYTHONUNBUFFERED"] = "1"
|
| 481 |
+
env["RUN_INFERENCE_SERVER"] = "1"
|
| 482 |
+
env["USE_MOCK_LLM"] = "1"
|
| 483 |
+
env["TEST_MODE"] = "1"
|
| 484 |
+
|
| 485 |
+
# Enable direct mode
|
| 486 |
+
env["RUN_DIRECT_MODE"] = "1"
|
| 487 |
+
|
| 488 |
+
# Task selection
|
| 489 |
+
env["DIRECT_MODE_TASK"] = task
|
| 490 |
+
|
| 491 |
+
# General arguments
|
| 492 |
+
if input_file:
|
| 493 |
+
# Use pipe separator to handle file paths with spaces
|
| 494 |
+
env["DIRECT_MODE_INPUT_FILE"] = input_abs_path
|
| 495 |
+
env["DIRECT_MODE_OUTPUT_DIR"] = output_abs_dir
|
| 496 |
+
if text_column:
|
| 497 |
+
env["DIRECT_MODE_TEXT_COLUMN"] = text_column
|
| 498 |
+
if previous_output_files:
|
| 499 |
+
# Use pipe separator to handle file paths with spaces
|
| 500 |
+
env["DIRECT_MODE_PREVIOUS_OUTPUT_FILES"] = "|".join(previous_output_files)
|
| 501 |
+
if username:
|
| 502 |
+
env["DIRECT_MODE_USERNAME"] = username
|
| 503 |
+
if save_to_user_folders is not None:
|
| 504 |
+
env["SESSION_OUTPUT_FOLDER"] = str(save_to_user_folders)
|
| 505 |
+
if excel_sheets:
|
| 506 |
+
env["DIRECT_MODE_EXCEL_SHEETS"] = ",".join(excel_sheets)
|
| 507 |
+
if group_by:
|
| 508 |
+
env["DIRECT_MODE_GROUP_BY"] = group_by
|
| 509 |
+
|
| 510 |
+
# Model configuration
|
| 511 |
+
if model_choice:
|
| 512 |
+
env["DIRECT_MODE_MODEL_CHOICE"] = model_choice
|
| 513 |
+
if temperature is not None:
|
| 514 |
+
env["DIRECT_MODE_TEMPERATURE"] = str(temperature)
|
| 515 |
+
if batch_size is not None:
|
| 516 |
+
env["DIRECT_MODE_BATCH_SIZE"] = str(batch_size)
|
| 517 |
+
if max_tokens is not None:
|
| 518 |
+
env["DIRECT_MODE_MAX_TOKENS"] = str(max_tokens)
|
| 519 |
+
if api_url:
|
| 520 |
+
env["API_URL"] = api_url
|
| 521 |
+
if inference_server_model:
|
| 522 |
+
env["DIRECT_MODE_INFERENCE_SERVER_MODEL"] = inference_server_model
|
| 523 |
+
|
| 524 |
+
# Topic extraction arguments
|
| 525 |
+
if context:
|
| 526 |
+
env["DIRECT_MODE_CONTEXT"] = context
|
| 527 |
+
if candidate_topics:
|
| 528 |
+
env["DIRECT_MODE_CANDIDATE_TOPICS"] = candidate_topics
|
| 529 |
+
if force_zero_shot:
|
| 530 |
+
env["DIRECT_MODE_FORCE_ZERO_SHOT"] = force_zero_shot
|
| 531 |
+
if force_single_topic:
|
| 532 |
+
env["DIRECT_MODE_FORCE_SINGLE_TOPIC"] = force_single_topic
|
| 533 |
+
if produce_structured_summary:
|
| 534 |
+
env["DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY"] = produce_structured_summary
|
| 535 |
+
if sentiment:
|
| 536 |
+
env["DIRECT_MODE_SENTIMENT"] = sentiment
|
| 537 |
+
if additional_summary_instructions:
|
| 538 |
+
env["DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS"] = (
|
| 539 |
+
additional_summary_instructions
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
# Validation arguments
|
| 543 |
+
if additional_validation_issues:
|
| 544 |
+
env["DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES"] = additional_validation_issues
|
| 545 |
+
if show_previous_table:
|
| 546 |
+
env["DIRECT_MODE_SHOW_PREVIOUS_TABLE"] = show_previous_table
|
| 547 |
+
if output_debug_files:
|
| 548 |
+
env["OUTPUT_DEBUG_FILES"] = output_debug_files
|
| 549 |
+
if max_time_for_loop is not None:
|
| 550 |
+
env["DIRECT_MODE_MAX_TIME_FOR_LOOP"] = str(max_time_for_loop)
|
| 551 |
+
|
| 552 |
+
# Deduplication arguments
|
| 553 |
+
if method:
|
| 554 |
+
env["DIRECT_MODE_DEDUP_METHOD"] = method
|
| 555 |
+
if similarity_threshold is not None:
|
| 556 |
+
env["DIRECT_MODE_SIMILARITY_THRESHOLD"] = str(similarity_threshold)
|
| 557 |
+
if merge_sentiment:
|
| 558 |
+
env["DIRECT_MODE_MERGE_SENTIMENT"] = merge_sentiment
|
| 559 |
+
if merge_general_topics:
|
| 560 |
+
env["DIRECT_MODE_MERGE_GENERAL_TOPICS"] = merge_general_topics
|
| 561 |
+
|
| 562 |
+
# Summarisation arguments
|
| 563 |
+
if summary_format:
|
| 564 |
+
env["DIRECT_MODE_SUMMARY_FORMAT"] = summary_format
|
| 565 |
+
if sample_reference_table:
|
| 566 |
+
env["DIRECT_MODE_SAMPLE_REFERENCE_TABLE"] = sample_reference_table
|
| 567 |
+
if no_of_sampled_summaries is not None:
|
| 568 |
+
env["DIRECT_MODE_NO_OF_SAMPLED_SUMMARIES"] = str(no_of_sampled_summaries)
|
| 569 |
+
if random_seed is not None:
|
| 570 |
+
env["DIRECT_MODE_RANDOM_SEED"] = str(random_seed)
|
| 571 |
+
|
| 572 |
+
# Output format arguments
|
| 573 |
+
if create_xlsx_output is not None:
|
| 574 |
+
env["DIRECT_MODE_CREATE_XLSX_OUTPUT"] = str(create_xlsx_output)
|
| 575 |
+
|
| 576 |
+
# Logging arguments
|
| 577 |
+
if save_logs_to_csv is not None:
|
| 578 |
+
env["SAVE_LOGS_TO_CSV"] = str(save_logs_to_csv)
|
| 579 |
+
if save_logs_to_dynamodb is not None:
|
| 580 |
+
env["SAVE_LOGS_TO_DYNAMODB"] = str(save_logs_to_dynamodb)
|
| 581 |
+
if cost_code:
|
| 582 |
+
env["DEFAULT_COST_CODE"] = cost_code
|
| 583 |
+
|
| 584 |
+
# 3. Build command (just run app.py, no arguments needed in direct mode)
|
| 585 |
+
command = ["python", app_abs_path]
|
| 586 |
+
command_str = " ".join(str(arg) for arg in command)
|
| 587 |
+
print(f"Executing direct mode command: {command_str}")
|
| 588 |
+
print(f"Direct mode task: {task}")
|
| 589 |
+
if input_file:
|
| 590 |
+
print(f"Input file: {input_abs_path}")
|
| 591 |
+
if text_column:
|
| 592 |
+
print(f"Text column: {text_column}")
|
| 593 |
+
|
| 594 |
+
# 4. Execute the command using subprocess
|
| 595 |
+
try:
|
| 596 |
+
result = subprocess.Popen(
|
| 597 |
+
command,
|
| 598 |
+
stdout=subprocess.PIPE,
|
| 599 |
+
stderr=subprocess.STDOUT, # Combine stderr with stdout to avoid deadlocks
|
| 600 |
+
text=True,
|
| 601 |
+
cwd=script_folder, # Important for relative paths within the script
|
| 602 |
+
env=env,
|
| 603 |
+
bufsize=0, # Unbuffered
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
# Read output in real-time to avoid deadlocks
|
| 607 |
+
start_time = time.time()
|
| 608 |
+
|
| 609 |
+
# For Windows, we need a different approach
|
| 610 |
+
if sys.platform == "win32":
|
| 611 |
+
# On Windows, use communicate with timeout
|
| 612 |
+
try:
|
| 613 |
+
stdout, stderr = result.communicate(timeout=timeout)
|
| 614 |
+
except subprocess.TimeoutExpired:
|
| 615 |
+
result.kill()
|
| 616 |
+
stdout, stderr = result.communicate()
|
| 617 |
+
raise subprocess.TimeoutExpired(result.args, timeout)
|
| 618 |
+
else:
|
| 619 |
+
# On Unix, we can use select for real-time reading
|
| 620 |
+
import select
|
| 621 |
+
|
| 622 |
+
stdout_lines = []
|
| 623 |
+
while result.poll() is None:
|
| 624 |
+
ready, _, _ = select.select([result.stdout], [], [], 0.1)
|
| 625 |
+
if ready:
|
| 626 |
+
line = result.stdout.readline()
|
| 627 |
+
if line:
|
| 628 |
+
print(line.rstrip(), flush=True)
|
| 629 |
+
stdout_lines.append(line)
|
| 630 |
+
# Check timeout
|
| 631 |
+
if time.time() - start_time > timeout:
|
| 632 |
+
result.kill()
|
| 633 |
+
raise subprocess.TimeoutExpired(result.args, timeout)
|
| 634 |
+
|
| 635 |
+
# Read remaining output
|
| 636 |
+
remaining = result.stdout.read()
|
| 637 |
+
if remaining:
|
| 638 |
+
print(remaining, end="", flush=True)
|
| 639 |
+
stdout_lines.append(remaining)
|
| 640 |
+
|
| 641 |
+
stdout = "".join(stdout_lines)
|
| 642 |
+
stderr = "" # Combined with stdout
|
| 643 |
+
|
| 644 |
+
print("--- SCRIPT STDOUT ---")
|
| 645 |
+
if stdout:
|
| 646 |
+
print(stdout)
|
| 647 |
+
print("--- SCRIPT STDERR ---")
|
| 648 |
+
if stderr:
|
| 649 |
+
print(stderr)
|
| 650 |
+
print("---------------------")
|
| 651 |
+
|
| 652 |
+
# Analyze the output for errors and success indicators
|
| 653 |
+
analysis = analyze_test_output(stdout, stderr)
|
| 654 |
+
|
| 655 |
+
if analysis["has_errors"]:
|
| 656 |
+
print("β Errors detected in output:")
|
| 657 |
+
for i, error_type in enumerate(analysis["error_types"]):
|
| 658 |
+
print(f" {i+1}. {error_type}")
|
| 659 |
+
if analysis["error_messages"]:
|
| 660 |
+
print(" Error messages:")
|
| 661 |
+
for msg in analysis["error_messages"][
|
| 662 |
+
:3
|
| 663 |
+
]: # Show first 3 error messages
|
| 664 |
+
print(f" - {msg}")
|
| 665 |
+
return False
|
| 666 |
+
elif result.returncode == 0:
|
| 667 |
+
success_msg = "β
Script executed successfully."
|
| 668 |
+
if analysis["success_indicators"]:
|
| 669 |
+
success_msg += f" (Success indicators: {', '.join(analysis['success_indicators'][:3])})"
|
| 670 |
+
print(success_msg)
|
| 671 |
+
return True
|
| 672 |
+
else:
|
| 673 |
+
print(f"β Command failed with return code {result.returncode}")
|
| 674 |
+
return False
|
| 675 |
+
|
| 676 |
+
except subprocess.TimeoutExpired:
|
| 677 |
+
result.kill()
|
| 678 |
+
print(f"β Subprocess timed out after {timeout} seconds.")
|
| 679 |
+
return False
|
| 680 |
+
except Exception as e:
|
| 681 |
+
print(f"β An unexpected error occurred: {e}")
|
| 682 |
+
return False
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
class TestCLITopicsExamples(unittest.TestCase):
|
| 686 |
+
"""Test suite for CLI topic extraction examples from the epilog."""
|
| 687 |
+
|
| 688 |
+
@classmethod
|
| 689 |
+
def setUpClass(cls):
|
| 690 |
+
"""Set up test environment before running tests."""
|
| 691 |
+
cls.script_path = os.path.join(
|
| 692 |
+
os.path.dirname(os.path.dirname(__file__)), "cli_topics.py"
|
| 693 |
+
)
|
| 694 |
+
cls.example_data_dir = os.path.join(
|
| 695 |
+
os.path.dirname(os.path.dirname(__file__)), "example_data"
|
| 696 |
+
)
|
| 697 |
+
cls.temp_output_dir = tempfile.mkdtemp(prefix="test_output_")
|
| 698 |
+
|
| 699 |
+
# Verify script exists
|
| 700 |
+
if not os.path.isfile(cls.script_path):
|
| 701 |
+
raise FileNotFoundError(f"CLI script not found: {cls.script_path}")
|
| 702 |
+
|
| 703 |
+
print(f"Test setup complete. Script: {cls.script_path}")
|
| 704 |
+
print(f"Example data directory: {cls.example_data_dir}")
|
| 705 |
+
print(f"Temp output directory: {cls.temp_output_dir}")
|
| 706 |
+
print("Using function mocking instead of HTTP server")
|
| 707 |
+
|
| 708 |
+
# Debug: Check if example data directory exists and list contents
|
| 709 |
+
if os.path.exists(cls.example_data_dir):
|
| 710 |
+
print("Example data directory exists. Contents:")
|
| 711 |
+
for item in os.listdir(cls.example_data_dir):
|
| 712 |
+
item_path = os.path.join(cls.example_data_dir, item)
|
| 713 |
+
if os.path.isfile(item_path):
|
| 714 |
+
print(f" File: {item} ({os.path.getsize(item_path)} bytes)")
|
| 715 |
+
else:
|
| 716 |
+
print(f" Directory: {item}")
|
| 717 |
+
else:
|
| 718 |
+
print(f"Example data directory does not exist: {cls.example_data_dir}")
|
| 719 |
+
|
| 720 |
+
@classmethod
|
| 721 |
+
def tearDownClass(cls):
|
| 722 |
+
"""Clean up test environment after running tests."""
|
| 723 |
+
if os.path.exists(cls.temp_output_dir):
|
| 724 |
+
shutil.rmtree(cls.temp_output_dir)
|
| 725 |
+
print(f"Cleaned up temp directory: {cls.temp_output_dir}")
|
| 726 |
+
|
| 727 |
+
def test_extract_topics_default_settings(self):
|
| 728 |
+
"""Test: Extract topics from a CSV file with default settings"""
|
| 729 |
+
print("\n=== Testing topic extraction with default settings ===")
|
| 730 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 731 |
+
|
| 732 |
+
if not os.path.isfile(input_file):
|
| 733 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 734 |
+
|
| 735 |
+
result = run_cli_topics(
|
| 736 |
+
script_path=self.script_path,
|
| 737 |
+
task="extract",
|
| 738 |
+
input_file=input_file,
|
| 739 |
+
text_column="Case Note",
|
| 740 |
+
output_dir=self.temp_output_dir,
|
| 741 |
+
model_choice="test-model",
|
| 742 |
+
inference_server_model="test-model",
|
| 743 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 744 |
+
create_xlsx_output=False,
|
| 745 |
+
save_logs_to_csv=False,
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
self.assertTrue(result, "Topic extraction with default settings should succeed")
|
| 749 |
+
print("β
Topic extraction with default settings passed")
|
| 750 |
+
|
| 751 |
+
def test_extract_topics_custom_model_and_context(self):
|
| 752 |
+
"""Test: Extract topics with custom model and context"""
|
| 753 |
+
print("\n=== Testing topic extraction with custom model and context ===")
|
| 754 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 755 |
+
|
| 756 |
+
if not os.path.isfile(input_file):
|
| 757 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 758 |
+
|
| 759 |
+
result = run_cli_topics(
|
| 760 |
+
script_path=self.script_path,
|
| 761 |
+
task="extract",
|
| 762 |
+
input_file=input_file,
|
| 763 |
+
text_column="Case Note",
|
| 764 |
+
output_dir=self.temp_output_dir,
|
| 765 |
+
model_choice="test-model",
|
| 766 |
+
inference_server_model="test-model",
|
| 767 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 768 |
+
context="Social Care case notes for young people",
|
| 769 |
+
create_xlsx_output=False,
|
| 770 |
+
save_logs_to_csv=False,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
self.assertTrue(
|
| 774 |
+
result, "Topic extraction with custom model and context should succeed"
|
| 775 |
+
)
|
| 776 |
+
print("β
Topic extraction with custom model and context passed")
|
| 777 |
+
|
| 778 |
+
def test_extract_topics_with_grouping(self):
|
| 779 |
+
"""Test: Extract topics with grouping"""
|
| 780 |
+
print("\n=== Testing topic extraction with grouping ===")
|
| 781 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 782 |
+
|
| 783 |
+
if not os.path.isfile(input_file):
|
| 784 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 785 |
+
|
| 786 |
+
result = run_cli_topics(
|
| 787 |
+
script_path=self.script_path,
|
| 788 |
+
task="extract",
|
| 789 |
+
input_file=input_file,
|
| 790 |
+
text_column="Case Note",
|
| 791 |
+
output_dir=self.temp_output_dir,
|
| 792 |
+
group_by="Client",
|
| 793 |
+
model_choice="test-model",
|
| 794 |
+
inference_server_model="test-model",
|
| 795 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 796 |
+
create_xlsx_output=False,
|
| 797 |
+
save_logs_to_csv=False,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
self.assertTrue(result, "Topic extraction with grouping should succeed")
|
| 801 |
+
print("β
Topic extraction with grouping passed")
|
| 802 |
+
|
| 803 |
+
def test_extract_topics_with_candidate_topics(self):
|
| 804 |
+
"""Test: Extract topics with candidate topics (zero-shot)"""
|
| 805 |
+
print("\n=== Testing topic extraction with candidate topics ===")
|
| 806 |
+
input_file = os.path.join(
|
| 807 |
+
self.example_data_dir, "dummy_consultation_response.csv"
|
| 808 |
+
)
|
| 809 |
+
candidate_topics_file = os.path.join(
|
| 810 |
+
self.example_data_dir, "dummy_consultation_response_themes.csv"
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
if not os.path.isfile(input_file):
|
| 814 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 815 |
+
if not os.path.isfile(candidate_topics_file):
|
| 816 |
+
self.skipTest(f"Candidate topics file not found: {candidate_topics_file}")
|
| 817 |
+
|
| 818 |
+
result = run_cli_topics(
|
| 819 |
+
script_path=self.script_path,
|
| 820 |
+
task="extract",
|
| 821 |
+
input_file=input_file,
|
| 822 |
+
text_column="Response text",
|
| 823 |
+
output_dir=self.temp_output_dir,
|
| 824 |
+
candidate_topics=candidate_topics_file,
|
| 825 |
+
model_choice="test-model",
|
| 826 |
+
inference_server_model="test-model",
|
| 827 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 828 |
+
create_xlsx_output=False,
|
| 829 |
+
save_logs_to_csv=False,
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
self.assertTrue(result, "Topic extraction with candidate topics should succeed")
|
| 833 |
+
print("β
Topic extraction with candidate topics passed")
|
| 834 |
+
|
| 835 |
+
def test_deduplicate_topics_fuzzy(self):
|
| 836 |
+
"""Test: Deduplicate topics using fuzzy matching"""
|
| 837 |
+
print("\n=== Testing topic deduplication with fuzzy matching ===")
|
| 838 |
+
|
| 839 |
+
# First, we need to create some output files by running extraction
|
| 840 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 841 |
+
|
| 842 |
+
if not os.path.isfile(input_file):
|
| 843 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 844 |
+
|
| 845 |
+
# Run extraction first to create output files
|
| 846 |
+
extract_result = run_cli_topics(
|
| 847 |
+
script_path=self.script_path,
|
| 848 |
+
task="extract",
|
| 849 |
+
input_file=input_file,
|
| 850 |
+
text_column="Case Note",
|
| 851 |
+
output_dir=self.temp_output_dir,
|
| 852 |
+
model_choice="test-model",
|
| 853 |
+
inference_server_model="test-model",
|
| 854 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 855 |
+
create_xlsx_output=False,
|
| 856 |
+
save_logs_to_csv=False,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
if not extract_result:
|
| 860 |
+
self.skipTest("Extraction failed, cannot test deduplication")
|
| 861 |
+
|
| 862 |
+
# Find the output files (they should be in temp_output_dir)
|
| 863 |
+
# The file names follow a pattern like: {input_file_name}_col_{text_column}_reference_table.csv
|
| 864 |
+
import glob
|
| 865 |
+
|
| 866 |
+
reference_files = glob.glob(
|
| 867 |
+
os.path.join(self.temp_output_dir, "*reference_table.csv")
|
| 868 |
+
)
|
| 869 |
+
unique_files = glob.glob(
|
| 870 |
+
os.path.join(self.temp_output_dir, "*unique_topics.csv")
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
if not reference_files or not unique_files:
|
| 874 |
+
self.skipTest("Could not find output files from extraction")
|
| 875 |
+
|
| 876 |
+
result = run_cli_topics(
|
| 877 |
+
script_path=self.script_path,
|
| 878 |
+
task="deduplicate",
|
| 879 |
+
previous_output_files=[reference_files[0], unique_files[0]],
|
| 880 |
+
output_dir=self.temp_output_dir,
|
| 881 |
+
method="fuzzy",
|
| 882 |
+
similarity_threshold=90,
|
| 883 |
+
create_xlsx_output=False,
|
| 884 |
+
save_logs_to_csv=False,
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
self.assertTrue(
|
| 888 |
+
result, "Topic deduplication with fuzzy matching should succeed"
|
| 889 |
+
)
|
| 890 |
+
print("β
Topic deduplication with fuzzy matching passed")
|
| 891 |
+
|
| 892 |
+
def test_deduplicate_topics_llm(self):
|
| 893 |
+
"""Test: Deduplicate topics using LLM"""
|
| 894 |
+
print("\n=== Testing topic deduplication with LLM ===")
|
| 895 |
+
|
| 896 |
+
# First, we need to create some output files by running extraction
|
| 897 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 898 |
+
|
| 899 |
+
if not os.path.isfile(input_file):
|
| 900 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 901 |
+
|
| 902 |
+
# Run extraction first to create output files
|
| 903 |
+
extract_result = run_cli_topics(
|
| 904 |
+
script_path=self.script_path,
|
| 905 |
+
task="extract",
|
| 906 |
+
input_file=input_file,
|
| 907 |
+
text_column="Case Note",
|
| 908 |
+
output_dir=self.temp_output_dir,
|
| 909 |
+
model_choice="test-model",
|
| 910 |
+
inference_server_model="test-model",
|
| 911 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 912 |
+
create_xlsx_output=False,
|
| 913 |
+
save_logs_to_csv=False,
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
if not extract_result:
|
| 917 |
+
self.skipTest("Extraction failed, cannot test deduplication")
|
| 918 |
+
|
| 919 |
+
# Find the output files
|
| 920 |
+
import glob
|
| 921 |
+
|
| 922 |
+
reference_files = glob.glob(
|
| 923 |
+
os.path.join(self.temp_output_dir, "*reference_table.csv")
|
| 924 |
+
)
|
| 925 |
+
unique_files = glob.glob(
|
| 926 |
+
os.path.join(self.temp_output_dir, "*unique_topics.csv")
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
if not reference_files or not unique_files:
|
| 930 |
+
self.skipTest("Could not find output files from extraction")
|
| 931 |
+
|
| 932 |
+
result = run_cli_topics(
|
| 933 |
+
script_path=self.script_path,
|
| 934 |
+
task="deduplicate",
|
| 935 |
+
previous_output_files=[reference_files[0], unique_files[0]],
|
| 936 |
+
output_dir=self.temp_output_dir,
|
| 937 |
+
method="llm",
|
| 938 |
+
model_choice="test-model",
|
| 939 |
+
inference_server_model="test-model",
|
| 940 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 941 |
+
create_xlsx_output=False,
|
| 942 |
+
save_logs_to_csv=False,
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
self.assertTrue(result, "Topic deduplication with LLM should succeed")
|
| 946 |
+
print("β
Topic deduplication with LLM passed")
|
| 947 |
+
|
| 948 |
+
def test_all_in_one_pipeline(self):
|
| 949 |
+
"""Test: Run complete pipeline (extract, deduplicate, summarise)"""
|
| 950 |
+
print("\n=== Testing all-in-one pipeline ===")
|
| 951 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 952 |
+
|
| 953 |
+
if not os.path.isfile(input_file):
|
| 954 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 955 |
+
|
| 956 |
+
result = run_cli_topics(
|
| 957 |
+
script_path=self.script_path,
|
| 958 |
+
task="all_in_one",
|
| 959 |
+
input_file=input_file,
|
| 960 |
+
text_column="Case Note",
|
| 961 |
+
output_dir=self.temp_output_dir,
|
| 962 |
+
model_choice="test-model",
|
| 963 |
+
inference_server_model="test-model",
|
| 964 |
+
api_url="http://localhost:8080", # URL doesn't matter with function mocking
|
| 965 |
+
create_xlsx_output=False,
|
| 966 |
+
save_logs_to_csv=False,
|
| 967 |
+
timeout=120, # Shorter timeout for debugging
|
| 968 |
+
)
|
| 969 |
+
|
| 970 |
+
self.assertTrue(result, "All-in-one pipeline should succeed")
|
| 971 |
+
print("β
All-in-one pipeline passed")
|
| 972 |
+
|
| 973 |
+
def test_direct_mode_extract(self):
|
| 974 |
+
"""Test: Run app in direct mode for topic extraction"""
|
| 975 |
+
print("\n=== Testing direct mode - topic extraction ===")
|
| 976 |
+
input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
|
| 977 |
+
|
| 978 |
+
if not os.path.isfile(input_file):
|
| 979 |
+
self.skipTest(f"Example file not found: {input_file}")
|
| 980 |
+
|
| 981 |
+
app_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "app.py")
|
| 982 |
+
|
| 983 |
+
if not os.path.isfile(app_path):
|
| 984 |
+
self.skipTest(f"App script not found: {app_path}")
|
| 985 |
+
|
| 986 |
+
result = run_app_direct_mode(
|
| 987 |
+
app_path=app_path,
|
| 988 |
+
task="extract",
|
| 989 |
+
input_file=input_file,
|
| 990 |
+
text_column="Case Note",
|
| 991 |
+
output_dir=self.temp_output_dir,
|
| 992 |
+
model_choice="test-model",
|
| 993 |
+
inference_server_model="test-model",
|
| 994 |
+
api_url="http://localhost:8080",
|
| 995 |
+
create_xlsx_output=False,
|
| 996 |
+
save_logs_to_csv=False,
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
self.assertTrue(result, "Direct mode topic extraction should succeed")
|
| 1000 |
+
print("β
Direct mode topic extraction passed")
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def run_all_tests():
|
| 1004 |
+
"""Run all test examples and report results."""
|
| 1005 |
+
print("=" * 80)
|
| 1006 |
+
print("LLM TOPIC MODELLER TEST SUITE")
|
| 1007 |
+
print("=" * 80)
|
| 1008 |
+
print("This test suite includes:")
|
| 1009 |
+
print("- CLI examples from the epilog")
|
| 1010 |
+
print("- GUI application tests")
|
| 1011 |
+
print("- Tests use a mock inference-server to avoid API costs")
|
| 1012 |
+
print("Tests will be skipped if required example files are not found.")
|
| 1013 |
+
print("=" * 80)
|
| 1014 |
+
|
| 1015 |
+
# Create test suite
|
| 1016 |
+
loader = unittest.TestLoader()
|
| 1017 |
+
suite = unittest.TestSuite()
|
| 1018 |
+
|
| 1019 |
+
# Add CLI tests
|
| 1020 |
+
cli_suite = loader.loadTestsFromTestCase(TestCLITopicsExamples)
|
| 1021 |
+
suite.addTests(cli_suite)
|
| 1022 |
+
|
| 1023 |
+
# Add GUI tests
|
| 1024 |
+
try:
|
| 1025 |
+
from test.test_gui_only import TestGUIAppOnly
|
| 1026 |
+
|
| 1027 |
+
gui_suite = loader.loadTestsFromTestCase(TestGUIAppOnly)
|
| 1028 |
+
suite.addTests(gui_suite)
|
| 1029 |
+
print("GUI tests included in test suite.")
|
| 1030 |
+
except ImportError as e:
|
| 1031 |
+
print(f"Warning: Could not import GUI tests: {e}")
|
| 1032 |
+
print("Skipping GUI tests.")
|
| 1033 |
+
|
| 1034 |
+
# Run tests with detailed output
|
| 1035 |
+
runner = unittest.TextTestRunner(verbosity=2, stream=None)
|
| 1036 |
+
result = runner.run(suite)
|
| 1037 |
+
|
| 1038 |
+
# Print summary
|
| 1039 |
+
print("\n" + "=" * 80)
|
| 1040 |
+
print("TEST SUMMARY")
|
| 1041 |
+
print("=" * 80)
|
| 1042 |
+
print(f"Tests run: {result.testsRun}")
|
| 1043 |
+
print(f"Failures: {len(result.failures)}")
|
| 1044 |
+
print(f"Errors: {len(result.errors)}")
|
| 1045 |
+
print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
|
| 1046 |
+
|
| 1047 |
+
if result.failures:
|
| 1048 |
+
print("\nFAILURES:")
|
| 1049 |
+
for test, traceback in result.failures:
|
| 1050 |
+
print(f"- {test}: {traceback}")
|
| 1051 |
+
|
| 1052 |
+
if result.errors:
|
| 1053 |
+
print("\nERRORS:")
|
| 1054 |
+
for test, traceback in result.errors:
|
| 1055 |
+
print(f"- {test}: {traceback}")
|
| 1056 |
+
|
| 1057 |
+
success = len(result.failures) == 0 and len(result.errors) == 0
|
| 1058 |
+
print(f"\nOverall result: {'β
PASSED' if success else 'β FAILED'}")
|
| 1059 |
+
print("=" * 80)
|
| 1060 |
+
|
| 1061 |
+
return success
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
+
if __name__ == "__main__":
|
| 1065 |
+
# Run the test suite
|
| 1066 |
+
success = run_all_tests()
|
| 1067 |
+
exit(0 if success else 1)
|
test/test_gui_only.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Standalone GUI test script for the LLM topic modeller application.
|
| 4 |
+
|
| 5 |
+
This script tests only the GUI functionality of app.py to ensure it loads correctly.
|
| 6 |
+
Run this script to verify that the Gradio interface can be imported and initialized.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import threading
|
| 12 |
+
import unittest
|
| 13 |
+
|
| 14 |
+
# Add the parent directory to the path so we can import the app
|
| 15 |
+
parent_dir = os.path.dirname(os.path.dirname(__file__))
|
| 16 |
+
if parent_dir not in sys.path:
|
| 17 |
+
sys.path.insert(0, parent_dir)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestGUIAppOnly(unittest.TestCase):
|
| 21 |
+
"""Test suite for GUI application loading and basic functionality."""
|
| 22 |
+
|
| 23 |
+
@classmethod
|
| 24 |
+
def setUpClass(cls):
|
| 25 |
+
"""Set up test environment for GUI tests."""
|
| 26 |
+
cls.app_path = os.path.join(parent_dir, "app.py")
|
| 27 |
+
|
| 28 |
+
# Verify app.py exists
|
| 29 |
+
if not os.path.isfile(cls.app_path):
|
| 30 |
+
raise FileNotFoundError(f"App file not found: {cls.app_path}")
|
| 31 |
+
|
| 32 |
+
print(f"GUI test setup complete. App: {cls.app_path}")
|
| 33 |
+
|
| 34 |
+
def test_app_import_and_initialization(self):
|
| 35 |
+
"""Test: Import app.py and check if the Gradio app object is created successfully."""
|
| 36 |
+
print("\n=== Testing GUI app import and initialization ===")
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
# Import the app module
|
| 40 |
+
import app
|
| 41 |
+
|
| 42 |
+
# Check if the app object exists and is a Gradio Blocks object
|
| 43 |
+
self.assertTrue(
|
| 44 |
+
hasattr(app, "app"), "App object should exist in the module"
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Check if it's a Gradio Blocks instance
|
| 48 |
+
import gradio as gr
|
| 49 |
+
|
| 50 |
+
self.assertIsInstance(
|
| 51 |
+
app.app, gr.Blocks, "App should be a Gradio Blocks instance"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
print("β
GUI app import and initialization passed")
|
| 55 |
+
|
| 56 |
+
except ImportError as e:
|
| 57 |
+
error_msg = f"Failed to import app module: {e}"
|
| 58 |
+
self.fail(error_msg)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
self.fail(f"Unexpected error during app initialization: {e}")
|
| 61 |
+
|
| 62 |
+
def test_app_launch_headless(self):
|
| 63 |
+
"""Test: Launch the app in headless mode to verify it starts without errors."""
|
| 64 |
+
print("\n=== Testing GUI app launch in headless mode ===")
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
# Import the app module
|
| 68 |
+
import app
|
| 69 |
+
|
| 70 |
+
# Set up a flag to track if the app launched successfully
|
| 71 |
+
app_launched = threading.Event()
|
| 72 |
+
launch_error = None
|
| 73 |
+
|
| 74 |
+
def launch_app():
|
| 75 |
+
try:
|
| 76 |
+
# Launch the app in headless mode with a short timeout
|
| 77 |
+
app.app.launch(
|
| 78 |
+
show_error=True,
|
| 79 |
+
inbrowser=False, # Don't open browser
|
| 80 |
+
server_port=0, # Use any available port
|
| 81 |
+
quiet=True, # Suppress output
|
| 82 |
+
prevent_thread_lock=True, # Don't block the main thread
|
| 83 |
+
)
|
| 84 |
+
app_launched.set()
|
| 85 |
+
except Exception:
|
| 86 |
+
app_launched.set()
|
| 87 |
+
|
| 88 |
+
# Start the app in a separate thread
|
| 89 |
+
launch_thread = threading.Thread(target=launch_app)
|
| 90 |
+
launch_thread.daemon = True
|
| 91 |
+
launch_thread.start()
|
| 92 |
+
|
| 93 |
+
# Wait for the app to launch (with timeout)
|
| 94 |
+
if app_launched.wait(timeout=10): # 10 second timeout
|
| 95 |
+
if launch_error:
|
| 96 |
+
self.fail(f"App launch failed: {launch_error}")
|
| 97 |
+
else:
|
| 98 |
+
print("β
GUI app launch in headless mode passed")
|
| 99 |
+
else:
|
| 100 |
+
self.fail("App launch timed out after 10 seconds")
|
| 101 |
+
|
| 102 |
+
except Exception as e:
|
| 103 |
+
error_msg = f"Unexpected error during app launch test: {e}"
|
| 104 |
+
self.fail(error_msg)
|
| 105 |
+
|
| 106 |
+
def test_app_configuration_loading(self):
|
| 107 |
+
"""Test: Verify that the app can load its configuration without errors."""
|
| 108 |
+
print("\n=== Testing GUI app configuration loading ===")
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
# Check if key configuration variables are accessible
|
| 112 |
+
# These should be imported from tools.config
|
| 113 |
+
from tools.config import (
|
| 114 |
+
DEFAULT_COST_CODE,
|
| 115 |
+
GRADIO_SERVER_PORT,
|
| 116 |
+
MAX_FILE_SIZE,
|
| 117 |
+
default_model_choice,
|
| 118 |
+
model_name_map,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Verify these are not None/empty
|
| 122 |
+
self.assertIsNotNone(
|
| 123 |
+
GRADIO_SERVER_PORT, "GRADIO_SERVER_PORT should be configured"
|
| 124 |
+
)
|
| 125 |
+
self.assertIsNotNone(MAX_FILE_SIZE, "MAX_FILE_SIZE should be configured")
|
| 126 |
+
self.assertIsNotNone(
|
| 127 |
+
DEFAULT_COST_CODE, "DEFAULT_COST_CODE should be configured"
|
| 128 |
+
)
|
| 129 |
+
self.assertIsNotNone(
|
| 130 |
+
default_model_choice, "default_model_choice should be configured"
|
| 131 |
+
)
|
| 132 |
+
self.assertIsNotNone(model_name_map, "model_name_map should be configured")
|
| 133 |
+
|
| 134 |
+
print("β
GUI app configuration loading passed")
|
| 135 |
+
|
| 136 |
+
except ImportError as e:
|
| 137 |
+
error_msg = f"Failed to import configuration: {e}"
|
| 138 |
+
self.fail(error_msg)
|
| 139 |
+
except Exception as e:
|
| 140 |
+
error_msg = f"Unexpected error during configuration test: {e}"
|
| 141 |
+
self.fail(error_msg)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def run_gui_tests():
|
| 145 |
+
"""Run GUI tests and report results."""
|
| 146 |
+
print("=" * 80)
|
| 147 |
+
print("LLM TOPIC MODELLER GUI TEST SUITE")
|
| 148 |
+
print("=" * 80)
|
| 149 |
+
print("This test suite verifies that the GUI application loads correctly.")
|
| 150 |
+
print("=" * 80)
|
| 151 |
+
|
| 152 |
+
# Create test suite
|
| 153 |
+
loader = unittest.TestLoader()
|
| 154 |
+
suite = loader.loadTestsFromTestCase(TestGUIAppOnly)
|
| 155 |
+
|
| 156 |
+
# Run tests with detailed output
|
| 157 |
+
runner = unittest.TextTestRunner(verbosity=2, stream=None)
|
| 158 |
+
result = runner.run(suite)
|
| 159 |
+
|
| 160 |
+
# Print summary
|
| 161 |
+
print("\n" + "=" * 80)
|
| 162 |
+
print("GUI TEST SUMMARY")
|
| 163 |
+
print("=" * 80)
|
| 164 |
+
print(f"Tests run: {result.testsRun}")
|
| 165 |
+
print(f"Failures: {len(result.failures)}")
|
| 166 |
+
print(f"Errors: {len(result.errors)}")
|
| 167 |
+
print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
|
| 168 |
+
|
| 169 |
+
if result.failures:
|
| 170 |
+
print("\nFAILURES:")
|
| 171 |
+
for test, traceback in result.failures:
|
| 172 |
+
print(f"- {test}: {traceback}")
|
| 173 |
+
|
| 174 |
+
if result.errors:
|
| 175 |
+
print("\nERRORS:")
|
| 176 |
+
for test, traceback in result.errors:
|
| 177 |
+
print(f"- {test}: {traceback}")
|
| 178 |
+
|
| 179 |
+
success = len(result.failures) == 0 and len(result.errors) == 0
|
| 180 |
+
print(f"\nOverall result: {'β
PASSED' if success else 'β FAILED'}")
|
| 181 |
+
print("=" * 80)
|
| 182 |
+
|
| 183 |
+
return success
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
# Run the GUI test suite
|
| 188 |
+
success = run_gui_tests()
|
| 189 |
+
exit(0 if success else 1)
|
tools/__init__.py
ADDED
|
File without changes
|
tools/auth.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import hashlib
|
| 3 |
+
import hmac
|
| 4 |
+
|
| 5 |
+
import boto3
|
| 6 |
+
|
| 7 |
+
from tools.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_REGION, AWS_USER_POOL_ID
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def calculate_secret_hash(client_id: str, client_secret: str, username: str):
|
| 11 |
+
message = username + client_id
|
| 12 |
+
dig = hmac.new(
|
| 13 |
+
str(client_secret).encode("utf-8"),
|
| 14 |
+
msg=str(message).encode("utf-8"),
|
| 15 |
+
digestmod=hashlib.sha256,
|
| 16 |
+
).digest()
|
| 17 |
+
secret_hash = base64.b64encode(dig).decode()
|
| 18 |
+
return secret_hash
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def authenticate_user(
|
| 22 |
+
username: str,
|
| 23 |
+
password: str,
|
| 24 |
+
user_pool_id: str = AWS_USER_POOL_ID,
|
| 25 |
+
client_id: str = AWS_CLIENT_ID,
|
| 26 |
+
client_secret: str = AWS_CLIENT_SECRET,
|
| 27 |
+
):
|
| 28 |
+
"""Authenticates a user against an AWS Cognito user pool.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
user_pool_id (str): The ID of the Cognito user pool.
|
| 32 |
+
client_id (str): The ID of the Cognito user pool client.
|
| 33 |
+
username (str): The username of the user.
|
| 34 |
+
password (str): The password of the user.
|
| 35 |
+
client_secret (str): The client secret of the app client
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
bool: True if the user is authenticated, False otherwise.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
client = boto3.client(
|
| 42 |
+
"cognito-idp", region_name=AWS_REGION
|
| 43 |
+
) # Cognito Identity Provider client
|
| 44 |
+
|
| 45 |
+
# Compute the secret hash
|
| 46 |
+
secret_hash = calculate_secret_hash(client_id, client_secret, username)
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
|
| 50 |
+
if client_secret == "":
|
| 51 |
+
response = client.initiate_auth(
|
| 52 |
+
AuthFlow="USER_PASSWORD_AUTH",
|
| 53 |
+
AuthParameters={
|
| 54 |
+
"USERNAME": username,
|
| 55 |
+
"PASSWORD": password,
|
| 56 |
+
},
|
| 57 |
+
ClientId=client_id,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
else:
|
| 61 |
+
response = client.initiate_auth(
|
| 62 |
+
AuthFlow="USER_PASSWORD_AUTH",
|
| 63 |
+
AuthParameters={
|
| 64 |
+
"USERNAME": username,
|
| 65 |
+
"PASSWORD": password,
|
| 66 |
+
"SECRET_HASH": secret_hash,
|
| 67 |
+
},
|
| 68 |
+
ClientId=client_id,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# If successful, you'll receive an AuthenticationResult in the response
|
| 72 |
+
if response.get("AuthenticationResult"):
|
| 73 |
+
return True
|
| 74 |
+
else:
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
except client.exceptions.NotAuthorizedException:
|
| 78 |
+
return False
|
| 79 |
+
except client.exceptions.UserNotFoundException:
|
| 80 |
+
return False
|
| 81 |
+
except Exception as e:
|
| 82 |
+
out_message = f"An error occurred: {e}"
|
| 83 |
+
print(out_message)
|
| 84 |
+
raise Exception(out_message)
|
| 85 |
+
return False
|
tools/aws_functions.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import boto3
|
| 5 |
+
|
| 6 |
+
from tools.config import (
|
| 7 |
+
AWS_ACCESS_KEY,
|
| 8 |
+
AWS_REGION,
|
| 9 |
+
AWS_SECRET_KEY,
|
| 10 |
+
PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS,
|
| 11 |
+
RUN_AWS_FUNCTIONS,
|
| 12 |
+
S3_LOG_BUCKET,
|
| 13 |
+
S3_OUTPUTS_BUCKET,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# Empty bucket name in case authentication fails
|
| 17 |
+
bucket_name = S3_LOG_BUCKET
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def connect_to_bedrock_runtime(
|
| 21 |
+
model_name_map: dict,
|
| 22 |
+
model_choice: str,
|
| 23 |
+
aws_access_key_textbox: str = "",
|
| 24 |
+
aws_secret_key_textbox: str = "",
|
| 25 |
+
aws_region_textbox: str = "",
|
| 26 |
+
):
|
| 27 |
+
# If running an anthropic model, assume that running an AWS Bedrock model, load in Bedrock
|
| 28 |
+
model_source = model_name_map[model_choice]["source"]
|
| 29 |
+
|
| 30 |
+
# Use aws_region_textbox if provided, otherwise fall back to AWS_REGION from config
|
| 31 |
+
region = aws_region_textbox if aws_region_textbox else AWS_REGION
|
| 32 |
+
|
| 33 |
+
if "AWS" in model_source:
|
| 34 |
+
if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
|
| 35 |
+
print("Connecting to Bedrock via existing SSO connection")
|
| 36 |
+
bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
|
| 37 |
+
elif aws_access_key_textbox and aws_secret_key_textbox:
|
| 38 |
+
print(
|
| 39 |
+
"Connecting to Bedrock using AWS access key and secret keys from user input."
|
| 40 |
+
)
|
| 41 |
+
bedrock_runtime = boto3.client(
|
| 42 |
+
"bedrock-runtime",
|
| 43 |
+
aws_access_key_id=aws_access_key_textbox,
|
| 44 |
+
aws_secret_access_key=aws_secret_key_textbox,
|
| 45 |
+
region_name=region,
|
| 46 |
+
)
|
| 47 |
+
elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
|
| 48 |
+
print("Getting Bedrock credentials from environment variables")
|
| 49 |
+
bedrock_runtime = boto3.client(
|
| 50 |
+
"bedrock-runtime",
|
| 51 |
+
aws_access_key_id=AWS_ACCESS_KEY,
|
| 52 |
+
aws_secret_access_key=AWS_SECRET_KEY,
|
| 53 |
+
region_name=region,
|
| 54 |
+
)
|
| 55 |
+
elif RUN_AWS_FUNCTIONS == "1":
|
| 56 |
+
print("Connecting to Bedrock via existing SSO connection")
|
| 57 |
+
bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
|
| 58 |
+
else:
|
| 59 |
+
bedrock_runtime = ""
|
| 60 |
+
out_message = "Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type."
|
| 61 |
+
print(out_message)
|
| 62 |
+
raise Exception(out_message)
|
| 63 |
+
else:
|
| 64 |
+
bedrock_runtime = None
|
| 65 |
+
|
| 66 |
+
return bedrock_runtime
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def connect_to_s3_client(
|
| 70 |
+
aws_access_key_textbox: str = "",
|
| 71 |
+
aws_secret_key_textbox: str = "",
|
| 72 |
+
aws_region_textbox: str = "",
|
| 73 |
+
):
|
| 74 |
+
# If running an anthropic model, assume that running an AWS s3 model, load in s3
|
| 75 |
+
s3_client = None
|
| 76 |
+
|
| 77 |
+
# Use aws_region_textbox if provided, otherwise fall back to AWS_REGION from config
|
| 78 |
+
region = aws_region_textbox if aws_region_textbox else AWS_REGION
|
| 79 |
+
|
| 80 |
+
if aws_access_key_textbox and aws_secret_key_textbox:
|
| 81 |
+
print("Connecting to s3 using AWS access key and secret keys from user input.")
|
| 82 |
+
s3_client = boto3.client(
|
| 83 |
+
"s3",
|
| 84 |
+
aws_access_key_id=aws_access_key_textbox,
|
| 85 |
+
aws_secret_access_key=aws_secret_key_textbox,
|
| 86 |
+
region_name=region,
|
| 87 |
+
)
|
| 88 |
+
elif RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
|
| 89 |
+
print("Connecting to s3 via existing SSO connection")
|
| 90 |
+
s3_client = boto3.client("s3", region_name=region)
|
| 91 |
+
elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
|
| 92 |
+
print("Getting s3 credentials from environment variables")
|
| 93 |
+
s3_client = boto3.client(
|
| 94 |
+
"s3",
|
| 95 |
+
aws_access_key_id=AWS_ACCESS_KEY,
|
| 96 |
+
aws_secret_access_key=AWS_SECRET_KEY,
|
| 97 |
+
region_name=region,
|
| 98 |
+
)
|
| 99 |
+
elif RUN_AWS_FUNCTIONS == "1":
|
| 100 |
+
print("Connecting to s3 via existing SSO connection")
|
| 101 |
+
s3_client = boto3.client("s3", region_name=region)
|
| 102 |
+
else:
|
| 103 |
+
s3_client = ""
|
| 104 |
+
out_message = "Cannot connect to S3 service. Please provide access keys under LLM settings, or choose another model type."
|
| 105 |
+
print(out_message)
|
| 106 |
+
raise Exception(out_message)
|
| 107 |
+
|
| 108 |
+
return s3_client
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Download direct from S3 - requires login credentials
|
| 112 |
+
def download_file_from_s3(
|
| 113 |
+
bucket_name: str,
|
| 114 |
+
key: str,
|
| 115 |
+
local_file_path: str,
|
| 116 |
+
aws_access_key_textbox: str = "",
|
| 117 |
+
aws_secret_key_textbox: str = "",
|
| 118 |
+
aws_region_textbox: str = "",
|
| 119 |
+
RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
|
| 120 |
+
):
|
| 121 |
+
|
| 122 |
+
if RUN_AWS_FUNCTIONS == "1":
|
| 123 |
+
|
| 124 |
+
s3 = connect_to_s3_client(
|
| 125 |
+
aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
|
| 126 |
+
)
|
| 127 |
+
# boto3.client('s3')
|
| 128 |
+
s3.download_file(bucket_name, key, local_file_path)
|
| 129 |
+
print(f"File downloaded from S3: s3://{bucket_name}/{key} to {local_file_path}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def download_folder_from_s3(
|
| 133 |
+
bucket_name: str,
|
| 134 |
+
s3_folder: str,
|
| 135 |
+
local_folder: str,
|
| 136 |
+
aws_access_key_textbox: str = "",
|
| 137 |
+
aws_secret_key_textbox: str = "",
|
| 138 |
+
aws_region_textbox: str = "",
|
| 139 |
+
RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
|
| 140 |
+
):
|
| 141 |
+
"""
|
| 142 |
+
Download all files from an S3 folder to a local folder.
|
| 143 |
+
"""
|
| 144 |
+
if RUN_AWS_FUNCTIONS == "1":
|
| 145 |
+
s3 = connect_to_s3_client(
|
| 146 |
+
aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
|
| 147 |
+
)
|
| 148 |
+
# boto3.client('s3')
|
| 149 |
+
|
| 150 |
+
# List objects in the specified S3 folder
|
| 151 |
+
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
|
| 152 |
+
|
| 153 |
+
# Download each object
|
| 154 |
+
for obj in response.get("Contents", []):
|
| 155 |
+
# Extract object key and construct local file path
|
| 156 |
+
object_key = obj["Key"]
|
| 157 |
+
local_file_path = os.path.join(
|
| 158 |
+
local_folder, os.path.relpath(object_key, s3_folder)
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Create directories if necessary
|
| 162 |
+
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
| 163 |
+
|
| 164 |
+
# Download the object
|
| 165 |
+
try:
|
| 166 |
+
s3.download_file(bucket_name, object_key, local_file_path)
|
| 167 |
+
print(
|
| 168 |
+
f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'"
|
| 169 |
+
)
|
| 170 |
+
except Exception as e:
|
| 171 |
+
print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def download_files_from_s3(
|
| 175 |
+
bucket_name: str,
|
| 176 |
+
s3_folder: str,
|
| 177 |
+
local_folder: str,
|
| 178 |
+
filenames: list[str],
|
| 179 |
+
aws_access_key_textbox: str = "",
|
| 180 |
+
aws_secret_key_textbox: str = "",
|
| 181 |
+
aws_region_textbox: str = "",
|
| 182 |
+
RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
|
| 183 |
+
):
|
| 184 |
+
"""
|
| 185 |
+
Download specific files from an S3 folder to a local folder.
|
| 186 |
+
"""
|
| 187 |
+
if RUN_AWS_FUNCTIONS == "1":
|
| 188 |
+
s3 = connect_to_s3_client(
|
| 189 |
+
aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
|
| 190 |
+
)
|
| 191 |
+
# boto3.client('s3')
|
| 192 |
+
|
| 193 |
+
print("Trying to download file: ", filenames)
|
| 194 |
+
|
| 195 |
+
if filenames == "*":
|
| 196 |
+
# List all objects in the S3 folder
|
| 197 |
+
print("Trying to download all files in AWS folder: ", s3_folder)
|
| 198 |
+
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
|
| 199 |
+
|
| 200 |
+
print("Found files in AWS folder: ", response.get("Contents", []))
|
| 201 |
+
|
| 202 |
+
filenames = [
|
| 203 |
+
obj["Key"].split("/")[-1] for obj in response.get("Contents", [])
|
| 204 |
+
]
|
| 205 |
+
|
| 206 |
+
print("Found filenames in AWS folder: ", filenames)
|
| 207 |
+
|
| 208 |
+
for filename in filenames:
|
| 209 |
+
object_key = os.path.join(s3_folder, filename)
|
| 210 |
+
local_file_path = os.path.join(local_folder, filename)
|
| 211 |
+
|
| 212 |
+
# Create directories if necessary
|
| 213 |
+
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
| 214 |
+
|
| 215 |
+
# Download the object
|
| 216 |
+
try:
|
| 217 |
+
s3.download_file(bucket_name, object_key, local_file_path)
|
| 218 |
+
print(
|
| 219 |
+
f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'"
|
| 220 |
+
)
|
| 221 |
+
except Exception as e:
|
| 222 |
+
print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def upload_file_to_s3(
|
| 226 |
+
local_file_paths: List[str],
|
| 227 |
+
s3_key: str,
|
| 228 |
+
s3_bucket: str = bucket_name,
|
| 229 |
+
aws_access_key_textbox: str = "",
|
| 230 |
+
aws_secret_key_textbox: str = "",
|
| 231 |
+
aws_region_textbox: str = "",
|
| 232 |
+
RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
|
| 233 |
+
):
|
| 234 |
+
"""
|
| 235 |
+
Uploads a file from local machine to Amazon S3.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
- local_file_path: Local file path(s) of the file(s) to upload.
|
| 239 |
+
- s3_key: Key (path) to the file in the S3 bucket.
|
| 240 |
+
- s3_bucket: Name of the S3 bucket.
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
- Message as variable/printed to console
|
| 244 |
+
"""
|
| 245 |
+
if RUN_AWS_FUNCTIONS == "1":
|
| 246 |
+
|
| 247 |
+
final_out_message = list()
|
| 248 |
+
|
| 249 |
+
s3_client = connect_to_s3_client(
|
| 250 |
+
aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
|
| 251 |
+
)
|
| 252 |
+
# boto3.client('s3')
|
| 253 |
+
|
| 254 |
+
if isinstance(local_file_paths, str):
|
| 255 |
+
local_file_paths = [local_file_paths]
|
| 256 |
+
|
| 257 |
+
for file in local_file_paths:
|
| 258 |
+
try:
|
| 259 |
+
# Get file name off file path
|
| 260 |
+
file_name = os.path.basename(file)
|
| 261 |
+
|
| 262 |
+
s3_key_full = s3_key + file_name
|
| 263 |
+
print("S3 key: ", s3_key_full)
|
| 264 |
+
|
| 265 |
+
s3_client.upload_file(file, s3_bucket, s3_key_full)
|
| 266 |
+
out_message = "File " + file_name + " uploaded successfully!"
|
| 267 |
+
print(out_message)
|
| 268 |
+
|
| 269 |
+
except Exception as e:
|
| 270 |
+
out_message = f"Error uploading file(s): {e}"
|
| 271 |
+
print(out_message)
|
| 272 |
+
|
| 273 |
+
final_out_message.append(out_message)
|
| 274 |
+
final_out_message_str = "\n".join(final_out_message)
|
| 275 |
+
|
| 276 |
+
else:
|
| 277 |
+
final_out_message_str = "Not connected to AWS, no files uploaded."
|
| 278 |
+
|
| 279 |
+
return final_out_message_str
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Helper to upload outputs to S3 when enabled in config.
|
| 283 |
+
def export_outputs_to_s3(
|
| 284 |
+
file_list_state,
|
| 285 |
+
s3_output_folder_state_value: str,
|
| 286 |
+
save_outputs_to_s3_flag: bool,
|
| 287 |
+
base_file_state=None,
|
| 288 |
+
s3_bucket: str = S3_OUTPUTS_BUCKET,
|
| 289 |
+
):
|
| 290 |
+
"""
|
| 291 |
+
Upload a list of local output files to the configured S3 outputs folder.
|
| 292 |
+
|
| 293 |
+
- file_list_state: Gradio dropdown state that holds a list of file paths or a
|
| 294 |
+
single path/string. If blank/empty, no action is taken.
|
| 295 |
+
- s3_output_folder_state_value: Final S3 key prefix (including any session hash)
|
| 296 |
+
to use as the destination folder for uploads.
|
| 297 |
+
- s3_bucket: Name of the S3 bucket.
|
| 298 |
+
"""
|
| 299 |
+
try:
|
| 300 |
+
|
| 301 |
+
# Respect the runtime toggle as well as environment configuration
|
| 302 |
+
if not save_outputs_to_s3_flag:
|
| 303 |
+
return
|
| 304 |
+
|
| 305 |
+
if not s3_output_folder_state_value:
|
| 306 |
+
# No configured S3 outputs folder β nothing to do
|
| 307 |
+
return
|
| 308 |
+
|
| 309 |
+
# Normalise input to a Python list of strings
|
| 310 |
+
file_paths = file_list_state
|
| 311 |
+
if not file_paths:
|
| 312 |
+
return
|
| 313 |
+
|
| 314 |
+
# Gradio dropdown may return a single string or a list
|
| 315 |
+
if isinstance(file_paths, str):
|
| 316 |
+
file_paths = [file_paths]
|
| 317 |
+
|
| 318 |
+
# Filter out any non-truthy values
|
| 319 |
+
file_paths = [p for p in file_paths if p]
|
| 320 |
+
if not file_paths:
|
| 321 |
+
return
|
| 322 |
+
|
| 323 |
+
# Derive a base file stem (name without extension) from the original
|
| 324 |
+
# file(s) being analysed, if provided. This is used to create an
|
| 325 |
+
# additional subfolder layer so that outputs are grouped under the
|
| 326 |
+
# analysed file name rather than under each output file name.
|
| 327 |
+
base_stem = None
|
| 328 |
+
if base_file_state:
|
| 329 |
+
base_path = None
|
| 330 |
+
|
| 331 |
+
# Gradio File components typically provide a list of objects with a `.name` attribute
|
| 332 |
+
if isinstance(base_file_state, str):
|
| 333 |
+
base_path = base_file_state
|
| 334 |
+
elif isinstance(base_file_state, list) and base_file_state:
|
| 335 |
+
first_item = base_file_state[0]
|
| 336 |
+
base_path = getattr(first_item, "name", None) or str(first_item)
|
| 337 |
+
else:
|
| 338 |
+
base_path = getattr(base_file_state, "name", None) or str(
|
| 339 |
+
base_file_state
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if base_path:
|
| 343 |
+
base_name = os.path.basename(base_path)
|
| 344 |
+
base_stem, _ = os.path.splitext(base_name)
|
| 345 |
+
|
| 346 |
+
# Ensure base S3 prefix (session/date) ends with a trailing slash
|
| 347 |
+
base_prefix = s3_output_folder_state_value
|
| 348 |
+
if not base_prefix.endswith("/"):
|
| 349 |
+
base_prefix = base_prefix + "/"
|
| 350 |
+
|
| 351 |
+
# For each file, append a subfolder. If we have a derived base_stem
|
| 352 |
+
# from the input being analysed, use that; otherwise, fall back to
|
| 353 |
+
# the individual output file name stem. Final pattern:
|
| 354 |
+
# <session_output_folder>/<date>/<base_file_stem>/<file_name>
|
| 355 |
+
# or, if base_file_stem is not available:
|
| 356 |
+
# <session_output_folder>/<date>/<output_file_stem>/<file_name>
|
| 357 |
+
for file in file_paths:
|
| 358 |
+
file_name = os.path.basename(file)
|
| 359 |
+
|
| 360 |
+
if base_stem:
|
| 361 |
+
folder_stem = base_stem
|
| 362 |
+
else:
|
| 363 |
+
folder_stem, _ = os.path.splitext(file_name)
|
| 364 |
+
|
| 365 |
+
per_file_prefix = base_prefix + folder_stem + "/"
|
| 366 |
+
|
| 367 |
+
out_message = upload_file_to_s3(
|
| 368 |
+
local_file_paths=[file],
|
| 369 |
+
s3_key=per_file_prefix,
|
| 370 |
+
s3_bucket=s3_bucket,
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
# Log any issues to console so failures are visible in logs/stdout
|
| 374 |
+
if (
|
| 375 |
+
"Error uploading file" in out_message
|
| 376 |
+
or "could not upload" in out_message.lower()
|
| 377 |
+
):
|
| 378 |
+
print("export_outputs_to_s3 encountered issues:", out_message)
|
| 379 |
+
|
| 380 |
+
print("Successfully uploaded outputs to S3")
|
| 381 |
+
|
| 382 |
+
except Exception as e:
|
| 383 |
+
# Do not break the app flow if S3 upload fails β just report to console
|
| 384 |
+
print(f"export_outputs_to_s3 failed with error: {e}")
|
| 385 |
+
|
| 386 |
+
# No GUI outputs to update
|
| 387 |
+
return
|
tools/combine_sheets_into_xlsx.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from datetime import date, datetime
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from openpyxl import Workbook
|
| 7 |
+
from openpyxl.styles import Alignment, Font
|
| 8 |
+
from openpyxl.utils import get_column_letter
|
| 9 |
+
from openpyxl.utils.dataframe import dataframe_to_rows
|
| 10 |
+
|
| 11 |
+
from tools.config import OUTPUT_FOLDER
|
| 12 |
+
from tools.config import model_name_map as global_model_name_map
|
| 13 |
+
from tools.helper_functions import (
|
| 14 |
+
clean_column_name,
|
| 15 |
+
convert_reference_table_to_pivot_table,
|
| 16 |
+
ensure_model_in_map,
|
| 17 |
+
get_basic_response_data,
|
| 18 |
+
load_in_data_file,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def add_cover_sheet(
|
| 23 |
+
wb: Workbook,
|
| 24 |
+
intro_paragraphs: list[str],
|
| 25 |
+
model_name: str,
|
| 26 |
+
analysis_date: str,
|
| 27 |
+
analysis_cost: str,
|
| 28 |
+
number_of_responses: int,
|
| 29 |
+
number_of_responses_with_text: int,
|
| 30 |
+
number_of_responses_with_text_five_plus_words: int,
|
| 31 |
+
llm_call_number: int,
|
| 32 |
+
input_tokens: int,
|
| 33 |
+
output_tokens: int,
|
| 34 |
+
time_taken: float,
|
| 35 |
+
file_name: str,
|
| 36 |
+
column_name: str,
|
| 37 |
+
number_of_responses_with_topic_assignment: int,
|
| 38 |
+
custom_title: str = "Cover Sheet",
|
| 39 |
+
):
|
| 40 |
+
ws = wb.create_sheet(title=custom_title, index=0)
|
| 41 |
+
|
| 42 |
+
# Freeze top row
|
| 43 |
+
ws.freeze_panes = "A2"
|
| 44 |
+
|
| 45 |
+
# Write title
|
| 46 |
+
ws["A1"] = "Large Language Model Topic analysis"
|
| 47 |
+
ws["A1"].font = Font(size=14, bold=True)
|
| 48 |
+
ws["A1"].alignment = Alignment(wrap_text=True, vertical="top")
|
| 49 |
+
|
| 50 |
+
# Add intro paragraphs
|
| 51 |
+
row = 3
|
| 52 |
+
for paragraph in intro_paragraphs:
|
| 53 |
+
ws.merge_cells(start_row=row, start_column=1, end_row=row, end_column=2)
|
| 54 |
+
cell = ws.cell(row=row, column=1, value=paragraph)
|
| 55 |
+
cell.alignment = Alignment(wrap_text=True, vertical="top")
|
| 56 |
+
ws.row_dimensions[row].height = 60 # Adjust height as needed
|
| 57 |
+
row += 2
|
| 58 |
+
|
| 59 |
+
# Add metadata
|
| 60 |
+
meta_start = row + 1
|
| 61 |
+
metadata = {
|
| 62 |
+
"Date Excel file created": date.today().strftime("%Y-%m-%d"),
|
| 63 |
+
"File name": file_name,
|
| 64 |
+
"Column name": column_name,
|
| 65 |
+
"Model name": model_name,
|
| 66 |
+
"Analysis date": analysis_date,
|
| 67 |
+
# "Analysis cost": analysis_cost,
|
| 68 |
+
"Number of responses": number_of_responses,
|
| 69 |
+
"Number of responses with text": number_of_responses_with_text,
|
| 70 |
+
"Number of responses with text five plus words": number_of_responses_with_text_five_plus_words,
|
| 71 |
+
"Number of responses with at least one assigned topic": number_of_responses_with_topic_assignment,
|
| 72 |
+
"Number of LLM calls": llm_call_number,
|
| 73 |
+
"Total number of input tokens from LLM calls": input_tokens,
|
| 74 |
+
"Total number of output tokens from LLM calls": output_tokens,
|
| 75 |
+
"Total time taken for all LLM calls (seconds)": round(float(time_taken), 1),
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
for i, (label, value) in enumerate(metadata.items()):
|
| 79 |
+
row_num = meta_start + i
|
| 80 |
+
ws[f"A{row_num}"] = label
|
| 81 |
+
ws[f"A{row_num}"].font = Font(bold=True)
|
| 82 |
+
|
| 83 |
+
cell = ws[f"B{row_num}"]
|
| 84 |
+
cell.value = value
|
| 85 |
+
cell.alignment = Alignment(wrap_text=True)
|
| 86 |
+
# Optional: Adjust column widths
|
| 87 |
+
ws.column_dimensions["A"].width = 25
|
| 88 |
+
ws.column_dimensions["B"].width = 75
|
| 89 |
+
|
| 90 |
+
# Ensure first row cells are wrapped on the cover sheet
|
| 91 |
+
for col_idx in range(1, ws.max_column + 1):
|
| 92 |
+
header_cell = ws.cell(row=1, column=col_idx)
|
| 93 |
+
header_cell.alignment = Alignment(wrap_text=True, vertical="center")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def csvs_to_excel(
|
| 97 |
+
csv_files: list[str],
|
| 98 |
+
output_filename: str,
|
| 99 |
+
sheet_names: list[str] = None,
|
| 100 |
+
column_widths: dict = None, # Dict of {sheet_name: {col_letter: width}}
|
| 101 |
+
wrap_text_columns: dict = None, # Dict of {sheet_name: [col_letters]}
|
| 102 |
+
intro_text: list[str] = None,
|
| 103 |
+
model_name: str = "",
|
| 104 |
+
analysis_date: str = "",
|
| 105 |
+
analysis_cost: str = "",
|
| 106 |
+
llm_call_number: int = 0,
|
| 107 |
+
input_tokens: int = 0,
|
| 108 |
+
output_tokens: int = 0,
|
| 109 |
+
time_taken: float = 0,
|
| 110 |
+
number_of_responses: int = 0,
|
| 111 |
+
number_of_responses_with_text: int = 0,
|
| 112 |
+
number_of_responses_with_text_five_plus_words: int = 0,
|
| 113 |
+
column_name: str = "",
|
| 114 |
+
number_of_responses_with_topic_assignment: int = 0,
|
| 115 |
+
file_name: str = "",
|
| 116 |
+
unique_reference_numbers: list = [],
|
| 117 |
+
):
|
| 118 |
+
if intro_text is None:
|
| 119 |
+
intro_text = list()
|
| 120 |
+
|
| 121 |
+
wb = Workbook()
|
| 122 |
+
# Remove default sheet
|
| 123 |
+
wb.remove(wb.active)
|
| 124 |
+
|
| 125 |
+
for idx, csv_path in enumerate(csv_files):
|
| 126 |
+
# Use provided sheet name or derive from file name
|
| 127 |
+
sheet_name = (
|
| 128 |
+
sheet_names[idx]
|
| 129 |
+
if sheet_names and idx < len(sheet_names)
|
| 130 |
+
else os.path.splitext(os.path.basename(csv_path))[0]
|
| 131 |
+
)
|
| 132 |
+
df = pd.read_csv(csv_path)
|
| 133 |
+
|
| 134 |
+
if sheet_name == "Original data":
|
| 135 |
+
try:
|
| 136 |
+
# Create a copy to avoid modifying the original
|
| 137 |
+
df_copy = df.copy()
|
| 138 |
+
# Insert the Reference column at position 0 (first column)
|
| 139 |
+
df_copy.insert(0, "Reference", unique_reference_numbers)
|
| 140 |
+
df = df_copy
|
| 141 |
+
except Exception as e:
|
| 142 |
+
print("Could not add reference number to original data due to:", e)
|
| 143 |
+
|
| 144 |
+
ws = wb.create_sheet(title=sheet_name)
|
| 145 |
+
|
| 146 |
+
for r_idx, row in enumerate(
|
| 147 |
+
dataframe_to_rows(df, index=False, header=True), start=1
|
| 148 |
+
):
|
| 149 |
+
ws.append(row)
|
| 150 |
+
|
| 151 |
+
for col_idx, value in enumerate(row, start=1):
|
| 152 |
+
cell = ws.cell(row=r_idx, column=col_idx)
|
| 153 |
+
|
| 154 |
+
# Bold header row
|
| 155 |
+
if r_idx == 1:
|
| 156 |
+
cell.font = Font(bold=True)
|
| 157 |
+
|
| 158 |
+
# Set vertical alignment to middle by default
|
| 159 |
+
cell.alignment = Alignment(vertical="center")
|
| 160 |
+
|
| 161 |
+
# Apply wrap text if needed
|
| 162 |
+
if wrap_text_columns and sheet_name in wrap_text_columns:
|
| 163 |
+
for col_letter in wrap_text_columns[sheet_name]:
|
| 164 |
+
cell = ws[f"{col_letter}{r_idx}"]
|
| 165 |
+
cell.alignment = Alignment(vertical="center", wrap_text=True)
|
| 166 |
+
|
| 167 |
+
# Freeze top row for all data sheets
|
| 168 |
+
ws.freeze_panes = "A2"
|
| 169 |
+
|
| 170 |
+
# Ensure all header cells (first row) are wrapped
|
| 171 |
+
for col_idx in range(1, ws.max_column + 1):
|
| 172 |
+
header_cell = ws.cell(row=1, column=col_idx)
|
| 173 |
+
header_cell.alignment = Alignment(vertical="center", wrap_text=True)
|
| 174 |
+
|
| 175 |
+
# Set column widths
|
| 176 |
+
if column_widths and sheet_name in column_widths:
|
| 177 |
+
for col_letter, width in column_widths[sheet_name].items():
|
| 178 |
+
ws.column_dimensions[col_letter].width = width
|
| 179 |
+
|
| 180 |
+
add_cover_sheet(
|
| 181 |
+
wb,
|
| 182 |
+
intro_paragraphs=intro_text,
|
| 183 |
+
model_name=model_name,
|
| 184 |
+
analysis_date=analysis_date,
|
| 185 |
+
analysis_cost=analysis_cost,
|
| 186 |
+
number_of_responses=number_of_responses,
|
| 187 |
+
number_of_responses_with_text=number_of_responses_with_text,
|
| 188 |
+
number_of_responses_with_text_five_plus_words=number_of_responses_with_text_five_plus_words,
|
| 189 |
+
llm_call_number=llm_call_number,
|
| 190 |
+
input_tokens=input_tokens,
|
| 191 |
+
output_tokens=output_tokens,
|
| 192 |
+
time_taken=time_taken,
|
| 193 |
+
file_name=file_name,
|
| 194 |
+
column_name=column_name,
|
| 195 |
+
number_of_responses_with_topic_assignment=number_of_responses_with_topic_assignment,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
wb.save(output_filename)
|
| 199 |
+
|
| 200 |
+
print(f"Output xlsx summary saved as '{output_filename}'")
|
| 201 |
+
|
| 202 |
+
return output_filename
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
###
|
| 206 |
+
# Run the functions
|
| 207 |
+
###
|
| 208 |
+
def collect_output_csvs_and_create_excel_output(
|
| 209 |
+
in_data_files: List,
|
| 210 |
+
chosen_cols: list[str],
|
| 211 |
+
reference_data_file_name_textbox: str,
|
| 212 |
+
in_group_col: str,
|
| 213 |
+
model_choice: str,
|
| 214 |
+
master_reference_df_state: pd.DataFrame,
|
| 215 |
+
master_unique_topics_df_state: pd.DataFrame,
|
| 216 |
+
summarised_output_df: pd.DataFrame,
|
| 217 |
+
missing_df_state: pd.DataFrame,
|
| 218 |
+
excel_sheets: str = "",
|
| 219 |
+
usage_logs_location: str = "",
|
| 220 |
+
model_name_map: dict = dict(),
|
| 221 |
+
output_folder: str = OUTPUT_FOLDER,
|
| 222 |
+
structured_summaries: str = "No",
|
| 223 |
+
):
|
| 224 |
+
"""
|
| 225 |
+
Collect together output CSVs from various output boxes and combine them into a single output Excel file.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
in_data_files (List): A list of paths to the input data files.
|
| 229 |
+
chosen_cols (list[str]): A list of column names selected for analysis.
|
| 230 |
+
reference_data_file_name_textbox (str): The name of the reference data file.
|
| 231 |
+
in_group_col (str): The column used for grouping the data.
|
| 232 |
+
model_choice (str): The LLM model chosen for the analysis.
|
| 233 |
+
master_reference_df_state (pd.DataFrame): The master DataFrame containing reference data.
|
| 234 |
+
master_unique_topics_df_state (pd.DataFrame): The master DataFrame containing unique topics data.
|
| 235 |
+
summarised_output_df (pd.DataFrame): DataFrame containing the summarised output.
|
| 236 |
+
missing_df_state (pd.DataFrame): DataFrame containing information about missing data.
|
| 237 |
+
excel_sheets (str): Information regarding Excel sheets, typically sheet names or structure.
|
| 238 |
+
usage_logs_location (str, optional): Path to the usage logs CSV file. Defaults to "".
|
| 239 |
+
model_name_map (dict, optional): A dictionary mapping model choices to their display names. Defaults to {}.
|
| 240 |
+
output_folder (str, optional): The directory where the output Excel file will be saved. Defaults to OUTPUT_FOLDER.
|
| 241 |
+
structured_summaries (str, optional): Indicates whether structured summaries are being produced ("Yes" or "No"). Defaults to "No".
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
tuple: A tuple containing:
|
| 245 |
+
- list: A list of paths to the generated Excel output files.
|
| 246 |
+
- list: A duplicate of the list of paths to the generated Excel output files (for UI compatibility).
|
| 247 |
+
"""
|
| 248 |
+
# Use passed model_name_map if provided and not empty, otherwise use global one
|
| 249 |
+
if not model_name_map:
|
| 250 |
+
model_name_map = global_model_name_map
|
| 251 |
+
|
| 252 |
+
# Ensure custom model_choice is registered in model_name_map
|
| 253 |
+
ensure_model_in_map(model_choice, model_name_map)
|
| 254 |
+
|
| 255 |
+
if structured_summaries == "Yes":
|
| 256 |
+
structured_summaries = True
|
| 257 |
+
else:
|
| 258 |
+
structured_summaries = False
|
| 259 |
+
|
| 260 |
+
if not chosen_cols:
|
| 261 |
+
raise Exception("Could not find chosen column")
|
| 262 |
+
|
| 263 |
+
today_date = datetime.today().strftime("%Y-%m-%d")
|
| 264 |
+
original_data_file_path = os.path.abspath(in_data_files[0])
|
| 265 |
+
|
| 266 |
+
csv_files = list()
|
| 267 |
+
sheet_names = list()
|
| 268 |
+
column_widths = dict()
|
| 269 |
+
wrap_text_columns = dict()
|
| 270 |
+
short_file_name = os.path.basename(reference_data_file_name_textbox)
|
| 271 |
+
reference_pivot_table = pd.DataFrame()
|
| 272 |
+
reference_table_csv_path = ""
|
| 273 |
+
reference_pivot_table_csv_path = ""
|
| 274 |
+
unique_topic_table_csv_path = ""
|
| 275 |
+
missing_df_state_csv_path = ""
|
| 276 |
+
overall_summary_csv_path = ""
|
| 277 |
+
number_of_responses_with_topic_assignment = 0
|
| 278 |
+
|
| 279 |
+
if in_group_col:
|
| 280 |
+
group = in_group_col
|
| 281 |
+
else:
|
| 282 |
+
group = "All"
|
| 283 |
+
|
| 284 |
+
overall_summary_csv_path = output_folder + "overall_summary_for_xlsx.csv"
|
| 285 |
+
|
| 286 |
+
if structured_summaries is True and not master_unique_topics_df_state.empty:
|
| 287 |
+
print("Producing overall summary based on structured summaries.")
|
| 288 |
+
# Create structured summary from master_unique_topics_df_state
|
| 289 |
+
structured_summary_data = list()
|
| 290 |
+
|
| 291 |
+
# Group by 'Group' column
|
| 292 |
+
for group_name, group_df in master_unique_topics_df_state.groupby("Group"):
|
| 293 |
+
group_summary = f"## {group_name}\n\n"
|
| 294 |
+
|
| 295 |
+
# Group by 'General topic' within each group
|
| 296 |
+
for general_topic, topic_df in group_df.groupby("General topic"):
|
| 297 |
+
group_summary += f"### {general_topic}\n\n"
|
| 298 |
+
|
| 299 |
+
# Add subtopics under each general topic
|
| 300 |
+
for _, row in topic_df.iterrows():
|
| 301 |
+
subtopic = row["Subtopic"]
|
| 302 |
+
summary = row["Summary"]
|
| 303 |
+
# sentiment = row.get('Sentiment', '')
|
| 304 |
+
# num_responses = row.get('Number of responses', '')
|
| 305 |
+
|
| 306 |
+
# Create subtopic entry
|
| 307 |
+
subtopic_entry = f"**{subtopic}**"
|
| 308 |
+
# if sentiment:
|
| 309 |
+
# subtopic_entry += f" ({sentiment})"
|
| 310 |
+
# if num_responses:
|
| 311 |
+
# subtopic_entry += f" - {num_responses} responses"
|
| 312 |
+
subtopic_entry += "\n\n"
|
| 313 |
+
|
| 314 |
+
if summary and pd.notna(summary):
|
| 315 |
+
subtopic_entry += f"{summary}\n\n"
|
| 316 |
+
|
| 317 |
+
group_summary += subtopic_entry
|
| 318 |
+
|
| 319 |
+
# Add to structured summary data
|
| 320 |
+
structured_summary_data.append(
|
| 321 |
+
{"Group": group_name, "Summary": group_summary.strip()}
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# Create DataFrame for structured summary
|
| 325 |
+
structured_summary_df = pd.DataFrame(structured_summary_data)
|
| 326 |
+
structured_summary_df.to_csv(overall_summary_csv_path, index=False)
|
| 327 |
+
else:
|
| 328 |
+
# Use original summarised_output_df
|
| 329 |
+
structured_summary_df = summarised_output_df
|
| 330 |
+
structured_summary_df.to_csv(overall_summary_csv_path, index=None)
|
| 331 |
+
|
| 332 |
+
if not structured_summary_df.empty:
|
| 333 |
+
csv_files.append(overall_summary_csv_path)
|
| 334 |
+
sheet_names.append("Overall summary")
|
| 335 |
+
column_widths["Overall summary"] = {"A": 20, "B": 100}
|
| 336 |
+
wrap_text_columns["Overall summary"] = ["B"]
|
| 337 |
+
|
| 338 |
+
if not master_reference_df_state.empty:
|
| 339 |
+
# Simplify table to just responses column and the Response reference number
|
| 340 |
+
file_data, file_name, num_batches = load_in_data_file(
|
| 341 |
+
in_data_files, chosen_cols, 1, in_excel_sheets=excel_sheets
|
| 342 |
+
)
|
| 343 |
+
basic_response_data = get_basic_response_data(
|
| 344 |
+
file_data, chosen_cols, verify_titles="No"
|
| 345 |
+
)
|
| 346 |
+
reference_pivot_table = convert_reference_table_to_pivot_table(
|
| 347 |
+
master_reference_df_state, basic_response_data
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
unique_reference_numbers = basic_response_data["Reference"].tolist()
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
master_reference_df_state.rename(
|
| 354 |
+
columns={"Topic_number": "Topic number"}, inplace=True, errors="ignore"
|
| 355 |
+
)
|
| 356 |
+
master_reference_df_state.drop(
|
| 357 |
+
columns=["1", "2", "3"], inplace=True, errors="ignore"
|
| 358 |
+
)
|
| 359 |
+
except Exception as e:
|
| 360 |
+
print("Could not rename Topic_number due to", e)
|
| 361 |
+
|
| 362 |
+
number_of_responses_with_topic_assignment = len(
|
| 363 |
+
master_reference_df_state["Response References"].unique()
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
reference_table_csv_path = output_folder + "reference_df_for_xlsx.csv"
|
| 367 |
+
master_reference_df_state.to_csv(reference_table_csv_path, index=None)
|
| 368 |
+
|
| 369 |
+
reference_pivot_table_csv_path = (
|
| 370 |
+
output_folder + "reference_pivot_df_for_xlsx.csv"
|
| 371 |
+
)
|
| 372 |
+
reference_pivot_table.to_csv(reference_pivot_table_csv_path, index=None)
|
| 373 |
+
|
| 374 |
+
short_file_name = os.path.basename(file_name)
|
| 375 |
+
|
| 376 |
+
if not master_unique_topics_df_state.empty:
|
| 377 |
+
|
| 378 |
+
master_unique_topics_df_state.drop(
|
| 379 |
+
columns=["1", "2", "3"], inplace=True, errors="ignore"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
unique_topic_table_csv_path = (
|
| 383 |
+
output_folder + "unique_topic_table_df_for_xlsx.csv"
|
| 384 |
+
)
|
| 385 |
+
master_unique_topics_df_state.to_csv(unique_topic_table_csv_path, index=None)
|
| 386 |
+
|
| 387 |
+
if unique_topic_table_csv_path:
|
| 388 |
+
csv_files.append(unique_topic_table_csv_path)
|
| 389 |
+
sheet_names.append("Topic summary")
|
| 390 |
+
column_widths["Topic summary"] = {"A": 25, "B": 25, "C": 15, "D": 15, "F": 100}
|
| 391 |
+
wrap_text_columns["Topic summary"] = ["B", "F"]
|
| 392 |
+
else:
|
| 393 |
+
print("Relevant unique topic files not found, excluding from xlsx output.")
|
| 394 |
+
|
| 395 |
+
if reference_table_csv_path:
|
| 396 |
+
if structured_summaries:
|
| 397 |
+
print(
|
| 398 |
+
"Structured summaries are being produced, excluding response level data from xlsx output."
|
| 399 |
+
)
|
| 400 |
+
else:
|
| 401 |
+
csv_files.append(reference_table_csv_path)
|
| 402 |
+
sheet_names.append("Response level data")
|
| 403 |
+
column_widths["Response level data"] = {"A": 15, "B": 30, "C": 40, "H": 100}
|
| 404 |
+
wrap_text_columns["Response level data"] = ["C", "G"]
|
| 405 |
+
else:
|
| 406 |
+
print("Relevant reference files not found, excluding from xlsx output.")
|
| 407 |
+
|
| 408 |
+
if reference_pivot_table_csv_path:
|
| 409 |
+
if structured_summaries:
|
| 410 |
+
print(
|
| 411 |
+
"Structured summaries are being produced, excluding topic response pivot table from xlsx output."
|
| 412 |
+
)
|
| 413 |
+
else:
|
| 414 |
+
csv_files.append(reference_pivot_table_csv_path)
|
| 415 |
+
sheet_names.append("Topic response pivot table")
|
| 416 |
+
|
| 417 |
+
if reference_pivot_table.empty:
|
| 418 |
+
reference_pivot_table = pd.read_csv(reference_pivot_table_csv_path)
|
| 419 |
+
|
| 420 |
+
# Base widths and wrap
|
| 421 |
+
column_widths["Topic response pivot table"] = {"A": 25, "B": 100}
|
| 422 |
+
wrap_text_columns["Topic response pivot table"] = ["B"]
|
| 423 |
+
|
| 424 |
+
num_cols = len(reference_pivot_table.columns)
|
| 425 |
+
col_letters = [get_column_letter(i) for i in range(3, num_cols + 1)]
|
| 426 |
+
|
| 427 |
+
for col_letter in col_letters:
|
| 428 |
+
column_widths["Topic response pivot table"][col_letter] = 25
|
| 429 |
+
|
| 430 |
+
wrap_text_columns["Topic response pivot table"].extend(col_letters)
|
| 431 |
+
else:
|
| 432 |
+
print(
|
| 433 |
+
"Relevant reference pivot table files not found, excluding from xlsx output."
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if not missing_df_state.empty:
|
| 437 |
+
missing_df_state_csv_path = output_folder + "missing_df_state_df_for_xlsx.csv"
|
| 438 |
+
missing_df_state.to_csv(missing_df_state_csv_path, index=None)
|
| 439 |
+
|
| 440 |
+
if missing_df_state_csv_path:
|
| 441 |
+
if structured_summaries:
|
| 442 |
+
print(
|
| 443 |
+
"Structured summaries are being produced, excluding missing responses from xlsx output."
|
| 444 |
+
)
|
| 445 |
+
else:
|
| 446 |
+
csv_files.append(missing_df_state_csv_path)
|
| 447 |
+
sheet_names.append("Missing responses")
|
| 448 |
+
column_widths["Missing responses"] = {"A": 25, "B": 30, "C": 50}
|
| 449 |
+
wrap_text_columns["Missing responses"] = ["C"]
|
| 450 |
+
else:
|
| 451 |
+
print("Relevant missing responses files not found, excluding from xlsx output.")
|
| 452 |
+
|
| 453 |
+
new_csv_files = csv_files.copy()
|
| 454 |
+
|
| 455 |
+
# Original data file
|
| 456 |
+
original_ext = os.path.splitext(original_data_file_path)[1].lower()
|
| 457 |
+
if original_ext == ".csv":
|
| 458 |
+
csv_files.append(original_data_file_path)
|
| 459 |
+
else:
|
| 460 |
+
# Read and convert to CSV
|
| 461 |
+
if original_ext == ".xlsx":
|
| 462 |
+
if excel_sheets:
|
| 463 |
+
df = pd.read_excel(original_data_file_path, sheet_name=excel_sheets)
|
| 464 |
+
else:
|
| 465 |
+
df = pd.read_excel(original_data_file_path)
|
| 466 |
+
elif original_ext == ".parquet":
|
| 467 |
+
df = pd.read_parquet(original_data_file_path)
|
| 468 |
+
else:
|
| 469 |
+
raise Exception(f"Unsupported file type for original data: {original_ext}")
|
| 470 |
+
|
| 471 |
+
# Save as CSV in output folder
|
| 472 |
+
original_data_csv_path = os.path.join(
|
| 473 |
+
output_folder,
|
| 474 |
+
os.path.splitext(os.path.basename(original_data_file_path))[0]
|
| 475 |
+
+ "_for_xlsx.csv",
|
| 476 |
+
)
|
| 477 |
+
df.to_csv(original_data_csv_path, index=False)
|
| 478 |
+
csv_files.append(original_data_csv_path)
|
| 479 |
+
|
| 480 |
+
sheet_names.append("Original data")
|
| 481 |
+
column_widths["Original data"] = {"A": 20, "B": 20, "C": 20}
|
| 482 |
+
wrap_text_columns["Original data"] = ["C"]
|
| 483 |
+
if isinstance(chosen_cols, list) and chosen_cols:
|
| 484 |
+
chosen_cols = chosen_cols[0]
|
| 485 |
+
else:
|
| 486 |
+
chosen_cols = str(chosen_cols) if chosen_cols else ""
|
| 487 |
+
|
| 488 |
+
# Intro page text
|
| 489 |
+
intro_text = [
|
| 490 |
+
"This workbook contains outputs from the large language model topic analysis of open text data. Each sheet corresponds to a different CSV report included in the analysis.",
|
| 491 |
+
f"The file analysed was {short_file_name}, the column analysed was '{chosen_cols}' and the data was grouped by column '{group}'."
|
| 492 |
+
" Please contact the LLM Topic Modelling app administrator if you need any explanation on how to use the results."
|
| 493 |
+
"Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this analysis **need to be checked by a human** to check for harmful outputs, false information, and bias.",
|
| 494 |
+
]
|
| 495 |
+
|
| 496 |
+
# Get values for number of rows, number of responses, and number of responses longer than five words
|
| 497 |
+
number_of_responses = basic_response_data.shape[0]
|
| 498 |
+
# number_of_responses_with_text = basic_response_data["Response"].str.strip().notnull().sum()
|
| 499 |
+
number_of_responses_with_text = (
|
| 500 |
+
basic_response_data["Response"].str.strip().notnull()
|
| 501 |
+
& (basic_response_data["Response"].str.split().str.len() >= 1)
|
| 502 |
+
).sum()
|
| 503 |
+
number_of_responses_with_text_five_plus_words = (
|
| 504 |
+
basic_response_data["Response"].str.strip().notnull()
|
| 505 |
+
& (basic_response_data["Response"].str.split().str.len() >= 5)
|
| 506 |
+
).sum()
|
| 507 |
+
|
| 508 |
+
# Get number of LLM calls, input and output tokens
|
| 509 |
+
if usage_logs_location:
|
| 510 |
+
try:
|
| 511 |
+
usage_logs = pd.read_csv(usage_logs_location)
|
| 512 |
+
|
| 513 |
+
relevant_logs = usage_logs.loc[
|
| 514 |
+
(
|
| 515 |
+
usage_logs["Reference data file name"]
|
| 516 |
+
== reference_data_file_name_textbox
|
| 517 |
+
)
|
| 518 |
+
& (
|
| 519 |
+
usage_logs[
|
| 520 |
+
"Large language model for topic extraction and summarisation"
|
| 521 |
+
]
|
| 522 |
+
== model_choice
|
| 523 |
+
)
|
| 524 |
+
& (
|
| 525 |
+
usage_logs[
|
| 526 |
+
"Select the open text column of interest. In an Excel file, this shows columns across all sheets."
|
| 527 |
+
]
|
| 528 |
+
== (
|
| 529 |
+
chosen_cols[0]
|
| 530 |
+
if isinstance(chosen_cols, list) and chosen_cols
|
| 531 |
+
else chosen_cols
|
| 532 |
+
)
|
| 533 |
+
),
|
| 534 |
+
:,
|
| 535 |
+
]
|
| 536 |
+
|
| 537 |
+
llm_call_number = sum(relevant_logs["Total LLM calls"].astype(int))
|
| 538 |
+
input_tokens = sum(relevant_logs["Total input tokens"].astype(int))
|
| 539 |
+
output_tokens = sum(relevant_logs["Total output tokens"].astype(int))
|
| 540 |
+
time_taken = sum(
|
| 541 |
+
relevant_logs["Estimated time taken (seconds)"].astype(float)
|
| 542 |
+
)
|
| 543 |
+
except Exception as e:
|
| 544 |
+
print("Could not obtain usage logs due to:", e)
|
| 545 |
+
usage_logs = pd.DataFrame()
|
| 546 |
+
llm_call_number = 0
|
| 547 |
+
input_tokens = 0
|
| 548 |
+
output_tokens = 0
|
| 549 |
+
time_taken = 0
|
| 550 |
+
else:
|
| 551 |
+
print("LLM call logs location not provided")
|
| 552 |
+
usage_logs = pd.DataFrame()
|
| 553 |
+
llm_call_number = 0
|
| 554 |
+
input_tokens = 0
|
| 555 |
+
output_tokens = 0
|
| 556 |
+
time_taken = 0
|
| 557 |
+
|
| 558 |
+
# Create short filename:
|
| 559 |
+
model_choice_clean_short = clean_column_name(
|
| 560 |
+
model_name_map[model_choice]["short_name"],
|
| 561 |
+
max_length=20,
|
| 562 |
+
front_characters=False,
|
| 563 |
+
)
|
| 564 |
+
# Extract first column name as string for cleaning and Excel output
|
| 565 |
+
chosen_col_str = (
|
| 566 |
+
chosen_cols[0]
|
| 567 |
+
if isinstance(chosen_cols, list) and chosen_cols
|
| 568 |
+
else str(chosen_cols) if chosen_cols else ""
|
| 569 |
+
)
|
| 570 |
+
in_column_cleaned = clean_column_name(chosen_col_str, max_length=20)
|
| 571 |
+
file_name_cleaned = clean_column_name(
|
| 572 |
+
file_name, max_length=20, front_characters=True
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
# Save outputs for each batch. If master file created, label file as master
|
| 576 |
+
file_path_details = (
|
| 577 |
+
f"{file_name_cleaned}_col_{in_column_cleaned}_{model_choice_clean_short}"
|
| 578 |
+
)
|
| 579 |
+
output_xlsx_filename = (
|
| 580 |
+
output_folder
|
| 581 |
+
+ file_path_details
|
| 582 |
+
+ ("_structured_summaries" if structured_summaries else "_topic_analysis")
|
| 583 |
+
+ ".xlsx"
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
xlsx_output_filename = csvs_to_excel(
|
| 587 |
+
csv_files=csv_files,
|
| 588 |
+
output_filename=output_xlsx_filename,
|
| 589 |
+
sheet_names=sheet_names,
|
| 590 |
+
column_widths=column_widths,
|
| 591 |
+
wrap_text_columns=wrap_text_columns,
|
| 592 |
+
intro_text=intro_text,
|
| 593 |
+
model_name=model_choice,
|
| 594 |
+
analysis_date=today_date,
|
| 595 |
+
analysis_cost="",
|
| 596 |
+
llm_call_number=llm_call_number,
|
| 597 |
+
input_tokens=input_tokens,
|
| 598 |
+
output_tokens=output_tokens,
|
| 599 |
+
time_taken=time_taken,
|
| 600 |
+
number_of_responses=number_of_responses,
|
| 601 |
+
number_of_responses_with_text=number_of_responses_with_text,
|
| 602 |
+
number_of_responses_with_text_five_plus_words=number_of_responses_with_text_five_plus_words,
|
| 603 |
+
column_name=chosen_col_str,
|
| 604 |
+
number_of_responses_with_topic_assignment=number_of_responses_with_topic_assignment,
|
| 605 |
+
file_name=short_file_name,
|
| 606 |
+
unique_reference_numbers=unique_reference_numbers,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
xlsx_output_filenames = [xlsx_output_filename]
|
| 610 |
+
|
| 611 |
+
# Delete intermediate csv files
|
| 612 |
+
for csv_file in new_csv_files:
|
| 613 |
+
os.remove(csv_file)
|
| 614 |
+
|
| 615 |
+
return xlsx_output_filenames, xlsx_output_filenames
|
tools/config.py
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import codecs
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import socket
|
| 5 |
+
import tempfile
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import List
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
today_rev = datetime.now().strftime("%Y%m%d")
|
| 12 |
+
HOST_NAME = socket.gethostname()
|
| 13 |
+
|
| 14 |
+
# Set or retrieve configuration variables for the redaction app
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_or_create_env_var(var_name: str, default_value: str, print_val: bool = False):
|
| 18 |
+
"""
|
| 19 |
+
Get an environmental variable, and set it to a default value if it doesn't exist
|
| 20 |
+
"""
|
| 21 |
+
# Get the environment variable if it exists
|
| 22 |
+
value = os.environ.get(var_name)
|
| 23 |
+
|
| 24 |
+
# If it doesn't exist, set the environment variable to the default value
|
| 25 |
+
if value is None:
|
| 26 |
+
os.environ[var_name] = default_value
|
| 27 |
+
value = default_value
|
| 28 |
+
|
| 29 |
+
if print_val is True:
|
| 30 |
+
print(f"The value of {var_name} is {value}")
|
| 31 |
+
|
| 32 |
+
return value
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def add_folder_to_path(folder_path: str):
|
| 36 |
+
"""
|
| 37 |
+
Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
if os.path.exists(folder_path) and os.path.isdir(folder_path):
|
| 41 |
+
print(folder_path, "folder exists.")
|
| 42 |
+
|
| 43 |
+
# Resolve relative path to absolute path
|
| 44 |
+
absolute_path = os.path.abspath(folder_path)
|
| 45 |
+
|
| 46 |
+
current_path = os.environ["PATH"]
|
| 47 |
+
if absolute_path not in current_path.split(os.pathsep):
|
| 48 |
+
full_path_extension = absolute_path + os.pathsep + current_path
|
| 49 |
+
os.environ["PATH"] = full_path_extension
|
| 50 |
+
# print(f"Updated PATH with: ", full_path_extension)
|
| 51 |
+
else:
|
| 52 |
+
print(f"Directory {folder_path} already exists in PATH.")
|
| 53 |
+
else:
|
| 54 |
+
print(f"Folder not found at {folder_path} - not added to PATH")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def convert_string_to_boolean(value: str) -> bool:
|
| 58 |
+
"""Convert string to boolean, handling various formats."""
|
| 59 |
+
if isinstance(value, bool):
|
| 60 |
+
return value
|
| 61 |
+
elif value in ["True", "1", "true", "TRUE"]:
|
| 62 |
+
return True
|
| 63 |
+
elif value in ["False", "0", "false", "FALSE"]:
|
| 64 |
+
return False
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError(f"Invalid boolean value: {value}")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
###
|
| 70 |
+
# LOAD CONFIG FROM ENV FILE
|
| 71 |
+
###
|
| 72 |
+
|
| 73 |
+
CONFIG_FOLDER = get_or_create_env_var("CONFIG_FOLDER", "config/")
|
| 74 |
+
|
| 75 |
+
# If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
|
| 76 |
+
APP_CONFIG_PATH = get_or_create_env_var(
|
| 77 |
+
"APP_CONFIG_PATH", CONFIG_FOLDER + "app_config.env"
|
| 78 |
+
) # e.g. config/app_config.env
|
| 79 |
+
|
| 80 |
+
if APP_CONFIG_PATH:
|
| 81 |
+
if os.path.exists(APP_CONFIG_PATH):
|
| 82 |
+
print(f"Loading app variables from config file {APP_CONFIG_PATH}")
|
| 83 |
+
load_dotenv(APP_CONFIG_PATH)
|
| 84 |
+
else:
|
| 85 |
+
print("App config file not found at location:", APP_CONFIG_PATH)
|
| 86 |
+
|
| 87 |
+
###
|
| 88 |
+
# AWS OPTIONS
|
| 89 |
+
###
|
| 90 |
+
|
| 91 |
+
# If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
|
| 92 |
+
AWS_CONFIG_PATH = get_or_create_env_var(
|
| 93 |
+
"AWS_CONFIG_PATH", ""
|
| 94 |
+
) # e.g. config/aws_config.env
|
| 95 |
+
|
| 96 |
+
if AWS_CONFIG_PATH:
|
| 97 |
+
if os.path.exists(AWS_CONFIG_PATH):
|
| 98 |
+
print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
|
| 99 |
+
load_dotenv(AWS_CONFIG_PATH)
|
| 100 |
+
else:
|
| 101 |
+
print("AWS config file not found at location:", AWS_CONFIG_PATH)
|
| 102 |
+
|
| 103 |
+
RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
|
| 104 |
+
|
| 105 |
+
AWS_REGION = get_or_create_env_var("AWS_REGION", "")
|
| 106 |
+
|
| 107 |
+
AWS_CLIENT_ID = get_or_create_env_var("AWS_CLIENT_ID", "")
|
| 108 |
+
|
| 109 |
+
AWS_CLIENT_SECRET = get_or_create_env_var("AWS_CLIENT_SECRET", "")
|
| 110 |
+
|
| 111 |
+
AWS_USER_POOL_ID = get_or_create_env_var("AWS_USER_POOL_ID", "")
|
| 112 |
+
|
| 113 |
+
AWS_ACCESS_KEY = get_or_create_env_var("AWS_ACCESS_KEY", "")
|
| 114 |
+
# if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
|
| 115 |
+
|
| 116 |
+
AWS_SECRET_KEY = get_or_create_env_var("AWS_SECRET_KEY", "")
|
| 117 |
+
# if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
|
| 118 |
+
|
| 119 |
+
# Should the app prioritise using AWS SSO over using API keys stored in environment variables/secrets (defaults to yes)
|
| 120 |
+
PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS = get_or_create_env_var(
|
| 121 |
+
"PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS", "1"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
S3_LOG_BUCKET = get_or_create_env_var("S3_LOG_BUCKET", "")
|
| 125 |
+
|
| 126 |
+
# Custom headers e.g. if routing traffic through Cloudfront
|
| 127 |
+
# Retrieving or setting CUSTOM_HEADER
|
| 128 |
+
CUSTOM_HEADER = get_or_create_env_var("CUSTOM_HEADER", "")
|
| 129 |
+
|
| 130 |
+
# Retrieving or setting CUSTOM_HEADER_VALUE
|
| 131 |
+
CUSTOM_HEADER_VALUE = get_or_create_env_var("CUSTOM_HEADER_VALUE", "")
|
| 132 |
+
|
| 133 |
+
###
|
| 134 |
+
# File I/O
|
| 135 |
+
###
|
| 136 |
+
SESSION_OUTPUT_FOLDER = get_or_create_env_var(
|
| 137 |
+
"SESSION_OUTPUT_FOLDER", "False"
|
| 138 |
+
) # i.e. do you want your input and output folders saved within a subfolder based on session hash value within output/input folders
|
| 139 |
+
|
| 140 |
+
OUTPUT_FOLDER = get_or_create_env_var("GRADIO_OUTPUT_FOLDER", "output/") # 'output/'
|
| 141 |
+
INPUT_FOLDER = get_or_create_env_var("GRADIO_INPUT_FOLDER", "input/") # 'input/'
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# Allow for files to be saved in a temporary folder for increased security in some instances
|
| 145 |
+
if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
|
| 146 |
+
# Create a temporary directory
|
| 147 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 148 |
+
print(f"Temporary directory created at: {temp_dir}")
|
| 149 |
+
|
| 150 |
+
if OUTPUT_FOLDER == "TEMP":
|
| 151 |
+
OUTPUT_FOLDER = temp_dir + "/"
|
| 152 |
+
if INPUT_FOLDER == "TEMP":
|
| 153 |
+
INPUT_FOLDER = temp_dir + "/"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
GRADIO_TEMP_DIR = get_or_create_env_var(
|
| 157 |
+
"GRADIO_TEMP_DIR", "tmp/gradio_tmp/"
|
| 158 |
+
) # Default Gradio temp folder
|
| 159 |
+
MPLCONFIGDIR = get_or_create_env_var(
|
| 160 |
+
"MPLCONFIGDIR", "tmp/matplotlib_cache/"
|
| 161 |
+
) # Matplotlib cache folder
|
| 162 |
+
|
| 163 |
+
S3_OUTPUTS_BUCKET = get_or_create_env_var("S3_OUTPUTS_BUCKET", "")
|
| 164 |
+
S3_OUTPUTS_FOLDER = get_or_create_env_var("S3_OUTPUTS_FOLDER", "")
|
| 165 |
+
SAVE_OUTPUTS_TO_S3 = get_or_create_env_var("SAVE_OUTPUTS_TO_S3", "False")
|
| 166 |
+
|
| 167 |
+
###
|
| 168 |
+
# LOGGING OPTIONS
|
| 169 |
+
###
|
| 170 |
+
|
| 171 |
+
# By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
|
| 172 |
+
# Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
|
| 173 |
+
|
| 174 |
+
SAVE_LOGS_TO_CSV = get_or_create_env_var("SAVE_LOGS_TO_CSV", "True")
|
| 175 |
+
|
| 176 |
+
USE_LOG_SUBFOLDERS = get_or_create_env_var("USE_LOG_SUBFOLDERS", "True")
|
| 177 |
+
|
| 178 |
+
FEEDBACK_LOGS_FOLDER = get_or_create_env_var("FEEDBACK_LOGS_FOLDER", "feedback/")
|
| 179 |
+
ACCESS_LOGS_FOLDER = get_or_create_env_var("ACCESS_LOGS_FOLDER", "logs/")
|
| 180 |
+
USAGE_LOGS_FOLDER = get_or_create_env_var("USAGE_LOGS_FOLDER", "usage/")
|
| 181 |
+
|
| 182 |
+
# Initialize full_log_subfolder based on USE_LOG_SUBFOLDERS setting
|
| 183 |
+
if USE_LOG_SUBFOLDERS == "True":
|
| 184 |
+
day_log_subfolder = today_rev + "/"
|
| 185 |
+
host_name_subfolder = HOST_NAME + "/"
|
| 186 |
+
full_log_subfolder = day_log_subfolder + host_name_subfolder
|
| 187 |
+
|
| 188 |
+
FEEDBACK_LOGS_FOLDER = FEEDBACK_LOGS_FOLDER + full_log_subfolder
|
| 189 |
+
ACCESS_LOGS_FOLDER = ACCESS_LOGS_FOLDER + full_log_subfolder
|
| 190 |
+
USAGE_LOGS_FOLDER = USAGE_LOGS_FOLDER + full_log_subfolder
|
| 191 |
+
else:
|
| 192 |
+
full_log_subfolder = "" # Empty string when subfolders are not used
|
| 193 |
+
|
| 194 |
+
S3_FEEDBACK_LOGS_FOLDER = get_or_create_env_var(
|
| 195 |
+
"S3_FEEDBACK_LOGS_FOLDER", "feedback/" + full_log_subfolder
|
| 196 |
+
)
|
| 197 |
+
S3_ACCESS_LOGS_FOLDER = get_or_create_env_var(
|
| 198 |
+
"S3_ACCESS_LOGS_FOLDER", "logs/" + full_log_subfolder
|
| 199 |
+
)
|
| 200 |
+
S3_USAGE_LOGS_FOLDER = get_or_create_env_var(
|
| 201 |
+
"S3_USAGE_LOGS_FOLDER", "usage/" + full_log_subfolder
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
LOG_FILE_NAME = get_or_create_env_var("LOG_FILE_NAME", "log.csv")
|
| 205 |
+
USAGE_LOG_FILE_NAME = get_or_create_env_var("USAGE_LOG_FILE_NAME", LOG_FILE_NAME)
|
| 206 |
+
FEEDBACK_LOG_FILE_NAME = get_or_create_env_var("FEEDBACK_LOG_FILE_NAME", LOG_FILE_NAME)
|
| 207 |
+
|
| 208 |
+
# Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
|
| 209 |
+
DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var(
|
| 210 |
+
"DISPLAY_FILE_NAMES_IN_LOGS", "False"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Further customisation options for CSV logs
|
| 214 |
+
|
| 215 |
+
CSV_ACCESS_LOG_HEADERS = get_or_create_env_var(
|
| 216 |
+
"CSV_ACCESS_LOG_HEADERS", ""
|
| 217 |
+
) # If blank, uses component labels
|
| 218 |
+
CSV_FEEDBACK_LOG_HEADERS = get_or_create_env_var(
|
| 219 |
+
"CSV_FEEDBACK_LOG_HEADERS", ""
|
| 220 |
+
) # If blank, uses component labels
|
| 221 |
+
CSV_USAGE_LOG_HEADERS = get_or_create_env_var(
|
| 222 |
+
"CSV_USAGE_LOG_HEADERS", ""
|
| 223 |
+
) # If blank, uses component labels
|
| 224 |
+
|
| 225 |
+
### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
|
| 226 |
+
SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var("SAVE_LOGS_TO_DYNAMODB", "False")
|
| 227 |
+
|
| 228 |
+
ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
|
| 229 |
+
"ACCESS_LOG_DYNAMODB_TABLE_NAME", "llm_topic_model_access_log"
|
| 230 |
+
)
|
| 231 |
+
DYNAMODB_ACCESS_LOG_HEADERS = get_or_create_env_var("DYNAMODB_ACCESS_LOG_HEADERS", "")
|
| 232 |
+
|
| 233 |
+
FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
|
| 234 |
+
"FEEDBACK_LOG_DYNAMODB_TABLE_NAME", "llm_topic_model_feedback"
|
| 235 |
+
)
|
| 236 |
+
DYNAMODB_FEEDBACK_LOG_HEADERS = get_or_create_env_var(
|
| 237 |
+
"DYNAMODB_FEEDBACK_LOG_HEADERS", ""
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
|
| 241 |
+
"USAGE_LOG_DYNAMODB_TABLE_NAME", "llm_topic_model_usage"
|
| 242 |
+
)
|
| 243 |
+
DYNAMODB_USAGE_LOG_HEADERS = get_or_create_env_var("DYNAMODB_USAGE_LOG_HEADERS", "")
|
| 244 |
+
|
| 245 |
+
# Report logging to console?
|
| 246 |
+
LOGGING = get_or_create_env_var("LOGGING", "False")
|
| 247 |
+
|
| 248 |
+
if LOGGING == "True":
|
| 249 |
+
# Configure logging
|
| 250 |
+
logging.basicConfig(
|
| 251 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
###
|
| 255 |
+
# App run variables
|
| 256 |
+
###
|
| 257 |
+
OUTPUT_DEBUG_FILES = get_or_create_env_var(
|
| 258 |
+
"OUTPUT_DEBUG_FILES", "False"
|
| 259 |
+
) # Whether to output debug files
|
| 260 |
+
SHOW_ADDITIONAL_INSTRUCTION_TEXTBOXES = get_or_create_env_var(
|
| 261 |
+
"SHOW_ADDITIONAL_INSTRUCTION_TEXTBOXES", "True"
|
| 262 |
+
) # Whether to show additional instruction textboxes in the GUI
|
| 263 |
+
|
| 264 |
+
TIMEOUT_WAIT = int(
|
| 265 |
+
get_or_create_env_var("TIMEOUT_WAIT", "30")
|
| 266 |
+
) # Maximum number of seconds to wait for a response from the LLM
|
| 267 |
+
NUMBER_OF_RETRY_ATTEMPTS = int(
|
| 268 |
+
get_or_create_env_var("NUMBER_OF_RETRY_ATTEMPTS", "5")
|
| 269 |
+
) # Maximum number of times to retry a request to the LLM
|
| 270 |
+
# Try up to 3 times to get a valid markdown table response with LLM calls, otherwise retry with temperature changed
|
| 271 |
+
MAX_OUTPUT_VALIDATION_ATTEMPTS = int(
|
| 272 |
+
get_or_create_env_var("MAX_OUTPUT_VALIDATION_ATTEMPTS", "3")
|
| 273 |
+
)
|
| 274 |
+
ENABLE_VALIDATION = get_or_create_env_var(
|
| 275 |
+
"ENABLE_VALIDATION", "False"
|
| 276 |
+
) # Whether to run validation loop after initial topic extraction
|
| 277 |
+
MAX_TIME_FOR_LOOP = int(
|
| 278 |
+
get_or_create_env_var("MAX_TIME_FOR_LOOP", "99999")
|
| 279 |
+
) # Maximum number of seconds to run the loop for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there)
|
| 280 |
+
|
| 281 |
+
MAX_COMMENT_CHARS = int(
|
| 282 |
+
get_or_create_env_var("MAX_COMMENT_CHARS", "14000")
|
| 283 |
+
) # Maximum number of characters in a comment
|
| 284 |
+
MAX_ROWS = int(
|
| 285 |
+
get_or_create_env_var("MAX_ROWS", "5000")
|
| 286 |
+
) # Maximum number of rows to process
|
| 287 |
+
MAX_GROUPS = int(
|
| 288 |
+
get_or_create_env_var("MAX_GROUPS", "99")
|
| 289 |
+
) # Maximum number of groups to process
|
| 290 |
+
BATCH_SIZE_DEFAULT = int(
|
| 291 |
+
get_or_create_env_var("BATCH_SIZE_DEFAULT", "5")
|
| 292 |
+
) # Default batch size for LLM calls
|
| 293 |
+
MAXIMUM_ZERO_SHOT_TOPICS = int(
|
| 294 |
+
get_or_create_env_var("MAXIMUM_ZERO_SHOT_TOPICS", "120")
|
| 295 |
+
) # Maximum number of zero shot topics to process
|
| 296 |
+
MAX_SPACES_GPU_RUN_TIME = int(
|
| 297 |
+
get_or_create_env_var("MAX_SPACES_GPU_RUN_TIME", "240")
|
| 298 |
+
) # Maximum number of seconds to run on GPU on Hugging Face Spaces
|
| 299 |
+
|
| 300 |
+
DEDUPLICATION_THRESHOLD = int(
|
| 301 |
+
get_or_create_env_var("DEDUPLICATION_THRESHOLD", "90")
|
| 302 |
+
) # Deduplication threshold for topic summary tables
|
| 303 |
+
|
| 304 |
+
###
|
| 305 |
+
# Model options
|
| 306 |
+
###
|
| 307 |
+
|
| 308 |
+
RUN_LOCAL_MODEL = get_or_create_env_var("RUN_LOCAL_MODEL", "0")
|
| 309 |
+
|
| 310 |
+
RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
|
| 311 |
+
|
| 312 |
+
RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
|
| 313 |
+
GEMINI_API_KEY = get_or_create_env_var("GEMINI_API_KEY", "")
|
| 314 |
+
|
| 315 |
+
INTRO_TEXT = get_or_create_env_var(
|
| 316 |
+
"INTRO_TEXT",
|
| 317 |
+
"""# Large language model topic modelling
|
| 318 |
+
|
| 319 |
+
Extract topics and summarise outputs using Large Language Models (LLMs, Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure/OpenAI, or AWS Bedrock models (e.g. Claude, Nova models). The app will query the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets below. API keys for AWS, Azure/OpenAI, and Gemini services can be entered on the settings page (note that Gemini has a free public API).
|
| 320 |
+
|
| 321 |
+
NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** to check for harmful outputs, hallucinations, and accuracy.""",
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# Read in intro text from a text file if it is a path to a text file
|
| 325 |
+
if INTRO_TEXT.endswith(".txt"):
|
| 326 |
+
INTRO_TEXT = open(INTRO_TEXT, "r").read()
|
| 327 |
+
|
| 328 |
+
INTRO_TEXT = INTRO_TEXT.strip('"').strip("'")
|
| 329 |
+
|
| 330 |
+
# Azure/OpenAI AI Inference settings
|
| 331 |
+
RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "1")
|
| 332 |
+
AZURE_OPENAI_API_KEY = get_or_create_env_var("AZURE_OPENAI_API_KEY", "")
|
| 333 |
+
AZURE_OPENAI_INFERENCE_ENDPOINT = get_or_create_env_var(
|
| 334 |
+
"AZURE_OPENAI_INFERENCE_ENDPOINT", ""
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# Llama-server settings
|
| 338 |
+
RUN_INFERENCE_SERVER = get_or_create_env_var("RUN_INFERENCE_SERVER", "0")
|
| 339 |
+
API_URL = get_or_create_env_var("API_URL", "http://localhost:8080")
|
| 340 |
+
|
| 341 |
+
RUN_MCP_SERVER = convert_string_to_boolean(
|
| 342 |
+
get_or_create_env_var("RUN_MCP_SERVER", "False")
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Build up options for models
|
| 346 |
+
model_full_names = list()
|
| 347 |
+
model_short_names = list()
|
| 348 |
+
model_source = list()
|
| 349 |
+
|
| 350 |
+
CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var(
|
| 351 |
+
"CHOSEN_LOCAL_MODEL_TYPE", "Qwen 3 4B"
|
| 352 |
+
) # Gemma 3 1B # "Gemma 2b" # "Gemma 3 4B"
|
| 353 |
+
|
| 354 |
+
USE_LLAMA_SWAP = get_or_create_env_var("USE_LLAMA_SWAP", "False")
|
| 355 |
+
if USE_LLAMA_SWAP == "True":
|
| 356 |
+
USE_LLAMA_SWAP = True
|
| 357 |
+
else:
|
| 358 |
+
USE_LLAMA_SWAP = False
|
| 359 |
+
|
| 360 |
+
if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
|
| 361 |
+
model_full_names.append(CHOSEN_LOCAL_MODEL_TYPE)
|
| 362 |
+
model_short_names.append(CHOSEN_LOCAL_MODEL_TYPE)
|
| 363 |
+
model_source.append("Local")
|
| 364 |
+
|
| 365 |
+
if RUN_AWS_BEDROCK_MODELS == "1":
|
| 366 |
+
amazon_models = [
|
| 367 |
+
"anthropic.claude-3-haiku-20240307-v1:0",
|
| 368 |
+
"anthropic.claude-3-7-sonnet-20250219-v1:0",
|
| 369 |
+
"anthropic.claude-sonnet-4-5-20250929-v1:0",
|
| 370 |
+
"amazon.nova-micro-v1:0",
|
| 371 |
+
"amazon.nova-lite-v1:0",
|
| 372 |
+
"amazon.nova-pro-v1:0",
|
| 373 |
+
"deepseek.v3-v1:0",
|
| 374 |
+
"openai.gpt-oss-20b-1:0",
|
| 375 |
+
"openai.gpt-oss-120b-1:0",
|
| 376 |
+
"google.gemma-3-12b-it",
|
| 377 |
+
"mistral.ministral-3-14b-instruct",
|
| 378 |
+
]
|
| 379 |
+
model_full_names.extend(amazon_models)
|
| 380 |
+
model_short_names.extend(
|
| 381 |
+
[
|
| 382 |
+
"haiku",
|
| 383 |
+
"sonnet_3_7",
|
| 384 |
+
"sonnet_4_5",
|
| 385 |
+
"nova_micro",
|
| 386 |
+
"nova_lite",
|
| 387 |
+
"nova_pro",
|
| 388 |
+
"deepseek_v3",
|
| 389 |
+
"gpt_oss_20b_aws",
|
| 390 |
+
"gpt_oss_120b_aws",
|
| 391 |
+
"gemma_3_12b_it",
|
| 392 |
+
"ministral_3_14b_instruct",
|
| 393 |
+
]
|
| 394 |
+
)
|
| 395 |
+
model_source.extend(["AWS"] * len(amazon_models))
|
| 396 |
+
|
| 397 |
+
if RUN_GEMINI_MODELS == "1":
|
| 398 |
+
gemini_models = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"]
|
| 399 |
+
model_full_names.extend(gemini_models)
|
| 400 |
+
model_short_names.extend(
|
| 401 |
+
["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"]
|
| 402 |
+
)
|
| 403 |
+
model_source.extend(["Gemini"] * len(gemini_models))
|
| 404 |
+
|
| 405 |
+
# Register Azure/OpenAI AI models (model names must match your Azure/OpenAI deployments)
|
| 406 |
+
if RUN_AZURE_MODELS == "1":
|
| 407 |
+
# Example deployments; adjust to the deployments you actually create in Azure/OpenAI
|
| 408 |
+
azure_models = ["gpt-5-mini", "gpt-4o-mini"]
|
| 409 |
+
model_full_names.extend(azure_models)
|
| 410 |
+
model_short_names.extend(["gpt-5-mini", "gpt-4o-mini"])
|
| 411 |
+
model_source.extend(["Azure/OpenAI"] * len(azure_models))
|
| 412 |
+
|
| 413 |
+
# Register inference-server models
|
| 414 |
+
CHOSEN_INFERENCE_SERVER_MODEL = ""
|
| 415 |
+
if RUN_INFERENCE_SERVER == "1":
|
| 416 |
+
# Example inference-server models; adjust to the models you have available on your server
|
| 417 |
+
inference_server_models = [
|
| 418 |
+
"unnamed-inference-server-model",
|
| 419 |
+
"qwen_3_4b_it",
|
| 420 |
+
"qwen_3_4b_think",
|
| 421 |
+
"gpt_oss_20b",
|
| 422 |
+
"gemma_3_12b",
|
| 423 |
+
"ministral_3_14b_it",
|
| 424 |
+
]
|
| 425 |
+
model_full_names.extend(inference_server_models)
|
| 426 |
+
model_short_names.extend(inference_server_models)
|
| 427 |
+
model_source.extend(["inference-server"] * len(inference_server_models))
|
| 428 |
+
|
| 429 |
+
CHOSEN_INFERENCE_SERVER_MODEL = get_or_create_env_var(
|
| 430 |
+
"CHOSEN_INFERENCE_SERVER_MODEL", inference_server_models[0]
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
if CHOSEN_INFERENCE_SERVER_MODEL not in inference_server_models:
|
| 434 |
+
model_full_names.append(CHOSEN_INFERENCE_SERVER_MODEL)
|
| 435 |
+
model_short_names.append(CHOSEN_INFERENCE_SERVER_MODEL)
|
| 436 |
+
model_source.append("inference-server")
|
| 437 |
+
|
| 438 |
+
model_name_map = {
|
| 439 |
+
full: {"short_name": short, "source": source}
|
| 440 |
+
for full, short, source in zip(model_full_names, model_short_names, model_source)
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
if RUN_LOCAL_MODEL == "1":
|
| 444 |
+
default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
|
| 445 |
+
elif RUN_INFERENCE_SERVER == "1":
|
| 446 |
+
default_model_choice = CHOSEN_INFERENCE_SERVER_MODEL
|
| 447 |
+
elif RUN_AWS_FUNCTIONS == "1":
|
| 448 |
+
default_model_choice = amazon_models[0]
|
| 449 |
+
else:
|
| 450 |
+
default_model_choice = gemini_models[0]
|
| 451 |
+
|
| 452 |
+
default_model_source = model_name_map[default_model_choice]["source"]
|
| 453 |
+
model_sources = list(
|
| 454 |
+
set([model_name_map[model]["source"] for model in model_full_names])
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def update_model_choice_config(default_model_source, model_name_map):
|
| 459 |
+
# Filter models by source and return the first matching model name
|
| 460 |
+
matching_models = [
|
| 461 |
+
model_name
|
| 462 |
+
for model_name, model_info in model_name_map.items()
|
| 463 |
+
if model_info["source"] == default_model_source
|
| 464 |
+
]
|
| 465 |
+
|
| 466 |
+
output_model = matching_models[0] if matching_models else model_full_names[0]
|
| 467 |
+
|
| 468 |
+
return output_model, matching_models
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
default_model_choice, default_source_models = update_model_choice_config(
|
| 472 |
+
default_model_source, model_name_map
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# print("model_name_map:", model_name_map)
|
| 476 |
+
|
| 477 |
+
# HF token may or may not be needed for downloading models from Hugging Face
|
| 478 |
+
HF_TOKEN = get_or_create_env_var("HF_TOKEN", "")
|
| 479 |
+
|
| 480 |
+
LOAD_LOCAL_MODEL_AT_START = get_or_create_env_var("LOAD_LOCAL_MODEL_AT_START", "False")
|
| 481 |
+
|
| 482 |
+
# If you are using a system with low VRAM, you can set this to True to reduce the memory requirements
|
| 483 |
+
LOW_VRAM_SYSTEM = get_or_create_env_var("LOW_VRAM_SYSTEM", "False")
|
| 484 |
+
|
| 485 |
+
MULTIMODAL_PROMPT_FORMAT = get_or_create_env_var("MULTIMODAL_PROMPT_FORMAT", "False")
|
| 486 |
+
|
| 487 |
+
if LOW_VRAM_SYSTEM == "True":
|
| 488 |
+
print("Using settings for low VRAM system")
|
| 489 |
+
USE_LLAMA_CPP = get_or_create_env_var("USE_LLAMA_CPP", "True")
|
| 490 |
+
LLM_MAX_NEW_TOKENS = int(get_or_create_env_var("LLM_MAX_NEW_TOKENS", "4096"))
|
| 491 |
+
LLM_CONTEXT_LENGTH = int(get_or_create_env_var("LLM_CONTEXT_LENGTH", "16384"))
|
| 492 |
+
LLM_BATCH_SIZE = int(get_or_create_env_var("LLM_BATCH_SIZE", "512"))
|
| 493 |
+
K_QUANT_LEVEL = int(
|
| 494 |
+
get_or_create_env_var("K_QUANT_LEVEL", "2")
|
| 495 |
+
) # 2 = q4_0, 8 = q8_0, 4 = fp16
|
| 496 |
+
V_QUANT_LEVEL = int(
|
| 497 |
+
get_or_create_env_var("V_QUANT_LEVEL", "2")
|
| 498 |
+
) # 2 = q4_0, 8 = q8_0, 4 = fp16
|
| 499 |
+
|
| 500 |
+
USE_LLAMA_CPP = get_or_create_env_var(
|
| 501 |
+
"USE_LLAMA_CPP", "True"
|
| 502 |
+
) # Llama.cpp or transformers with unsloth
|
| 503 |
+
|
| 504 |
+
LOCAL_REPO_ID = get_or_create_env_var("LOCAL_REPO_ID", "")
|
| 505 |
+
LOCAL_MODEL_FILE = get_or_create_env_var("LOCAL_MODEL_FILE", "")
|
| 506 |
+
LOCAL_MODEL_FOLDER = get_or_create_env_var("LOCAL_MODEL_FOLDER", "")
|
| 507 |
+
|
| 508 |
+
GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "unsloth/gemma-2-it-GGUF")
|
| 509 |
+
GEMMA2_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 510 |
+
"GEMMA2_2B_REPO_TRANSFORMERS_ID", "unsloth/gemma-2-2b-it-bnb-4bit"
|
| 511 |
+
)
|
| 512 |
+
if USE_LLAMA_CPP == "False":
|
| 513 |
+
GEMMA2_REPO_ID = GEMMA2_REPO_TRANSFORMERS_ID
|
| 514 |
+
GEMMA2_MODEL_FILE = get_or_create_env_var(
|
| 515 |
+
"GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it.q8_0.gguf"
|
| 516 |
+
)
|
| 517 |
+
GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
|
| 518 |
+
|
| 519 |
+
GEMMA3_4B_REPO_ID = get_or_create_env_var(
|
| 520 |
+
"GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF"
|
| 521 |
+
)
|
| 522 |
+
GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 523 |
+
"GEMMA3_4B_REPO_TRANSFORMERS_ID", "unsloth/gemma-3-4b-it-bnb-4bit"
|
| 524 |
+
)
|
| 525 |
+
if USE_LLAMA_CPP == "False":
|
| 526 |
+
GEMMA3_4B_REPO_ID = GEMMA3_4B_REPO_TRANSFORMERS_ID
|
| 527 |
+
GEMMA3_4B_MODEL_FILE = get_or_create_env_var(
|
| 528 |
+
"GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-UD-Q4_K_XL.gguf"
|
| 529 |
+
)
|
| 530 |
+
GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var(
|
| 531 |
+
"GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b"
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
GEMMA3_12B_REPO_ID = get_or_create_env_var(
|
| 535 |
+
"GEMMA3_12B_REPO_ID", "unsloth/gemma-3-12b-it-GGUF"
|
| 536 |
+
)
|
| 537 |
+
GEMMA3_12B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 538 |
+
"GEMMA3_12B_REPO_TRANSFORMERS_ID", "unsloth/gemma-3-12b-it-bnb-4bit"
|
| 539 |
+
)
|
| 540 |
+
if USE_LLAMA_CPP == "False":
|
| 541 |
+
GEMMA3_12B_REPO_ID = GEMMA3_12B_REPO_TRANSFORMERS_ID
|
| 542 |
+
GEMMA3_12B_MODEL_FILE = get_or_create_env_var(
|
| 543 |
+
"GEMMA3_12B_MODEL_FILE", "gemma-3-12b-it-UD-Q4_K_XL.gguf"
|
| 544 |
+
)
|
| 545 |
+
GEMMA3_12B_MODEL_FOLDER = get_or_create_env_var(
|
| 546 |
+
"GEMMA3_12B_MODEL_FOLDER", "model/gemma3_12b"
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
|
| 550 |
+
GPT_OSS_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 551 |
+
"GPT_OSS_REPO_TRANSFORMERS_ID", "unsloth/gpt-oss-20b-unsloth-bnb-4bit"
|
| 552 |
+
)
|
| 553 |
+
if USE_LLAMA_CPP == "False":
|
| 554 |
+
GPT_OSS_REPO_ID = GPT_OSS_REPO_TRANSFORMERS_ID
|
| 555 |
+
GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
|
| 556 |
+
GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
|
| 557 |
+
|
| 558 |
+
QWEN3_4B_REPO_ID = get_or_create_env_var(
|
| 559 |
+
"QWEN3_4B_REPO_ID", "unsloth/Qwen3-4B-Instruct-2507-GGUF"
|
| 560 |
+
)
|
| 561 |
+
QWEN3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 562 |
+
"QWEN3_4B_REPO_TRANSFORMERS_ID", "unsloth/Qwen3-4B-unsloth-bnb-4bit"
|
| 563 |
+
)
|
| 564 |
+
if USE_LLAMA_CPP == "False":
|
| 565 |
+
QWEN3_4B_REPO_ID = QWEN3_4B_REPO_TRANSFORMERS_ID
|
| 566 |
+
|
| 567 |
+
QWEN3_4B_MODEL_FILE = get_or_create_env_var(
|
| 568 |
+
"QWEN3_4B_MODEL_FILE", "Qwen3-4B-Instruct-2507-UD-Q4_K_XL.gguf"
|
| 569 |
+
)
|
| 570 |
+
QWEN3_4B_MODEL_FOLDER = get_or_create_env_var("QWEN3_4B_MODEL_FOLDER", "model/qwen")
|
| 571 |
+
|
| 572 |
+
GRANITE_4_TINY_REPO_ID = get_or_create_env_var(
|
| 573 |
+
"GRANITE_4_TINY_REPO_ID", "unsloth/granite-4.0-h-tiny-GGUF"
|
| 574 |
+
)
|
| 575 |
+
GRANITE_4_TINY_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 576 |
+
"GRANITE_4_TINY_REPO_TRANSFORMERS_ID", "unsloth/granite-4.0-h-tiny-FP8-Dynamic"
|
| 577 |
+
)
|
| 578 |
+
if USE_LLAMA_CPP == "False":
|
| 579 |
+
GRANITE_4_TINY_REPO_ID = GRANITE_4_TINY_REPO_TRANSFORMERS_ID
|
| 580 |
+
GRANITE_4_TINY_MODEL_FILE = get_or_create_env_var(
|
| 581 |
+
"GRANITE_4_TINY_MODEL_FILE", "granite-4.0-h-tiny-UD-Q4_K_XL.gguf"
|
| 582 |
+
)
|
| 583 |
+
GRANITE_4_TINY_MODEL_FOLDER = get_or_create_env_var(
|
| 584 |
+
"GRANITE_4_TINY_MODEL_FOLDER", "model/granite"
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
GRANITE_4_3B_REPO_ID = get_or_create_env_var(
|
| 588 |
+
"GRANITE_4_3B_REPO_ID", "unsloth/granite-4.0-h-micro-GGUF"
|
| 589 |
+
)
|
| 590 |
+
GRANITE_4_3B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
|
| 591 |
+
"GRANITE_4_3B_REPO_TRANSFORMERS_ID", "unsloth/granite-4.0-micro-unsloth-bnb-4bit"
|
| 592 |
+
)
|
| 593 |
+
if USE_LLAMA_CPP == "False":
|
| 594 |
+
GRANITE_4_3B_REPO_ID = GRANITE_4_3B_REPO_TRANSFORMERS_ID
|
| 595 |
+
GRANITE_4_3B_MODEL_FILE = get_or_create_env_var(
|
| 596 |
+
"GRANITE_4_3B_MODEL_FILE", "granite-4.0-h-micro-UD-Q4_K_XL.gguf"
|
| 597 |
+
)
|
| 598 |
+
GRANITE_4_3B_MODEL_FOLDER = get_or_create_env_var(
|
| 599 |
+
"GRANITE_4_3B_MODEL_FOLDER", "model/granite"
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 2b":
|
| 603 |
+
LOCAL_REPO_ID = GEMMA2_REPO_ID
|
| 604 |
+
LOCAL_MODEL_FILE = GEMMA2_MODEL_FILE
|
| 605 |
+
LOCAL_MODEL_FOLDER = GEMMA2_MODEL_FOLDER
|
| 606 |
+
|
| 607 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
|
| 608 |
+
LOCAL_REPO_ID = GEMMA3_4B_REPO_ID
|
| 609 |
+
LOCAL_MODEL_FILE = GEMMA3_4B_MODEL_FILE
|
| 610 |
+
LOCAL_MODEL_FOLDER = GEMMA3_4B_MODEL_FOLDER
|
| 611 |
+
MULTIMODAL_PROMPT_FORMAT = "True"
|
| 612 |
+
|
| 613 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 12B":
|
| 614 |
+
LOCAL_REPO_ID = GEMMA3_12B_REPO_ID
|
| 615 |
+
LOCAL_MODEL_FILE = GEMMA3_12B_MODEL_FILE
|
| 616 |
+
LOCAL_MODEL_FOLDER = GEMMA3_12B_MODEL_FOLDER
|
| 617 |
+
MULTIMODAL_PROMPT_FORMAT = "True"
|
| 618 |
+
|
| 619 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Qwen 3 4B":
|
| 620 |
+
LOCAL_REPO_ID = QWEN3_4B_REPO_ID
|
| 621 |
+
LOCAL_MODEL_FILE = QWEN3_4B_MODEL_FILE
|
| 622 |
+
LOCAL_MODEL_FOLDER = QWEN3_4B_MODEL_FOLDER
|
| 623 |
+
|
| 624 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "gpt-oss-20b":
|
| 625 |
+
LOCAL_REPO_ID = GPT_OSS_REPO_ID
|
| 626 |
+
LOCAL_MODEL_FILE = GPT_OSS_MODEL_FILE
|
| 627 |
+
LOCAL_MODEL_FOLDER = GPT_OSS_MODEL_FOLDER
|
| 628 |
+
|
| 629 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Granite 4 Tiny":
|
| 630 |
+
LOCAL_REPO_ID = GRANITE_4_TINY_REPO_ID
|
| 631 |
+
LOCAL_MODEL_FILE = GRANITE_4_TINY_MODEL_FILE
|
| 632 |
+
LOCAL_MODEL_FOLDER = GRANITE_4_TINY_MODEL_FOLDER
|
| 633 |
+
|
| 634 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Granite 4 Micro":
|
| 635 |
+
LOCAL_REPO_ID = GRANITE_4_3B_REPO_ID
|
| 636 |
+
LOCAL_MODEL_FILE = GRANITE_4_3B_MODEL_FILE
|
| 637 |
+
LOCAL_MODEL_FOLDER = GRANITE_4_3B_MODEL_FOLDER
|
| 638 |
+
|
| 639 |
+
elif not CHOSEN_LOCAL_MODEL_TYPE:
|
| 640 |
+
print("No local model type chosen")
|
| 641 |
+
LOCAL_REPO_ID = ""
|
| 642 |
+
LOCAL_MODEL_FILE = ""
|
| 643 |
+
LOCAL_MODEL_FOLDER = ""
|
| 644 |
+
else:
|
| 645 |
+
print("CHOSEN_LOCAL_MODEL_TYPE not found")
|
| 646 |
+
LOCAL_REPO_ID = ""
|
| 647 |
+
LOCAL_MODEL_FILE = ""
|
| 648 |
+
LOCAL_MODEL_FOLDER = ""
|
| 649 |
+
|
| 650 |
+
USE_SPECULATIVE_DECODING = get_or_create_env_var("USE_SPECULATIVE_DECODING", "False")
|
| 651 |
+
|
| 652 |
+
ASSISTANT_MODEL = get_or_create_env_var("ASSISTANT_MODEL", "")
|
| 653 |
+
if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
|
| 654 |
+
ASSISTANT_MODEL = get_or_create_env_var(
|
| 655 |
+
"ASSISTANT_MODEL", "unsloth/gemma-3-270m-it"
|
| 656 |
+
)
|
| 657 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Qwen 3 4B":
|
| 658 |
+
ASSISTANT_MODEL = get_or_create_env_var("ASSISTANT_MODEL", "unsloth/Qwen3-0.6B")
|
| 659 |
+
|
| 660 |
+
DRAFT_MODEL_LOC = get_or_create_env_var("DRAFT_MODEL_LOC", ".cache/llama.cpp/")
|
| 661 |
+
|
| 662 |
+
GEMMA3_DRAFT_MODEL_LOC = get_or_create_env_var(
|
| 663 |
+
"GEMMA3_DRAFT_MODEL_LOC",
|
| 664 |
+
DRAFT_MODEL_LOC + "unsloth_gemma-3-270m-it-qat-GGUF_gemma-3-270m-it-qat-F16.gguf",
|
| 665 |
+
)
|
| 666 |
+
GEMMA3_4B_DRAFT_MODEL_LOC = get_or_create_env_var(
|
| 667 |
+
"GEMMA3_4B_DRAFT_MODEL_LOC",
|
| 668 |
+
DRAFT_MODEL_LOC + "unsloth_gemma-3-4b-it-qat-GGUF_gemma-3-4b-it-qat-Q4_K_M.gguf",
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
QWEN3_DRAFT_MODEL_LOC = get_or_create_env_var(
|
| 672 |
+
"QWEN3_DRAFT_MODEL_LOC", DRAFT_MODEL_LOC + "Qwen3-0.6B-Q8_0.gguf"
|
| 673 |
+
)
|
| 674 |
+
QWEN3_4B_DRAFT_MODEL_LOC = get_or_create_env_var(
|
| 675 |
+
"QWEN3_4B_DRAFT_MODEL_LOC",
|
| 676 |
+
DRAFT_MODEL_LOC + "Qwen3-4B-Instruct-2507-UD-Q4_K_XL.gguf",
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
LLM_MAX_GPU_LAYERS = int(
|
| 681 |
+
get_or_create_env_var("LLM_MAX_GPU_LAYERS", "-1")
|
| 682 |
+
) # Maximum possible
|
| 683 |
+
LLM_TEMPERATURE = float(get_or_create_env_var("LLM_TEMPERATURE", "0.6"))
|
| 684 |
+
LLM_TOP_K = int(
|
| 685 |
+
get_or_create_env_var("LLM_TOP_K", "64")
|
| 686 |
+
) # https://docs.unsloth.ai/basics/gemma-3-how-to-run-and-fine-tune
|
| 687 |
+
LLM_MIN_P = float(get_or_create_env_var("LLM_MIN_P", "0"))
|
| 688 |
+
LLM_TOP_P = float(get_or_create_env_var("LLM_TOP_P", "0.95"))
|
| 689 |
+
LLM_REPETITION_PENALTY = float(get_or_create_env_var("LLM_REPETITION_PENALTY", "1.0"))
|
| 690 |
+
|
| 691 |
+
LLM_LAST_N_TOKENS = int(get_or_create_env_var("LLM_LAST_N_TOKENS", "512"))
|
| 692 |
+
LLM_MAX_NEW_TOKENS = int(get_or_create_env_var("LLM_MAX_NEW_TOKENS", "4096"))
|
| 693 |
+
LLM_SEED = int(get_or_create_env_var("LLM_SEED", "42"))
|
| 694 |
+
LLM_RESET = get_or_create_env_var("LLM_RESET", "False")
|
| 695 |
+
LLM_STREAM = get_or_create_env_var("LLM_STREAM", "True")
|
| 696 |
+
LLM_THREADS = int(get_or_create_env_var("LLM_THREADS", "-1"))
|
| 697 |
+
LLM_BATCH_SIZE = int(get_or_create_env_var("LLM_BATCH_SIZE", "2048"))
|
| 698 |
+
LLM_CONTEXT_LENGTH = int(get_or_create_env_var("LLM_CONTEXT_LENGTH", "24576"))
|
| 699 |
+
LLM_SAMPLE = get_or_create_env_var("LLM_SAMPLE", "True")
|
| 700 |
+
LLM_STOP_STRINGS = get_or_create_env_var("LLM_STOP_STRINGS", r"['\n\n\n\n\n\n']")
|
| 701 |
+
|
| 702 |
+
SPECULATIVE_DECODING = get_or_create_env_var("SPECULATIVE_DECODING", "False")
|
| 703 |
+
NUM_PRED_TOKENS = int(get_or_create_env_var("NUM_PRED_TOKENS", "2"))
|
| 704 |
+
K_QUANT_LEVEL = get_or_create_env_var(
|
| 705 |
+
"K_QUANT_LEVEL", ""
|
| 706 |
+
) # 2 = q4_0, 8 = q8_0, 4 = fp16
|
| 707 |
+
V_QUANT_LEVEL = get_or_create_env_var(
|
| 708 |
+
"V_QUANT_LEVEL", ""
|
| 709 |
+
) # 2 = q4_0, 8 = q8_0, 4 = fp16
|
| 710 |
+
|
| 711 |
+
if not K_QUANT_LEVEL:
|
| 712 |
+
K_QUANT_LEVEL = None
|
| 713 |
+
else:
|
| 714 |
+
K_QUANT_LEVEL = int(K_QUANT_LEVEL)
|
| 715 |
+
if not V_QUANT_LEVEL:
|
| 716 |
+
V_QUANT_LEVEL = None
|
| 717 |
+
else:
|
| 718 |
+
V_QUANT_LEVEL = int(V_QUANT_LEVEL)
|
| 719 |
+
|
| 720 |
+
# If you are using e.g. gpt-oss, you can add a reasoning suffix to set reasoning level, or turn it off in the case of Qwen 3 4B
|
| 721 |
+
if CHOSEN_LOCAL_MODEL_TYPE == "gpt-oss-20b":
|
| 722 |
+
REASONING_SUFFIX = get_or_create_env_var("REASONING_SUFFIX", "Reasoning: low")
|
| 723 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Qwen 3 4B" and USE_LLAMA_CPP == "False":
|
| 724 |
+
REASONING_SUFFIX = get_or_create_env_var("REASONING_SUFFIX", "/nothink")
|
| 725 |
+
else:
|
| 726 |
+
REASONING_SUFFIX = get_or_create_env_var("REASONING_SUFFIX", "")
|
| 727 |
+
|
| 728 |
+
# Transformers variables
|
| 729 |
+
COMPILE_TRANSFORMERS = get_or_create_env_var(
|
| 730 |
+
"COMPILE_TRANSFORMERS", "False"
|
| 731 |
+
) # Whether to compile transformers models
|
| 732 |
+
USE_BITSANDBYTES = get_or_create_env_var(
|
| 733 |
+
"USE_BITSANDBYTES", "True"
|
| 734 |
+
) # Whether to use bitsandbytes for quantization
|
| 735 |
+
COMPILE_MODE = get_or_create_env_var(
|
| 736 |
+
"COMPILE_MODE", "reduce-overhead"
|
| 737 |
+
) # alternatively 'max-autotune'
|
| 738 |
+
MODEL_DTYPE = get_or_create_env_var(
|
| 739 |
+
"MODEL_DTYPE", "bfloat16"
|
| 740 |
+
) # alternatively 'bfloat16'
|
| 741 |
+
INT8_WITH_OFFLOAD_TO_CPU = get_or_create_env_var(
|
| 742 |
+
"INT8_WITH_OFFLOAD_TO_CPU", "False"
|
| 743 |
+
) # Whether to offload to CPU
|
| 744 |
+
|
| 745 |
+
DEFAULT_SAMPLED_SUMMARIES = int(
|
| 746 |
+
get_or_create_env_var("DEFAULT_SAMPLED_SUMMARIES", "75")
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
###
|
| 750 |
+
# Gradio app variables
|
| 751 |
+
###
|
| 752 |
+
|
| 753 |
+
# Get some environment variables and Launch the Gradio app
|
| 754 |
+
COGNITO_AUTH = get_or_create_env_var("COGNITO_AUTH", "0")
|
| 755 |
+
|
| 756 |
+
RUN_DIRECT_MODE = get_or_create_env_var("RUN_DIRECT_MODE", "0")
|
| 757 |
+
|
| 758 |
+
# Direct mode environment variables
|
| 759 |
+
DIRECT_MODE_TASK = get_or_create_env_var("DIRECT_MODE_TASK", "extract")
|
| 760 |
+
DIRECT_MODE_INPUT_FILE = get_or_create_env_var("DIRECT_MODE_INPUT_FILE", "")
|
| 761 |
+
DIRECT_MODE_OUTPUT_DIR = get_or_create_env_var("DIRECT_MODE_OUTPUT_DIR", OUTPUT_FOLDER)
|
| 762 |
+
DIRECT_MODE_S3_OUTPUT_BUCKET = get_or_create_env_var(
|
| 763 |
+
"DIRECT_MODE_S3_OUTPUT_BUCKET", S3_OUTPUTS_BUCKET
|
| 764 |
+
)
|
| 765 |
+
DIRECT_MODE_TEXT_COLUMN = get_or_create_env_var("DIRECT_MODE_TEXT_COLUMN", "")
|
| 766 |
+
DIRECT_MODE_PREVIOUS_OUTPUT_FILES = get_or_create_env_var(
|
| 767 |
+
"DIRECT_MODE_PREVIOUS_OUTPUT_FILES", ""
|
| 768 |
+
)
|
| 769 |
+
DIRECT_MODE_USERNAME = get_or_create_env_var("DIRECT_MODE_USERNAME", "")
|
| 770 |
+
DIRECT_MODE_GROUP_BY = get_or_create_env_var("DIRECT_MODE_GROUP_BY", "")
|
| 771 |
+
DIRECT_MODE_EXCEL_SHEETS = get_or_create_env_var("DIRECT_MODE_EXCEL_SHEETS", "")
|
| 772 |
+
DIRECT_MODE_MODEL_CHOICE = get_or_create_env_var(
|
| 773 |
+
"DIRECT_MODE_MODEL_CHOICE", default_model_choice
|
| 774 |
+
)
|
| 775 |
+
DIRECT_MODE_TEMPERATURE = get_or_create_env_var(
|
| 776 |
+
"DIRECT_MODE_TEMPERATURE", str(LLM_TEMPERATURE)
|
| 777 |
+
)
|
| 778 |
+
DIRECT_MODE_BATCH_SIZE = get_or_create_env_var(
|
| 779 |
+
"DIRECT_MODE_BATCH_SIZE", str(BATCH_SIZE_DEFAULT)
|
| 780 |
+
)
|
| 781 |
+
DIRECT_MODE_MAX_TOKENS = get_or_create_env_var(
|
| 782 |
+
"DIRECT_MODE_MAX_TOKENS", str(LLM_MAX_NEW_TOKENS)
|
| 783 |
+
)
|
| 784 |
+
DIRECT_MODE_CONTEXT = get_or_create_env_var("DIRECT_MODE_CONTEXT", "")
|
| 785 |
+
DIRECT_MODE_CANDIDATE_TOPICS = get_or_create_env_var("DIRECT_MODE_CANDIDATE_TOPICS", "")
|
| 786 |
+
DIRECT_MODE_FORCE_ZERO_SHOT = get_or_create_env_var("DIRECT_MODE_FORCE_ZERO_SHOT", "No")
|
| 787 |
+
DIRECT_MODE_FORCE_SINGLE_TOPIC = get_or_create_env_var(
|
| 788 |
+
"DIRECT_MODE_FORCE_SINGLE_TOPIC", "No"
|
| 789 |
+
)
|
| 790 |
+
DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY = get_or_create_env_var(
|
| 791 |
+
"DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY", "No"
|
| 792 |
+
)
|
| 793 |
+
DIRECT_MODE_SENTIMENT = get_or_create_env_var(
|
| 794 |
+
"DIRECT_MODE_SENTIMENT", "Negative or Positive"
|
| 795 |
+
)
|
| 796 |
+
DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS = get_or_create_env_var(
|
| 797 |
+
"DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS", ""
|
| 798 |
+
)
|
| 799 |
+
DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES = get_or_create_env_var(
|
| 800 |
+
"DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES", ""
|
| 801 |
+
)
|
| 802 |
+
DIRECT_MODE_SHOW_PREVIOUS_TABLE = get_or_create_env_var(
|
| 803 |
+
"DIRECT_MODE_SHOW_PREVIOUS_TABLE", "Yes"
|
| 804 |
+
)
|
| 805 |
+
DIRECT_MODE_MAX_TIME_FOR_LOOP = get_or_create_env_var(
|
| 806 |
+
"DIRECT_MODE_MAX_TIME_FOR_LOOP", str(MAX_TIME_FOR_LOOP)
|
| 807 |
+
)
|
| 808 |
+
DIRECT_MODE_DEDUP_METHOD = get_or_create_env_var("DIRECT_MODE_DEDUP_METHOD", "fuzzy")
|
| 809 |
+
DIRECT_MODE_SIMILARITY_THRESHOLD = get_or_create_env_var(
|
| 810 |
+
"DIRECT_MODE_SIMILARITY_THRESHOLD", str(DEDUPLICATION_THRESHOLD)
|
| 811 |
+
)
|
| 812 |
+
DIRECT_MODE_MERGE_SENTIMENT = get_or_create_env_var("DIRECT_MODE_MERGE_SENTIMENT", "No")
|
| 813 |
+
DIRECT_MODE_MERGE_GENERAL_TOPICS = get_or_create_env_var(
|
| 814 |
+
"DIRECT_MODE_MERGE_GENERAL_TOPICS", "Yes"
|
| 815 |
+
)
|
| 816 |
+
DIRECT_MODE_SUMMARY_FORMAT = get_or_create_env_var(
|
| 817 |
+
"DIRECT_MODE_SUMMARY_FORMAT", "two_paragraph"
|
| 818 |
+
)
|
| 819 |
+
DIRECT_MODE_SAMPLE_REFERENCE_TABLE = get_or_create_env_var(
|
| 820 |
+
"DIRECT_MODE_SAMPLE_REFERENCE_TABLE", "True"
|
| 821 |
+
)
|
| 822 |
+
DIRECT_MODE_NO_OF_SAMPLED_SUMMARIES = get_or_create_env_var(
|
| 823 |
+
"DIRECT_MODE_NO_OF_SAMPLED_SUMMARIES", str(DEFAULT_SAMPLED_SUMMARIES)
|
| 824 |
+
)
|
| 825 |
+
DIRECT_MODE_RANDOM_SEED = get_or_create_env_var(
|
| 826 |
+
"DIRECT_MODE_RANDOM_SEED", str(LLM_SEED)
|
| 827 |
+
)
|
| 828 |
+
DIRECT_MODE_CREATE_XLSX_OUTPUT = get_or_create_env_var(
|
| 829 |
+
"DIRECT_MODE_CREATE_XLSX_OUTPUT", "True"
|
| 830 |
+
)
|
| 831 |
+
# CHOSEN_INFERENCE_SERVER_MODEL is defined later, so we'll handle it after that definition
|
| 832 |
+
|
| 833 |
+
MAX_QUEUE_SIZE = int(get_or_create_env_var("MAX_QUEUE_SIZE", "5"))
|
| 834 |
+
|
| 835 |
+
MAX_FILE_SIZE = get_or_create_env_var("MAX_FILE_SIZE", "250mb")
|
| 836 |
+
|
| 837 |
+
GRADIO_SERVER_PORT = int(get_or_create_env_var("GRADIO_SERVER_PORT", "7860"))
|
| 838 |
+
|
| 839 |
+
ROOT_PATH = get_or_create_env_var("ROOT_PATH", "")
|
| 840 |
+
|
| 841 |
+
DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var("DEFAULT_CONCURRENCY_LIMIT", "3")
|
| 842 |
+
|
| 843 |
+
GET_DEFAULT_ALLOW_LIST = get_or_create_env_var("GET_DEFAULT_ALLOW_LIST", "")
|
| 844 |
+
|
| 845 |
+
ALLOW_LIST_PATH = get_or_create_env_var(
|
| 846 |
+
"ALLOW_LIST_PATH", ""
|
| 847 |
+
) # config/default_allow_list.csv
|
| 848 |
+
|
| 849 |
+
S3_ALLOW_LIST_PATH = get_or_create_env_var(
|
| 850 |
+
"S3_ALLOW_LIST_PATH", ""
|
| 851 |
+
) # default_allow_list.csv # This is a path within the named S3 bucket
|
| 852 |
+
|
| 853 |
+
if ALLOW_LIST_PATH:
|
| 854 |
+
OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
|
| 855 |
+
else:
|
| 856 |
+
OUTPUT_ALLOW_LIST_PATH = "config/default_allow_list.csv"
|
| 857 |
+
|
| 858 |
+
FILE_INPUT_HEIGHT = int(get_or_create_env_var("FILE_INPUT_HEIGHT", "125"))
|
| 859 |
+
|
| 860 |
+
SHOW_EXAMPLES = get_or_create_env_var("SHOW_EXAMPLES", "True")
|
| 861 |
+
|
| 862 |
+
###
|
| 863 |
+
# COST CODE OPTIONS
|
| 864 |
+
###
|
| 865 |
+
|
| 866 |
+
SHOW_COSTS = get_or_create_env_var("SHOW_COSTS", "False")
|
| 867 |
+
|
| 868 |
+
GET_COST_CODES = get_or_create_env_var("GET_COST_CODES", "False")
|
| 869 |
+
|
| 870 |
+
DEFAULT_COST_CODE = get_or_create_env_var("DEFAULT_COST_CODE", "")
|
| 871 |
+
|
| 872 |
+
COST_CODES_PATH = get_or_create_env_var(
|
| 873 |
+
"COST_CODES_PATH", ""
|
| 874 |
+
) # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
|
| 875 |
+
|
| 876 |
+
S3_COST_CODES_PATH = get_or_create_env_var(
|
| 877 |
+
"S3_COST_CODES_PATH", ""
|
| 878 |
+
) # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
|
| 879 |
+
|
| 880 |
+
# A default path in case s3 cost code location is provided but no local cost code location given
|
| 881 |
+
if COST_CODES_PATH:
|
| 882 |
+
OUTPUT_COST_CODES_PATH = COST_CODES_PATH
|
| 883 |
+
else:
|
| 884 |
+
OUTPUT_COST_CODES_PATH = "config/cost_codes.csv"
|
| 885 |
+
|
| 886 |
+
ENFORCE_COST_CODES = get_or_create_env_var(
|
| 887 |
+
"ENFORCE_COST_CODES", "False"
|
| 888 |
+
) # If you have cost codes listed, is it compulsory to choose one before redacting?
|
| 889 |
+
|
| 890 |
+
if ENFORCE_COST_CODES == "True":
|
| 891 |
+
GET_COST_CODES = "True"
|
| 892 |
+
|
| 893 |
+
###
|
| 894 |
+
# VALIDATE FOLDERS AND CONFIG OPTIONS
|
| 895 |
+
###
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def ensure_folder_exists(output_folder: str):
|
| 899 |
+
"""Checks if the specified folder exists, creates it if not."""
|
| 900 |
+
|
| 901 |
+
if not os.path.exists(output_folder):
|
| 902 |
+
# Create the folder if it doesn't exist
|
| 903 |
+
os.makedirs(output_folder, exist_ok=True)
|
| 904 |
+
print(f"Created the {output_folder} folder.")
|
| 905 |
+
else:
|
| 906 |
+
pass
|
| 907 |
+
# print(f"The {output_folder} folder already exists.")
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
def _get_env_list(env_var_name: str, strip_strings: bool = True) -> List[str]:
|
| 911 |
+
"""Parses a comma-separated environment variable into a list of strings."""
|
| 912 |
+
value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
|
| 913 |
+
if not value:
|
| 914 |
+
return []
|
| 915 |
+
# Split by comma and filter out any empty strings that might result from extra commas
|
| 916 |
+
if strip_strings:
|
| 917 |
+
return [s.strip() for s in value.split(",") if s.strip()]
|
| 918 |
+
else:
|
| 919 |
+
return [codecs.decode(s, "unicode_escape") for s in value.split(",") if s]
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
# Convert string environment variables to string or list
|
| 923 |
+
if SAVE_LOGS_TO_CSV == "True":
|
| 924 |
+
SAVE_LOGS_TO_CSV = True
|
| 925 |
+
else:
|
| 926 |
+
SAVE_LOGS_TO_CSV = False
|
| 927 |
+
if SAVE_LOGS_TO_DYNAMODB == "True":
|
| 928 |
+
SAVE_LOGS_TO_DYNAMODB = True
|
| 929 |
+
else:
|
| 930 |
+
SAVE_LOGS_TO_DYNAMODB = False
|
| 931 |
+
|
| 932 |
+
if CSV_ACCESS_LOG_HEADERS:
|
| 933 |
+
CSV_ACCESS_LOG_HEADERS = _get_env_list(CSV_ACCESS_LOG_HEADERS)
|
| 934 |
+
if CSV_FEEDBACK_LOG_HEADERS:
|
| 935 |
+
CSV_FEEDBACK_LOG_HEADERS = _get_env_list(CSV_FEEDBACK_LOG_HEADERS)
|
| 936 |
+
if CSV_USAGE_LOG_HEADERS:
|
| 937 |
+
CSV_USAGE_LOG_HEADERS = _get_env_list(CSV_USAGE_LOG_HEADERS)
|
| 938 |
+
|
| 939 |
+
if DYNAMODB_ACCESS_LOG_HEADERS:
|
| 940 |
+
DYNAMODB_ACCESS_LOG_HEADERS = _get_env_list(DYNAMODB_ACCESS_LOG_HEADERS)
|
| 941 |
+
if DYNAMODB_FEEDBACK_LOG_HEADERS:
|
| 942 |
+
DYNAMODB_FEEDBACK_LOG_HEADERS = _get_env_list(DYNAMODB_FEEDBACK_LOG_HEADERS)
|
| 943 |
+
if DYNAMODB_USAGE_LOG_HEADERS:
|
| 944 |
+
DYNAMODB_USAGE_LOG_HEADERS = _get_env_list(DYNAMODB_USAGE_LOG_HEADERS)
|
| 945 |
+
|
| 946 |
+
# Set DIRECT_MODE_INFERENCE_SERVER_MODEL after CHOSEN_INFERENCE_SERVER_MODEL is defined
|
| 947 |
+
DIRECT_MODE_INFERENCE_SERVER_MODEL = get_or_create_env_var(
|
| 948 |
+
"DIRECT_MODE_INFERENCE_SERVER_MODEL",
|
| 949 |
+
CHOSEN_INFERENCE_SERVER_MODEL if CHOSEN_INFERENCE_SERVER_MODEL else "",
|
| 950 |
+
)
|
tools/custom_csvlogger.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from collections.abc import Sequence
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
# from multiprocessing import Lock
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import TYPE_CHECKING, Any
|
| 14 |
+
|
| 15 |
+
import boto3
|
| 16 |
+
import botocore
|
| 17 |
+
from gradio import utils
|
| 18 |
+
from gradio_client import utils as client_utils
|
| 19 |
+
|
| 20 |
+
from tools.config import AWS_ACCESS_KEY, AWS_REGION, AWS_SECRET_KEY, RUN_AWS_FUNCTIONS
|
| 21 |
+
|
| 22 |
+
if TYPE_CHECKING:
|
| 23 |
+
from gradio.components import Component
|
| 24 |
+
from threading import Lock
|
| 25 |
+
|
| 26 |
+
from gradio.flagging import FlaggingCallback
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CSVLogger_custom(FlaggingCallback):
|
| 30 |
+
"""
|
| 31 |
+
The default implementation of the FlaggingCallback abstract class in gradio>=5.0. Each flagged
|
| 32 |
+
sample (both the input and output data) is logged to a CSV file with headers on the machine running
|
| 33 |
+
the gradio app. Unlike ClassicCSVLogger, this implementation is concurrent-safe and it creates a new
|
| 34 |
+
dataset file every time the headers of the CSV (derived from the labels of the components) change. It also
|
| 35 |
+
only creates columns for "username" and "flag" if the flag_option and username are provided, respectively.
|
| 36 |
+
|
| 37 |
+
Example:
|
| 38 |
+
import gradio as gr
|
| 39 |
+
def image_classifier(inp):
|
| 40 |
+
return {'cat': 0.3, 'dog': 0.7}
|
| 41 |
+
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
|
| 42 |
+
flagging_callback=CSVLogger())
|
| 43 |
+
Guides: using-flagging
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
simplify_file_data: bool = True,
|
| 49 |
+
verbose: bool = True,
|
| 50 |
+
dataset_file_name: str | None = None,
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Parameters:
|
| 54 |
+
simplify_file_data: If True, the file data will be simplified before being written to the CSV file. If CSVLogger is being used to cache examples, this is set to False to preserve the original FileData class
|
| 55 |
+
verbose: If True, prints messages to the console about the dataset file creation
|
| 56 |
+
dataset_file_name: The name of the dataset file to be created (should end in ".csv"). If None, the dataset file will be named "dataset1.csv" or the next available number.
|
| 57 |
+
"""
|
| 58 |
+
self.simplify_file_data = simplify_file_data
|
| 59 |
+
self.verbose = verbose
|
| 60 |
+
self.dataset_file_name = dataset_file_name
|
| 61 |
+
self.lock = Lock()
|
| 62 |
+
|
| 63 |
+
def setup(
|
| 64 |
+
self,
|
| 65 |
+
components: Sequence[Component],
|
| 66 |
+
flagging_dir: str | Path,
|
| 67 |
+
):
|
| 68 |
+
self.components = components
|
| 69 |
+
self.flagging_dir = Path(flagging_dir)
|
| 70 |
+
self.first_time = True
|
| 71 |
+
|
| 72 |
+
def _create_dataset_file(
|
| 73 |
+
self,
|
| 74 |
+
additional_headers: list[str] | None = None,
|
| 75 |
+
replacement_headers: list[str] | None = None,
|
| 76 |
+
):
|
| 77 |
+
os.makedirs(self.flagging_dir, exist_ok=True)
|
| 78 |
+
|
| 79 |
+
if replacement_headers:
|
| 80 |
+
if additional_headers is None:
|
| 81 |
+
additional_headers = []
|
| 82 |
+
|
| 83 |
+
if len(replacement_headers) != len(self.components):
|
| 84 |
+
raise ValueError(
|
| 85 |
+
f"replacement_headers must have the same length as components "
|
| 86 |
+
f"({len(replacement_headers)} provided, {len(self.components)} expected)"
|
| 87 |
+
)
|
| 88 |
+
headers = replacement_headers + additional_headers + ["timestamp"]
|
| 89 |
+
else:
|
| 90 |
+
if additional_headers is None:
|
| 91 |
+
additional_headers = []
|
| 92 |
+
headers = (
|
| 93 |
+
[
|
| 94 |
+
getattr(component, "label", None) or f"component {idx}"
|
| 95 |
+
for idx, component in enumerate(self.components)
|
| 96 |
+
]
|
| 97 |
+
+ additional_headers
|
| 98 |
+
+ ["timestamp"]
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
headers = utils.sanitize_list_for_csv(headers)
|
| 102 |
+
dataset_files = list(Path(self.flagging_dir).glob("dataset*.csv"))
|
| 103 |
+
|
| 104 |
+
if self.dataset_file_name:
|
| 105 |
+
self.dataset_filepath = self.flagging_dir / self.dataset_file_name
|
| 106 |
+
elif dataset_files:
|
| 107 |
+
try:
|
| 108 |
+
latest_file = max(
|
| 109 |
+
dataset_files, key=lambda f: int(re.findall(r"\d+", f.stem)[0])
|
| 110 |
+
)
|
| 111 |
+
latest_num = int(re.findall(r"\d+", latest_file.stem)[0])
|
| 112 |
+
|
| 113 |
+
with open(latest_file, newline="", encoding="utf-8-sig") as csvfile:
|
| 114 |
+
reader = csv.reader(csvfile)
|
| 115 |
+
existing_headers = next(reader, None)
|
| 116 |
+
|
| 117 |
+
if existing_headers != headers:
|
| 118 |
+
new_num = latest_num + 1
|
| 119 |
+
self.dataset_filepath = self.flagging_dir / f"dataset{new_num}.csv"
|
| 120 |
+
else:
|
| 121 |
+
self.dataset_filepath = latest_file
|
| 122 |
+
except Exception:
|
| 123 |
+
self.dataset_filepath = self.flagging_dir / "dataset1.csv"
|
| 124 |
+
else:
|
| 125 |
+
self.dataset_filepath = self.flagging_dir / "dataset1.csv"
|
| 126 |
+
|
| 127 |
+
if not Path(self.dataset_filepath).exists():
|
| 128 |
+
with open(
|
| 129 |
+
self.dataset_filepath, "w", newline="", encoding="utf-8-sig"
|
| 130 |
+
) as csvfile:
|
| 131 |
+
writer = csv.writer(csvfile)
|
| 132 |
+
writer.writerow(utils.sanitize_list_for_csv(headers))
|
| 133 |
+
if self.verbose:
|
| 134 |
+
print("Created dataset file at:", self.dataset_filepath)
|
| 135 |
+
elif self.verbose:
|
| 136 |
+
print("Using existing dataset file at:", self.dataset_filepath)
|
| 137 |
+
|
| 138 |
+
def flag(
|
| 139 |
+
self,
|
| 140 |
+
flag_data: list[Any],
|
| 141 |
+
flag_option: str | None = None,
|
| 142 |
+
username: str | None = None,
|
| 143 |
+
save_to_csv: bool = True,
|
| 144 |
+
save_to_dynamodb: bool = False,
|
| 145 |
+
dynamodb_table_name: str | None = None,
|
| 146 |
+
dynamodb_headers: list[str] | None = None, # New: specify headers for DynamoDB
|
| 147 |
+
replacement_headers: list[str] | None = None,
|
| 148 |
+
) -> int:
|
| 149 |
+
if self.first_time:
|
| 150 |
+
# print("First time creating log file")
|
| 151 |
+
additional_headers = []
|
| 152 |
+
if flag_option is not None:
|
| 153 |
+
additional_headers.append("flag")
|
| 154 |
+
if username is not None:
|
| 155 |
+
additional_headers.append("username")
|
| 156 |
+
additional_headers.append("id")
|
| 157 |
+
# additional_headers.append("timestamp")
|
| 158 |
+
self._create_dataset_file(
|
| 159 |
+
additional_headers=additional_headers,
|
| 160 |
+
replacement_headers=replacement_headers,
|
| 161 |
+
)
|
| 162 |
+
self.first_time = False
|
| 163 |
+
|
| 164 |
+
csv_data = []
|
| 165 |
+
for idx, (component, sample) in enumerate(
|
| 166 |
+
zip(self.components, flag_data, strict=False)
|
| 167 |
+
):
|
| 168 |
+
save_dir = (
|
| 169 |
+
self.flagging_dir
|
| 170 |
+
/ client_utils.strip_invalid_filename_characters(
|
| 171 |
+
getattr(component, "label", None) or f"component {idx}"
|
| 172 |
+
)
|
| 173 |
+
)
|
| 174 |
+
if utils.is_prop_update(sample):
|
| 175 |
+
csv_data.append(str(sample))
|
| 176 |
+
else:
|
| 177 |
+
data = (
|
| 178 |
+
component.flag(sample, flag_dir=save_dir)
|
| 179 |
+
if sample is not None
|
| 180 |
+
else ""
|
| 181 |
+
)
|
| 182 |
+
if self.simplify_file_data:
|
| 183 |
+
data = utils.simplify_file_data_in_str(data)
|
| 184 |
+
csv_data.append(data)
|
| 185 |
+
|
| 186 |
+
if flag_option is not None:
|
| 187 |
+
csv_data.append(flag_option)
|
| 188 |
+
if username is not None:
|
| 189 |
+
csv_data.append(username)
|
| 190 |
+
|
| 191 |
+
generated_id = str(uuid.uuid4())
|
| 192 |
+
csv_data.append(generated_id)
|
| 193 |
+
|
| 194 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
|
| 195 |
+
:-3
|
| 196 |
+
] # Correct format for Amazon Athena
|
| 197 |
+
csv_data.append(timestamp)
|
| 198 |
+
|
| 199 |
+
# Build the headers
|
| 200 |
+
headers = [
|
| 201 |
+
getattr(component, "label", None) or f"component {idx}"
|
| 202 |
+
for idx, component in enumerate(self.components)
|
| 203 |
+
]
|
| 204 |
+
if flag_option is not None:
|
| 205 |
+
headers.append("flag")
|
| 206 |
+
if username is not None:
|
| 207 |
+
headers.append("username")
|
| 208 |
+
headers.append("id")
|
| 209 |
+
headers.append("timestamp")
|
| 210 |
+
|
| 211 |
+
line_count = -1
|
| 212 |
+
|
| 213 |
+
if save_to_csv:
|
| 214 |
+
with self.lock:
|
| 215 |
+
with open(
|
| 216 |
+
self.dataset_filepath, "a", newline="", encoding="utf-8-sig"
|
| 217 |
+
) as csvfile:
|
| 218 |
+
writer = csv.writer(csvfile)
|
| 219 |
+
writer.writerow(utils.sanitize_list_for_csv(csv_data))
|
| 220 |
+
with open(self.dataset_filepath, encoding="utf-8-sig") as csvfile:
|
| 221 |
+
line_count = len(list(csv.reader(csvfile))) - 1
|
| 222 |
+
|
| 223 |
+
if save_to_dynamodb is True:
|
| 224 |
+
print("Saving to DynamoDB")
|
| 225 |
+
|
| 226 |
+
if RUN_AWS_FUNCTIONS == "1":
|
| 227 |
+
try:
|
| 228 |
+
print("Connecting to DynamoDB via existing SSO connection")
|
| 229 |
+
dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION)
|
| 230 |
+
# client = boto3.client('dynamodb')
|
| 231 |
+
|
| 232 |
+
dynamodb.meta.client.list_tables()
|
| 233 |
+
|
| 234 |
+
except Exception as e:
|
| 235 |
+
print("No SSO credentials found:", e)
|
| 236 |
+
if AWS_ACCESS_KEY and AWS_SECRET_KEY:
|
| 237 |
+
print("Trying DynamoDB credentials from environment variables")
|
| 238 |
+
dynamodb = boto3.resource(
|
| 239 |
+
"dynamodb",
|
| 240 |
+
aws_access_key_id=AWS_ACCESS_KEY,
|
| 241 |
+
aws_secret_access_key=AWS_SECRET_KEY,
|
| 242 |
+
region_name=AWS_REGION,
|
| 243 |
+
)
|
| 244 |
+
# client = boto3.client('dynamodb',aws_access_key_id=AWS_ACCESS_KEY,
|
| 245 |
+
# aws_secret_access_key=AWS_SECRET_KEY, region_name=AWS_REGION)
|
| 246 |
+
else:
|
| 247 |
+
raise Exception(
|
| 248 |
+
"AWS credentials for DynamoDB logging not found"
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
raise Exception("AWS credentials for DynamoDB logging not found")
|
| 252 |
+
|
| 253 |
+
if dynamodb_table_name is None:
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"You must provide a dynamodb_table_name if save_to_dynamodb is True"
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
if dynamodb_headers:
|
| 259 |
+
dynamodb_headers = dynamodb_headers
|
| 260 |
+
if not dynamodb_headers and replacement_headers:
|
| 261 |
+
dynamodb_headers = replacement_headers
|
| 262 |
+
elif headers:
|
| 263 |
+
dynamodb_headers = headers
|
| 264 |
+
elif not dynamodb_headers:
|
| 265 |
+
raise ValueError(
|
| 266 |
+
"Headers not found. You must provide dynamodb_headers or replacement_headers to create a new table."
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
if flag_option is not None:
|
| 270 |
+
if "flag" not in dynamodb_headers:
|
| 271 |
+
dynamodb_headers.append("flag")
|
| 272 |
+
if username is not None:
|
| 273 |
+
if "username" not in dynamodb_headers:
|
| 274 |
+
dynamodb_headers.append("username")
|
| 275 |
+
if "timestamp" not in dynamodb_headers:
|
| 276 |
+
dynamodb_headers.append("timestamp")
|
| 277 |
+
if "id" not in dynamodb_headers:
|
| 278 |
+
dynamodb_headers.append("id")
|
| 279 |
+
|
| 280 |
+
# Table doesn't exist β create it
|
| 281 |
+
try:
|
| 282 |
+
table = dynamodb.Table(dynamodb_table_name)
|
| 283 |
+
table.load()
|
| 284 |
+
except botocore.exceptions.ClientError as e:
|
| 285 |
+
if e.response["Error"]["Code"] == "ResourceNotFoundException":
|
| 286 |
+
|
| 287 |
+
attribute_definitions = [
|
| 288 |
+
{
|
| 289 |
+
"AttributeName": "id",
|
| 290 |
+
"AttributeType": "S",
|
| 291 |
+
} # Only define key attributes here
|
| 292 |
+
]
|
| 293 |
+
|
| 294 |
+
table = dynamodb.create_table(
|
| 295 |
+
TableName=dynamodb_table_name,
|
| 296 |
+
KeySchema=[
|
| 297 |
+
{"AttributeName": "id", "KeyType": "HASH"} # Partition key
|
| 298 |
+
],
|
| 299 |
+
AttributeDefinitions=attribute_definitions,
|
| 300 |
+
BillingMode="PAY_PER_REQUEST",
|
| 301 |
+
)
|
| 302 |
+
# Wait until the table exists
|
| 303 |
+
table.meta.client.get_waiter("table_exists").wait(
|
| 304 |
+
TableName=dynamodb_table_name
|
| 305 |
+
)
|
| 306 |
+
time.sleep(5)
|
| 307 |
+
print(f"Table '{dynamodb_table_name}' created successfully.")
|
| 308 |
+
else:
|
| 309 |
+
raise
|
| 310 |
+
|
| 311 |
+
# Prepare the DynamoDB item to upload
|
| 312 |
+
try:
|
| 313 |
+
item = {
|
| 314 |
+
"id": str(generated_id), # UUID primary key
|
| 315 |
+
#'created_by': username if username else "unknown",
|
| 316 |
+
"timestamp": timestamp,
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
# Map the headers to values
|
| 320 |
+
item.update(
|
| 321 |
+
{
|
| 322 |
+
header: str(value)
|
| 323 |
+
for header, value in zip(dynamodb_headers, csv_data)
|
| 324 |
+
}
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
table.put_item(Item=item)
|
| 328 |
+
|
| 329 |
+
print("Successfully uploaded log to DynamoDB")
|
| 330 |
+
except Exception as e:
|
| 331 |
+
print("Could not upload log to DynamobDB due to", e)
|
| 332 |
+
|
| 333 |
+
return line_count
|
tools/dedup_summaries.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tools/example_table_outputs.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dummy_consultation_table = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
|
| 2 |
+
|:---------------------|:-----------------------|:------------|:--------|----------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 3 |
+
| Development proposal | Affordable housing | Positive | All | 6 | The proposed development is overwhelmingly viewed favorably by the local community, primarily due to<br>its potential to address significant needs. residents express strong support for the provision of<br>amenities, particularly much-needed housing for young people and families. crucially, the<br>development is also considered vital for increasing green space, with some suggesting this could<br>even contribute to further housing opportunities. <br>furthermore, a key theme emerging from the<br>responses is the des... |
|
| 4 |
+
| Development proposal | Environmental damage | Negative | All | 5 | A primary concern expressed across the dataset relates to the potential negative consequences of a<br>proposed development on the local environment and wildlife. multiple respondents highlighted worries<br>about environmental damage, suggesting a significant apprehension regarding the projectβs impact.<br>specifically, thereβs a shared concern about the detrimental effects on local wildlife, indicating a<br>potential disruption to the natural ecosystem.<br>furthermore, the proposed development is perceived<br>to ... |
|
| 5 |
+
| Community impact | Loss of local business | Negative | All | 4 | A significant concern expressed within the dataset relates to the potential negative consequences of<br>a development project on local businesses and the wider community. multiple respondents voiced<br>worries about increased traffic congestion, directly impacting the viability of local businesses and<br>leading to apprehension about their future. specifically, there is a palpable sadness surrounding<br>the possibility of a beloved local cafe closing, emphasizing the detrimental effect on community<br>connect... |
|
| 6 |
+
| Development proposal | Architectural style | Negative | All | 4 | The primary concern regarding the proposed development is its incompatibility with the established<br>character of the area. residents express a strong feeling that the design is fundamentally at odds<br>with the existing aesthetic and atmosphere, suggesting a lack of sensitivity to the local context.<br>this sentiment highlights a significant apprehension about disrupting the areaβs unique<br>identity.<br>furthermore, significant anxieties center on the potential negative impact of the<br>development on the surr... |
|
| 7 |
+
| Economic impact | Investment and jobs | Positive | All | 4 | The proposed development is widely anticipated to generate significant positive economic impacts<br>within the local community. residents believe it will lead to substantial investment and the<br>creation of numerous job opportunities, directly boosting the local economy and revitalizing the<br>town centre. thereβs a strong consensus that this development represents a key step towards economic<br>growth and prosperity for the area.<br>specifically, the anticipated benefits include the creation<br>of jobs for loca... |
|
| 8 |
+
| Development proposal | Height of building | Negative | All | 3 | Residents expressed significant concerns regarding the proposed developmentβs height, specifically<br>highlighting the five-storey structure as a major issue. this height was perceived as excessively<br>tall and likely to cause overshadowing of existing buildings in the area, directly impacting the<br>views enjoyed by current residents. the potential for diminished sunlight and altered visual<br>landscapes was a central point of contention.<br>furthermore, the impact on the existing character<br>of the neighborho... |
|
| 9 |
+
| Development proposal | Infrastructure impact | Negative | All | 3 | Analysis of the provided text reveals significant concerns regarding the potential consequences of a<br>proposed project on existing infrastructure. specifically, there is a notable worry about the<br>detrimental effects on local infrastructure, with the possibility of widespread disruption as a key<br>consequence. this suggests a need for careful assessment and mitigation strategies to avoid<br>negatively impacting essential services and community operations.<br>furthermore, the repeated<br>emphasis on infrastru... |
|
| 10 |
+
| Development proposal | Noise pollution | Negative | All | 3 | The primary concern highlighted within the dataset relates to anticipated noise pollution stemming<br>from the proposed development. multiple responses explicitly express this worry, emphasizing it as a<br>βsignificant concernβ and a key area of apprehension. thereβs a clear understanding that the<br>development will likely exacerbate existing noise levels within the surrounding area, suggesting a<br>potential negative impact on residents and the local environment.<br>several respondents reiterate<br>this concer... |
|
| 11 |
+
| Housing needs | Supply of housing | Positive | All | 3 | The proposed development is viewed as a crucial solution to address the townβs significant housing<br>shortage, reflecting a clear desire for increased housing supply within the community. residents<br>express a strong need for more homes, and this development is seen as a key step towards meeting<br>that demand. <br>furthermore, the project is anticipated to alleviate existing parking issues, which<br>is considered a valuable contribution to the overall housing supply. the provision of additional<br>parking spac... |
|
| 12 |
+
| Development proposal | Height of building | Neutral | All | 2 | The analysis of the provided text reveals a nuanced perspective regarding a development project,<br>with no explicit sentiment expressed concerning the building's height itself. however, significant<br>concerns are raised about the potential impact of the development on local schools. these concerns<br>appear to be linked, at least in part, to the building's height, suggesting a worry that the<br>increased scale could strain existing resources and infrastructure within the school system.<br><br>further investigat... |
|
| 13 |
+
| Community impact | Community facilities | Negative | All | 1 | Concerns exist regarding the negative impact on local amenities. |
|
| 14 |
+
| Community impact | Community facilities | Positive | All | 1 | The development will provide much-needed community facilities, enhancing the local area. |
|
| 15 |
+
| Development proposal | Architectural style | Neutral | All | 1 | The development is expected to provide facilities for young people, but no specific architectural<br>concerns. |
|
| 16 |
+
| Development proposal | Noise pollution | Neutral | All | 1 | Potential for increased noise pollution due to the development is a concern. |
|
| 17 |
+
| Economic impact | Economic decline | Negative | All | 1 | Worries about a negative impact on the local economy are expressed, suggesting potential harm. |"""
|
| 18 |
+
|
| 19 |
+
dummy_consultation_table_zero_shot = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
|
| 20 |
+
|:---------------------------|:------------------------------------|:------------|:--------|----------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 21 |
+
| Planning & development | Impact on the character of the area | Negative | All | 10 | Residents overwhelmingly express strong objections to the proposed development, primarily focusing<br>on its incompatibility with the established character of the area. A central concern is the<br>development's height and design, which they believe clashes significantly with the existing<br>aesthetic and creates a sense of being overshadowed by taller structures, leading to a feeling of<br>crampedness. Many respondents specifically highlighted the potential for the development to<br>negatively impact Main Stre... |
|
| 22 |
+
| Environmental impact | Impact on the local environment | Negative | All | 8 | Several concerns have been raised regarding the potential negative impacts of a development on the<br>local environment. Multiple respondents expressed worry about the developmentβs possible detrimental<br>effects on the surrounding environment and quality of life, highlighting a significant area of<br>concern. These anxieties include potential damage to the environment and a general feeling of unease<br>about the developmentβs consequences.<br><br>Despite a single positive note regarding the provision<br>of green s... |
|
| 23 |
+
| Infrastructure & transport | Traffic congestion | Negative | All | 7 | Concerns regarding increased traffic congestion are prevalent in the dataset, largely stemming from<br>the anticipated impact of the proposed development. Specifically, Main Street is predicted to<br>experience heightened congestion due to the increased volume of traffic it will attract. Multiple<br>responses repeatedly highlight this anticipation as a key issue associated with the<br>project.<br><br>Despite the consistent apprehension about traffic congestion, no direct responses<br>offer specific solutions or miti... |
|
| 24 |
+
| Planning & development | Need for family housing | Positive | All | 7 | The proposed development is overwhelmingly viewed as a crucial solution to the need for family<br>housing within the community. Multiple sources highlight its significance in providing much-needed<br>homes, particularly for families, and specifically addressing the demand for affordable family<br>housing options. Several respondents emphasized the beneficial impact on local residents, with the<br>development also anticipated to create jobs and offer facilities geared towards young people<br>alongside housing. ... |
|
| 25 |
+
| Quality of life | Impact on quality of life | Negative | All | 7 | Analysis of the provided text reveals significant concerns regarding a proposed development's<br>potential negative impact on the quality of life within the area. Residents are particularly worried<br>that the development will overshadow existing buildings, creating a sense of crampedness and<br>diminishing their living experience. Furthermore, anxieties extend beyond immediate residential<br>impacts, encompassing broader concerns about the developmentβs effects on local businesses, schools,<br>and crucial inf... |
|
| 26 |
+
| Economic impact | Investment and job creation | Positive | All | 6 | The proposed development is overwhelmingly viewed positively, with significant anticipation for its<br>economic impact on the area. Residents and observers alike believe it will stimulate considerable<br>investment and generate numerous job opportunities, particularly for local residents. Furthermore,<br>the project is expected to revitalize the town center and provide crucial affordable housing,<br>potentially benefiting young people seeking to establish themselves in the<br>community.<br><br>Specifically, the deve... |
|
| 27 |
+
| Infrastructure & transport | Parking | Negative | All | 6 | Analysis of the '{column_name}' column reveals significant concerns regarding the potential impact<br>of a new development on Main Street. The primary issue identified is increased traffic congestion,<br>directly linked to the developmentβs activity. Furthermore, there is widespread apprehension that<br>the project will worsen existing parking problems, with multiple respondents explicitly stating a<br>lack of adequate parking provisions as a key worry. <br><br>Specifically, numerous individuals<br>expressed concern... |
|
| 28 |
+
| Community & local life | Amenities for the local community | Positive | All | 5 | The proposed development is anticipated to significantly benefit the local community, offering a<br>range of amenities and a positive contribution to the area. Specifically, the project will deliver<br>crucial green space alongside facilities designed to cater to the needs of young people and the<br>broader community.<br><br>Furthermore, the development is expected to address critical social needs<br>by providing much-needed community facilities and social housing, indicating a commitment to<br>supporting local resi... |
|
| 29 |
+
| Environmental impact | Impact on local wildlife | Neutral | All | 4 | No specific responses were provided, and the dataset contained no information relevant to the<br>specified consultation context. Consequently, a summary cannot be generated based on the provided<br>data. <br><br>Due to the absence of any textual data within the dataset, there is no content to<br>consolidate and summarize. |
|
| 30 |
+
| Improvement of main street | Improvement of main street | Positive | All | 4 | This development is being hailed as a positive step for the revitalization of Main Street, primarily<br>due to its anticipated improvement in the streetβs appearance. Stakeholders view this initiative as<br>a crucial element in breathing new life into the area, suggesting a significant upgrade to the<br>existing landscape.<br><br>Specifically, the project aims to enhance the visual appeal of Main<br>Street, representing a tangible advancement in its overall attractiveness and desirability. The<br>development is wide... |
|
| 31 |
+
| Planning & development | Impact on views | Negative | All | 4 | A primary concern expressed regarding the proposed development is its potential negative impact on<br>existing views. Multiple respondents voiced worries about how the development might obstruct or<br>diminish the current vistas, alongside specific concerns about its effect on views from neighboring<br>properties. This suggests a significant sensitivity to the visual landscape and its value within the<br>community.<br><br>Furthermore, the potential aesthetic consequences of the development are<br>highlighted, with s... |
|
| 32 |
+
| Community & local life | Amenities for the local community | Negative | All | 2 | Residents are voicing significant concerns regarding a proposed development, primarily focusing on<br>its anticipated detrimental effects on local amenities. A key point of contention is the planned<br>removal of the existing cafe, which is being viewed as a substantial loss to the communityβs social<br>fabric and a vital local resource.<br><br>The overall sentiment suggests a strong apprehension that<br>the development will diminish the quality of life for those living nearby, highlighting a desire to<br>preserve c... |
|
| 33 |
+
| Impact on local businesses | Impact on local businesses | Negative | All | 2 | A primary concern expressed relates to the potential detrimental effects of the development on local<br>businesses. Thereβs a clear worry that the project will negatively impact these businesses,<br>suggesting a potential loss of revenue, customer base, or even business closure. The repeated<br>emphasis on a βnegative impactβ highlights a significant apprehension regarding the economic<br>repercussions for the existing business community.<br><br>The sentiment underscores a desire to<br>mitigate potential harm and li... |
|
| 34 |
+
| Impact on local heritage | Impact on local heritage | Negative | All | 2 | There are growing concerns regarding the potential negative impact of the development on the local<br>heritage. While specific details and references havenβt been explicitly stated, the underlying<br>sentiment suggests a worry about the developmentβs effects on historically significant elements<br>within the area. This implies a recognition that the proposed project could, perhaps inadvertently,<br>threaten or diminish the cultural value and character of the local environment.<br><br>The presence<br>of these concern... |
|
| 35 |
+
| Environmental impact | Impact on local wildlife | Negative | All | 1 | Concerns regarding the negative impact of the development on local wildlife. |
|
| 36 |
+
| Impact on local heritage | Impact on local heritage | Neutral | All | 1 | No specific responses mention this topic. |
|
| 37 |
+
| Impact on local schools | Impact on local schools | Negative | All | 1 | Concerns about the negative impact on the local schools. |
|
| 38 |
+
| Impact on local schools | Impact on local schools | Neutral | All | 1 | No specific responses mention this topic. |
|
| 39 |
+
| Infrastructure & transport | Parking | Positive | All | 1 | The development is expected to provide much-needed parking spaces. |"""
|
| 40 |
+
|
| 41 |
+
case_notes_table = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
|
| 42 |
+
|:------------------|:----------------------------|:------------|:--------|----------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 43 |
+
| Family dynamics | Parental conflict | Negative | All | 6 | Several parents expressed significant concerns regarding the well-being of their children, primarily<br>focusing on escalating aggression and withdrawal. alexβs mother specifically highlighted a pattern<br>of arguments at home and attributed the aggressive behavior to external provocation, suggesting a<br>destabilizing family environment. furthermore, parents voiced a lack of confidence in existing<br>interventions for their children, particularly jamie, indicating a perceived need for supplemental<br>support ... |
|
| 44 |
+
| Mental health | Feelings of isolation | Negative | All | 6 | Several individuals expressed significant emotional distress, primarily centered around feelings of<br>isolation and hopelessness. jamieβs withdrawal and reported emptiness suggest a deep-seated sense of<br>disconnection, while alex powerfully articulated his feeling of being misunderstood with the<br>statement βno one gets me.β these experiences appear to be impacting daily functioning, as evidenced<br>by jamieβs struggles and disruption of his sleep patterns. <br>the observations from parents further<br>indicat... |
|
| 45 |
+
| Overall mood | Agitation | Negative | All | 6 | Several individuals expressed concerns regarding heightened emotional distress within the subject<br>group, primarily focusing on alex. alex repeatedly demonstrated significant frustration and<br>aggression, necessitating anger management support, and appeared increasingly agitated in recent<br>meetings, suggesting a worsening emotional state. parents specifically highlighted jamieβs ongoing<br>struggles and agitation, emphasizing the need for a comprehensive assessment and subsequent<br>intervention.<br>while so... |
|
| 46 |
+
| Family dynamics | Peer influence | Negative | All | 5 | Concerns centered around alexβs social circle and behavior highlighted potential negative peer<br>influence, specifically due to his new friends and frequent late-night activities. further<br>investigation revealed troubling admissions from alex himself, including alcohol use and<br>participation in a physical altercation with a fellow student, indicating a concerning pattern of<br>risk-taking behavior. <br>simultaneously, jamieβs situation presented a separate area of concern,<br>characterized by isolation and l... |
|
| 47 |
+
| Substance use | Potential substance abuse | Negative | All | 5 | Alex has disclosed alcohol use, raising significant concerns about potential ongoing substance<br>abuse. the situation is further complicated by reports from alexβs mother, who has observed<br>potential signs of substance abuse and expressed her worries regarding this matter. these<br>observations highlight a need for further assessment and support to address the individualβs<br>substance use patterns and ensure their well-being.<br>the situation requires careful monitoring and<br>intervention. the motherβs repor... |
|
| 48 |
+
| Mental health | Depression | Positive | All | 3 | Jamie is diagnosed with major depressive disorder and initiated on antidepressant medication, with<br>initial positive feedback on mood and energy. |
|
| 49 |
+
| Mental health | Self-harm | Negative | All | 3 | The assessment revealed a complex picture regarding the individualβs mental state. while initial<br>observations did not indicate any active self-harm, a thorough evaluation is strongly recommended to<br>identify potential underlying issues contributing to the risk. this proactive approach is crucial<br>for a complete understanding of the individualβs needs.<br>more concerningly, alex presented with<br>visible self-harm indicators on his arms and explicitly communicated thoughts of self-harm,<br>signifying a sign... |
|
| 50 |
+
| School engagement | Absenteeism | Negative | All | 3 | Recent reports highlight a concerning trend of declining student engagement within the school<br>environment. specifically, there has been an increase in absences alongside a decrease in academic<br>performance, suggesting a fundamental lack of connection with schoolwork and learning. several<br>students, including jamie, are exhibiting problematic behaviors that further underscore this issue,<br>such as consistent tardiness and reduced participation in classroom activities.<br>furthermore,<br>observations indica... |
|
| 51 |
+
| Mental health | Depression | Negative | All | 2 | Jamie exhibits symptoms of moderate depression, requiring further evaluation and intervention. |
|
| 52 |
+
| Mental health | Self-harm | Neutral | All | 2 | The psychiatristβs assessment centered on the potential advantages and drawbacks of antidepressant<br>medication, with a notable emphasis on evaluating the possibility of self-harm risk. this indicates<br>a proactive approach to patient safety and a recognition of the complex interplay between medication<br>and mental health. the discussion highlights a careful consideration of the potential for increased<br>suicidal ideation, suggesting a thorough risk assessment was undertaken.<br>furthermore, the<br>analysis o... |
|
| 53 |
+
| School engagement | Academic performance | Negative | All | 2 | Analysis of the provided text reveals concerns regarding student engagement and academic<br>performance. specifically, jamieβs reduced involvement in class is flagged as a potential indicator<br>of negative consequences, with declining grades reported as a direct result. this suggests a<br>concerning downward trend in alexβs academic progress, highlighting a need for further investigation<br>into the underlying causes of this shift.<br>the combined observations point to a possible<br>correlation between decreased... |
|
| 54 |
+
| Substance use | Substance use (unspecified) | Negative | All | 2 | Concerns regarding ongoing substance use prompted discussion about the possibility of a short-term<br>residential treatment program. alexβs involvement highlighted a potential issue, as they reported<br>occasional substance use, though the specific substances involved were not detailed during the<br>consultation. this lack of specificity regarding the substances used raises a need for further<br>investigation into the nature and frequency of alexβs substance use.<br>the consultation focused on<br>assessing the ri... |
|
| 55 |
+
| Family dynamics | Stepfather relationship | Negative | All | 1 | Alex displayed sudden outbursts of anger when discussing his new stepfather, indicating significant<br>distress related to this family change. |
|
| 56 |
+
| School engagement | Academic performance | Positive | All | 1 | Jamie's academic performance has slightly improved, indicating a potential positive change. |"""
|
| 57 |
+
|
| 58 |
+
case_notes_table_grouped = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
|
| 59 |
+
|:--------------------|:---------------------------|:------------|:---------|----------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 60 |
+
| Trends over time | Trends over time | Negative | Alex D. | 7 | Alexβs case note reveals a troubling deterioration in his well-being marked by a gradual escalation<br>of issues. Initially, the record details an incident involving a physical altercation, which quickly<br>spiraled into increasingly concerning behaviours at home, specifically escalating aggression. Over<br>subsequent meetings, observations consistently pointed towards heightened agitation and expressions<br>of hopelessness, indicating a worsening emotional state and a significant decline in his overall<br>con... |
|
| 61 |
+
| Physical health | Substance misuse | Negative | Alex D. | 6 | Alexβs substance use remains a significant concern, necessitating continued vigilance and support<br>despite recent positive developments in group therapy. While Alex has acknowledged instances of<br>substance use, the details surrounding these occurrences have not been shared, raising questions<br>about the extent and nature of the problem. Concerns were specifically noted regarding potential<br>substance abuse, highlighting a need for further investigation and assessment.<br><br>Ongoing<br>monitoring is crucial to... |
|
| 62 |
+
| Behaviour at school | Behaviour at school | Negative | Alex D. | 3 | A recent case note details a troubling incident involving a physical altercation at school,<br>alongside concerning admissions from Alex regarding alcohol use. This event has sparked worries<br>about potential behavioural issues within the school setting, suggesting a need for further<br>investigation and support. Alexβs demeanor was notably problematic, characterized by sullen behavior<br>and a deliberate avoidance of eye contact, indicating a possible struggle with emotional<br>regulation.<br><br>Furthermore, Alex... |
|
| 63 |
+
| Mental health | Anger | Negative | Alex D. | 3 | Alex exhibits a pronounced anger issue, characterized by frustration and a tendency to blame others<br>for triggering his aggressive behavior. He demonstrated this significantly when discussing his<br>personal life, particularly relating to his new stepfather, suggesting a volatile emotional response<br>to this change. The observed outbursts highlight a need for immediate intervention to manage his<br>escalating anger.<br><br>Further investigation reveals that Alexβs anger is closely linked to his<br>home environmen... |
|
| 64 |
+
| Mental health | Self-harm | Negative | Alex D. | 3 | The analysis reveals significant concerns regarding Alexβs mental health, centering around potential<br>self-harm behaviors. Indications suggest a possible diagnosis of Oppositional Defiant Disorder<br>alongside a co-occurring substance use disorder, warranting a comprehensive treatment plan. Alex<br>demonstrated visible signs of self-harm and openly confessed to experiencing thoughts of self-harm,<br>highlighting a critical need for immediate intervention.<br><br>Following this disclosure, an<br>immediate referral ... |
|
| 65 |
+
| Mental health | Social issues | Negative | Alex D. | 3 | Alex exhibits a pattern of blaming others for his problematic behavior, indicating underlying<br>challenges in social interaction and conflict resolution. This behavior appears to be contributing<br>to further instability in his life. Specifically, his mother voiced concerns regarding his new<br>social circle and increasingly frequent late-night activities, suggesting she perceives these<br>relationships and outings as potentially risky.<br><br>The motherβs observations highlight a<br>potential area of concern for A... |
|
| 66 |
+
| Mental health | Depression | Negative | Jamie L. | 6 | Jamie is currently experiencing concerning symptoms indicative of depression, as noted by both<br>Jamieβs behavior and parental observations. Specifically, he demonstrates limited social<br>interaction, struggles with his mood, and has difficulty engaging with his schoolwork. These<br>difficulties appear persistent, with parents reporting ongoing struggles despite occasional positive<br>moments. <br><br>Further assessment suggests a more pronounced picture, with indications of moderate<br>depression characterized by... |
|
| 67 |
+
| Mental health | Social isolation | Negative | Jamie L. | 4 | Jamie is experiencing significant social isolation, which is negatively affecting both his academic<br>performance and his general well-being. He has expressed feelings of loneliness and difficulty<br>sleeping, strongly suggesting a core social issue is contributing to his distress. Current efforts<br>are focused on promoting increased social interaction to address these challenges.<br><br>The report<br>highlights the urgency of this situation, emphasizing the need for intervention to mitigate Jamieβs<br>isolation a... |
|
| 68 |
+
| Mental health | Medication | Neutral | Jamie L. | 3 | Consideration is being given to medication as a potential intervention alongside therapy to manage<br>depressive symptoms. Initial feedback on the antidepressant is positive. |
|
| 69 |
+
| Mental health | Withdrawal & sadness | Negative | Jamie L. | 3 | Jamie is experiencing a significant downturn in his emotional state, characterized by withdrawal,<br>sadness, and a pervasive sense of emptiness and hopelessness. These negative feelings appear to be<br>triggered by recent reports of tardiness and decreased participation, suggesting a possible link<br>between his behavior and external pressures or expectations. The combination of these symptoms<br>points to a low mood and a feeling of struggle, indicating a potentially serious situation requiring<br>attention.... |
|
| 70 |
+
| Mental health | Low self-worth | Negative | Jamie L. | 2 | Parents are increasingly concerned about Jamieβs well-being due to observed difficulties and a<br>potential lack of self-worth. These concerns are primarily fueled by Jamieβs own statements, where<br>he articulated feelings of low self-esteem and a significant struggle to find<br>motivation.<br><br>Further investigation revealed a direct link between Jamieβs emotional state and<br>recent family financial hardships. The pressures of these struggles appear to have deeply impacted<br>his self-perception and ability to ... |
|
| 71 |
+
| Trends over time | Increasing withdrawal | Negative | Jamie L. | 2 | A significant and worrying trend is emerging regarding withdrawal, necessitating continuous<br>observation and targeted intervention strategies. Specifically, Jamie is exhibiting a noticeable<br>decline in engagement with family activities, representing a key indicator of this broader issue.<br>This withdrawal suggests a potential underlying problem requiring careful assessment and proactive<br>support.<br><br>The observed pattern of withdrawal highlights the importance of sustained monitoring<br>to understand its p... |
|
| 72 |
+
| Behaviour at school | Attendance issues | Negative | Jamie L. | 1 | Jamieβs consistent tardiness was a concern leading to a meeting. |
|
| 73 |
+
| Behaviour at school | Reduced participation | Negative | Jamie L. | 1 | Jamieβs decreased participation in class was noted. |
|
| 74 |
+
| Behaviour at school | Social engagement | Negative | Jamie L. | 1 | Jamie's withdrawal from family activities and hobbies was highlighted. |
|
| 75 |
+
| Behaviour at school | Social engagement | Positive | Jamie L. | 1 | Encouraging Jamie to join school clubs and groups is a strategy to foster social connection and<br>improve his social engagement. |
|
| 76 |
+
| Family & social | Family communication | Negative | Jamie L. | 1 | Parents expressed concerns about Jamieβs withdrawal and lack of communication within the family. |
|
| 77 |
+
| Family & social | Family communication | Neutral | Jamie L. | 1 | Parents are actively involved in Jamie's care and are communicating their observations to the care<br>team. |
|
| 78 |
+
| Family & social | Family financial struggles | Negative | Jamie L. | 1 | Jamie's low motivation is attributed to recent family financial difficulties. |"""
|
| 79 |
+
|
| 80 |
+
case_notes_table_structured_summary = """| Main heading | Subheading | Summary | Group |
|
| 81 |
+
|:--------------------|:--------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------|
|
| 82 |
+
| Behaviour at school | Behaviour at school | Several cases involved disruptions at school, including increased absences, declining grades, and a<br>physical altercation. Alex displayed sullenness, avoidance, and agitation, sometimes reacting with<br>frustration. A key theme was isolation and a lack of connection with peers and school staff. | Alex D. |
|
| 83 |
+
| Mental health | Anger | Anger was a prominent feature across multiple cases, particularly when discussing home life and<br>family dynamics. Outbursts of anger were observed, especially related to a new stepfather, and Alex<br>displayed defensiveness when questioned about his actions. | Alex D. |
|
| 84 |
+
| Mental health | Social issues | Alex experienced feelings of isolation and difficulty connecting with others. He had a new group of<br>friends and engaged in late-night outings, which raised concerns about potential risky behaviours<br>and social influences. | Alex D. |
|
| 85 |
+
| Physical health | General | Signs of self-harm were present on Alexβs arms, indicating a heightened level of distress and<br>potentially a need for immediate support. He displayed visible agitation and defensive behaviour<br>during questioning. | Alex D. |
|
| 86 |
+
| Physical health | Substance misuse | Substance use was a recurring concern, with Alex admitting to occasional substance use and his<br>mother reporting potential signs of abuse. Alcohol use was noted in several instances, leading to<br>recommendations for assessment and potential intervention. | Alex D. |
|
| 87 |
+
| Trends over time | Trends over time | There was a gradual escalation of concerning behaviours over time. Early interventions focused on<br>initial meetings and observation, progressing to more intensive interventions like referrals to<br>mental health professionals, residential treatment programs, and family counseling. | Alex D. |
|
| 88 |
+
| Behaviour at school | Behaviour at school | Jamie exhibited concerning behaviours at school, including consistent tardiness and decreased<br>participation in class. This was accompanied by withdrawn behaviour and signs of sadness, suggesting<br>a need for immediate intervention to address potential underlying issues impacting his academic<br>performance. | Jamie L. |
|
| 89 |
+
| Mental health | Anger | There is no direct indication of anger in Jamie's case notes. | Jamie L. |
|
| 90 |
+
| Mental health | Mental health | Jamie displayed concerning signs of mental health difficulties, including feelings of emptiness,<br>hopelessness, low self-worth, and isolation. He reported difficulty sleeping and a lack of<br>motivation. The need for a comprehensive mental health assessment was highlighted to fully<br>understand the nature and severity of his condition. | Jamie L. |
|
| 91 |
+
| Mental health | Social issues | Jamie experienced significant social difficulties, including limited social interactions, feelings<br>of isolation, and a lack of engagement with family activities and hobbies. He spends a lot of time<br>alone in his room. Recommendations focused on fostering connection through school clubs and family<br>therapy were made. | Jamie L. |
|
| 92 |
+
| Physical health | General | While no direct physical health concerns were explicitly stated, Jamie's emotional state and<br>associated symptoms (difficulty sleeping) warrant consideration of his overall well-being and<br>potential physical manifestations of his mental health challenges. | Jamie L. |
|
| 93 |
+
| Physical health | Substance misuse | There is no indication of substance misuse in the provided case notes. | Jamie L. |
|
| 94 |
+
| Trends over time | Trends over time | Jamieβs case demonstrates fluctuating progress. Initial feedback indicated slight improvements in<br>mood on some days, but overall he continues to struggle. A shift occurred with the commencement of<br>antidepressant medication, showing initial positive feedback in terms of mood and energy levels,<br>requiring continued monitoring and adjustment. | Jamie L. |"""
|
tools/helper_functions.py
ADDED
|
@@ -0,0 +1,1245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import codecs
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import boto3
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from botocore.exceptions import ClientError
|
| 12 |
+
|
| 13 |
+
from tools.config import (
|
| 14 |
+
AWS_USER_POOL_ID,
|
| 15 |
+
CUSTOM_HEADER,
|
| 16 |
+
CUSTOM_HEADER_VALUE,
|
| 17 |
+
INPUT_FOLDER,
|
| 18 |
+
MAXIMUM_ZERO_SHOT_TOPICS,
|
| 19 |
+
OUTPUT_FOLDER,
|
| 20 |
+
SESSION_OUTPUT_FOLDER,
|
| 21 |
+
model_full_names,
|
| 22 |
+
model_name_map,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def empty_output_vars_extract_topics():
|
| 27 |
+
# Empty output objects before processing a new file
|
| 28 |
+
|
| 29 |
+
master_topic_df_state = pd.DataFrame()
|
| 30 |
+
master_topic_summary_df_state = pd.DataFrame()
|
| 31 |
+
master_reference_df_state = pd.DataFrame()
|
| 32 |
+
text_output_file = list()
|
| 33 |
+
text_output_file_list_state = list()
|
| 34 |
+
latest_batch_completed = 0
|
| 35 |
+
log_files_output = list()
|
| 36 |
+
log_files_output_list_state = list()
|
| 37 |
+
conversation_metadata_textbox = ""
|
| 38 |
+
estimated_time_taken_number = 0
|
| 39 |
+
file_data_state = pd.DataFrame()
|
| 40 |
+
reference_data_file_name_textbox = ""
|
| 41 |
+
display_topic_table_markdown = ""
|
| 42 |
+
summary_output_file_list = list()
|
| 43 |
+
summary_input_file_list = list()
|
| 44 |
+
overall_summarisation_input_files = list()
|
| 45 |
+
overall_summary_output_files = list()
|
| 46 |
+
|
| 47 |
+
return (
|
| 48 |
+
master_topic_df_state,
|
| 49 |
+
master_topic_summary_df_state,
|
| 50 |
+
master_reference_df_state,
|
| 51 |
+
text_output_file,
|
| 52 |
+
text_output_file_list_state,
|
| 53 |
+
latest_batch_completed,
|
| 54 |
+
log_files_output,
|
| 55 |
+
log_files_output_list_state,
|
| 56 |
+
conversation_metadata_textbox,
|
| 57 |
+
estimated_time_taken_number,
|
| 58 |
+
file_data_state,
|
| 59 |
+
reference_data_file_name_textbox,
|
| 60 |
+
display_topic_table_markdown,
|
| 61 |
+
summary_output_file_list,
|
| 62 |
+
summary_input_file_list,
|
| 63 |
+
overall_summarisation_input_files,
|
| 64 |
+
overall_summary_output_files,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def empty_output_vars_summarise():
|
| 69 |
+
# Empty output objects before summarising files
|
| 70 |
+
|
| 71 |
+
summary_reference_table_sample_state = pd.DataFrame()
|
| 72 |
+
master_topic_summary_df_revised_summaries_state = pd.DataFrame()
|
| 73 |
+
master_reference_df_revised_summaries_state = pd.DataFrame()
|
| 74 |
+
summary_output_files = list()
|
| 75 |
+
summarised_outputs_list = list()
|
| 76 |
+
latest_summary_completed_num = 0
|
| 77 |
+
overall_summarisation_input_files = list()
|
| 78 |
+
|
| 79 |
+
return (
|
| 80 |
+
summary_reference_table_sample_state,
|
| 81 |
+
master_topic_summary_df_revised_summaries_state,
|
| 82 |
+
master_reference_df_revised_summaries_state,
|
| 83 |
+
summary_output_files,
|
| 84 |
+
summarised_outputs_list,
|
| 85 |
+
latest_summary_completed_num,
|
| 86 |
+
overall_summarisation_input_files,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_or_create_env_var(var_name: str, default_value: str):
|
| 91 |
+
# Get the environment variable if it exists
|
| 92 |
+
value = os.environ.get(var_name)
|
| 93 |
+
|
| 94 |
+
# If it doesn't exist, set it to the default value
|
| 95 |
+
if value is None:
|
| 96 |
+
os.environ[var_name] = default_value
|
| 97 |
+
value = default_value
|
| 98 |
+
|
| 99 |
+
return value
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_file_path_with_extension(file_path: str):
|
| 103 |
+
# First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
|
| 104 |
+
basename = os.path.basename(file_path)
|
| 105 |
+
|
| 106 |
+
# Return the basename with its extension
|
| 107 |
+
return basename
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_file_name_no_ext(file_path: str):
|
| 111 |
+
# First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
|
| 112 |
+
basename = os.path.basename(file_path)
|
| 113 |
+
|
| 114 |
+
# Then, split the basename and its extension and return only the basename without the extension
|
| 115 |
+
filename_without_extension, _ = os.path.splitext(basename)
|
| 116 |
+
|
| 117 |
+
# print(filename_without_extension)
|
| 118 |
+
|
| 119 |
+
return filename_without_extension
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def detect_file_type(filename: str):
|
| 123 |
+
"""Detect the file type based on its extension."""
|
| 124 |
+
|
| 125 |
+
# Strip quotes and whitespace that might have been accidentally included
|
| 126 |
+
filename = filename.strip().strip("'\"")
|
| 127 |
+
|
| 128 |
+
if (
|
| 129 |
+
(filename.endswith(".csv"))
|
| 130 |
+
| (filename.endswith(".csv.gz"))
|
| 131 |
+
| (filename.endswith(".zip"))
|
| 132 |
+
):
|
| 133 |
+
return "csv"
|
| 134 |
+
elif filename.endswith(".xlsx"):
|
| 135 |
+
return "xlsx"
|
| 136 |
+
elif filename.endswith(".parquet"):
|
| 137 |
+
return "parquet"
|
| 138 |
+
elif filename.endswith(".pdf"):
|
| 139 |
+
return "pdf"
|
| 140 |
+
elif filename.endswith(".jpg"):
|
| 141 |
+
return "jpg"
|
| 142 |
+
elif filename.endswith(".jpeg"):
|
| 143 |
+
return "jpeg"
|
| 144 |
+
elif filename.endswith(".png"):
|
| 145 |
+
return "png"
|
| 146 |
+
else:
|
| 147 |
+
raise ValueError("Unsupported file type.")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def read_file(filename: str, sheet: str = ""):
|
| 151 |
+
"""Read the file based on its detected type."""
|
| 152 |
+
# Strip quotes and whitespace that might have been accidentally included
|
| 153 |
+
filename = filename.strip().strip("'\"")
|
| 154 |
+
file_type = detect_file_type(filename)
|
| 155 |
+
|
| 156 |
+
if file_type == "csv":
|
| 157 |
+
return pd.read_csv(filename, low_memory=False)
|
| 158 |
+
elif file_type == "xlsx":
|
| 159 |
+
if sheet:
|
| 160 |
+
return pd.read_excel(filename, sheet_name=sheet)
|
| 161 |
+
else:
|
| 162 |
+
return pd.read_excel(filename)
|
| 163 |
+
elif file_type == "parquet":
|
| 164 |
+
return pd.read_parquet(filename)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_in_file(file_path: str, colnames: List[str] = "", excel_sheet: str = ""):
|
| 168 |
+
"""
|
| 169 |
+
Loads in a tabular data file and returns data and file name.
|
| 170 |
+
|
| 171 |
+
Parameters:
|
| 172 |
+
- file_path (str): The path to the file to be processed.
|
| 173 |
+
- colnames (List[str], optional): list of colnames to load in
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
file_name = get_file_name_no_ext(file_path)
|
| 177 |
+
file_data = read_file(file_path, excel_sheet)
|
| 178 |
+
|
| 179 |
+
if colnames and isinstance(colnames, list):
|
| 180 |
+
col_list = colnames
|
| 181 |
+
else:
|
| 182 |
+
col_list = list(file_data.columns)
|
| 183 |
+
|
| 184 |
+
if not isinstance(col_list, List):
|
| 185 |
+
col_list = [col_list]
|
| 186 |
+
|
| 187 |
+
col_list = [item for item in col_list if item not in ["", "NA"]]
|
| 188 |
+
|
| 189 |
+
for col in col_list:
|
| 190 |
+
file_data[col] = file_data[col].fillna("")
|
| 191 |
+
file_data[col] = (
|
| 192 |
+
file_data[col].astype(str).str.replace("\bnan\b", "", regex=True)
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# print(file_data[colnames])
|
| 196 |
+
|
| 197 |
+
return file_data, file_name
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def load_in_data_file(
|
| 201 |
+
file_paths: List[str],
|
| 202 |
+
in_colnames: List[str],
|
| 203 |
+
batch_size: int = 5,
|
| 204 |
+
in_excel_sheets: str = "",
|
| 205 |
+
):
|
| 206 |
+
"""Load in data table, work out how many batches needed."""
|
| 207 |
+
|
| 208 |
+
if not isinstance(in_colnames, list):
|
| 209 |
+
in_colnames = [in_colnames]
|
| 210 |
+
|
| 211 |
+
# print("in_colnames:", in_colnames)
|
| 212 |
+
|
| 213 |
+
try:
|
| 214 |
+
file_data, file_name = load_in_file(
|
| 215 |
+
file_paths[0], colnames=in_colnames, excel_sheet=in_excel_sheets
|
| 216 |
+
)
|
| 217 |
+
num_batches = math.ceil(len(file_data) / batch_size)
|
| 218 |
+
print(
|
| 219 |
+
f"File {file_name} loaded successfully. Number of rows: {len(file_data)}. Total number of batches: {num_batches}"
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
except Exception as e:
|
| 223 |
+
print("Could not load data file due to:", e)
|
| 224 |
+
file_data = pd.DataFrame()
|
| 225 |
+
file_name = ""
|
| 226 |
+
num_batches = 1
|
| 227 |
+
|
| 228 |
+
return file_data, file_name, num_batches
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def clean_column_name(
|
| 232 |
+
column_name: str, max_length: int = 20, front_characters: bool = True
|
| 233 |
+
):
|
| 234 |
+
# Convert to string
|
| 235 |
+
column_name = str(column_name)
|
| 236 |
+
# Replace non-alphanumeric characters (except underscores) with underscores
|
| 237 |
+
column_name = re.sub(r"\W+", "_", column_name)
|
| 238 |
+
# Remove leading/trailing underscores
|
| 239 |
+
column_name = column_name.strip("_")
|
| 240 |
+
# Ensure the result is not empty; fall back to "column" if necessary
|
| 241 |
+
column_name = column_name if column_name else "column"
|
| 242 |
+
# Truncate to max_length
|
| 243 |
+
if front_characters is True:
|
| 244 |
+
output_text = column_name[:max_length]
|
| 245 |
+
else:
|
| 246 |
+
output_text = column_name[-max_length:]
|
| 247 |
+
return output_text
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_in_previous_reference_file(file: str):
|
| 251 |
+
"""Load in data table from a partially completed consultation summary to continue it."""
|
| 252 |
+
|
| 253 |
+
reference_file_data = pd.DataFrame()
|
| 254 |
+
reference_file_name = ""
|
| 255 |
+
out_message = ""
|
| 256 |
+
|
| 257 |
+
# for file in file_paths:
|
| 258 |
+
|
| 259 |
+
print("file:", file)
|
| 260 |
+
|
| 261 |
+
# If reference table
|
| 262 |
+
if "reference_table" in file:
|
| 263 |
+
try:
|
| 264 |
+
reference_file_data, reference_file_name = load_in_file(file)
|
| 265 |
+
# print("reference_file_data:", reference_file_data.head(2))
|
| 266 |
+
out_message = out_message + " Reference file load successful."
|
| 267 |
+
except Exception as e:
|
| 268 |
+
out_message = "Could not load reference file data:" + str(e)
|
| 269 |
+
raise Exception("Could not load reference file data:", e)
|
| 270 |
+
|
| 271 |
+
if reference_file_data.empty:
|
| 272 |
+
out_message = out_message + " No reference data table provided."
|
| 273 |
+
raise Exception(out_message)
|
| 274 |
+
|
| 275 |
+
print(out_message)
|
| 276 |
+
|
| 277 |
+
return reference_file_data, reference_file_name
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def load_in_previous_data_files(
|
| 281 |
+
file_paths_partial_output: List[str], for_modified_table: bool = False
|
| 282 |
+
):
|
| 283 |
+
"""Load in data table from a partially completed consultation summary to continue it."""
|
| 284 |
+
|
| 285 |
+
reference_file_data = pd.DataFrame()
|
| 286 |
+
reference_file_name = ""
|
| 287 |
+
unique_file_data = pd.DataFrame()
|
| 288 |
+
unique_file_name = ""
|
| 289 |
+
out_message = ""
|
| 290 |
+
latest_batch = 0
|
| 291 |
+
|
| 292 |
+
if not file_paths_partial_output:
|
| 293 |
+
out_message = out_message + " No reference or unique data table provided."
|
| 294 |
+
return (
|
| 295 |
+
reference_file_data,
|
| 296 |
+
unique_file_data,
|
| 297 |
+
latest_batch,
|
| 298 |
+
out_message,
|
| 299 |
+
reference_file_name,
|
| 300 |
+
unique_file_name,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if not isinstance(file_paths_partial_output, list):
|
| 304 |
+
file_paths_partial_output = [file_paths_partial_output]
|
| 305 |
+
|
| 306 |
+
for file in file_paths_partial_output:
|
| 307 |
+
|
| 308 |
+
if isinstance(file, gr.FileData):
|
| 309 |
+
name = file.name
|
| 310 |
+
else:
|
| 311 |
+
name = file
|
| 312 |
+
|
| 313 |
+
# If reference table
|
| 314 |
+
if "reference_table" in name:
|
| 315 |
+
try:
|
| 316 |
+
reference_file_data, reference_file_name = load_in_file(file)
|
| 317 |
+
# print("reference_file_data:", reference_file_data.head(2))
|
| 318 |
+
out_message = out_message + " Reference file load successful."
|
| 319 |
+
|
| 320 |
+
except Exception as e:
|
| 321 |
+
out_message = "Could not load reference file data:" + str(e)
|
| 322 |
+
raise Exception("Could not load reference file data:", e)
|
| 323 |
+
# If unique table
|
| 324 |
+
if "unique_topic" in name:
|
| 325 |
+
try:
|
| 326 |
+
unique_file_data, unique_file_name = load_in_file(file)
|
| 327 |
+
# print("unique_topics_file:", unique_file_data.head(2))
|
| 328 |
+
out_message = out_message + " Unique table file load successful."
|
| 329 |
+
except Exception as e:
|
| 330 |
+
out_message = "Could not load unique table file data:" + str(e)
|
| 331 |
+
raise Exception("Could not load unique table file data:", e)
|
| 332 |
+
if "batch_" in name:
|
| 333 |
+
latest_batch = re.search(r"batch_(\d+)", file.name).group(1)
|
| 334 |
+
print("latest batch:", latest_batch)
|
| 335 |
+
latest_batch = int(latest_batch)
|
| 336 |
+
|
| 337 |
+
if latest_batch == 0:
|
| 338 |
+
out_message = out_message + " Latest batch number not found."
|
| 339 |
+
if reference_file_data.empty:
|
| 340 |
+
out_message = out_message + " No reference data table provided."
|
| 341 |
+
# raise Exception(out_message)
|
| 342 |
+
if unique_file_data.empty:
|
| 343 |
+
out_message = out_message + " No unique data table provided."
|
| 344 |
+
|
| 345 |
+
print(out_message)
|
| 346 |
+
|
| 347 |
+
# Return all data if using for deduplication task. Return just modified unique table if using just for table modification
|
| 348 |
+
if for_modified_table is False:
|
| 349 |
+
return (
|
| 350 |
+
reference_file_data,
|
| 351 |
+
unique_file_data,
|
| 352 |
+
latest_batch,
|
| 353 |
+
out_message,
|
| 354 |
+
reference_file_name,
|
| 355 |
+
unique_file_name,
|
| 356 |
+
)
|
| 357 |
+
else:
|
| 358 |
+
reference_file_data.drop("Topic number", axis=1, inplace=True, errors="ignore")
|
| 359 |
+
|
| 360 |
+
unique_file_data = create_topic_summary_df_from_reference_table(
|
| 361 |
+
reference_file_data
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
unique_file_data.drop("Summary", axis=1, inplace=True)
|
| 365 |
+
|
| 366 |
+
# Then merge the topic numbers back to the original dataframe
|
| 367 |
+
reference_file_data = reference_file_data.merge(
|
| 368 |
+
unique_file_data[
|
| 369 |
+
["General topic", "Subtopic", "Sentiment", "Topic number"]
|
| 370 |
+
],
|
| 371 |
+
on=["General topic", "Subtopic", "Sentiment"],
|
| 372 |
+
how="left",
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
out_file_names = [reference_file_name + ".csv"]
|
| 376 |
+
out_file_names.append(unique_file_name + ".csv")
|
| 377 |
+
|
| 378 |
+
return (
|
| 379 |
+
unique_file_data,
|
| 380 |
+
reference_file_data,
|
| 381 |
+
unique_file_data,
|
| 382 |
+
reference_file_name,
|
| 383 |
+
unique_file_name,
|
| 384 |
+
out_file_names,
|
| 385 |
+
) # gr.Dataframe(value=unique_file_data, headers=None, column_count=(unique_file_data.shape[1], "fixed"), row_count = (unique_file_data.shape[0], "fixed"), visible=True, type="pandas")
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def join_cols_onto_reference_df(
|
| 389 |
+
reference_df: pd.DataFrame,
|
| 390 |
+
original_data_df: pd.DataFrame,
|
| 391 |
+
join_columns: List[str],
|
| 392 |
+
original_file_name: str,
|
| 393 |
+
output_folder: str = OUTPUT_FOLDER,
|
| 394 |
+
):
|
| 395 |
+
|
| 396 |
+
# print("original_data_df columns:", original_data_df.columns)
|
| 397 |
+
# print("original_data_df:", original_data_df)
|
| 398 |
+
|
| 399 |
+
original_data_df.reset_index(names="Response References", inplace=True)
|
| 400 |
+
original_data_df["Response References"] += 1
|
| 401 |
+
|
| 402 |
+
# print("reference_df columns:", reference_df.columns)
|
| 403 |
+
# print("reference_df:", reference_df)
|
| 404 |
+
|
| 405 |
+
join_columns.append("Response References")
|
| 406 |
+
|
| 407 |
+
reference_df["Response References"] = (
|
| 408 |
+
reference_df["Response References"].fillna("-1").astype(int)
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
save_file_name = output_folder + original_file_name + "_j.csv"
|
| 412 |
+
|
| 413 |
+
out_reference_df = reference_df.merge(
|
| 414 |
+
original_data_df[join_columns], on="Response References", how="left"
|
| 415 |
+
)
|
| 416 |
+
out_reference_df.to_csv(save_file_name, index=None)
|
| 417 |
+
|
| 418 |
+
file_data_outputs = [save_file_name]
|
| 419 |
+
|
| 420 |
+
return out_reference_df, file_data_outputs
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def get_basic_response_data(
|
| 424 |
+
file_data: pd.DataFrame, chosen_cols: List[str], verify_titles: bool = False
|
| 425 |
+
) -> pd.DataFrame:
|
| 426 |
+
|
| 427 |
+
if not isinstance(chosen_cols, list):
|
| 428 |
+
chosen_cols = [chosen_cols]
|
| 429 |
+
|
| 430 |
+
if chosen_cols[0] not in file_data.columns:
|
| 431 |
+
error_msg = (
|
| 432 |
+
f"Column '{chosen_cols[0]}' not found in file_data columns. "
|
| 433 |
+
f"Available columns: {list(file_data.columns)}"
|
| 434 |
+
)
|
| 435 |
+
print(error_msg)
|
| 436 |
+
raise KeyError(error_msg)
|
| 437 |
+
|
| 438 |
+
# If verify_titles is True, we need to check and include the second column
|
| 439 |
+
if verify_titles is True:
|
| 440 |
+
if len(chosen_cols) < 2:
|
| 441 |
+
error_msg = (
|
| 442 |
+
"verify_titles is True but only one column provided. "
|
| 443 |
+
"Need at least 2 columns: one for response text and one for title."
|
| 444 |
+
)
|
| 445 |
+
print(error_msg)
|
| 446 |
+
raise ValueError(error_msg)
|
| 447 |
+
if chosen_cols[1] not in file_data.columns:
|
| 448 |
+
error_msg = (
|
| 449 |
+
f"Column '{chosen_cols[1]}' not found in file_data columns for title. "
|
| 450 |
+
f"Available columns: {list(file_data.columns)}"
|
| 451 |
+
)
|
| 452 |
+
print(error_msg)
|
| 453 |
+
raise KeyError(error_msg)
|
| 454 |
+
# Include both columns when verify_titles is True
|
| 455 |
+
basic_response_data = file_data[[chosen_cols[0], chosen_cols[1]]]
|
| 456 |
+
basic_response_data = basic_response_data.rename(
|
| 457 |
+
columns={
|
| 458 |
+
basic_response_data.columns[0]: "Response",
|
| 459 |
+
basic_response_data.columns[1]: "Title",
|
| 460 |
+
}
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
basic_response_data = file_data[[chosen_cols[0]]]
|
| 464 |
+
basic_response_data = basic_response_data.rename(
|
| 465 |
+
columns={basic_response_data.columns[0]: "Response"}
|
| 466 |
+
)
|
| 467 |
+
basic_response_data = basic_response_data.reset_index(
|
| 468 |
+
names="Original Reference"
|
| 469 |
+
) # .reset_index(drop=True) #
|
| 470 |
+
# Try to convert to int, if it fails, return a range of 1 to last row + 1
|
| 471 |
+
try:
|
| 472 |
+
basic_response_data["Original Reference"] = (
|
| 473 |
+
basic_response_data["Original Reference"].astype(int) + 1
|
| 474 |
+
)
|
| 475 |
+
except (ValueError, TypeError):
|
| 476 |
+
basic_response_data["Original Reference"] = range(
|
| 477 |
+
1, len(basic_response_data) + 1
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
basic_response_data["Reference"] = basic_response_data.index.astype(int) + 1
|
| 481 |
+
|
| 482 |
+
if verify_titles is True:
|
| 483 |
+
basic_response_data["Title"] = basic_response_data["Title"].str.strip()
|
| 484 |
+
basic_response_data["Title"] = basic_response_data["Title"].apply(initial_clean)
|
| 485 |
+
else:
|
| 486 |
+
basic_response_data = basic_response_data[
|
| 487 |
+
["Reference", "Response", "Original Reference"]
|
| 488 |
+
]
|
| 489 |
+
|
| 490 |
+
basic_response_data["Response"] = basic_response_data["Response"].str.strip()
|
| 491 |
+
basic_response_data["Response"] = basic_response_data["Response"].apply(
|
| 492 |
+
initial_clean
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
return basic_response_data
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def convert_reference_table_to_pivot_table(
|
| 499 |
+
df: pd.DataFrame, basic_response_data: pd.DataFrame = pd.DataFrame()
|
| 500 |
+
):
|
| 501 |
+
|
| 502 |
+
df_in = df[["Response References", "General topic", "Subtopic", "Sentiment"]].copy()
|
| 503 |
+
|
| 504 |
+
df_in["Response References"] = df_in["Response References"].astype(int)
|
| 505 |
+
|
| 506 |
+
# Create a combined category column
|
| 507 |
+
df_in["Category"] = (
|
| 508 |
+
df_in["General topic"] + " - " + df_in["Subtopic"] + " - " + df_in["Sentiment"]
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Create pivot table counting occurrences of each unique combination
|
| 512 |
+
pivot_table = pd.crosstab(
|
| 513 |
+
index=df_in["Response References"],
|
| 514 |
+
columns=[df_in["General topic"], df_in["Subtopic"], df_in["Sentiment"]],
|
| 515 |
+
margins=True,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Flatten column names to make them more readable
|
| 519 |
+
pivot_table.columns = [" - ".join(col) for col in pivot_table.columns]
|
| 520 |
+
|
| 521 |
+
pivot_table.reset_index(inplace=True)
|
| 522 |
+
|
| 523 |
+
if not basic_response_data.empty:
|
| 524 |
+
pivot_table = basic_response_data.merge(
|
| 525 |
+
pivot_table, right_on="Response References", left_on="Reference", how="left"
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
pivot_table.drop("Response References", axis=1, inplace=True)
|
| 529 |
+
|
| 530 |
+
pivot_table.columns = pivot_table.columns.str.replace(
|
| 531 |
+
"Not assessed - ", ""
|
| 532 |
+
).str.replace("- Not assessed", "")
|
| 533 |
+
|
| 534 |
+
return pivot_table
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def create_topic_summary_df_from_reference_table(reference_df: pd.DataFrame):
|
| 538 |
+
|
| 539 |
+
if "Group" not in reference_df.columns:
|
| 540 |
+
reference_df["Group"] = "All"
|
| 541 |
+
|
| 542 |
+
# Ensure 'Start row of group' column is numeric to avoid comparison errors
|
| 543 |
+
if "Start row of group" in reference_df.columns:
|
| 544 |
+
reference_df["Start row of group"] = pd.to_numeric(
|
| 545 |
+
reference_df["Start row of group"], errors="coerce"
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
out_topic_summary_df = (
|
| 549 |
+
reference_df.groupby(["General topic", "Subtopic", "Sentiment", "Group"])
|
| 550 |
+
.agg(
|
| 551 |
+
{
|
| 552 |
+
"Response References": "size", # Count the number of references
|
| 553 |
+
"Summary": lambda x: "<br>".join(
|
| 554 |
+
sorted(
|
| 555 |
+
set(x),
|
| 556 |
+
key=lambda summary: reference_df.loc[
|
| 557 |
+
reference_df["Summary"] == summary, "Start row of group"
|
| 558 |
+
].min(),
|
| 559 |
+
)
|
| 560 |
+
),
|
| 561 |
+
}
|
| 562 |
+
)
|
| 563 |
+
.reset_index()
|
| 564 |
+
# .sort_values('Response References', ascending=False) # Sort by size, biggest first
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
out_topic_summary_df = out_topic_summary_df.rename(
|
| 568 |
+
columns={"Response References": "Number of responses"}, errors="ignore"
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# Sort the dataframe first
|
| 572 |
+
out_topic_summary_df = out_topic_summary_df.sort_values(
|
| 573 |
+
["Group", "Number of responses", "General topic", "Subtopic", "Sentiment"],
|
| 574 |
+
ascending=[True, False, True, True, True],
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
# Then assign Topic number based on the final sorted order
|
| 578 |
+
out_topic_summary_df = out_topic_summary_df.assign(
|
| 579 |
+
Topic_number=lambda df: np.arange(1, len(df) + 1)
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
out_topic_summary_df.rename(columns={"Topic_number": "Topic number"}, inplace=True)
|
| 583 |
+
|
| 584 |
+
return out_topic_summary_df
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# Wrap text in each column to the specified max width, including whole words
|
| 588 |
+
def wrap_text(text: str, max_width=80, max_text_length=None):
|
| 589 |
+
if not isinstance(text, str):
|
| 590 |
+
return text
|
| 591 |
+
|
| 592 |
+
# If max_text_length is set, truncate the text and add ellipsis
|
| 593 |
+
if max_text_length and len(text) > max_text_length:
|
| 594 |
+
text = text[:max_text_length] + "..."
|
| 595 |
+
|
| 596 |
+
text = text.replace("\r\n", "<br>").replace("\n", "<br>")
|
| 597 |
+
|
| 598 |
+
words = text.split()
|
| 599 |
+
if not words:
|
| 600 |
+
return text
|
| 601 |
+
|
| 602 |
+
# First pass: initial word wrapping
|
| 603 |
+
wrapped_lines = list()
|
| 604 |
+
current_line = list()
|
| 605 |
+
current_length = 0
|
| 606 |
+
|
| 607 |
+
def add_line():
|
| 608 |
+
if current_line:
|
| 609 |
+
wrapped_lines.append(" ".join(current_line))
|
| 610 |
+
current_line.clear()
|
| 611 |
+
|
| 612 |
+
for i, word in enumerate(words):
|
| 613 |
+
word_length = len(word)
|
| 614 |
+
|
| 615 |
+
# Handle words longer than max_width
|
| 616 |
+
if word_length > max_width:
|
| 617 |
+
add_line()
|
| 618 |
+
wrapped_lines.append(word)
|
| 619 |
+
current_length = 0
|
| 620 |
+
continue
|
| 621 |
+
|
| 622 |
+
# Calculate space needed for this word
|
| 623 |
+
space_needed = word_length if not current_line else word_length + 1
|
| 624 |
+
|
| 625 |
+
# Check if adding this word would exceed max_width
|
| 626 |
+
if current_length + space_needed > max_width:
|
| 627 |
+
add_line()
|
| 628 |
+
current_line.append(word)
|
| 629 |
+
current_length = word_length
|
| 630 |
+
else:
|
| 631 |
+
current_line.append(word)
|
| 632 |
+
current_length += space_needed
|
| 633 |
+
|
| 634 |
+
add_line() # Add any remaining text
|
| 635 |
+
|
| 636 |
+
# Second pass: redistribute words from lines following single-word lines
|
| 637 |
+
def can_fit_in_previous_line(prev_line, word):
|
| 638 |
+
return len(prev_line) + 1 + len(word) <= max_width
|
| 639 |
+
|
| 640 |
+
i = 0
|
| 641 |
+
while i < len(wrapped_lines) - 1:
|
| 642 |
+
words_in_line = wrapped_lines[i].split()
|
| 643 |
+
next_line_words = wrapped_lines[i + 1].split()
|
| 644 |
+
|
| 645 |
+
# If current line has only one word and isn't too long
|
| 646 |
+
if len(words_in_line) == 1 and len(words_in_line[0]) < max_width * 0.8:
|
| 647 |
+
# Try to bring words back from the next line
|
| 648 |
+
words_to_bring_back = list()
|
| 649 |
+
remaining_words = list()
|
| 650 |
+
current_length = len(words_in_line[0])
|
| 651 |
+
|
| 652 |
+
for word in next_line_words:
|
| 653 |
+
if current_length + len(word) + 1 <= max_width:
|
| 654 |
+
words_to_bring_back.append(word)
|
| 655 |
+
current_length += len(word) + 1
|
| 656 |
+
else:
|
| 657 |
+
remaining_words.append(word)
|
| 658 |
+
|
| 659 |
+
if words_to_bring_back:
|
| 660 |
+
# Update current line with additional words
|
| 661 |
+
wrapped_lines[i] = " ".join(words_in_line + words_to_bring_back)
|
| 662 |
+
|
| 663 |
+
# Update next line with remaining words
|
| 664 |
+
if remaining_words:
|
| 665 |
+
wrapped_lines[i + 1] = " ".join(remaining_words)
|
| 666 |
+
else:
|
| 667 |
+
wrapped_lines.pop(i + 1)
|
| 668 |
+
continue # Don't increment i if we removed a line
|
| 669 |
+
i += 1
|
| 670 |
+
|
| 671 |
+
return "<br>".join(wrapped_lines)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def initial_clean(text: str):
|
| 675 |
+
#### Some of my cleaning functions
|
| 676 |
+
html_pattern_regex = r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0| "
|
| 677 |
+
html_start_pattern_end_dots_regex = r"<(.*?)\.\."
|
| 678 |
+
non_ascii_pattern = r"[^\x00-\x7F]+"
|
| 679 |
+
multiple_spaces_regex = r"\s{2,}"
|
| 680 |
+
|
| 681 |
+
# Define a list of patterns and their replacements
|
| 682 |
+
patterns = [
|
| 683 |
+
(html_pattern_regex, " "),
|
| 684 |
+
(html_start_pattern_end_dots_regex, " "),
|
| 685 |
+
(non_ascii_pattern, " "),
|
| 686 |
+
(multiple_spaces_regex, " "),
|
| 687 |
+
]
|
| 688 |
+
|
| 689 |
+
# Apply each regex replacement
|
| 690 |
+
for pattern, replacement in patterns:
|
| 691 |
+
text = re.sub(pattern, replacement, text)
|
| 692 |
+
|
| 693 |
+
return text
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def view_table(file_path: str): # Added max_width parameter
|
| 697 |
+
df = pd.read_csv(file_path)
|
| 698 |
+
|
| 699 |
+
df_cleaned = df.replace("\n", " ", regex=True)
|
| 700 |
+
|
| 701 |
+
# Use apply with axis=1 to apply wrap_text to each element
|
| 702 |
+
df_cleaned = df_cleaned.apply(lambda col: col.map(wrap_text))
|
| 703 |
+
|
| 704 |
+
table_out = df_cleaned.to_markdown(index=False)
|
| 705 |
+
|
| 706 |
+
return table_out
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def ensure_output_folder_exists():
|
| 710 |
+
"""Checks if the 'output/' folder exists, creates it if not."""
|
| 711 |
+
|
| 712 |
+
folder_name = "output/"
|
| 713 |
+
|
| 714 |
+
if not os.path.exists(folder_name):
|
| 715 |
+
# Create the folder if it doesn't exist
|
| 716 |
+
os.makedirs(folder_name)
|
| 717 |
+
print("Created the 'output/' folder.")
|
| 718 |
+
else:
|
| 719 |
+
print("The 'output/' folder already exists.")
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def put_columns_in_df(in_file: List[str]):
|
| 723 |
+
new_choices = list()
|
| 724 |
+
concat_choices = list()
|
| 725 |
+
all_sheet_names = list()
|
| 726 |
+
number_of_excel_files = 0
|
| 727 |
+
|
| 728 |
+
if not in_file:
|
| 729 |
+
return (
|
| 730 |
+
gr.Dropdown(choices=list()),
|
| 731 |
+
gr.Dropdown(choices=list()),
|
| 732 |
+
"",
|
| 733 |
+
gr.Dropdown(choices=list()),
|
| 734 |
+
gr.Dropdown(choices=list()),
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
for file in in_file:
|
| 738 |
+
file_name = file.name
|
| 739 |
+
file_type = detect_file_type(file_name)
|
| 740 |
+
# print("File type is:", file_type)
|
| 741 |
+
|
| 742 |
+
file_end = get_file_path_with_extension(file_name)
|
| 743 |
+
|
| 744 |
+
if file_type == "xlsx":
|
| 745 |
+
number_of_excel_files += 1
|
| 746 |
+
new_choices = list()
|
| 747 |
+
print("Running through all xlsx sheets")
|
| 748 |
+
anon_xlsx = pd.ExcelFile(file_name)
|
| 749 |
+
new_sheet_names = anon_xlsx.sheet_names
|
| 750 |
+
# Iterate through the sheet names
|
| 751 |
+
for sheet_name in new_sheet_names:
|
| 752 |
+
# Read each sheet into a DataFrame
|
| 753 |
+
df = pd.read_excel(file_name, sheet_name=sheet_name)
|
| 754 |
+
|
| 755 |
+
new_choices.extend(list(df.columns))
|
| 756 |
+
|
| 757 |
+
all_sheet_names.extend(new_sheet_names)
|
| 758 |
+
|
| 759 |
+
else:
|
| 760 |
+
df = read_file(file_name)
|
| 761 |
+
new_choices = list(df.columns)
|
| 762 |
+
|
| 763 |
+
concat_choices.extend(new_choices)
|
| 764 |
+
|
| 765 |
+
# Drop duplicate columns
|
| 766 |
+
concat_choices = sorted(set(concat_choices))
|
| 767 |
+
|
| 768 |
+
if number_of_excel_files > 0:
|
| 769 |
+
return (
|
| 770 |
+
gr.Dropdown(choices=concat_choices, value=concat_choices[0]),
|
| 771 |
+
gr.Dropdown(
|
| 772 |
+
choices=all_sheet_names,
|
| 773 |
+
value=all_sheet_names[0],
|
| 774 |
+
visible=True,
|
| 775 |
+
interactive=True,
|
| 776 |
+
),
|
| 777 |
+
file_end,
|
| 778 |
+
gr.Dropdown(choices=concat_choices),
|
| 779 |
+
gr.Dropdown(choices=concat_choices),
|
| 780 |
+
)
|
| 781 |
+
else:
|
| 782 |
+
return (
|
| 783 |
+
gr.Dropdown(choices=concat_choices, value=concat_choices[0]),
|
| 784 |
+
gr.Dropdown(visible=False),
|
| 785 |
+
file_end,
|
| 786 |
+
gr.Dropdown(choices=concat_choices),
|
| 787 |
+
gr.Dropdown(choices=concat_choices),
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
# Following function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
|
| 792 |
+
def add_folder_to_path(folder_path: str):
|
| 793 |
+
"""
|
| 794 |
+
Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist.
|
| 795 |
+
"""
|
| 796 |
+
|
| 797 |
+
if os.path.exists(folder_path) and os.path.isdir(folder_path):
|
| 798 |
+
print(folder_path, "folder exists.")
|
| 799 |
+
|
| 800 |
+
# Resolve relative path to absolute path
|
| 801 |
+
absolute_path = os.path.abspath(folder_path)
|
| 802 |
+
|
| 803 |
+
current_path = os.environ["PATH"]
|
| 804 |
+
if absolute_path not in current_path.split(os.pathsep):
|
| 805 |
+
full_path_extension = absolute_path + os.pathsep + current_path
|
| 806 |
+
os.environ["PATH"] = full_path_extension
|
| 807 |
+
# print(f"Updated PATH with: ", full_path_extension)
|
| 808 |
+
else:
|
| 809 |
+
print(f"Directory {folder_path} already exists in PATH.")
|
| 810 |
+
else:
|
| 811 |
+
print(f"Folder not found at {folder_path} - not added to PATH")
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
# Upon running a process, the feedback buttons are revealed
|
| 815 |
+
def reveal_feedback_buttons():
|
| 816 |
+
return (
|
| 817 |
+
gr.Radio(visible=True),
|
| 818 |
+
gr.Textbox(visible=True),
|
| 819 |
+
gr.Button(visible=True),
|
| 820 |
+
gr.Markdown(visible=True),
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def wipe_logs(feedback_logs_loc: str, usage_logs_loc: str):
|
| 825 |
+
try:
|
| 826 |
+
os.remove(feedback_logs_loc)
|
| 827 |
+
except Exception as e:
|
| 828 |
+
print("Could not remove feedback logs file", e)
|
| 829 |
+
try:
|
| 830 |
+
os.remove(usage_logs_loc)
|
| 831 |
+
except Exception as e:
|
| 832 |
+
print("Could not remove usage logs file", e)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
async def get_connection_params(
|
| 836 |
+
request: gr.Request,
|
| 837 |
+
output_folder_textbox: str = OUTPUT_FOLDER,
|
| 838 |
+
input_folder_textbox: str = INPUT_FOLDER,
|
| 839 |
+
session_output_folder: str = SESSION_OUTPUT_FOLDER,
|
| 840 |
+
):
|
| 841 |
+
|
| 842 |
+
# print("Session hash:", request.session_hash)
|
| 843 |
+
|
| 844 |
+
if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
|
| 845 |
+
if CUSTOM_HEADER in request.headers:
|
| 846 |
+
supplied_custom_header_value = request.headers[CUSTOM_HEADER]
|
| 847 |
+
if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
|
| 848 |
+
print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
|
| 849 |
+
else:
|
| 850 |
+
print("Custom header value does not match expected value.")
|
| 851 |
+
raise ValueError("Custom header value does not match expected value.")
|
| 852 |
+
else:
|
| 853 |
+
print("Custom header value not found.")
|
| 854 |
+
raise ValueError("Custom header value not found.")
|
| 855 |
+
|
| 856 |
+
# Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
|
| 857 |
+
|
| 858 |
+
if request.username:
|
| 859 |
+
out_session_hash = request.username
|
| 860 |
+
# print("Request username found:", out_session_hash)
|
| 861 |
+
|
| 862 |
+
elif "x-cognito-id" in request.headers:
|
| 863 |
+
out_session_hash = request.headers["x-cognito-id"]
|
| 864 |
+
# print("Cognito ID found:", out_session_hash)
|
| 865 |
+
|
| 866 |
+
elif "x-amzn-oidc-identity" in request.headers:
|
| 867 |
+
out_session_hash = request.headers["x-amzn-oidc-identity"]
|
| 868 |
+
|
| 869 |
+
# Fetch email address using Cognito client
|
| 870 |
+
cognito_client = boto3.client("cognito-idp")
|
| 871 |
+
try:
|
| 872 |
+
response = cognito_client.admin_get_user(
|
| 873 |
+
UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
|
| 874 |
+
Username=out_session_hash,
|
| 875 |
+
)
|
| 876 |
+
email = next(
|
| 877 |
+
attr["Value"]
|
| 878 |
+
for attr in response["UserAttributes"]
|
| 879 |
+
if attr["Name"] == "email"
|
| 880 |
+
)
|
| 881 |
+
# print("Email address found:", email)
|
| 882 |
+
|
| 883 |
+
out_session_hash = email
|
| 884 |
+
except ClientError as e:
|
| 885 |
+
print("Error fetching user details:", e)
|
| 886 |
+
email = None
|
| 887 |
+
|
| 888 |
+
print("Cognito ID found:", out_session_hash)
|
| 889 |
+
|
| 890 |
+
else:
|
| 891 |
+
out_session_hash = request.session_hash
|
| 892 |
+
|
| 893 |
+
if session_output_folder == "True" or session_output_folder is True:
|
| 894 |
+
output_folder = output_folder_textbox + out_session_hash + "/"
|
| 895 |
+
input_folder = input_folder_textbox + out_session_hash + "/"
|
| 896 |
+
else:
|
| 897 |
+
output_folder = output_folder_textbox
|
| 898 |
+
input_folder = input_folder_textbox
|
| 899 |
+
|
| 900 |
+
if not os.path.exists(output_folder):
|
| 901 |
+
os.mkdir(output_folder)
|
| 902 |
+
if not os.path.exists(input_folder):
|
| 903 |
+
os.mkdir(input_folder)
|
| 904 |
+
|
| 905 |
+
return out_session_hash, output_folder, out_session_hash, input_folder
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def load_in_default_cost_codes(cost_codes_path: str, default_cost_code: str = ""):
|
| 909 |
+
"""
|
| 910 |
+
Load in the cost codes list from file.
|
| 911 |
+
"""
|
| 912 |
+
cost_codes_df = pd.read_csv(cost_codes_path)
|
| 913 |
+
dropdown_choices = cost_codes_df.iloc[:, 0].astype(str).tolist()
|
| 914 |
+
|
| 915 |
+
# Avoid inserting duplicate or empty cost code values
|
| 916 |
+
if default_cost_code and default_cost_code not in dropdown_choices:
|
| 917 |
+
dropdown_choices.insert(0, default_cost_code)
|
| 918 |
+
|
| 919 |
+
# Always have a blank option at the top
|
| 920 |
+
if "" not in dropdown_choices:
|
| 921 |
+
dropdown_choices.insert(0, "")
|
| 922 |
+
|
| 923 |
+
out_dropdown = gr.Dropdown(
|
| 924 |
+
value=default_cost_code if default_cost_code in dropdown_choices else "",
|
| 925 |
+
label="Choose cost code for analysis",
|
| 926 |
+
choices=dropdown_choices,
|
| 927 |
+
allow_custom_value=False,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
return cost_codes_df, cost_codes_df, out_dropdown
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def df_select_callback_cost(df: pd.DataFrame, evt: gr.SelectData):
|
| 934 |
+
row_value_code = evt.row_value[0] # This is the value for cost code
|
| 935 |
+
|
| 936 |
+
return row_value_code
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
def update_cost_code_dataframe_from_dropdown_select(
|
| 940 |
+
cost_dropdown_selection: str, cost_code_df: pd.DataFrame
|
| 941 |
+
):
|
| 942 |
+
cost_code_df = cost_code_df.loc[
|
| 943 |
+
cost_code_df.iloc[:, 0] == cost_dropdown_selection, :
|
| 944 |
+
]
|
| 945 |
+
return cost_code_df
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def reset_base_dataframe(df: pd.DataFrame):
|
| 949 |
+
return df
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def enforce_cost_codes(
|
| 953 |
+
enforce_cost_code_textbox: str,
|
| 954 |
+
cost_code_choice: str,
|
| 955 |
+
cost_code_df: pd.DataFrame,
|
| 956 |
+
verify_cost_codes: bool = True,
|
| 957 |
+
):
|
| 958 |
+
"""
|
| 959 |
+
Check if the enforce cost codes variable is set to true, and then check that a cost cost has been chosen. If not, raise an error. Then, check against the values in the cost code dataframe to ensure that the cost code exists.
|
| 960 |
+
"""
|
| 961 |
+
|
| 962 |
+
if enforce_cost_code_textbox == "True":
|
| 963 |
+
if not cost_code_choice:
|
| 964 |
+
raise Exception("Please choose a cost code before continuing")
|
| 965 |
+
|
| 966 |
+
if verify_cost_codes is True:
|
| 967 |
+
if cost_code_df.empty:
|
| 968 |
+
# Warn but don't block - cost code is still required above
|
| 969 |
+
print(
|
| 970 |
+
"Warning: Cost code dataframe is empty. Verification skipped. Please ensure cost codes are loaded for full validation."
|
| 971 |
+
)
|
| 972 |
+
else:
|
| 973 |
+
valid_cost_codes_list = list(cost_code_df.iloc[:, 0].unique())
|
| 974 |
+
|
| 975 |
+
if cost_code_choice not in valid_cost_codes_list:
|
| 976 |
+
raise Exception(
|
| 977 |
+
"Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions."
|
| 978 |
+
)
|
| 979 |
+
return
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def _get_env_list(env_var_name: str, strip_strings: bool = True) -> List[str]:
|
| 983 |
+
"""Parses a comma-separated environment variable into a list of strings."""
|
| 984 |
+
value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
|
| 985 |
+
if not value:
|
| 986 |
+
return []
|
| 987 |
+
# Split by comma and filter out any empty strings that might result from extra commas
|
| 988 |
+
if strip_strings:
|
| 989 |
+
return [s.strip() for s in value.split(",") if s.strip()]
|
| 990 |
+
else:
|
| 991 |
+
return [codecs.decode(s, "unicode_escape") for s in value.split(",") if s]
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def create_batch_file_path_details(
|
| 995 |
+
reference_data_file_name: str,
|
| 996 |
+
latest_batch_completed: int = None,
|
| 997 |
+
batch_size_number: int = None,
|
| 998 |
+
in_column: str = None,
|
| 999 |
+
) -> str:
|
| 1000 |
+
"""
|
| 1001 |
+
Creates a standardised batch file path detail string from a reference data filename.
|
| 1002 |
+
|
| 1003 |
+
Args:
|
| 1004 |
+
reference_data_file_name (str): Name of the reference data file
|
| 1005 |
+
latest_batch_completed (int, optional): Latest batch completed. Defaults to None.
|
| 1006 |
+
batch_size_number (int, optional): Batch size number. Defaults to None.
|
| 1007 |
+
in_column (str, optional): In column. Defaults to None.
|
| 1008 |
+
Returns:
|
| 1009 |
+
str: Formatted batch file path detail string
|
| 1010 |
+
"""
|
| 1011 |
+
|
| 1012 |
+
# Extract components from filename using regex
|
| 1013 |
+
file_name = (
|
| 1014 |
+
re.search(
|
| 1015 |
+
r"(.*?)(?:_all_|_final_|_batch_|_col_)", reference_data_file_name
|
| 1016 |
+
).group(1)
|
| 1017 |
+
if re.search(r"(.*?)(?:_all_|_final_|_batch_|_col_)", reference_data_file_name)
|
| 1018 |
+
else reference_data_file_name
|
| 1019 |
+
)
|
| 1020 |
+
latest_batch_completed = (
|
| 1021 |
+
int(re.search(r"batch_(\d+)_", reference_data_file_name).group(1))
|
| 1022 |
+
if "batch_" in reference_data_file_name
|
| 1023 |
+
else latest_batch_completed
|
| 1024 |
+
)
|
| 1025 |
+
batch_size_number = (
|
| 1026 |
+
int(re.search(r"size_(\d+)_", reference_data_file_name).group(1))
|
| 1027 |
+
if "size_" in reference_data_file_name
|
| 1028 |
+
else batch_size_number
|
| 1029 |
+
)
|
| 1030 |
+
in_column = (
|
| 1031 |
+
re.search(r"col_(.*?)_reference", reference_data_file_name).group(1)
|
| 1032 |
+
if "col_" in reference_data_file_name
|
| 1033 |
+
else in_column
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
# Clean the extracted names
|
| 1037 |
+
file_name_cleaned = clean_column_name(file_name, max_length=20)
|
| 1038 |
+
in_column_cleaned = clean_column_name(in_column, max_length=20)
|
| 1039 |
+
|
| 1040 |
+
# Create batch file path details string
|
| 1041 |
+
if latest_batch_completed:
|
| 1042 |
+
return f"{file_name_cleaned}_batch_{latest_batch_completed}_size_{batch_size_number}_col_{in_column_cleaned}"
|
| 1043 |
+
return f"{file_name_cleaned}_col_{in_column_cleaned}"
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
def move_overall_summary_output_files_to_front_page(
|
| 1047 |
+
overall_summary_output_files_xlsx: List[str],
|
| 1048 |
+
):
|
| 1049 |
+
return overall_summary_output_files_xlsx
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def generate_zero_shot_topics_df(
|
| 1053 |
+
zero_shot_topics: pd.DataFrame,
|
| 1054 |
+
force_zero_shot_radio: str = "No",
|
| 1055 |
+
create_revised_general_topics: bool = False,
|
| 1056 |
+
max_topic_no: int = MAXIMUM_ZERO_SHOT_TOPICS,
|
| 1057 |
+
):
|
| 1058 |
+
"""
|
| 1059 |
+
Preprocesses a DataFrame of zero-shot topics, cleaning and formatting them
|
| 1060 |
+
for use with a large language model. It handles different column configurations
|
| 1061 |
+
(e.g., only subtopics, general topics and subtopics, or subtopics with descriptions)
|
| 1062 |
+
and enforces a maximum number of topics.
|
| 1063 |
+
|
| 1064 |
+
Args:
|
| 1065 |
+
zero_shot_topics (pd.DataFrame): A DataFrame containing the initial zero-shot topics.
|
| 1066 |
+
Expected columns can vary, but typically include
|
| 1067 |
+
"General topic", "Subtopic", and/or "Description".
|
| 1068 |
+
force_zero_shot_radio (str, optional): A string indicating whether to force
|
| 1069 |
+
the use of zero-shot topics. Defaults to "No".
|
| 1070 |
+
(Currently not used in the function logic, but kept for signature consistency).
|
| 1071 |
+
create_revised_general_topics (bool, optional): A boolean indicating whether to
|
| 1072 |
+
create revised general topics. Defaults to False.
|
| 1073 |
+
(Currently not used in the function logic, but kept for signature consistency).
|
| 1074 |
+
max_topic_no (int, optional): The maximum number of topics allowed to fit within
|
| 1075 |
+
LLM context limits. If `zero_shot_topics` exceeds this,
|
| 1076 |
+
it will be truncated. Defaults to 120.
|
| 1077 |
+
|
| 1078 |
+
Returns:
|
| 1079 |
+
tuple: A tuple containing:
|
| 1080 |
+
- zero_shot_topics_gen_topics_list (list): A list of cleaned general topics.
|
| 1081 |
+
- zero_shot_topics_subtopics_list (list): A list of cleaned subtopics.
|
| 1082 |
+
- zero_shot_topics_description_list (list): A list of cleaned topic descriptions.
|
| 1083 |
+
"""
|
| 1084 |
+
|
| 1085 |
+
zero_shot_topics_gen_topics_list = list()
|
| 1086 |
+
zero_shot_topics_subtopics_list = list()
|
| 1087 |
+
zero_shot_topics_description_list = list()
|
| 1088 |
+
|
| 1089 |
+
# Max 120 topics allowed
|
| 1090 |
+
if zero_shot_topics.shape[0] > max_topic_no:
|
| 1091 |
+
out_message = (
|
| 1092 |
+
"Maximum "
|
| 1093 |
+
+ str(max_topic_no)
|
| 1094 |
+
+ " zero-shot topics allowed according to application configuration."
|
| 1095 |
+
)
|
| 1096 |
+
print(out_message)
|
| 1097 |
+
raise Exception(out_message)
|
| 1098 |
+
|
| 1099 |
+
# Forward slashes in the topic names seems to confuse the model
|
| 1100 |
+
if zero_shot_topics.shape[1] >= 1: # Check if there is at least one column
|
| 1101 |
+
for x in zero_shot_topics.columns:
|
| 1102 |
+
if not zero_shot_topics[x].isnull().all():
|
| 1103 |
+
zero_shot_topics[x] = zero_shot_topics[x].apply(initial_clean)
|
| 1104 |
+
|
| 1105 |
+
zero_shot_topics.loc[:, x] = (
|
| 1106 |
+
zero_shot_topics.loc[:, x]
|
| 1107 |
+
.str.strip()
|
| 1108 |
+
.str.replace("\n", " ")
|
| 1109 |
+
.str.replace("\r", " ")
|
| 1110 |
+
.str.replace("/", " or ")
|
| 1111 |
+
.str.replace("&", " and ")
|
| 1112 |
+
.str.replace(" s ", "s ")
|
| 1113 |
+
.str.lower()
|
| 1114 |
+
.str.capitalize()
|
| 1115 |
+
)
|
| 1116 |
+
|
| 1117 |
+
# If number of columns is 1, keep only subtopics
|
| 1118 |
+
if (
|
| 1119 |
+
zero_shot_topics.shape[1] == 1
|
| 1120 |
+
and "General topic" not in zero_shot_topics.columns
|
| 1121 |
+
):
|
| 1122 |
+
print("Found only Subtopic in zero shot topics")
|
| 1123 |
+
zero_shot_topics_gen_topics_list = [""] * zero_shot_topics.shape[0]
|
| 1124 |
+
zero_shot_topics_subtopics_list = list(zero_shot_topics.iloc[:, 0])
|
| 1125 |
+
# Allow for possibility that the user only wants to set general topics and not subtopics
|
| 1126 |
+
elif (
|
| 1127 |
+
zero_shot_topics.shape[1] == 1
|
| 1128 |
+
and "General topic" in zero_shot_topics.columns
|
| 1129 |
+
):
|
| 1130 |
+
print("Found only General topic in zero shot topics")
|
| 1131 |
+
zero_shot_topics_gen_topics_list = list(zero_shot_topics["General topic"])
|
| 1132 |
+
zero_shot_topics_subtopics_list = [""] * zero_shot_topics.shape[0]
|
| 1133 |
+
# If general topic and subtopic are specified
|
| 1134 |
+
elif set(["General topic", "Subtopic"]).issubset(zero_shot_topics.columns):
|
| 1135 |
+
print("Found General topic and Subtopic in zero shot topics")
|
| 1136 |
+
zero_shot_topics_gen_topics_list = list(zero_shot_topics["General topic"])
|
| 1137 |
+
zero_shot_topics_subtopics_list = list(zero_shot_topics["Subtopic"])
|
| 1138 |
+
# If subtopic and description are specified
|
| 1139 |
+
elif set(["Subtopic", "Description"]).issubset(zero_shot_topics.columns):
|
| 1140 |
+
print("Found Subtopic and Description in zero shot topics")
|
| 1141 |
+
zero_shot_topics_gen_topics_list = [""] * zero_shot_topics.shape[0]
|
| 1142 |
+
zero_shot_topics_subtopics_list = list(zero_shot_topics["Subtopic"])
|
| 1143 |
+
zero_shot_topics_description_list = list(zero_shot_topics["Description"])
|
| 1144 |
+
|
| 1145 |
+
# If number of columns is at least 2, keep general topics and subtopics
|
| 1146 |
+
elif (
|
| 1147 |
+
zero_shot_topics.shape[1] >= 2
|
| 1148 |
+
and "Description" not in zero_shot_topics.columns
|
| 1149 |
+
):
|
| 1150 |
+
zero_shot_topics_gen_topics_list = list(zero_shot_topics.iloc[:, 0])
|
| 1151 |
+
zero_shot_topics_subtopics_list = list(zero_shot_topics.iloc[:, 1])
|
| 1152 |
+
else:
|
| 1153 |
+
# If there are more columns, just assume that the first column was meant to be a subtopic
|
| 1154 |
+
zero_shot_topics_gen_topics_list = [""] * zero_shot_topics.shape[0]
|
| 1155 |
+
zero_shot_topics_subtopics_list = list(zero_shot_topics.iloc[:, 0])
|
| 1156 |
+
|
| 1157 |
+
# Add a description if column is present
|
| 1158 |
+
if not zero_shot_topics_description_list:
|
| 1159 |
+
if "Description" in zero_shot_topics.columns:
|
| 1160 |
+
zero_shot_topics_description_list = list(
|
| 1161 |
+
zero_shot_topics["Description"]
|
| 1162 |
+
)
|
| 1163 |
+
# print("Description found in topic title. List is:", zero_shot_topics_description_list)
|
| 1164 |
+
elif zero_shot_topics.shape[1] >= 3:
|
| 1165 |
+
zero_shot_topics_description_list = list(
|
| 1166 |
+
zero_shot_topics.iloc[:, 2]
|
| 1167 |
+
) # Assume the third column is description
|
| 1168 |
+
else:
|
| 1169 |
+
zero_shot_topics_description_list = [""] * zero_shot_topics.shape[0]
|
| 1170 |
+
|
| 1171 |
+
# If the responses are being forced into zero shot topics, allow an option for nothing relevant
|
| 1172 |
+
if force_zero_shot_radio == "Yes":
|
| 1173 |
+
zero_shot_topics_gen_topics_list.append("")
|
| 1174 |
+
zero_shot_topics_subtopics_list.append("No relevant topic")
|
| 1175 |
+
zero_shot_topics_description_list.append("")
|
| 1176 |
+
|
| 1177 |
+
# Add description or not
|
| 1178 |
+
zero_shot_topics_df = pd.DataFrame(
|
| 1179 |
+
data={
|
| 1180 |
+
"General topic": zero_shot_topics_gen_topics_list,
|
| 1181 |
+
"Subtopic": zero_shot_topics_subtopics_list,
|
| 1182 |
+
"Description": zero_shot_topics_description_list,
|
| 1183 |
+
}
|
| 1184 |
+
)
|
| 1185 |
+
|
| 1186 |
+
# Filter out duplicate General topic and subtopic names
|
| 1187 |
+
zero_shot_topics_df = zero_shot_topics_df.drop_duplicates(
|
| 1188 |
+
["General topic", "Subtopic"], keep="first"
|
| 1189 |
+
)
|
| 1190 |
+
|
| 1191 |
+
# Sort the dataframe by General topic and subtopic
|
| 1192 |
+
zero_shot_topics_df = zero_shot_topics_df.sort_values(
|
| 1193 |
+
["General topic", "Subtopic"], ascending=[True, True]
|
| 1194 |
+
)
|
| 1195 |
+
|
| 1196 |
+
return zero_shot_topics_df
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
def update_model_choice(model_source):
|
| 1200 |
+
# Filter models by source and return the first matching model name
|
| 1201 |
+
matching_models = [
|
| 1202 |
+
model_name
|
| 1203 |
+
for model_name, model_info in model_name_map.items()
|
| 1204 |
+
if model_info["source"] == model_source
|
| 1205 |
+
]
|
| 1206 |
+
|
| 1207 |
+
output_model = matching_models[0] if matching_models else model_full_names[0]
|
| 1208 |
+
|
| 1209 |
+
return gr.Dropdown(
|
| 1210 |
+
value=output_model,
|
| 1211 |
+
choices=matching_models,
|
| 1212 |
+
label="Large language model for topic extraction and summarisation",
|
| 1213 |
+
multiselect=False,
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
def ensure_model_in_map(model_choice: str, model_name_map_dict: dict = None) -> dict:
|
| 1218 |
+
"""
|
| 1219 |
+
Ensures that a model_choice is registered in model_name_map.
|
| 1220 |
+
If the model_choice is not found, it assumes it's an inference-server model
|
| 1221 |
+
and adds it to the map with source "inference-server".
|
| 1222 |
+
|
| 1223 |
+
Args:
|
| 1224 |
+
model_choice (str): The model name to check/register
|
| 1225 |
+
model_name_map_dict (dict, optional): The model_name_map dictionary to update.
|
| 1226 |
+
If None, uses the global model_name_map from config.
|
| 1227 |
+
|
| 1228 |
+
Returns:
|
| 1229 |
+
dict: The model_name_map dictionary (updated if needed)
|
| 1230 |
+
"""
|
| 1231 |
+
# Use provided dict or global one
|
| 1232 |
+
if model_name_map_dict is None:
|
| 1233 |
+
from tools.config import model_name_map
|
| 1234 |
+
|
| 1235 |
+
model_name_map_dict = model_name_map
|
| 1236 |
+
|
| 1237 |
+
# If model_choice is not in the map, assume it's an inference-server model
|
| 1238 |
+
if model_choice not in model_name_map_dict:
|
| 1239 |
+
model_name_map_dict[model_choice] = {
|
| 1240 |
+
"short_name": model_choice,
|
| 1241 |
+
"source": "inference-server",
|
| 1242 |
+
}
|
| 1243 |
+
print(f"Registered custom model '{model_choice}' as inference-server model")
|
| 1244 |
+
|
| 1245 |
+
return model_name_map_dict
|
tools/llm_api_call.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tools/llm_funcs.py
ADDED
|
@@ -0,0 +1,1999 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import time
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
import boto3
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
# Import mock patches if in test mode
|
| 12 |
+
if os.environ.get("USE_MOCK_LLM") == "1" or os.environ.get("TEST_MODE") == "1":
|
| 13 |
+
try:
|
| 14 |
+
# Try to import and apply mock patches
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
# Add project root to sys.path so we can import test.mock_llm_calls
|
| 18 |
+
project_root = os.path.dirname(os.path.dirname(__file__))
|
| 19 |
+
if project_root not in sys.path:
|
| 20 |
+
sys.path.insert(0, project_root)
|
| 21 |
+
try:
|
| 22 |
+
from test.mock_llm_calls import apply_mock_patches
|
| 23 |
+
|
| 24 |
+
apply_mock_patches()
|
| 25 |
+
except ImportError:
|
| 26 |
+
# If mock module not found, continue without mocking
|
| 27 |
+
pass
|
| 28 |
+
except Exception:
|
| 29 |
+
# If anything fails, continue without mocking
|
| 30 |
+
pass
|
| 31 |
+
from google import genai as ai
|
| 32 |
+
from google.genai import types
|
| 33 |
+
from gradio import Progress
|
| 34 |
+
from huggingface_hub import hf_hub_download
|
| 35 |
+
from openai import OpenAI
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
|
| 38 |
+
model_type = None # global variable setup
|
| 39 |
+
full_text = (
|
| 40 |
+
"" # Define dummy source text (full text) just to enable highlight function to load
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Global variables for model and tokenizer
|
| 44 |
+
_model = None
|
| 45 |
+
_tokenizer = None
|
| 46 |
+
_assistant_model = None
|
| 47 |
+
|
| 48 |
+
from tools.config import (
|
| 49 |
+
ASSISTANT_MODEL,
|
| 50 |
+
BATCH_SIZE_DEFAULT,
|
| 51 |
+
CHOSEN_LOCAL_MODEL_TYPE,
|
| 52 |
+
COMPILE_MODE,
|
| 53 |
+
COMPILE_TRANSFORMERS,
|
| 54 |
+
DEDUPLICATION_THRESHOLD,
|
| 55 |
+
HF_TOKEN,
|
| 56 |
+
INT8_WITH_OFFLOAD_TO_CPU,
|
| 57 |
+
K_QUANT_LEVEL,
|
| 58 |
+
LLM_BATCH_SIZE,
|
| 59 |
+
LLM_CONTEXT_LENGTH,
|
| 60 |
+
LLM_LAST_N_TOKENS,
|
| 61 |
+
LLM_MAX_GPU_LAYERS,
|
| 62 |
+
LLM_MAX_NEW_TOKENS,
|
| 63 |
+
LLM_MIN_P,
|
| 64 |
+
LLM_REPETITION_PENALTY,
|
| 65 |
+
LLM_RESET,
|
| 66 |
+
LLM_SAMPLE,
|
| 67 |
+
LLM_SEED,
|
| 68 |
+
LLM_STOP_STRINGS,
|
| 69 |
+
LLM_STREAM,
|
| 70 |
+
LLM_TEMPERATURE,
|
| 71 |
+
LLM_THREADS,
|
| 72 |
+
LLM_TOP_K,
|
| 73 |
+
LLM_TOP_P,
|
| 74 |
+
LOAD_LOCAL_MODEL_AT_START,
|
| 75 |
+
LOCAL_MODEL_FILE,
|
| 76 |
+
LOCAL_MODEL_FOLDER,
|
| 77 |
+
LOCAL_REPO_ID,
|
| 78 |
+
MAX_COMMENT_CHARS,
|
| 79 |
+
MAX_TIME_FOR_LOOP,
|
| 80 |
+
MODEL_DTYPE,
|
| 81 |
+
MULTIMODAL_PROMPT_FORMAT,
|
| 82 |
+
NUM_PRED_TOKENS,
|
| 83 |
+
NUMBER_OF_RETRY_ATTEMPTS,
|
| 84 |
+
RUN_LOCAL_MODEL,
|
| 85 |
+
SPECULATIVE_DECODING,
|
| 86 |
+
TIMEOUT_WAIT,
|
| 87 |
+
USE_BITSANDBYTES,
|
| 88 |
+
USE_LLAMA_CPP,
|
| 89 |
+
USE_LLAMA_SWAP,
|
| 90 |
+
V_QUANT_LEVEL,
|
| 91 |
+
)
|
| 92 |
+
from tools.helper_functions import _get_env_list
|
| 93 |
+
|
| 94 |
+
if SPECULATIVE_DECODING == "True":
|
| 95 |
+
SPECULATIVE_DECODING = True
|
| 96 |
+
else:
|
| 97 |
+
SPECULATIVE_DECODING = False
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if isinstance(NUM_PRED_TOKENS, str):
|
| 101 |
+
NUM_PRED_TOKENS = int(NUM_PRED_TOKENS)
|
| 102 |
+
if isinstance(LLM_MAX_GPU_LAYERS, str):
|
| 103 |
+
LLM_MAX_GPU_LAYERS = int(LLM_MAX_GPU_LAYERS)
|
| 104 |
+
if isinstance(LLM_THREADS, str):
|
| 105 |
+
LLM_THREADS = int(LLM_THREADS)
|
| 106 |
+
|
| 107 |
+
if LLM_RESET == "True":
|
| 108 |
+
reset = True
|
| 109 |
+
else:
|
| 110 |
+
reset = False
|
| 111 |
+
|
| 112 |
+
if LLM_STREAM == "True":
|
| 113 |
+
stream = True
|
| 114 |
+
else:
|
| 115 |
+
stream = False
|
| 116 |
+
|
| 117 |
+
if LLM_SAMPLE == "True":
|
| 118 |
+
sample = True
|
| 119 |
+
else:
|
| 120 |
+
sample = False
|
| 121 |
+
|
| 122 |
+
if LLM_STOP_STRINGS:
|
| 123 |
+
LLM_STOP_STRINGS = _get_env_list(LLM_STOP_STRINGS, strip_strings=False)
|
| 124 |
+
|
| 125 |
+
max_tokens = LLM_MAX_NEW_TOKENS
|
| 126 |
+
timeout_wait = TIMEOUT_WAIT
|
| 127 |
+
number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS
|
| 128 |
+
max_time_for_loop = MAX_TIME_FOR_LOOP
|
| 129 |
+
batch_size_default = BATCH_SIZE_DEFAULT
|
| 130 |
+
deduplication_threshold = DEDUPLICATION_THRESHOLD
|
| 131 |
+
max_comment_character_length = MAX_COMMENT_CHARS
|
| 132 |
+
|
| 133 |
+
temperature = LLM_TEMPERATURE
|
| 134 |
+
top_k = LLM_TOP_K
|
| 135 |
+
top_p = LLM_TOP_P
|
| 136 |
+
min_p = LLM_MIN_P
|
| 137 |
+
repetition_penalty = LLM_REPETITION_PENALTY
|
| 138 |
+
last_n_tokens = LLM_LAST_N_TOKENS
|
| 139 |
+
LLM_MAX_NEW_TOKENS: int = LLM_MAX_NEW_TOKENS
|
| 140 |
+
seed: int = LLM_SEED
|
| 141 |
+
reset: bool = reset
|
| 142 |
+
stream: bool = stream
|
| 143 |
+
batch_size: int = LLM_BATCH_SIZE
|
| 144 |
+
context_length: int = LLM_CONTEXT_LENGTH
|
| 145 |
+
sample = LLM_SAMPLE
|
| 146 |
+
stop_strings = LLM_STOP_STRINGS
|
| 147 |
+
speculative_decoding = SPECULATIVE_DECODING
|
| 148 |
+
if LLM_MAX_GPU_LAYERS != 0:
|
| 149 |
+
gpu_layers = int(LLM_MAX_GPU_LAYERS)
|
| 150 |
+
torch_device = "cuda"
|
| 151 |
+
else:
|
| 152 |
+
gpu_layers = 0
|
| 153 |
+
torch_device = "cpu"
|
| 154 |
+
|
| 155 |
+
if not LLM_THREADS:
|
| 156 |
+
threads = 1
|
| 157 |
+
else:
|
| 158 |
+
threads = LLM_THREADS
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class llama_cpp_init_config_gpu:
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
last_n_tokens=last_n_tokens,
|
| 165 |
+
seed=seed,
|
| 166 |
+
n_threads=threads,
|
| 167 |
+
n_batch=batch_size,
|
| 168 |
+
n_ctx=context_length,
|
| 169 |
+
n_gpu_layers=gpu_layers,
|
| 170 |
+
reset=reset,
|
| 171 |
+
):
|
| 172 |
+
|
| 173 |
+
self.last_n_tokens = last_n_tokens
|
| 174 |
+
self.seed = seed
|
| 175 |
+
self.n_threads = n_threads
|
| 176 |
+
self.n_batch = n_batch
|
| 177 |
+
self.n_ctx = n_ctx
|
| 178 |
+
self.n_gpu_layers = n_gpu_layers
|
| 179 |
+
self.reset = reset
|
| 180 |
+
# self.stop: list[str] = field(default_factory=lambda: [stop_string])
|
| 181 |
+
|
| 182 |
+
def update_gpu(self, new_value):
|
| 183 |
+
self.n_gpu_layers = new_value
|
| 184 |
+
|
| 185 |
+
def update_context(self, new_value):
|
| 186 |
+
self.n_ctx = new_value
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class llama_cpp_init_config_cpu(llama_cpp_init_config_gpu):
|
| 190 |
+
def __init__(self):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.n_gpu_layers = gpu_layers
|
| 193 |
+
self.n_ctx = context_length
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
gpu_config = llama_cpp_init_config_gpu()
|
| 197 |
+
cpu_config = llama_cpp_init_config_cpu()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class LlamaCPPGenerationConfig:
|
| 201 |
+
def __init__(
|
| 202 |
+
self,
|
| 203 |
+
temperature=temperature,
|
| 204 |
+
top_k=top_k,
|
| 205 |
+
min_p=min_p,
|
| 206 |
+
top_p=top_p,
|
| 207 |
+
repeat_penalty=repetition_penalty,
|
| 208 |
+
seed=seed,
|
| 209 |
+
stream=stream,
|
| 210 |
+
max_tokens=LLM_MAX_NEW_TOKENS,
|
| 211 |
+
reset=reset,
|
| 212 |
+
):
|
| 213 |
+
self.temperature = temperature
|
| 214 |
+
self.top_k = top_k
|
| 215 |
+
self.top_p = top_p
|
| 216 |
+
self.repeat_penalty = repeat_penalty
|
| 217 |
+
self.seed = seed
|
| 218 |
+
self.max_tokens = max_tokens
|
| 219 |
+
self.stream = stream
|
| 220 |
+
self.reset = reset
|
| 221 |
+
|
| 222 |
+
def update_temp(self, new_value):
|
| 223 |
+
self.temperature = new_value
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ResponseObject class for AWS Bedrock calls
|
| 227 |
+
class ResponseObject:
|
| 228 |
+
def __init__(self, text, usage_metadata):
|
| 229 |
+
self.text = text
|
| 230 |
+
self.usage_metadata = usage_metadata
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
###
|
| 234 |
+
# LOCAL MODEL FUNCTIONS
|
| 235 |
+
###
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def get_model_path(
|
| 239 |
+
repo_id=LOCAL_REPO_ID,
|
| 240 |
+
model_filename=LOCAL_MODEL_FILE,
|
| 241 |
+
model_dir=LOCAL_MODEL_FOLDER,
|
| 242 |
+
hf_token=HF_TOKEN,
|
| 243 |
+
):
|
| 244 |
+
# Construct the expected local path
|
| 245 |
+
local_path = os.path.join(model_dir, model_filename)
|
| 246 |
+
|
| 247 |
+
print("local path for model load:", local_path)
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
if os.path.exists(local_path):
|
| 251 |
+
print(f"Model already exists at: {local_path}")
|
| 252 |
+
|
| 253 |
+
return local_path
|
| 254 |
+
else:
|
| 255 |
+
if hf_token:
|
| 256 |
+
print("Downloading model from Hugging Face Hub with HF token")
|
| 257 |
+
downloaded_model_path = hf_hub_download(
|
| 258 |
+
repo_id=repo_id, token=hf_token, filename=model_filename
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
return downloaded_model_path
|
| 262 |
+
else:
|
| 263 |
+
print(
|
| 264 |
+
"No HF token found, downloading model from Hugging Face Hub without token"
|
| 265 |
+
)
|
| 266 |
+
downloaded_model_path = hf_hub_download(
|
| 267 |
+
repo_id=repo_id, filename=model_filename
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return downloaded_model_path
|
| 271 |
+
|
| 272 |
+
except Exception as e:
|
| 273 |
+
print("Error loading model:", e)
|
| 274 |
+
raise Warning("Error loading model:", e)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def load_model(
|
| 278 |
+
local_model_type: str = CHOSEN_LOCAL_MODEL_TYPE,
|
| 279 |
+
gpu_layers: int = gpu_layers,
|
| 280 |
+
max_context_length: int = context_length,
|
| 281 |
+
gpu_config: llama_cpp_init_config_gpu = gpu_config,
|
| 282 |
+
cpu_config: llama_cpp_init_config_cpu = cpu_config,
|
| 283 |
+
torch_device: str = torch_device,
|
| 284 |
+
repo_id=LOCAL_REPO_ID,
|
| 285 |
+
model_filename=LOCAL_MODEL_FILE,
|
| 286 |
+
model_dir=LOCAL_MODEL_FOLDER,
|
| 287 |
+
compile_mode=COMPILE_MODE,
|
| 288 |
+
model_dtype=MODEL_DTYPE,
|
| 289 |
+
hf_token=HF_TOKEN,
|
| 290 |
+
speculative_decoding=speculative_decoding,
|
| 291 |
+
model=None,
|
| 292 |
+
tokenizer=None,
|
| 293 |
+
assistant_model=None,
|
| 294 |
+
):
|
| 295 |
+
"""
|
| 296 |
+
Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
local_model_type (str): The type of local model to load (e.g., "llama-cpp").
|
| 300 |
+
gpu_layers (int): The number of GPU layers to offload to the GPU.
|
| 301 |
+
max_context_length (int): The maximum context length for the model.
|
| 302 |
+
gpu_config (llama_cpp_init_config_gpu): Configuration object for GPU-specific Llama.cpp parameters.
|
| 303 |
+
cpu_config (llama_cpp_init_config_cpu): Configuration object for CPU-specific Llama.cpp parameters.
|
| 304 |
+
torch_device (str): The device to load the model on ("cuda" for GPU, "cpu" for CPU).
|
| 305 |
+
repo_id (str): The Hugging Face repository ID where the model is located.
|
| 306 |
+
model_filename (str): The specific filename of the model to download from the repository.
|
| 307 |
+
model_dir (str): The local directory where the model will be stored or downloaded.
|
| 308 |
+
compile_mode (str): The compilation mode to use for the model.
|
| 309 |
+
model_dtype (str): The data type to use for the model.
|
| 310 |
+
hf_token (str): The Hugging Face token to use for the model.
|
| 311 |
+
speculative_decoding (bool): Whether to use speculative decoding.
|
| 312 |
+
model (Llama/transformers model): The model to load.
|
| 313 |
+
tokenizer (list/transformers tokenizer): The tokenizer to load.
|
| 314 |
+
assistant_model (transformers model): The assistant model for speculative decoding.
|
| 315 |
+
Returns:
|
| 316 |
+
tuple: A tuple containing:
|
| 317 |
+
- model (Llama/transformers model): The loaded Llama.cpp/transformers model instance.
|
| 318 |
+
- tokenizer (list/transformers tokenizer): An empty list (tokenizer is not used with Llama.cpp directly in this setup), or a transformers tokenizer.
|
| 319 |
+
- assistant_model (transformers model): The assistant model for speculative decoding (if speculative_decoding is True).
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
if model:
|
| 323 |
+
return model, tokenizer, assistant_model
|
| 324 |
+
|
| 325 |
+
print("Loading model:", local_model_type)
|
| 326 |
+
|
| 327 |
+
# Verify the device and cuda settings
|
| 328 |
+
# Check if CUDA is enabled
|
| 329 |
+
|
| 330 |
+
import torch
|
| 331 |
+
|
| 332 |
+
torch.cuda.empty_cache()
|
| 333 |
+
print("Is CUDA enabled? ", torch.cuda.is_available())
|
| 334 |
+
print("Is a CUDA device available on this computer?", torch.backends.cudnn.enabled)
|
| 335 |
+
if torch.cuda.is_available():
|
| 336 |
+
torch_device = "cuda"
|
| 337 |
+
gpu_layers = int(LLM_MAX_GPU_LAYERS)
|
| 338 |
+
print("CUDA version:", torch.version.cuda)
|
| 339 |
+
# try:
|
| 340 |
+
# os.system("nvidia-smi")
|
| 341 |
+
# except Exception as e:
|
| 342 |
+
# print("Could not print nvidia-smi settings due to:", e)
|
| 343 |
+
else:
|
| 344 |
+
torch_device = "cpu"
|
| 345 |
+
gpu_layers = 0
|
| 346 |
+
|
| 347 |
+
print("Running on device:", torch_device)
|
| 348 |
+
print("GPU layers assigned to cuda:", gpu_layers)
|
| 349 |
+
|
| 350 |
+
if not LLM_THREADS:
|
| 351 |
+
threads = torch.get_num_threads()
|
| 352 |
+
else:
|
| 353 |
+
threads = LLM_THREADS
|
| 354 |
+
print("CPU threads:", threads)
|
| 355 |
+
|
| 356 |
+
# GPU mode
|
| 357 |
+
if torch_device == "cuda":
|
| 358 |
+
torch.cuda.empty_cache()
|
| 359 |
+
gpu_config.update_gpu(gpu_layers)
|
| 360 |
+
gpu_config.update_context(max_context_length)
|
| 361 |
+
|
| 362 |
+
if USE_LLAMA_CPP == "True":
|
| 363 |
+
from llama_cpp import Llama
|
| 364 |
+
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
|
| 365 |
+
|
| 366 |
+
model_path = get_model_path(
|
| 367 |
+
repo_id=repo_id, model_filename=model_filename, model_dir=model_dir
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
try:
|
| 371 |
+
print("GPU load variables:", vars(gpu_config))
|
| 372 |
+
if speculative_decoding:
|
| 373 |
+
model = Llama(
|
| 374 |
+
model_path=model_path,
|
| 375 |
+
type_k=K_QUANT_LEVEL,
|
| 376 |
+
type_v=V_QUANT_LEVEL,
|
| 377 |
+
flash_attn=True,
|
| 378 |
+
draft_model=LlamaPromptLookupDecoding(
|
| 379 |
+
num_pred_tokens=NUM_PRED_TOKENS
|
| 380 |
+
),
|
| 381 |
+
**vars(gpu_config),
|
| 382 |
+
)
|
| 383 |
+
else:
|
| 384 |
+
model = Llama(
|
| 385 |
+
model_path=model_path,
|
| 386 |
+
type_k=K_QUANT_LEVEL,
|
| 387 |
+
type_v=V_QUANT_LEVEL,
|
| 388 |
+
flash_attn=True,
|
| 389 |
+
**vars(gpu_config),
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
except Exception as e:
|
| 393 |
+
print("GPU load failed due to:", e, "Loading model in CPU mode")
|
| 394 |
+
# If fails, go to CPU mode
|
| 395 |
+
model = Llama(model_path=model_path, **vars(cpu_config))
|
| 396 |
+
|
| 397 |
+
else:
|
| 398 |
+
from transformers import (
|
| 399 |
+
AutoModelForCausalLM,
|
| 400 |
+
BitsAndBytesConfig,
|
| 401 |
+
)
|
| 402 |
+
from unsloth import FastLanguageModel
|
| 403 |
+
|
| 404 |
+
print("Loading model from transformers")
|
| 405 |
+
# Use the official model ID for Gemma 3 4B
|
| 406 |
+
model_id = (
|
| 407 |
+
repo_id.split("https://huggingface.co/")[-1]
|
| 408 |
+
if "https://huggingface.co/" in repo_id
|
| 409 |
+
else repo_id
|
| 410 |
+
)
|
| 411 |
+
# 1. Set Data Type (dtype)
|
| 412 |
+
# For H200/Hopper: 'bfloat16'
|
| 413 |
+
# For RTX 3060/Ampere: 'float16'
|
| 414 |
+
dtype_str = model_dtype # os.environ.get("MODEL_DTYPE", "bfloat16").lower()
|
| 415 |
+
if dtype_str == "bfloat16":
|
| 416 |
+
torch_dtype = torch.bfloat16
|
| 417 |
+
elif dtype_str == "float16":
|
| 418 |
+
torch_dtype = torch.float16
|
| 419 |
+
else:
|
| 420 |
+
torch_dtype = torch.float32 # A safe fallback
|
| 421 |
+
|
| 422 |
+
# 2. Set Compilation Mode
|
| 423 |
+
# 'max-autotune' is great for both but can be slow initially.
|
| 424 |
+
# 'reduce-overhead' is a faster alternative for compiling.
|
| 425 |
+
|
| 426 |
+
print("--- System Configuration ---")
|
| 427 |
+
print(f"Using model id: {model_id}")
|
| 428 |
+
print(f"Using dtype: {torch_dtype}")
|
| 429 |
+
print(f"Using compile mode: {compile_mode}")
|
| 430 |
+
print(f"Using bitsandbytes: {USE_BITSANDBYTES}")
|
| 431 |
+
print("--------------------------\n")
|
| 432 |
+
|
| 433 |
+
# --- Load Tokenizer and Model ---
|
| 434 |
+
|
| 435 |
+
try:
|
| 436 |
+
|
| 437 |
+
# Load Tokenizer and Model
|
| 438 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 439 |
+
|
| 440 |
+
if USE_BITSANDBYTES == "True":
|
| 441 |
+
|
| 442 |
+
if INT8_WITH_OFFLOAD_TO_CPU == "True":
|
| 443 |
+
# This will be very slow. Requires at least 4GB of VRAM and 32GB of RAM
|
| 444 |
+
print(
|
| 445 |
+
"Using bitsandbytes for quantisation to 8 bits, with offloading to CPU"
|
| 446 |
+
)
|
| 447 |
+
max_memory = {0: "4GB", "cpu": "32GB"}
|
| 448 |
+
BitsAndBytesConfig(
|
| 449 |
+
load_in_8bit=True,
|
| 450 |
+
max_memory=max_memory,
|
| 451 |
+
llm_int8_enable_fp32_cpu_offload=True, # Note: if bitsandbytes has to offload to CPU, inference will be slow
|
| 452 |
+
)
|
| 453 |
+
else:
|
| 454 |
+
# For Gemma 4B, requires at least 6GB of VRAM
|
| 455 |
+
print("Using bitsandbytes for quantisation to 4 bits")
|
| 456 |
+
BitsAndBytesConfig(
|
| 457 |
+
load_in_4bit=True,
|
| 458 |
+
bnb_4bit_quant_type="nf4", # Use the modern NF4 quantisation for better performance
|
| 459 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
| 460 |
+
bnb_4bit_use_double_quant=True, # Optional: uses a second quantisation step to save even more memory
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# print("Loading model with bitsandbytes quantisation config:", quantisation_config)
|
| 464 |
+
|
| 465 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 466 |
+
model_id,
|
| 467 |
+
max_seq_length=max_context_length,
|
| 468 |
+
dtype=torch_dtype,
|
| 469 |
+
device_map="auto",
|
| 470 |
+
load_in_4bit=True,
|
| 471 |
+
# quantization_config=quantisation_config, # Not actually used in Unsloth
|
| 472 |
+
token=hf_token,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
FastLanguageModel.for_inference(model)
|
| 476 |
+
else:
|
| 477 |
+
print("Loading model without bitsandbytes quantisation")
|
| 478 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 479 |
+
model_id,
|
| 480 |
+
max_seq_length=max_context_length,
|
| 481 |
+
dtype=torch_dtype,
|
| 482 |
+
device_map="auto",
|
| 483 |
+
token=hf_token,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
FastLanguageModel.for_inference(model)
|
| 487 |
+
|
| 488 |
+
if not tokenizer.pad_token:
|
| 489 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 490 |
+
|
| 491 |
+
except Exception as e:
|
| 492 |
+
print("Error loading model with bitsandbytes quantisation config:", e)
|
| 493 |
+
raise Warning(
|
| 494 |
+
"Error loading model with bitsandbytes quantisation config:", e
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
# Compile the Model with the selected mode π
|
| 498 |
+
if COMPILE_TRANSFORMERS == "True":
|
| 499 |
+
try:
|
| 500 |
+
model = torch.compile(model, mode=compile_mode, fullgraph=True)
|
| 501 |
+
except Exception as e:
|
| 502 |
+
print(f"Could not compile model: {e}. Running in eager mode.")
|
| 503 |
+
|
| 504 |
+
print(
|
| 505 |
+
"Loading with",
|
| 506 |
+
gpu_config.n_gpu_layers,
|
| 507 |
+
"model layers sent to GPU and a maximum context length of",
|
| 508 |
+
gpu_config.n_ctx,
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# CPU mode
|
| 512 |
+
else:
|
| 513 |
+
if USE_LLAMA_CPP == "False":
|
| 514 |
+
raise Warning(
|
| 515 |
+
"Using transformers model in CPU mode is not supported. Please change your config variable USE_LLAMA_CPP to True if you want to do CPU inference."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
model_path = get_model_path(
|
| 519 |
+
repo_id=repo_id, model_filename=model_filename, model_dir=model_dir
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# gpu_config.update_gpu(gpu_layers)
|
| 523 |
+
cpu_config.update_gpu(gpu_layers)
|
| 524 |
+
|
| 525 |
+
# Update context length according to slider
|
| 526 |
+
# gpu_config.update_context(max_context_length)
|
| 527 |
+
cpu_config.update_context(max_context_length)
|
| 528 |
+
|
| 529 |
+
if speculative_decoding:
|
| 530 |
+
model = Llama(
|
| 531 |
+
model_path=model_path,
|
| 532 |
+
draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS),
|
| 533 |
+
**vars(cpu_config),
|
| 534 |
+
)
|
| 535 |
+
else:
|
| 536 |
+
model = Llama(model_path=model_path, **vars(cpu_config))
|
| 537 |
+
|
| 538 |
+
print(
|
| 539 |
+
"Loading with",
|
| 540 |
+
cpu_config.n_gpu_layers,
|
| 541 |
+
"model layers sent to GPU and a maximum context length of",
|
| 542 |
+
cpu_config.n_ctx,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
print("Finished loading model:", local_model_type)
|
| 546 |
+
print("GPU layers assigned to cuda:", gpu_layers)
|
| 547 |
+
|
| 548 |
+
# Load assistant model for speculative decoding if enabled
|
| 549 |
+
if speculative_decoding and USE_LLAMA_CPP == "False" and torch_device == "cuda":
|
| 550 |
+
print("Loading assistant model for speculative decoding:", ASSISTANT_MODEL)
|
| 551 |
+
try:
|
| 552 |
+
from transformers import AutoModelForCausalLM
|
| 553 |
+
|
| 554 |
+
# Load the assistant model with the same configuration as the main model
|
| 555 |
+
assistant_model = AutoModelForCausalLM.from_pretrained(
|
| 556 |
+
ASSISTANT_MODEL, dtype=torch_dtype, device_map="auto", token=hf_token
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
# assistant_model.config._name_or_path = model.config._name_or_path
|
| 560 |
+
|
| 561 |
+
# Compile the assistant model if compilation is enabled
|
| 562 |
+
if COMPILE_TRANSFORMERS == "True":
|
| 563 |
+
try:
|
| 564 |
+
assistant_model = torch.compile(
|
| 565 |
+
assistant_model, mode=compile_mode, fullgraph=True
|
| 566 |
+
)
|
| 567 |
+
except Exception as e:
|
| 568 |
+
print(
|
| 569 |
+
f"Could not compile assistant model: {e}. Running in eager mode."
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
print("Successfully loaded assistant model for speculative decoding")
|
| 573 |
+
|
| 574 |
+
except Exception as e:
|
| 575 |
+
print(f"Error loading assistant model: {e}")
|
| 576 |
+
assistant_model = None
|
| 577 |
+
else:
|
| 578 |
+
assistant_model = None
|
| 579 |
+
|
| 580 |
+
return model, tokenizer, assistant_model
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def get_model():
|
| 584 |
+
"""Get the globally loaded model. Load it if not already loaded."""
|
| 585 |
+
global _model, _tokenizer, _assistant_model
|
| 586 |
+
if _model is None:
|
| 587 |
+
_model, _tokenizer, _assistant_model = load_model(
|
| 588 |
+
local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
|
| 589 |
+
gpu_layers=gpu_layers,
|
| 590 |
+
max_context_length=context_length,
|
| 591 |
+
gpu_config=gpu_config,
|
| 592 |
+
cpu_config=cpu_config,
|
| 593 |
+
torch_device=torch_device,
|
| 594 |
+
repo_id=LOCAL_REPO_ID,
|
| 595 |
+
model_filename=LOCAL_MODEL_FILE,
|
| 596 |
+
model_dir=LOCAL_MODEL_FOLDER,
|
| 597 |
+
compile_mode=COMPILE_MODE,
|
| 598 |
+
model_dtype=MODEL_DTYPE,
|
| 599 |
+
hf_token=HF_TOKEN,
|
| 600 |
+
model=_model,
|
| 601 |
+
tokenizer=_tokenizer,
|
| 602 |
+
assistant_model=_assistant_model,
|
| 603 |
+
)
|
| 604 |
+
return _model
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def get_tokenizer():
|
| 608 |
+
"""Get the globally loaded tokenizer. Load it if not already loaded."""
|
| 609 |
+
global _model, _tokenizer, _assistant_model
|
| 610 |
+
if _tokenizer is None:
|
| 611 |
+
_model, _tokenizer, _assistant_model = load_model(
|
| 612 |
+
local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
|
| 613 |
+
gpu_layers=gpu_layers,
|
| 614 |
+
max_context_length=context_length,
|
| 615 |
+
gpu_config=gpu_config,
|
| 616 |
+
cpu_config=cpu_config,
|
| 617 |
+
torch_device=torch_device,
|
| 618 |
+
repo_id=LOCAL_REPO_ID,
|
| 619 |
+
model_filename=LOCAL_MODEL_FILE,
|
| 620 |
+
model_dir=LOCAL_MODEL_FOLDER,
|
| 621 |
+
compile_mode=COMPILE_MODE,
|
| 622 |
+
model_dtype=MODEL_DTYPE,
|
| 623 |
+
hf_token=HF_TOKEN,
|
| 624 |
+
model=_model,
|
| 625 |
+
tokenizer=_tokenizer,
|
| 626 |
+
assistant_model=_assistant_model,
|
| 627 |
+
)
|
| 628 |
+
return _tokenizer
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def get_assistant_model():
|
| 632 |
+
"""Get the globally loaded assistant model. Load it if not already loaded."""
|
| 633 |
+
global _model, _tokenizer, _assistant_model
|
| 634 |
+
if _assistant_model is None:
|
| 635 |
+
_model, _tokenizer, _assistant_model = load_model(
|
| 636 |
+
local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
|
| 637 |
+
gpu_layers=gpu_layers,
|
| 638 |
+
max_context_length=context_length,
|
| 639 |
+
gpu_config=gpu_config,
|
| 640 |
+
cpu_config=cpu_config,
|
| 641 |
+
torch_device=torch_device,
|
| 642 |
+
repo_id=LOCAL_REPO_ID,
|
| 643 |
+
model_filename=LOCAL_MODEL_FILE,
|
| 644 |
+
model_dir=LOCAL_MODEL_FOLDER,
|
| 645 |
+
compile_mode=COMPILE_MODE,
|
| 646 |
+
model_dtype=MODEL_DTYPE,
|
| 647 |
+
hf_token=HF_TOKEN,
|
| 648 |
+
model=_model,
|
| 649 |
+
tokenizer=_tokenizer,
|
| 650 |
+
assistant_model=_assistant_model,
|
| 651 |
+
)
|
| 652 |
+
return _assistant_model
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def set_model(model, tokenizer, assistant_model=None):
|
| 656 |
+
"""Set the global model, tokenizer, and assistant model."""
|
| 657 |
+
global _model, _tokenizer, _assistant_model
|
| 658 |
+
_model = model
|
| 659 |
+
_tokenizer = tokenizer
|
| 660 |
+
_assistant_model = assistant_model
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
# Initialize model at startup if configured
|
| 664 |
+
if LOAD_LOCAL_MODEL_AT_START == "True" and RUN_LOCAL_MODEL == "1":
|
| 665 |
+
get_model() # This will trigger loading
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def call_llama_cpp_model(formatted_string: str, gen_config: str, model=None):
|
| 669 |
+
"""
|
| 670 |
+
Calls your generation model with parameters from the LlamaCPPGenerationConfig object.
|
| 671 |
+
|
| 672 |
+
Args:
|
| 673 |
+
formatted_string (str): The formatted input text for the model.
|
| 674 |
+
gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
|
| 675 |
+
model: Optional model instance. If None, will use the globally loaded model.
|
| 676 |
+
"""
|
| 677 |
+
if model is None:
|
| 678 |
+
model = get_model()
|
| 679 |
+
|
| 680 |
+
if model is None:
|
| 681 |
+
raise ValueError(
|
| 682 |
+
"No model available. Either pass a model parameter or ensure LOAD_LOCAL_MODEL_AT_START is True."
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
# Extracting parameters from the gen_config object
|
| 686 |
+
temperature = gen_config.temperature
|
| 687 |
+
top_k = gen_config.top_k
|
| 688 |
+
top_p = gen_config.top_p
|
| 689 |
+
repeat_penalty = gen_config.repeat_penalty
|
| 690 |
+
seed = gen_config.seed
|
| 691 |
+
max_tokens = gen_config.max_tokens
|
| 692 |
+
stream = gen_config.stream
|
| 693 |
+
|
| 694 |
+
# Now you can call your model directly, passing the parameters:
|
| 695 |
+
output = model(
|
| 696 |
+
formatted_string,
|
| 697 |
+
temperature=temperature,
|
| 698 |
+
top_k=top_k,
|
| 699 |
+
top_p=top_p,
|
| 700 |
+
repeat_penalty=repeat_penalty,
|
| 701 |
+
seed=seed,
|
| 702 |
+
max_tokens=max_tokens,
|
| 703 |
+
stream=stream, # ,
|
| 704 |
+
# stop=["<|eot_id|>", "\n\n"]
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
return output
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
def call_llama_cpp_chatmodel(
|
| 711 |
+
formatted_string: str,
|
| 712 |
+
system_prompt: str,
|
| 713 |
+
gen_config: LlamaCPPGenerationConfig,
|
| 714 |
+
model=None,
|
| 715 |
+
):
|
| 716 |
+
"""
|
| 717 |
+
Calls your Llama.cpp chat model with a formatted user message and system prompt,
|
| 718 |
+
using generation parameters from the LlamaCPPGenerationConfig object.
|
| 719 |
+
|
| 720 |
+
Args:
|
| 721 |
+
formatted_string (str): The formatted input text for the user's message.
|
| 722 |
+
system_prompt (str): The system-level instructions for the model.
|
| 723 |
+
gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
|
| 724 |
+
model: Optional model instance. If None, will use the globally loaded model.
|
| 725 |
+
"""
|
| 726 |
+
if model is None:
|
| 727 |
+
model = get_model()
|
| 728 |
+
|
| 729 |
+
if model is None:
|
| 730 |
+
raise ValueError(
|
| 731 |
+
"No model available. Either pass a model parameter or ensure LOAD_LOCAL_MODEL_AT_START is True."
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
# Extracting parameters from the gen_config object
|
| 735 |
+
temperature = gen_config.temperature
|
| 736 |
+
top_k = gen_config.top_k
|
| 737 |
+
top_p = gen_config.top_p
|
| 738 |
+
repeat_penalty = gen_config.repeat_penalty
|
| 739 |
+
seed = gen_config.seed
|
| 740 |
+
max_tokens = gen_config.max_tokens
|
| 741 |
+
stream = gen_config.stream
|
| 742 |
+
reset = gen_config.reset
|
| 743 |
+
|
| 744 |
+
messages = [
|
| 745 |
+
{"role": "system", "content": system_prompt},
|
| 746 |
+
{"role": "user", "content": formatted_string},
|
| 747 |
+
]
|
| 748 |
+
|
| 749 |
+
input_tokens = len(
|
| 750 |
+
model.tokenize(
|
| 751 |
+
(system_prompt + "\n" + formatted_string).encode("utf-8"), special=True
|
| 752 |
+
)
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
if stream:
|
| 756 |
+
final_tokens = list()
|
| 757 |
+
output_tokens = 0
|
| 758 |
+
for chunk in model.create_chat_completion(
|
| 759 |
+
messages=messages,
|
| 760 |
+
temperature=temperature,
|
| 761 |
+
top_k=top_k,
|
| 762 |
+
top_p=top_p,
|
| 763 |
+
repeat_penalty=repeat_penalty,
|
| 764 |
+
seed=seed,
|
| 765 |
+
max_tokens=max_tokens,
|
| 766 |
+
stream=True,
|
| 767 |
+
stop=stop_strings,
|
| 768 |
+
):
|
| 769 |
+
delta = chunk["choices"][0].get("delta", {})
|
| 770 |
+
token = delta.get("content") or chunk["choices"][0].get("text") or ""
|
| 771 |
+
if token:
|
| 772 |
+
print(token, end="", flush=True)
|
| 773 |
+
final_tokens.append(token)
|
| 774 |
+
output_tokens += 1
|
| 775 |
+
print() # newline after stream finishes
|
| 776 |
+
|
| 777 |
+
text = "".join(final_tokens)
|
| 778 |
+
|
| 779 |
+
if reset:
|
| 780 |
+
model.reset()
|
| 781 |
+
|
| 782 |
+
return {
|
| 783 |
+
"choices": [
|
| 784 |
+
{
|
| 785 |
+
"index": 0,
|
| 786 |
+
"finish_reason": "stop",
|
| 787 |
+
"message": {"role": "assistant", "content": text},
|
| 788 |
+
}
|
| 789 |
+
],
|
| 790 |
+
# Provide a usage object so downstream code can read it
|
| 791 |
+
"usage": {
|
| 792 |
+
"prompt_tokens": input_tokens, # unknown during streaming
|
| 793 |
+
"completion_tokens": output_tokens, # unknown during streaming
|
| 794 |
+
"total_tokens": input_tokens
|
| 795 |
+
+ output_tokens, # unknown during streaming
|
| 796 |
+
},
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
else:
|
| 800 |
+
response = model.create_chat_completion(
|
| 801 |
+
messages=messages,
|
| 802 |
+
temperature=temperature,
|
| 803 |
+
top_k=top_k,
|
| 804 |
+
top_p=top_p,
|
| 805 |
+
repeat_penalty=repeat_penalty,
|
| 806 |
+
seed=seed,
|
| 807 |
+
max_tokens=max_tokens,
|
| 808 |
+
stream=False,
|
| 809 |
+
stop=stop_strings,
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
if reset:
|
| 813 |
+
model.reset()
|
| 814 |
+
|
| 815 |
+
return response
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def call_inference_server_api(
|
| 819 |
+
formatted_string: str,
|
| 820 |
+
system_prompt: str,
|
| 821 |
+
gen_config: LlamaCPPGenerationConfig,
|
| 822 |
+
api_url: str = "http://localhost:8080",
|
| 823 |
+
model_name: str = None,
|
| 824 |
+
use_llama_swap: bool = USE_LLAMA_SWAP,
|
| 825 |
+
):
|
| 826 |
+
"""
|
| 827 |
+
Calls a inference-server API endpoint with a formatted user message and system prompt,
|
| 828 |
+
using generation parameters from the LlamaCPPGenerationConfig object.
|
| 829 |
+
|
| 830 |
+
This function provides the same interface as call_llama_cpp_chatmodel but calls
|
| 831 |
+
a remote inference-server instance instead of a local model.
|
| 832 |
+
|
| 833 |
+
Args:
|
| 834 |
+
formatted_string (str): The formatted input text for the user's message.
|
| 835 |
+
system_prompt (str): The system-level instructions for the model.
|
| 836 |
+
gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
|
| 837 |
+
api_url (str): The base URL of the inference-server API (default: "http://localhost:8080").
|
| 838 |
+
model_name (str): Optional model name to use. If None, uses the default model.
|
| 839 |
+
use_llama_swap (bool): Whether to use llama-swap for the model.
|
| 840 |
+
Returns:
|
| 841 |
+
dict: Response in the same format as call_llama_cpp_chatmodel
|
| 842 |
+
|
| 843 |
+
Example:
|
| 844 |
+
# Create generation config
|
| 845 |
+
gen_config = LlamaCPPGenerationConfig(temperature=0.7, max_tokens=100)
|
| 846 |
+
|
| 847 |
+
# Call the API
|
| 848 |
+
response = call_inference_server_api(
|
| 849 |
+
formatted_string="Hello, how are you?",
|
| 850 |
+
system_prompt="You are a helpful assistant.",
|
| 851 |
+
gen_config=gen_config,
|
| 852 |
+
api_url="http://localhost:8080"
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
# Extract the response text
|
| 856 |
+
response_text = response['choices'][0]['message']['content']
|
| 857 |
+
|
| 858 |
+
Integration Example:
|
| 859 |
+
# To use inference-server instead of local model:
|
| 860 |
+
# 1. Set model_source to "inference-server"
|
| 861 |
+
# 2. Provide api_url parameter
|
| 862 |
+
# 3. Call your existing functions as normal
|
| 863 |
+
|
| 864 |
+
responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(
|
| 865 |
+
batch_prompts=["Your prompt here"],
|
| 866 |
+
system_prompt="Your system prompt",
|
| 867 |
+
conversation_history=[],
|
| 868 |
+
whole_conversation=[],
|
| 869 |
+
whole_conversation_metadata=[],
|
| 870 |
+
client=None, # Not used for inference-server
|
| 871 |
+
client_config=None, # Not used for inference-server
|
| 872 |
+
model_choice="your-model-name", # Model name on the server
|
| 873 |
+
temperature=0.7,
|
| 874 |
+
reported_batch_no=1,
|
| 875 |
+
local_model=None, # Not used for inference-server
|
| 876 |
+
tokenizer=None, # Not used for inference-server
|
| 877 |
+
bedrock_runtime=None, # Not used for inference-server
|
| 878 |
+
model_source="inference-server",
|
| 879 |
+
MAX_OUTPUT_VALIDATION_ATTEMPTS=3,
|
| 880 |
+
api_url="http://localhost:8080"
|
| 881 |
+
)
|
| 882 |
+
"""
|
| 883 |
+
# Extract parameters from the gen_config object
|
| 884 |
+
temperature = gen_config.temperature
|
| 885 |
+
top_k = gen_config.top_k
|
| 886 |
+
top_p = gen_config.top_p
|
| 887 |
+
repeat_penalty = gen_config.repeat_penalty
|
| 888 |
+
seed = gen_config.seed
|
| 889 |
+
max_tokens = gen_config.max_tokens
|
| 890 |
+
stream = gen_config.stream
|
| 891 |
+
|
| 892 |
+
# Prepare the request payload
|
| 893 |
+
messages = [
|
| 894 |
+
{"role": "system", "content": system_prompt},
|
| 895 |
+
{"role": "user", "content": formatted_string},
|
| 896 |
+
]
|
| 897 |
+
|
| 898 |
+
payload = {
|
| 899 |
+
"messages": messages,
|
| 900 |
+
"temperature": temperature,
|
| 901 |
+
"top_k": top_k,
|
| 902 |
+
"top_p": top_p,
|
| 903 |
+
"repeat_penalty": repeat_penalty,
|
| 904 |
+
"seed": seed,
|
| 905 |
+
"max_tokens": max_tokens,
|
| 906 |
+
"stream": stream,
|
| 907 |
+
"stop": LLM_STOP_STRINGS if LLM_STOP_STRINGS else [],
|
| 908 |
+
}
|
| 909 |
+
# Add model name if specified and use llama-swap
|
| 910 |
+
if model_name and use_llama_swap:
|
| 911 |
+
payload["model"] = model_name
|
| 912 |
+
|
| 913 |
+
# Determine the endpoint based on streaming preference
|
| 914 |
+
if stream:
|
| 915 |
+
endpoint = f"{api_url}/v1/chat/completions"
|
| 916 |
+
else:
|
| 917 |
+
endpoint = f"{api_url}/v1/chat/completions"
|
| 918 |
+
|
| 919 |
+
try:
|
| 920 |
+
if stream:
|
| 921 |
+
# Handle streaming response
|
| 922 |
+
response = requests.post(
|
| 923 |
+
endpoint,
|
| 924 |
+
json=payload,
|
| 925 |
+
headers={"Content-Type": "application/json"},
|
| 926 |
+
stream=True,
|
| 927 |
+
timeout=timeout_wait,
|
| 928 |
+
)
|
| 929 |
+
response.raise_for_status()
|
| 930 |
+
|
| 931 |
+
final_tokens = []
|
| 932 |
+
output_tokens = 0
|
| 933 |
+
|
| 934 |
+
for line in response.iter_lines():
|
| 935 |
+
if line:
|
| 936 |
+
line = line.decode("utf-8")
|
| 937 |
+
if line.startswith("data: "):
|
| 938 |
+
data = line[6:] # Remove 'data: ' prefix
|
| 939 |
+
if data.strip() == "[DONE]":
|
| 940 |
+
break
|
| 941 |
+
try:
|
| 942 |
+
chunk = json.loads(data)
|
| 943 |
+
if "choices" in chunk and len(chunk["choices"]) > 0:
|
| 944 |
+
delta = chunk["choices"][0].get("delta", {})
|
| 945 |
+
token = delta.get("content", "")
|
| 946 |
+
if token:
|
| 947 |
+
print(token, end="", flush=True)
|
| 948 |
+
final_tokens.append(token)
|
| 949 |
+
output_tokens += 1
|
| 950 |
+
except json.JSONDecodeError:
|
| 951 |
+
continue
|
| 952 |
+
|
| 953 |
+
print() # newline after stream finishes
|
| 954 |
+
|
| 955 |
+
text = "".join(final_tokens)
|
| 956 |
+
|
| 957 |
+
# Estimate input tokens (rough approximation)
|
| 958 |
+
input_tokens = len((system_prompt + "\n" + formatted_string).split())
|
| 959 |
+
|
| 960 |
+
return {
|
| 961 |
+
"choices": [
|
| 962 |
+
{
|
| 963 |
+
"index": 0,
|
| 964 |
+
"finish_reason": "stop",
|
| 965 |
+
"message": {"role": "assistant", "content": text},
|
| 966 |
+
}
|
| 967 |
+
],
|
| 968 |
+
"usage": {
|
| 969 |
+
"prompt_tokens": input_tokens,
|
| 970 |
+
"completion_tokens": output_tokens,
|
| 971 |
+
"total_tokens": input_tokens + output_tokens,
|
| 972 |
+
},
|
| 973 |
+
}
|
| 974 |
+
else:
|
| 975 |
+
# Handle non-streaming response
|
| 976 |
+
response = requests.post(
|
| 977 |
+
endpoint,
|
| 978 |
+
json=payload,
|
| 979 |
+
headers={"Content-Type": "application/json"},
|
| 980 |
+
timeout=timeout_wait,
|
| 981 |
+
)
|
| 982 |
+
response.raise_for_status()
|
| 983 |
+
|
| 984 |
+
result = response.json()
|
| 985 |
+
|
| 986 |
+
# Ensure the response has the expected format
|
| 987 |
+
if "choices" not in result:
|
| 988 |
+
raise ValueError("Invalid response format from inference-server")
|
| 989 |
+
|
| 990 |
+
return result
|
| 991 |
+
|
| 992 |
+
except requests.exceptions.RequestException as e:
|
| 993 |
+
raise ConnectionError(
|
| 994 |
+
f"Failed to connect to inference-server at {api_url}: {str(e)}"
|
| 995 |
+
)
|
| 996 |
+
except json.JSONDecodeError as e:
|
| 997 |
+
raise ValueError(f"Invalid JSON response from inference-server: {str(e)}")
|
| 998 |
+
except Exception as e:
|
| 999 |
+
raise RuntimeError(f"Error calling inference-server API: {str(e)}")
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
###
|
| 1003 |
+
# LLM FUNCTIONS
|
| 1004 |
+
###
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
def construct_gemini_generative_model(
|
| 1008 |
+
in_api_key: str,
|
| 1009 |
+
temperature: float,
|
| 1010 |
+
model_choice: str,
|
| 1011 |
+
system_prompt: str,
|
| 1012 |
+
max_tokens: int,
|
| 1013 |
+
random_seed=seed,
|
| 1014 |
+
) -> Tuple[object, dict]:
|
| 1015 |
+
"""
|
| 1016 |
+
Constructs a GenerativeModel for Gemini API calls.
|
| 1017 |
+
...
|
| 1018 |
+
"""
|
| 1019 |
+
# Construct a GenerativeModel
|
| 1020 |
+
try:
|
| 1021 |
+
if in_api_key:
|
| 1022 |
+
# print("Getting API key from textbox")
|
| 1023 |
+
api_key = in_api_key
|
| 1024 |
+
client = ai.Client(api_key=api_key)
|
| 1025 |
+
elif "GOOGLE_API_KEY" in os.environ:
|
| 1026 |
+
# print("Searching for API key in environmental variables")
|
| 1027 |
+
api_key = os.environ["GOOGLE_API_KEY"]
|
| 1028 |
+
client = ai.Client(api_key=api_key)
|
| 1029 |
+
else:
|
| 1030 |
+
print("No Gemini API key found")
|
| 1031 |
+
raise Warning("No Gemini API key found.")
|
| 1032 |
+
except Exception as e:
|
| 1033 |
+
print("Error constructing Gemini generative model:", e)
|
| 1034 |
+
raise Warning("Error constructing Gemini generative model:", e)
|
| 1035 |
+
|
| 1036 |
+
config = types.GenerateContentConfig(
|
| 1037 |
+
temperature=temperature, max_output_tokens=max_tokens, seed=random_seed
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
+
return client, config
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
|
| 1044 |
+
"""
|
| 1045 |
+
Constructs an OpenAI client for Azure/OpenAI AI Inference.
|
| 1046 |
+
"""
|
| 1047 |
+
try:
|
| 1048 |
+
key = None
|
| 1049 |
+
if in_api_key:
|
| 1050 |
+
key = in_api_key
|
| 1051 |
+
elif os.environ.get("AZURE_OPENAI_API_KEY"):
|
| 1052 |
+
key = os.environ["AZURE_OPENAI_API_KEY"]
|
| 1053 |
+
if not key:
|
| 1054 |
+
raise Warning("No Azure/OpenAI API key found.")
|
| 1055 |
+
|
| 1056 |
+
if not endpoint:
|
| 1057 |
+
endpoint = os.environ.get("AZURE_OPENAI_INFERENCE_ENDPOINT", "")
|
| 1058 |
+
if not endpoint:
|
| 1059 |
+
# Assume using OpenAI API
|
| 1060 |
+
client = OpenAI(
|
| 1061 |
+
api_key=key,
|
| 1062 |
+
)
|
| 1063 |
+
else:
|
| 1064 |
+
# Use the provided endpoint
|
| 1065 |
+
client = OpenAI(
|
| 1066 |
+
api_key=key,
|
| 1067 |
+
base_url=f"{endpoint}",
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
return client, dict()
|
| 1071 |
+
except Exception as e:
|
| 1072 |
+
print("Error constructing Azure/OpenAI client:", e)
|
| 1073 |
+
raise
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
def call_aws_bedrock(
|
| 1077 |
+
prompt: str,
|
| 1078 |
+
system_prompt: str,
|
| 1079 |
+
temperature: float,
|
| 1080 |
+
max_tokens: int,
|
| 1081 |
+
model_choice: str,
|
| 1082 |
+
bedrock_runtime: boto3.Session.client,
|
| 1083 |
+
assistant_prefill: str = "",
|
| 1084 |
+
) -> ResponseObject:
|
| 1085 |
+
"""
|
| 1086 |
+
This function sends a request to AWS Claude with the following parameters:
|
| 1087 |
+
- prompt: The user's input prompt to be processed by the model.
|
| 1088 |
+
- system_prompt: A system-defined prompt that provides context or instructions for the model.
|
| 1089 |
+
- temperature: A value that controls the randomness of the model's output, with higher values resulting in more diverse responses.
|
| 1090 |
+
- max_tokens: The maximum number of tokens (words or characters) in the model's response.
|
| 1091 |
+
- model_choice: The specific model to use for processing the request.
|
| 1092 |
+
- bedrock_runtime: The client object for boto3 Bedrock runtime
|
| 1093 |
+
- assistant_prefill: A string indicating the text that the response should start with.
|
| 1094 |
+
|
| 1095 |
+
The function constructs the request configuration, invokes the model, extracts the response text, and returns a ResponseObject containing the text and metadata.
|
| 1096 |
+
"""
|
| 1097 |
+
|
| 1098 |
+
inference_config = {
|
| 1099 |
+
"maxTokens": max_tokens,
|
| 1100 |
+
"topP": 0.999,
|
| 1101 |
+
"temperature": temperature,
|
| 1102 |
+
}
|
| 1103 |
+
|
| 1104 |
+
# Using an assistant prefill only works for Anthropic models.
|
| 1105 |
+
if assistant_prefill and "anthropic" in model_choice:
|
| 1106 |
+
assistant_prefill_added = True
|
| 1107 |
+
messages = [
|
| 1108 |
+
{
|
| 1109 |
+
"role": "user",
|
| 1110 |
+
"content": [
|
| 1111 |
+
{"text": prompt},
|
| 1112 |
+
],
|
| 1113 |
+
},
|
| 1114 |
+
{
|
| 1115 |
+
"role": "assistant",
|
| 1116 |
+
# Pre-filling with '|'
|
| 1117 |
+
"content": [{"text": assistant_prefill}],
|
| 1118 |
+
},
|
| 1119 |
+
]
|
| 1120 |
+
else:
|
| 1121 |
+
assistant_prefill_added = False
|
| 1122 |
+
messages = [
|
| 1123 |
+
{
|
| 1124 |
+
"role": "user",
|
| 1125 |
+
"content": [
|
| 1126 |
+
{"text": prompt},
|
| 1127 |
+
],
|
| 1128 |
+
}
|
| 1129 |
+
]
|
| 1130 |
+
|
| 1131 |
+
system_prompt_list = [{"text": system_prompt}]
|
| 1132 |
+
|
| 1133 |
+
# The converse API call.
|
| 1134 |
+
api_response = bedrock_runtime.converse(
|
| 1135 |
+
modelId=model_choice,
|
| 1136 |
+
messages=messages,
|
| 1137 |
+
system=system_prompt_list,
|
| 1138 |
+
inferenceConfig=inference_config,
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
output_message = api_response["output"]["message"]
|
| 1142 |
+
|
| 1143 |
+
if "reasoningContent" in output_message["content"][0]:
|
| 1144 |
+
# Extract the reasoning text
|
| 1145 |
+
output_message["content"][0]["reasoningContent"]["reasoningText"]["text"]
|
| 1146 |
+
|
| 1147 |
+
# Extract the output text
|
| 1148 |
+
if assistant_prefill_added:
|
| 1149 |
+
text = assistant_prefill + output_message["content"][1]["text"]
|
| 1150 |
+
else:
|
| 1151 |
+
text = output_message["content"][1]["text"]
|
| 1152 |
+
else:
|
| 1153 |
+
if assistant_prefill_added:
|
| 1154 |
+
text = assistant_prefill + output_message["content"][0]["text"]
|
| 1155 |
+
else:
|
| 1156 |
+
text = output_message["content"][0]["text"]
|
| 1157 |
+
|
| 1158 |
+
# The usage statistics are neatly provided in the 'usage' key.
|
| 1159 |
+
usage = api_response["usage"]
|
| 1160 |
+
|
| 1161 |
+
# The full API response metadata is in 'ResponseMetadata' if you still need it.
|
| 1162 |
+
api_response["ResponseMetadata"]
|
| 1163 |
+
|
| 1164 |
+
# Create ResponseObject with the cleanly extracted data.
|
| 1165 |
+
response = ResponseObject(text=text, usage_metadata=usage)
|
| 1166 |
+
|
| 1167 |
+
return response
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
def call_transformers_model(
|
| 1171 |
+
prompt: str,
|
| 1172 |
+
system_prompt: str,
|
| 1173 |
+
gen_config: LlamaCPPGenerationConfig,
|
| 1174 |
+
model=None,
|
| 1175 |
+
tokenizer=None,
|
| 1176 |
+
assistant_model=None,
|
| 1177 |
+
speculative_decoding=speculative_decoding,
|
| 1178 |
+
):
|
| 1179 |
+
"""
|
| 1180 |
+
This function sends a request to a transformers model (through Unsloth) with the given prompt, system prompt, and generation configuration.
|
| 1181 |
+
"""
|
| 1182 |
+
from transformers import TextStreamer
|
| 1183 |
+
|
| 1184 |
+
if model is None:
|
| 1185 |
+
model = get_model()
|
| 1186 |
+
if tokenizer is None:
|
| 1187 |
+
tokenizer = get_tokenizer()
|
| 1188 |
+
if assistant_model is None and speculative_decoding:
|
| 1189 |
+
assistant_model = get_assistant_model()
|
| 1190 |
+
|
| 1191 |
+
if model is None or tokenizer is None:
|
| 1192 |
+
raise ValueError(
|
| 1193 |
+
"No model or tokenizer available. Either pass them as parameters or ensure LOAD_LOCAL_MODEL_AT_START is True."
|
| 1194 |
+
)
|
| 1195 |
+
|
| 1196 |
+
# 1. Define the conversation as a list of dictionaries
|
| 1197 |
+
# Note: The multimodal format [{"type": "text", "text": text}] is only needed for actual multimodal models
|
| 1198 |
+
# with images/videos. For text-only content, even multimodal models expect plain strings.
|
| 1199 |
+
|
| 1200 |
+
# Always use string format for text-only content, regardless of MULTIMODAL_PROMPT_FORMAT setting
|
| 1201 |
+
# MULTIMODAL_PROMPT_FORMAT should only be used when you actually have multimodal inputs (images, etc.)
|
| 1202 |
+
if MULTIMODAL_PROMPT_FORMAT == "True":
|
| 1203 |
+
conversation = [
|
| 1204 |
+
{
|
| 1205 |
+
"role": "system",
|
| 1206 |
+
"content": [{"type": "text", "text": str(system_prompt)}],
|
| 1207 |
+
},
|
| 1208 |
+
{"role": "user", "content": [{"type": "text", "text": str(prompt)}]},
|
| 1209 |
+
]
|
| 1210 |
+
else:
|
| 1211 |
+
conversation = [
|
| 1212 |
+
{"role": "system", "content": str(system_prompt)},
|
| 1213 |
+
{"role": "user", "content": str(prompt)},
|
| 1214 |
+
]
|
| 1215 |
+
|
| 1216 |
+
# 2. Apply the chat template
|
| 1217 |
+
try:
|
| 1218 |
+
# Try applying chat template
|
| 1219 |
+
input_ids = tokenizer.apply_chat_template(
|
| 1220 |
+
conversation,
|
| 1221 |
+
add_generation_prompt=True,
|
| 1222 |
+
tokenize=True,
|
| 1223 |
+
return_tensors="pt",
|
| 1224 |
+
).to("cuda")
|
| 1225 |
+
except (TypeError, KeyError, IndexError) as e:
|
| 1226 |
+
# If chat template fails, try manual formatting
|
| 1227 |
+
print(f"Chat template failed ({e}), using manual tokenization")
|
| 1228 |
+
# Combine system and user prompts manually
|
| 1229 |
+
full_prompt = f"{system_prompt}\n\n{prompt}"
|
| 1230 |
+
# Tokenize manually with special tokens
|
| 1231 |
+
encoded = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=True)
|
| 1232 |
+
if encoded is None:
|
| 1233 |
+
raise ValueError(
|
| 1234 |
+
"Tokenizer returned None - tokenizer may not be properly initialized"
|
| 1235 |
+
)
|
| 1236 |
+
if not hasattr(encoded, "input_ids") or encoded.input_ids is None:
|
| 1237 |
+
raise ValueError("Tokenizer output does not contain input_ids")
|
| 1238 |
+
input_ids = encoded.input_ids.to("cuda")
|
| 1239 |
+
except Exception as e:
|
| 1240 |
+
print("Error applying chat template:", e)
|
| 1241 |
+
import traceback
|
| 1242 |
+
|
| 1243 |
+
traceback.print_exc()
|
| 1244 |
+
raise
|
| 1245 |
+
|
| 1246 |
+
# Map LlamaCPP parameters to transformers parameters
|
| 1247 |
+
generation_kwargs = {
|
| 1248 |
+
"max_new_tokens": gen_config.max_tokens,
|
| 1249 |
+
"temperature": gen_config.temperature,
|
| 1250 |
+
"top_p": gen_config.top_p,
|
| 1251 |
+
"top_k": gen_config.top_k,
|
| 1252 |
+
"do_sample": True,
|
| 1253 |
+
#'pad_token_id': tokenizer.eos_token_id
|
| 1254 |
+
}
|
| 1255 |
+
|
| 1256 |
+
if gen_config.stream:
|
| 1257 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
| 1258 |
+
else:
|
| 1259 |
+
streamer = None
|
| 1260 |
+
|
| 1261 |
+
# Remove parameters that don't exist in transformers
|
| 1262 |
+
if hasattr(gen_config, "repeat_penalty"):
|
| 1263 |
+
generation_kwargs["repetition_penalty"] = gen_config.repeat_penalty
|
| 1264 |
+
|
| 1265 |
+
# --- Timed Inference Test ---
|
| 1266 |
+
print("\nStarting model inference...")
|
| 1267 |
+
start_time = time.time()
|
| 1268 |
+
|
| 1269 |
+
# Use speculative decoding if assistant model is available
|
| 1270 |
+
try:
|
| 1271 |
+
if speculative_decoding and assistant_model is not None:
|
| 1272 |
+
# print("Using speculative decoding with assistant model")
|
| 1273 |
+
outputs = model.generate(
|
| 1274 |
+
input_ids,
|
| 1275 |
+
assistant_model=assistant_model,
|
| 1276 |
+
**generation_kwargs,
|
| 1277 |
+
streamer=streamer,
|
| 1278 |
+
)
|
| 1279 |
+
else:
|
| 1280 |
+
# print("Generating without speculative decoding")
|
| 1281 |
+
outputs = model.generate(input_ids, **generation_kwargs, streamer=streamer)
|
| 1282 |
+
except Exception as e:
|
| 1283 |
+
error_msg = str(e)
|
| 1284 |
+
# Check if this is a CUDA compilation error
|
| 1285 |
+
if (
|
| 1286 |
+
"sm_120" in error_msg
|
| 1287 |
+
or "LLVM ERROR" in error_msg
|
| 1288 |
+
or "Cannot select" in error_msg
|
| 1289 |
+
):
|
| 1290 |
+
print("\n" + "=" * 80)
|
| 1291 |
+
print("CUDA COMPILATION ERROR DETECTED")
|
| 1292 |
+
print("=" * 80)
|
| 1293 |
+
print(
|
| 1294 |
+
"\nThe error is caused by torch.compile() trying to compile CUDA kernels"
|
| 1295 |
+
)
|
| 1296 |
+
print(
|
| 1297 |
+
"with incompatible settings. This is a known issue with certain CUDA/PyTorch"
|
| 1298 |
+
)
|
| 1299 |
+
print("combinations.\n")
|
| 1300 |
+
print(
|
| 1301 |
+
"SOLUTION: Disable model compilation by setting COMPILE_TRANSFORMERS=False"
|
| 1302 |
+
)
|
| 1303 |
+
print("in your config file (config/app_config.env).")
|
| 1304 |
+
print(
|
| 1305 |
+
"\nThe model will still work without compilation, just slightly slower."
|
| 1306 |
+
)
|
| 1307 |
+
print("=" * 80 + "\n")
|
| 1308 |
+
raise RuntimeError(
|
| 1309 |
+
"CUDA compilation error detected. Please set COMPILE_TRANSFORMERS=False "
|
| 1310 |
+
"in your config file to disable model compilation and avoid this error."
|
| 1311 |
+
) from e
|
| 1312 |
+
else:
|
| 1313 |
+
# Re-raise other errors as-is
|
| 1314 |
+
raise
|
| 1315 |
+
|
| 1316 |
+
end_time = time.time()
|
| 1317 |
+
|
| 1318 |
+
# --- Decode and Display Results ---
|
| 1319 |
+
new_tokens = outputs[0][input_ids.shape[-1] :]
|
| 1320 |
+
assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 1321 |
+
|
| 1322 |
+
num_input_tokens = input_ids.shape[
|
| 1323 |
+
-1
|
| 1324 |
+
] # This gets the sequence length (number of tokens)
|
| 1325 |
+
num_generated_tokens = len(new_tokens)
|
| 1326 |
+
duration = end_time - start_time
|
| 1327 |
+
tokens_per_second = num_generated_tokens / duration
|
| 1328 |
+
|
| 1329 |
+
print("\n--- Performance ---")
|
| 1330 |
+
print(f"Time taken: {duration:.2f} seconds")
|
| 1331 |
+
print(f"Generated tokens: {num_generated_tokens}")
|
| 1332 |
+
print(f"Tokens per second: {tokens_per_second:.2f}")
|
| 1333 |
+
|
| 1334 |
+
return assistant_reply, num_input_tokens, num_generated_tokens
|
| 1335 |
+
|
| 1336 |
+
|
| 1337 |
+
# Function to send a request and update history
|
| 1338 |
+
def send_request(
|
| 1339 |
+
prompt: str,
|
| 1340 |
+
conversation_history: List[dict],
|
| 1341 |
+
client: ai.Client | OpenAI,
|
| 1342 |
+
config: types.GenerateContentConfig,
|
| 1343 |
+
model_choice: str,
|
| 1344 |
+
system_prompt: str,
|
| 1345 |
+
temperature: float,
|
| 1346 |
+
bedrock_runtime: boto3.Session.client,
|
| 1347 |
+
model_source: str,
|
| 1348 |
+
local_model=list(),
|
| 1349 |
+
tokenizer=None,
|
| 1350 |
+
assistant_model=None,
|
| 1351 |
+
assistant_prefill="",
|
| 1352 |
+
progress=Progress(track_tqdm=True),
|
| 1353 |
+
api_url: str = None,
|
| 1354 |
+
) -> Tuple[str, List[dict]]:
|
| 1355 |
+
"""Sends a request to a language model and manages the conversation history.
|
| 1356 |
+
|
| 1357 |
+
This function constructs the full prompt by appending the new user prompt to the conversation history,
|
| 1358 |
+
generates a response from the model, and updates the conversation history with the new prompt and response.
|
| 1359 |
+
It handles different model sources (Gemini, AWS, Local, inference-server) and includes retry logic for API calls.
|
| 1360 |
+
|
| 1361 |
+
Args:
|
| 1362 |
+
prompt (str): The user's input prompt to be sent to the model.
|
| 1363 |
+
conversation_history (List[dict]): A list of dictionaries representing the ongoing conversation.
|
| 1364 |
+
Each dictionary should have 'role' and 'parts' keys.
|
| 1365 |
+
client (ai.Client): The API client object for the chosen model (e.g., Gemini `ai.Client`, or Azure/OpenAI `OpenAI`).
|
| 1366 |
+
config (types.GenerateContentConfig): Configuration settings for content generation (e.g., Gemini `types.GenerateContentConfig`).
|
| 1367 |
+
model_choice (str): The specific model identifier to use (e.g., "gemini-pro", "claude-v2").
|
| 1368 |
+
system_prompt (str): An optional system-level instruction or context for the model.
|
| 1369 |
+
temperature (float): Controls the randomness of the model's output, with higher values leading to more diverse responses.
|
| 1370 |
+
bedrock_runtime (boto3.Session.client): The boto3 Bedrock runtime client object for AWS models.
|
| 1371 |
+
model_source (str): Indicates the source/provider of the model (e.g., "Gemini", "AWS", "Local", "inference-server").
|
| 1372 |
+
local_model (list, optional): A list containing the local model and its tokenizer (if `model_source` is "Local"). Defaults to [].
|
| 1373 |
+
tokenizer (object, optional): The tokenizer object for local models. Defaults to None.
|
| 1374 |
+
assistant_model (object, optional): An optional assistant model used for speculative decoding with local models. Defaults to None.
|
| 1375 |
+
assistant_prefill (str, optional): A string to pre-fill the assistant's response, useful for certain models like Claude. Defaults to "".
|
| 1376 |
+
progress (Progress, optional): A progress object for tracking the operation, typically from `tqdm`. Defaults to Progress(track_tqdm=True).
|
| 1377 |
+
api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.
|
| 1378 |
+
|
| 1379 |
+
Returns:
|
| 1380 |
+
Tuple[str, List[dict]]: A tuple containing the model's response text and the updated conversation history.
|
| 1381 |
+
"""
|
| 1382 |
+
# Constructing the full prompt from the conversation history
|
| 1383 |
+
full_prompt = "Conversation history:\n"
|
| 1384 |
+
num_transformer_input_tokens = 0
|
| 1385 |
+
num_transformer_generated_tokens = 0
|
| 1386 |
+
response_text = ""
|
| 1387 |
+
|
| 1388 |
+
for entry in conversation_history:
|
| 1389 |
+
role = entry[
|
| 1390 |
+
"role"
|
| 1391 |
+
].capitalize() # Assuming the history is stored with 'role' and 'parts'
|
| 1392 |
+
message = " ".join(entry["parts"]) # Combining all parts of the message
|
| 1393 |
+
full_prompt += f"{role}: {message}\n"
|
| 1394 |
+
|
| 1395 |
+
# Adding the new user prompt
|
| 1396 |
+
full_prompt += f"\nUser: {prompt}"
|
| 1397 |
+
|
| 1398 |
+
# Clear any existing progress bars
|
| 1399 |
+
tqdm._instances.clear()
|
| 1400 |
+
|
| 1401 |
+
progress_bar = range(0, number_of_api_retry_attempts)
|
| 1402 |
+
|
| 1403 |
+
# Generate the model's response
|
| 1404 |
+
if "Gemini" in model_source:
|
| 1405 |
+
|
| 1406 |
+
for i in progress_bar:
|
| 1407 |
+
try:
|
| 1408 |
+
print("Calling Gemini model, attempt", i + 1)
|
| 1409 |
+
|
| 1410 |
+
response = client.models.generate_content(
|
| 1411 |
+
model=model_choice, contents=full_prompt, config=config
|
| 1412 |
+
)
|
| 1413 |
+
|
| 1414 |
+
# print("Successful call to Gemini model.")
|
| 1415 |
+
break
|
| 1416 |
+
except Exception as e:
|
| 1417 |
+
# If fails, try again after X seconds in case there is a throttle limit
|
| 1418 |
+
print(
|
| 1419 |
+
"Call to Gemini model failed:",
|
| 1420 |
+
e,
|
| 1421 |
+
" Waiting for ",
|
| 1422 |
+
str(timeout_wait),
|
| 1423 |
+
"seconds and trying again.",
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
time.sleep(timeout_wait)
|
| 1427 |
+
|
| 1428 |
+
if i == number_of_api_retry_attempts:
|
| 1429 |
+
return (
|
| 1430 |
+
ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
|
| 1431 |
+
conversation_history,
|
| 1432 |
+
response_text,
|
| 1433 |
+
num_transformer_input_tokens,
|
| 1434 |
+
num_transformer_generated_tokens,
|
| 1435 |
+
)
|
| 1436 |
+
|
| 1437 |
+
elif "AWS" in model_source:
|
| 1438 |
+
for i in progress_bar:
|
| 1439 |
+
try:
|
| 1440 |
+
print("Calling AWS Bedrock model, attempt", i + 1)
|
| 1441 |
+
response = call_aws_bedrock(
|
| 1442 |
+
prompt,
|
| 1443 |
+
system_prompt,
|
| 1444 |
+
temperature,
|
| 1445 |
+
max_tokens,
|
| 1446 |
+
model_choice,
|
| 1447 |
+
bedrock_runtime=bedrock_runtime,
|
| 1448 |
+
assistant_prefill=assistant_prefill,
|
| 1449 |
+
)
|
| 1450 |
+
|
| 1451 |
+
# print("Successful call to Claude model.")
|
| 1452 |
+
break
|
| 1453 |
+
except Exception as e:
|
| 1454 |
+
# If fails, try again after X seconds in case there is a throttle limit
|
| 1455 |
+
print(
|
| 1456 |
+
"Call to Bedrock model failed:",
|
| 1457 |
+
e,
|
| 1458 |
+
" Waiting for ",
|
| 1459 |
+
str(timeout_wait),
|
| 1460 |
+
"seconds and trying again.",
|
| 1461 |
+
)
|
| 1462 |
+
time.sleep(timeout_wait)
|
| 1463 |
+
|
| 1464 |
+
if i == number_of_api_retry_attempts:
|
| 1465 |
+
return (
|
| 1466 |
+
ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
|
| 1467 |
+
conversation_history,
|
| 1468 |
+
response_text,
|
| 1469 |
+
num_transformer_input_tokens,
|
| 1470 |
+
num_transformer_generated_tokens,
|
| 1471 |
+
)
|
| 1472 |
+
elif "Azure/OpenAI" in model_source:
|
| 1473 |
+
for i in progress_bar:
|
| 1474 |
+
try:
|
| 1475 |
+
print("Calling Azure/OpenAI inference model, attempt", i + 1)
|
| 1476 |
+
|
| 1477 |
+
messages = [
|
| 1478 |
+
{
|
| 1479 |
+
"role": "system",
|
| 1480 |
+
"content": system_prompt,
|
| 1481 |
+
},
|
| 1482 |
+
{
|
| 1483 |
+
"role": "user",
|
| 1484 |
+
"content": prompt,
|
| 1485 |
+
},
|
| 1486 |
+
]
|
| 1487 |
+
|
| 1488 |
+
response_raw = client.chat.completions.create(
|
| 1489 |
+
messages=messages,
|
| 1490 |
+
model=model_choice,
|
| 1491 |
+
temperature=temperature,
|
| 1492 |
+
max_completion_tokens=max_tokens,
|
| 1493 |
+
)
|
| 1494 |
+
|
| 1495 |
+
response_text = response_raw.choices[0].message.content
|
| 1496 |
+
usage = getattr(response_raw, "usage", None)
|
| 1497 |
+
input_tokens = 0
|
| 1498 |
+
output_tokens = 0
|
| 1499 |
+
if usage is not None:
|
| 1500 |
+
input_tokens = getattr(
|
| 1501 |
+
usage, "input_tokens", getattr(usage, "prompt_tokens", 0)
|
| 1502 |
+
)
|
| 1503 |
+
output_tokens = getattr(
|
| 1504 |
+
usage, "output_tokens", getattr(usage, "completion_tokens", 0)
|
| 1505 |
+
)
|
| 1506 |
+
response = ResponseObject(
|
| 1507 |
+
text=response_text,
|
| 1508 |
+
usage_metadata={
|
| 1509 |
+
"inputTokens": input_tokens,
|
| 1510 |
+
"outputTokens": output_tokens,
|
| 1511 |
+
},
|
| 1512 |
+
)
|
| 1513 |
+
break
|
| 1514 |
+
except Exception as e:
|
| 1515 |
+
print(
|
| 1516 |
+
"Call to Azure/OpenAI model failed:",
|
| 1517 |
+
e,
|
| 1518 |
+
" Waiting for ",
|
| 1519 |
+
str(timeout_wait),
|
| 1520 |
+
"seconds and trying again.",
|
| 1521 |
+
)
|
| 1522 |
+
time.sleep(timeout_wait)
|
| 1523 |
+
if i == number_of_api_retry_attempts:
|
| 1524 |
+
return (
|
| 1525 |
+
ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
|
| 1526 |
+
conversation_history,
|
| 1527 |
+
response_text,
|
| 1528 |
+
num_transformer_input_tokens,
|
| 1529 |
+
num_transformer_generated_tokens,
|
| 1530 |
+
)
|
| 1531 |
+
elif "Local" in model_source:
|
| 1532 |
+
# This is the local model
|
| 1533 |
+
for i in progress_bar:
|
| 1534 |
+
try:
|
| 1535 |
+
print("Calling local model, attempt", i + 1)
|
| 1536 |
+
|
| 1537 |
+
gen_config = LlamaCPPGenerationConfig()
|
| 1538 |
+
gen_config.update_temp(temperature)
|
| 1539 |
+
|
| 1540 |
+
if USE_LLAMA_CPP == "True":
|
| 1541 |
+
response = call_llama_cpp_chatmodel(
|
| 1542 |
+
prompt, system_prompt, gen_config, model=local_model
|
| 1543 |
+
)
|
| 1544 |
+
|
| 1545 |
+
else:
|
| 1546 |
+
(
|
| 1547 |
+
response,
|
| 1548 |
+
num_transformer_input_tokens,
|
| 1549 |
+
num_transformer_generated_tokens,
|
| 1550 |
+
) = call_transformers_model(
|
| 1551 |
+
prompt,
|
| 1552 |
+
system_prompt,
|
| 1553 |
+
gen_config,
|
| 1554 |
+
model=local_model,
|
| 1555 |
+
tokenizer=tokenizer,
|
| 1556 |
+
assistant_model=assistant_model,
|
| 1557 |
+
)
|
| 1558 |
+
response_text = response
|
| 1559 |
+
|
| 1560 |
+
break
|
| 1561 |
+
except Exception as e:
|
| 1562 |
+
# If fails, try again after X seconds in case there is a throttle limit
|
| 1563 |
+
print(
|
| 1564 |
+
"Call to local model failed:",
|
| 1565 |
+
e,
|
| 1566 |
+
" Waiting for ",
|
| 1567 |
+
str(timeout_wait),
|
| 1568 |
+
"seconds and trying again.",
|
| 1569 |
+
)
|
| 1570 |
+
|
| 1571 |
+
time.sleep(timeout_wait)
|
| 1572 |
+
|
| 1573 |
+
if i == number_of_api_retry_attempts:
|
| 1574 |
+
return (
|
| 1575 |
+
ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
|
| 1576 |
+
conversation_history,
|
| 1577 |
+
response_text,
|
| 1578 |
+
num_transformer_input_tokens,
|
| 1579 |
+
num_transformer_generated_tokens,
|
| 1580 |
+
)
|
| 1581 |
+
elif "inference-server" in model_source:
|
| 1582 |
+
# This is the inference-server API
|
| 1583 |
+
for i in progress_bar:
|
| 1584 |
+
try:
|
| 1585 |
+
print("Calling inference-server API, attempt", i + 1)
|
| 1586 |
+
|
| 1587 |
+
if api_url is None:
|
| 1588 |
+
raise ValueError(
|
| 1589 |
+
"api_url is required when model_source is 'inference-server'"
|
| 1590 |
+
)
|
| 1591 |
+
|
| 1592 |
+
gen_config = LlamaCPPGenerationConfig()
|
| 1593 |
+
gen_config.update_temp(temperature)
|
| 1594 |
+
|
| 1595 |
+
response = call_inference_server_api(
|
| 1596 |
+
prompt,
|
| 1597 |
+
system_prompt,
|
| 1598 |
+
gen_config,
|
| 1599 |
+
api_url=api_url,
|
| 1600 |
+
model_name=model_choice,
|
| 1601 |
+
)
|
| 1602 |
+
|
| 1603 |
+
break
|
| 1604 |
+
except Exception as e:
|
| 1605 |
+
# If fails, try again after X seconds in case there is a throttle limit
|
| 1606 |
+
print(
|
| 1607 |
+
"Call to inference-server API failed:",
|
| 1608 |
+
e,
|
| 1609 |
+
" Waiting for ",
|
| 1610 |
+
str(timeout_wait),
|
| 1611 |
+
"seconds and trying again.",
|
| 1612 |
+
)
|
| 1613 |
+
|
| 1614 |
+
time.sleep(timeout_wait)
|
| 1615 |
+
|
| 1616 |
+
if i == number_of_api_retry_attempts:
|
| 1617 |
+
return (
|
| 1618 |
+
ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
|
| 1619 |
+
conversation_history,
|
| 1620 |
+
response_text,
|
| 1621 |
+
num_transformer_input_tokens,
|
| 1622 |
+
num_transformer_generated_tokens,
|
| 1623 |
+
)
|
| 1624 |
+
else:
|
| 1625 |
+
print("Model source not recognised")
|
| 1626 |
+
return (
|
| 1627 |
+
ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
|
| 1628 |
+
conversation_history,
|
| 1629 |
+
response_text,
|
| 1630 |
+
num_transformer_input_tokens,
|
| 1631 |
+
num_transformer_generated_tokens,
|
| 1632 |
+
)
|
| 1633 |
+
|
| 1634 |
+
# Update the conversation history with the new prompt and response
|
| 1635 |
+
conversation_history.append({"role": "user", "parts": [prompt]})
|
| 1636 |
+
|
| 1637 |
+
# Check if is a LLama.cpp model response or inference-server response
|
| 1638 |
+
if isinstance(response, ResponseObject):
|
| 1639 |
+
response_text = response.text
|
| 1640 |
+
elif "choices" in response: # LLama.cpp model response or inference-server response
|
| 1641 |
+
if "gpt-oss" in model_choice:
|
| 1642 |
+
response_text = response["choices"][0]["message"]["content"].split(
|
| 1643 |
+
"<|start|>assistant<|channel|>final<|message|>"
|
| 1644 |
+
)[1]
|
| 1645 |
+
else:
|
| 1646 |
+
response_text = response["choices"][0]["message"]["content"]
|
| 1647 |
+
elif model_source == "Gemini":
|
| 1648 |
+
response_text = response.text
|
| 1649 |
+
else: # Assume transformers model response
|
| 1650 |
+
if "gpt-oss" in model_choice:
|
| 1651 |
+
response_text = response.split(
|
| 1652 |
+
"<|start|>assistant<|channel|>final<|message|>"
|
| 1653 |
+
)[1]
|
| 1654 |
+
else:
|
| 1655 |
+
response_text = response
|
| 1656 |
+
|
| 1657 |
+
# Replace multiple spaces with single space
|
| 1658 |
+
response_text = re.sub(r" {2,}", " ", response_text)
|
| 1659 |
+
response_text = response_text.strip()
|
| 1660 |
+
|
| 1661 |
+
conversation_history.append({"role": "assistant", "parts": [response_text]})
|
| 1662 |
+
|
| 1663 |
+
return (
|
| 1664 |
+
response,
|
| 1665 |
+
conversation_history,
|
| 1666 |
+
response_text,
|
| 1667 |
+
num_transformer_input_tokens,
|
| 1668 |
+
num_transformer_generated_tokens,
|
| 1669 |
+
)
|
| 1670 |
+
|
| 1671 |
+
|
| 1672 |
+
def process_requests(
|
| 1673 |
+
prompts: List[str],
|
| 1674 |
+
system_prompt: str,
|
| 1675 |
+
conversation_history: List[dict],
|
| 1676 |
+
whole_conversation: List[str],
|
| 1677 |
+
whole_conversation_metadata: List[str],
|
| 1678 |
+
client: ai.Client | OpenAI,
|
| 1679 |
+
config: types.GenerateContentConfig,
|
| 1680 |
+
model_choice: str,
|
| 1681 |
+
temperature: float,
|
| 1682 |
+
bedrock_runtime: boto3.Session.client,
|
| 1683 |
+
model_source: str,
|
| 1684 |
+
batch_no: int = 1,
|
| 1685 |
+
local_model=list(),
|
| 1686 |
+
tokenizer=None,
|
| 1687 |
+
assistant_model=None,
|
| 1688 |
+
master: bool = False,
|
| 1689 |
+
assistant_prefill="",
|
| 1690 |
+
api_url: str = None,
|
| 1691 |
+
) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
|
| 1692 |
+
"""
|
| 1693 |
+
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
|
| 1694 |
+
|
| 1695 |
+
Args:
|
| 1696 |
+
prompts (List[str]): A list of prompts to be processed.
|
| 1697 |
+
system_prompt (str): The system prompt.
|
| 1698 |
+
conversation_history (List[dict]): The history of the conversation.
|
| 1699 |
+
whole_conversation (List[str]): The complete conversation including prompts and responses.
|
| 1700 |
+
whole_conversation_metadata (List[str]): Metadata about the whole conversation.
|
| 1701 |
+
client (object): The client to use for processing the prompts, from either Gemini or OpenAI client.
|
| 1702 |
+
config (dict): Configuration for the model.
|
| 1703 |
+
model_choice (str): The choice of model to use.
|
| 1704 |
+
temperature (float): The temperature parameter for the model.
|
| 1705 |
+
model_source (str): Source of the model, whether local, AWS, Gemini, or inference-server
|
| 1706 |
+
batch_no (int): Batch number of the large language model request.
|
| 1707 |
+
local_model: Local gguf model (if loaded)
|
| 1708 |
+
master (bool): Is this request for the master table.
|
| 1709 |
+
assistant_prefill (str, optional): Is there a prefill for the assistant response. Currently only working for AWS model calls
|
| 1710 |
+
bedrock_runtime: The client object for boto3 Bedrock runtime
|
| 1711 |
+
api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.
|
| 1712 |
+
|
| 1713 |
+
Returns:
|
| 1714 |
+
Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
|
| 1715 |
+
"""
|
| 1716 |
+
responses = list()
|
| 1717 |
+
|
| 1718 |
+
# Clear any existing progress bars
|
| 1719 |
+
tqdm._instances.clear()
|
| 1720 |
+
|
| 1721 |
+
for prompt in prompts:
|
| 1722 |
+
|
| 1723 |
+
(
|
| 1724 |
+
response,
|
| 1725 |
+
conversation_history,
|
| 1726 |
+
response_text,
|
| 1727 |
+
num_transformer_input_tokens,
|
| 1728 |
+
num_transformer_generated_tokens,
|
| 1729 |
+
) = send_request(
|
| 1730 |
+
prompt,
|
| 1731 |
+
conversation_history,
|
| 1732 |
+
client=client,
|
| 1733 |
+
config=config,
|
| 1734 |
+
model_choice=model_choice,
|
| 1735 |
+
system_prompt=system_prompt,
|
| 1736 |
+
temperature=temperature,
|
| 1737 |
+
local_model=local_model,
|
| 1738 |
+
tokenizer=tokenizer,
|
| 1739 |
+
assistant_model=assistant_model,
|
| 1740 |
+
assistant_prefill=assistant_prefill,
|
| 1741 |
+
bedrock_runtime=bedrock_runtime,
|
| 1742 |
+
model_source=model_source,
|
| 1743 |
+
api_url=api_url,
|
| 1744 |
+
)
|
| 1745 |
+
|
| 1746 |
+
responses.append(response)
|
| 1747 |
+
whole_conversation.append(system_prompt)
|
| 1748 |
+
whole_conversation.append(prompt)
|
| 1749 |
+
whole_conversation.append(response_text)
|
| 1750 |
+
|
| 1751 |
+
whole_conversation_metadata.append(f"Batch {batch_no}:")
|
| 1752 |
+
|
| 1753 |
+
try:
|
| 1754 |
+
if "AWS" in model_source:
|
| 1755 |
+
output_tokens = response.usage_metadata.get("outputTokens", 0)
|
| 1756 |
+
input_tokens = response.usage_metadata.get("inputTokens", 0)
|
| 1757 |
+
|
| 1758 |
+
elif "Gemini" in model_source:
|
| 1759 |
+
output_tokens = response.usage_metadata.candidates_token_count
|
| 1760 |
+
input_tokens = response.usage_metadata.prompt_token_count
|
| 1761 |
+
|
| 1762 |
+
elif "Azure/OpenAI" in model_source:
|
| 1763 |
+
input_tokens = response.usage_metadata.get("inputTokens", 0)
|
| 1764 |
+
output_tokens = response.usage_metadata.get("outputTokens", 0)
|
| 1765 |
+
|
| 1766 |
+
elif "Local" in model_source:
|
| 1767 |
+
if USE_LLAMA_CPP == "True":
|
| 1768 |
+
output_tokens = response["usage"].get("completion_tokens", 0)
|
| 1769 |
+
input_tokens = response["usage"].get("prompt_tokens", 0)
|
| 1770 |
+
|
| 1771 |
+
if USE_LLAMA_CPP == "False":
|
| 1772 |
+
input_tokens = num_transformer_input_tokens
|
| 1773 |
+
output_tokens = num_transformer_generated_tokens
|
| 1774 |
+
|
| 1775 |
+
elif "inference-server" in model_source:
|
| 1776 |
+
# inference-server returns the same format as llama-cpp
|
| 1777 |
+
output_tokens = response["usage"].get("completion_tokens", 0)
|
| 1778 |
+
input_tokens = response["usage"].get("prompt_tokens", 0)
|
| 1779 |
+
|
| 1780 |
+
else:
|
| 1781 |
+
input_tokens = 0
|
| 1782 |
+
output_tokens = 0
|
| 1783 |
+
|
| 1784 |
+
whole_conversation_metadata.append(
|
| 1785 |
+
"input_tokens: "
|
| 1786 |
+
+ str(input_tokens)
|
| 1787 |
+
+ " output_tokens: "
|
| 1788 |
+
+ str(output_tokens)
|
| 1789 |
+
)
|
| 1790 |
+
|
| 1791 |
+
except KeyError as e:
|
| 1792 |
+
print(f"Key error: {e} - Check the structure of response.usage_metadata")
|
| 1793 |
+
|
| 1794 |
+
return (
|
| 1795 |
+
responses,
|
| 1796 |
+
conversation_history,
|
| 1797 |
+
whole_conversation,
|
| 1798 |
+
whole_conversation_metadata,
|
| 1799 |
+
response_text,
|
| 1800 |
+
)
|
| 1801 |
+
|
| 1802 |
+
|
| 1803 |
+
def call_llm_with_markdown_table_checks(
|
| 1804 |
+
batch_prompts: List[str],
|
| 1805 |
+
system_prompt: str,
|
| 1806 |
+
conversation_history: List[dict],
|
| 1807 |
+
whole_conversation: List[str],
|
| 1808 |
+
whole_conversation_metadata: List[str],
|
| 1809 |
+
client: ai.Client | OpenAI,
|
| 1810 |
+
client_config: types.GenerateContentConfig,
|
| 1811 |
+
model_choice: str,
|
| 1812 |
+
temperature: float,
|
| 1813 |
+
reported_batch_no: int,
|
| 1814 |
+
local_model: object,
|
| 1815 |
+
tokenizer: object,
|
| 1816 |
+
bedrock_runtime: boto3.Session.client,
|
| 1817 |
+
model_source: str,
|
| 1818 |
+
MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
|
| 1819 |
+
assistant_prefill: str = "",
|
| 1820 |
+
master: bool = False,
|
| 1821 |
+
CHOSEN_LOCAL_MODEL_TYPE: str = CHOSEN_LOCAL_MODEL_TYPE,
|
| 1822 |
+
random_seed: int = seed,
|
| 1823 |
+
api_url: str = None,
|
| 1824 |
+
) -> Tuple[List[ResponseObject], List[dict], List[str], List[str], str]:
|
| 1825 |
+
"""
|
| 1826 |
+
Call the large language model with checks for a valid markdown table.
|
| 1827 |
+
|
| 1828 |
+
Parameters:
|
| 1829 |
+
- batch_prompts (List[str]): A list of prompts to be processed.
|
| 1830 |
+
- system_prompt (str): The system prompt.
|
| 1831 |
+
- conversation_history (List[dict]): The history of the conversation.
|
| 1832 |
+
- whole_conversation (List[str]): The complete conversation including prompts and responses.
|
| 1833 |
+
- whole_conversation_metadata (List[str]): Metadata about the whole conversation.
|
| 1834 |
+
- client (ai.Client | OpenAI): The client object for running Gemini or Azure/OpenAI API calls.
|
| 1835 |
+
- client_config (types.GenerateContentConfig): Configuration for the model.
|
| 1836 |
+
- model_choice (str): The choice of model to use.
|
| 1837 |
+
- temperature (float): The temperature parameter for the model.
|
| 1838 |
+
- reported_batch_no (int): The reported batch number.
|
| 1839 |
+
- local_model (object): The local model to use.
|
| 1840 |
+
- tokenizer (object): The tokenizer to use.
|
| 1841 |
+
- bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
|
| 1842 |
+
- model_source (str): The source of the model, whether in AWS, Gemini, local, or inference-server.
|
| 1843 |
+
- MAX_OUTPUT_VALIDATION_ATTEMPTS (int): The maximum number of attempts to validate the output.
|
| 1844 |
+
- assistant_prefill (str, optional): The text to prefill the LLM response. Currently only working with AWS Claude calls.
|
| 1845 |
+
- master (bool, optional): Boolean to determine whether this call is for the master output table.
|
| 1846 |
+
- CHOSEN_LOCAL_MODEL_TYPE (str, optional): String to determine model type loaded.
|
| 1847 |
+
- random_seed (int, optional): The random seed used for LLM generation.
|
| 1848 |
+
- api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.
|
| 1849 |
+
|
| 1850 |
+
Returns:
|
| 1851 |
+
- Tuple[List[ResponseObject], List[dict], List[str], List[str], str]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, the updated whole conversation metadata, and the response text.
|
| 1852 |
+
"""
|
| 1853 |
+
|
| 1854 |
+
call_temperature = temperature # This is correct now with the fixed parameter name
|
| 1855 |
+
|
| 1856 |
+
# Update Gemini config with the new temperature settings
|
| 1857 |
+
client_config = types.GenerateContentConfig(
|
| 1858 |
+
temperature=call_temperature, max_output_tokens=max_tokens, seed=random_seed
|
| 1859 |
+
)
|
| 1860 |
+
|
| 1861 |
+
for attempt in range(MAX_OUTPUT_VALIDATION_ATTEMPTS):
|
| 1862 |
+
# Process requests to large language model
|
| 1863 |
+
(
|
| 1864 |
+
responses,
|
| 1865 |
+
conversation_history,
|
| 1866 |
+
whole_conversation,
|
| 1867 |
+
whole_conversation_metadata,
|
| 1868 |
+
response_text,
|
| 1869 |
+
) = process_requests(
|
| 1870 |
+
batch_prompts,
|
| 1871 |
+
system_prompt,
|
| 1872 |
+
conversation_history,
|
| 1873 |
+
whole_conversation,
|
| 1874 |
+
whole_conversation_metadata,
|
| 1875 |
+
client,
|
| 1876 |
+
client_config,
|
| 1877 |
+
model_choice,
|
| 1878 |
+
call_temperature,
|
| 1879 |
+
bedrock_runtime,
|
| 1880 |
+
model_source,
|
| 1881 |
+
reported_batch_no,
|
| 1882 |
+
local_model,
|
| 1883 |
+
tokenizer=tokenizer,
|
| 1884 |
+
master=master,
|
| 1885 |
+
assistant_prefill=assistant_prefill,
|
| 1886 |
+
api_url=api_url,
|
| 1887 |
+
)
|
| 1888 |
+
|
| 1889 |
+
stripped_response = response_text.strip()
|
| 1890 |
+
|
| 1891 |
+
# Check if response meets our criteria (length and contains table) OR is "No change"
|
| 1892 |
+
if (
|
| 1893 |
+
len(stripped_response) > 120 and "|" in stripped_response
|
| 1894 |
+
) or stripped_response.lower().startswith("no change"):
|
| 1895 |
+
if stripped_response.lower().startswith("no change"):
|
| 1896 |
+
print(f"Attempt {attempt + 1} produced 'No change' response.")
|
| 1897 |
+
else:
|
| 1898 |
+
print(f"Attempt {attempt + 1} produced response with markdown table.")
|
| 1899 |
+
break # Success - exit loop
|
| 1900 |
+
|
| 1901 |
+
# Increase temperature for next attempt
|
| 1902 |
+
call_temperature = temperature + (0.1 * (attempt + 1))
|
| 1903 |
+
print(
|
| 1904 |
+
f"Attempt {attempt + 1} resulted in invalid table: {stripped_response}. "
|
| 1905 |
+
f"Trying again with temperature: {call_temperature}"
|
| 1906 |
+
)
|
| 1907 |
+
|
| 1908 |
+
else: # This runs if no break occurred (all attempts failed)
|
| 1909 |
+
print(
|
| 1910 |
+
f"Failed to get valid response after {MAX_OUTPUT_VALIDATION_ATTEMPTS} attempts"
|
| 1911 |
+
)
|
| 1912 |
+
|
| 1913 |
+
return (
|
| 1914 |
+
responses,
|
| 1915 |
+
conversation_history,
|
| 1916 |
+
whole_conversation,
|
| 1917 |
+
whole_conversation_metadata,
|
| 1918 |
+
stripped_response,
|
| 1919 |
+
)
|
| 1920 |
+
|
| 1921 |
+
|
| 1922 |
+
def create_missing_references_df(
|
| 1923 |
+
basic_response_df: pd.DataFrame, existing_reference_df: pd.DataFrame
|
| 1924 |
+
) -> pd.DataFrame:
|
| 1925 |
+
"""
|
| 1926 |
+
Identifies references in basic_response_df that are not present in existing_reference_df.
|
| 1927 |
+
Returns a DataFrame with the missing references and the character count of their responses.
|
| 1928 |
+
|
| 1929 |
+
Args:
|
| 1930 |
+
basic_response_df (pd.DataFrame): DataFrame containing 'Reference' and 'Response' columns.
|
| 1931 |
+
existing_reference_df (pd.DataFrame): DataFrame containing 'Response References' column.
|
| 1932 |
+
|
| 1933 |
+
Returns:
|
| 1934 |
+
pd.DataFrame: A DataFrame with 'Missing Reference' and 'Response Character Count' columns.
|
| 1935 |
+
'Response Character Count' will be 0 for empty strings and NaN for actual missing data.
|
| 1936 |
+
"""
|
| 1937 |
+
# Ensure columns are treated as strings for robust comparison
|
| 1938 |
+
existing_references_unique = (
|
| 1939 |
+
existing_reference_df["Response References"].astype(str).unique()
|
| 1940 |
+
)
|
| 1941 |
+
|
| 1942 |
+
# Step 1: Identify all rows from basic_response_df that correspond to missing references
|
| 1943 |
+
# We want the entire row to access the 'Response' column later
|
| 1944 |
+
missing_data_rows = basic_response_df[
|
| 1945 |
+
~basic_response_df["Reference"].astype(str).isin(existing_references_unique)
|
| 1946 |
+
].copy() # .copy() to avoid SettingWithCopyWarning
|
| 1947 |
+
|
| 1948 |
+
# Step 2: Create the new DataFrame
|
| 1949 |
+
# Populate the 'Missing Reference' column directly
|
| 1950 |
+
missing_df = pd.DataFrame({"Missing Reference": missing_data_rows["Reference"]})
|
| 1951 |
+
|
| 1952 |
+
# Step 3: Calculate and add 'Response Character Count'
|
| 1953 |
+
# .str.len() works on Series of strings, handling empty strings (0) and NaN (NaN)
|
| 1954 |
+
missing_df["Response Character Count"] = missing_data_rows["Response"].str.len()
|
| 1955 |
+
|
| 1956 |
+
# Optional: Add the actual response text for easier debugging/inspection if needed
|
| 1957 |
+
# missing_df['Response Text'] = missing_data_rows['Response']
|
| 1958 |
+
|
| 1959 |
+
# Reset index to have a clean, sequential index for the new DataFrame
|
| 1960 |
+
missing_df = missing_df.reset_index(drop=True)
|
| 1961 |
+
|
| 1962 |
+
return missing_df
|
| 1963 |
+
|
| 1964 |
+
|
| 1965 |
+
def calculate_tokens_from_metadata(
|
| 1966 |
+
metadata_string: str, model_choice: str, model_name_map: dict
|
| 1967 |
+
):
|
| 1968 |
+
"""
|
| 1969 |
+
Calculate the number of input and output tokens for given queries based on metadata strings.
|
| 1970 |
+
|
| 1971 |
+
Args:
|
| 1972 |
+
metadata_string (str): A string containing all relevant metadata from the string.
|
| 1973 |
+
model_choice (str): A string describing the model name
|
| 1974 |
+
model_name_map (dict): A dictionary mapping model name to source
|
| 1975 |
+
"""
|
| 1976 |
+
|
| 1977 |
+
model_name_map[model_choice]["source"]
|
| 1978 |
+
|
| 1979 |
+
# Regex to find the numbers following the keys in the "Query summary metadata" section
|
| 1980 |
+
# This ensures we get the final, aggregated totals for the whole query.
|
| 1981 |
+
input_regex = r"input_tokens: (\d+)"
|
| 1982 |
+
output_regex = r"output_tokens: (\d+)"
|
| 1983 |
+
|
| 1984 |
+
# re.findall returns a list of all matching strings (the captured groups).
|
| 1985 |
+
input_token_strings = re.findall(input_regex, metadata_string)
|
| 1986 |
+
output_token_strings = re.findall(output_regex, metadata_string)
|
| 1987 |
+
|
| 1988 |
+
# Convert the lists of strings to lists of integers and sum them up
|
| 1989 |
+
total_input_tokens = sum([int(token) for token in input_token_strings])
|
| 1990 |
+
total_output_tokens = sum([int(token) for token in output_token_strings])
|
| 1991 |
+
|
| 1992 |
+
number_of_calls = len(input_token_strings)
|
| 1993 |
+
|
| 1994 |
+
print(f"Found {number_of_calls} LLM call entries in metadata.")
|
| 1995 |
+
print("-" * 20)
|
| 1996 |
+
print(f"Total Input Tokens: {total_input_tokens}")
|
| 1997 |
+
print(f"Total Output Tokens: {total_output_tokens}")
|
| 1998 |
+
|
| 1999 |
+
return total_input_tokens, total_output_tokens, number_of_calls
|
tools/prompts.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###
|
| 2 |
+
# System prompt
|
| 3 |
+
###
|
| 4 |
+
|
| 5 |
+
generic_system_prompt = """You are a researcher analysing responses from an open text dataset. You are analysing a single column from this dataset."""
|
| 6 |
+
|
| 7 |
+
system_prompt = """You are a researcher analysing responses from an open text dataset. You are analysing a single column from this dataset called '{column_name}'. {consultation_context}"""
|
| 8 |
+
|
| 9 |
+
markdown_additional_prompt = """ You will be given a request for a markdown table. You must respond with ONLY the markdown table. Do not include any introduction, explanation, or concluding text."""
|
| 10 |
+
|
| 11 |
+
###
|
| 12 |
+
# Initial topic table prompt
|
| 13 |
+
###
|
| 14 |
+
initial_table_system_prompt = system_prompt + markdown_additional_prompt
|
| 15 |
+
|
| 16 |
+
initial_table_assistant_prefill = "|"
|
| 17 |
+
|
| 18 |
+
default_response_reference_format = "In the next column named 'Response References', list each specific Response reference number that is relevant to the Subtopic, separated by commas. Do not write any other text in this column."
|
| 19 |
+
|
| 20 |
+
initial_table_prompt = """{validate_prompt_prefix}Your task is to create one new markdown table based on open text responses in the reponse table below.
|
| 21 |
+
In the first column named 'General topic', identify general topics relevant to responses. Create as many general topics as you can.
|
| 22 |
+
In the second column named 'Subtopic', list subtopics relevant to responses. Make the subtopics as specific as possible and make sure they cover every issue mentioned. The subtopic should never be empty.
|
| 23 |
+
{sentiment_choices}{response_reference_format}
|
| 24 |
+
In the final column named 'Summary', write a summary of the subtopic based on relevant responses - highlight specific issues that appear. {add_existing_topics_summary_format}
|
| 25 |
+
Do not add any other columns. Do not add any other text to your response. Only mention topics that are relevant to at least one response.
|
| 26 |
+
|
| 27 |
+
Response table:
|
| 28 |
+
{response_table}
|
| 29 |
+
|
| 30 |
+
New table:{previous_table_introduction}{previous_table}{validate_prompt_suffix}"""
|
| 31 |
+
|
| 32 |
+
###
|
| 33 |
+
# Adding existing topics to consultation responses
|
| 34 |
+
###
|
| 35 |
+
|
| 36 |
+
add_existing_topics_system_prompt = system_prompt + markdown_additional_prompt
|
| 37 |
+
|
| 38 |
+
add_existing_topics_assistant_prefill = "|"
|
| 39 |
+
|
| 40 |
+
force_existing_topics_prompt = """Create a new markdown table. In the first column named 'Placeholder', write 'Not assessed'. In the second column named 'Subtopics', assign Topics from the above table to Responses. Assign topics only if they are very relevant to the text of the Response. The assigned Subtopics should be chosen from the topics table above, exactly as written. Do not add any new topics, or modify existing topic names."""
|
| 41 |
+
|
| 42 |
+
allow_new_topics_prompt = """Create a new markdown table. In the first column named 'General topic', and the second column named 'Subtopic', assign General Topics and Subtopics to Responses. Assign topics from the Topics table above only if they are very relevant to the text of the Response. Fill in the General topic and Subtopic for the Topic if they do not already exist. If you find a new topic that does not exist in the Topics table, add a new row to the new table. Make the General topic and Subtopic as specific as possible. The subtopic should never be blank or empty."""
|
| 43 |
+
|
| 44 |
+
force_single_topic_prompt = """ Assign each response to one single topic only."""
|
| 45 |
+
|
| 46 |
+
add_existing_topics_prompt = """{validate_prompt_prefix}Your task is to create one new markdown table, assigning responses from the Response table below to topics.
|
| 47 |
+
{topic_assignment}{force_single_topic}
|
| 48 |
+
{sentiment_choices}{response_reference_format}
|
| 49 |
+
In the final column named 'Summary', write a summary of the Subtopic based on relevant responses - highlight specific issues that appear. {add_existing_topics_summary_format}
|
| 50 |
+
Do not add any other columns. Do not add any other text to your response. Only mention topics that are relevant to at least one response.
|
| 51 |
+
|
| 52 |
+
Choose from among the following topic names to assign to the responses, only if they are directly relevant to responses from the response table below:
|
| 53 |
+
{topics}
|
| 54 |
+
|
| 55 |
+
{response_table}
|
| 56 |
+
|
| 57 |
+
New table:{previous_table_introduction}{previous_table}{validate_prompt_suffix}"""
|
| 58 |
+
|
| 59 |
+
###
|
| 60 |
+
# VALIDATION PROMPTS
|
| 61 |
+
###
|
| 62 |
+
# These are prompts used to validate previous LLM outputs, and create corrected versions of the outputs if errors are found.
|
| 63 |
+
validation_system_prompt = system_prompt
|
| 64 |
+
|
| 65 |
+
validation_prompt_prefix_default = """The following instructions were previously provided to create an output table:\n'"""
|
| 66 |
+
|
| 67 |
+
previous_table_introduction_default = (
|
| 68 |
+
"""'\n\nThe following output table was created based on the above instructions:\n"""
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
validation_prompt_suffix_default = """\n\nBased on the above information, you need to create a corrected version of the output table. Examples of issues to correct include:
|
| 72 |
+
|
| 73 |
+
- Remove rows where responses are not relevant to the assigned topic, or where responses are not relevant to any topic.
|
| 74 |
+
- Remove rows where a topic is not assigned to any specific response.
|
| 75 |
+
- If the current topic assignment does not cover all information in a response, assign responses to relevant topics from the suggested topics table, or create a new topic if necessary.
|
| 76 |
+
- Correct any false information in the summary column, which is a summary of the relevant response text.
|
| 77 |
+
{additional_validation_issues}
|
| 78 |
+
- Any other obvious errors that you can identify.
|
| 79 |
+
|
| 80 |
+
With the above issues in mind, create a new, corrected version of the markdown table below. If there are no issues to correct, write simply "No change". Return only the corrected table without additional text, or 'no change' alone."""
|
| 81 |
+
|
| 82 |
+
validation_prompt_suffix_struct_summary_default = """\n\nBased on the above information, you need to create a corrected version of the output table. Examples of issues to correct include:
|
| 83 |
+
|
| 84 |
+
- Any misspellings in the Main heading or Subheading columns
|
| 85 |
+
- Correct any false information in the summary column, which is a summary of the relevant response text.
|
| 86 |
+
{additional_validation_issues}
|
| 87 |
+
- Any other obvious errors that you can identify.
|
| 88 |
+
|
| 89 |
+
With the above issues in mind, create a new, corrected version of the markdown table below. If there are no issues to correct, write simply "No change". Return only the corrected table without additional text, or 'no change' alone."""
|
| 90 |
+
|
| 91 |
+
###
|
| 92 |
+
# SENTIMENT CHOICES
|
| 93 |
+
###
|
| 94 |
+
|
| 95 |
+
negative_neutral_positive_sentiment_prompt = (
|
| 96 |
+
"write the sentiment of the Subtopic: Negative, Neutral, or Positive"
|
| 97 |
+
)
|
| 98 |
+
negative_or_positive_sentiment_prompt = (
|
| 99 |
+
"write the sentiment of the Subtopic: Negative or Positive"
|
| 100 |
+
)
|
| 101 |
+
do_not_assess_sentiment_prompt = "write the text 'Not assessed'" # Not used anymore. Instead, the column is filled in automatically with 'Not assessed'
|
| 102 |
+
default_sentiment_prompt = (
|
| 103 |
+
"write the sentiment of the Subtopic: Negative, Neutral, or Positive"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
###
|
| 107 |
+
# STRUCTURED SUMMARY PROMPT
|
| 108 |
+
###
|
| 109 |
+
|
| 110 |
+
structured_summary_prompt = """Your task is to write a structured summary for open text responses.
|
| 111 |
+
|
| 112 |
+
Create a new markdown table based on the response table below with the headings 'Main heading', 'Subheading' and 'Summary'.
|
| 113 |
+
|
| 114 |
+
For each of the responses in the Response table, you will create a row for each summary associated with each of the Main headings and Subheadings from the Headings table. If there is no Headings table, created your own headings. In the first and second columns, write a Main heading and Subheading from the Headings table. Then in Summary, write a detailed and comprehensive summary that covers all information relevant to the Main heading and Subheading on the same row.
|
| 115 |
+
{summary_format}
|
| 116 |
+
|
| 117 |
+
Do not add any other columns. Do not add any other text to your response.
|
| 118 |
+
|
| 119 |
+
{response_table}
|
| 120 |
+
|
| 121 |
+
Headings to structure the summary are in the following table:
|
| 122 |
+
{topics}
|
| 123 |
+
|
| 124 |
+
New table:"""
|
| 125 |
+
|
| 126 |
+
###
|
| 127 |
+
# SUMMARISE TOPICS PROMPT
|
| 128 |
+
###
|
| 129 |
+
|
| 130 |
+
summary_assistant_prefill = ""
|
| 131 |
+
|
| 132 |
+
summarise_topic_descriptions_system_prompt = system_prompt
|
| 133 |
+
|
| 134 |
+
summarise_topic_descriptions_prompt = """Your task is to make a consolidated summary of the text below. {summary_format}
|
| 135 |
+
|
| 136 |
+
Return only the summary and no other text:
|
| 137 |
+
|
| 138 |
+
{summaries}
|
| 139 |
+
|
| 140 |
+
Summary:"""
|
| 141 |
+
|
| 142 |
+
single_para_summary_format_prompt = "Return a concise summary up to one paragraph long that summarises only the most important themes from the original text"
|
| 143 |
+
|
| 144 |
+
two_para_summary_format_prompt = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text"
|
| 145 |
+
|
| 146 |
+
###
|
| 147 |
+
# OVERALL SUMMARY PROMPTS
|
| 148 |
+
###
|
| 149 |
+
|
| 150 |
+
summarise_everything_system_prompt = system_prompt
|
| 151 |
+
|
| 152 |
+
summarise_everything_prompt = """Below is a table that gives an overview of the main topics from a dataset of open text responses along with a description of each topic, and the number of responses that mentioned each topic:
|
| 153 |
+
|
| 154 |
+
'{topic_summary_table}'
|
| 155 |
+
|
| 156 |
+
Your task is to summarise the above table. {summary_format}. Return only the summary and no other text.
|
| 157 |
+
|
| 158 |
+
Summary:"""
|
| 159 |
+
|
| 160 |
+
comprehensive_summary_format_prompt = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Format the output for Excel display using: **bold text** for main headings, β’ bullet points for sub-items, and line breaks between sections. Avoid markdown symbols like # or ##."
|
| 161 |
+
|
| 162 |
+
comprehensive_summary_format_prompt_by_group = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Compare and contrast differences between the topics and themes from each Group. Format the output for Excel display using: **bold text** for main headings, β’ bullet points for sub-items, and line breaks between sections. Avoid markdown symbols like # or ##."
|
| 163 |
+
|
| 164 |
+
# Alternative Excel formatting options
|
| 165 |
+
excel_rich_text_format_prompt = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Format for Excel using: BOLD for main headings, bullet points (β’) for sub-items, and line breaks between sections. Use simple text formatting that Excel can interpret."
|
| 166 |
+
|
| 167 |
+
excel_plain_text_format_prompt = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Format as plain text with clear structure: use ALL CAPS for main headings, bullet points (β’) for sub-items, and line breaks between sections. Avoid any special formatting symbols."
|
| 168 |
+
|
| 169 |
+
###
|
| 170 |
+
# LLM-BASED TOPIC DEDUPLICATION PROMPTS
|
| 171 |
+
###
|
| 172 |
+
|
| 173 |
+
llm_deduplication_system_prompt = """You are an expert at analysing and consolidating topic categories. Your task is to identify semantically similar topics that should be merged together, even if they use different wording or synonyms."""
|
| 174 |
+
|
| 175 |
+
llm_deduplication_prompt = """You are given a table of topics with their General topics, Subtopics, and Sentiment classifications. Your task is to identify topics that are semantically similar and should be merged together. Only merge topics that are almost identical in terms of meaning - if in doubt, do not merge.
|
| 176 |
+
|
| 177 |
+
Analyse the following topics table and identify groups of topics that describe essentially the same concept but may use different words or phrases. For example:
|
| 178 |
+
- "Transportation issues" and "Public transport problems"
|
| 179 |
+
- "Housing costs" and "Rent prices"
|
| 180 |
+
- "Environmental concerns" and "Green issues"
|
| 181 |
+
|
| 182 |
+
Create a markdown table with the following columns:
|
| 183 |
+
1. 'Original General topic' - The current general topic name
|
| 184 |
+
2. 'Original Subtopic' - The current subtopic name
|
| 185 |
+
3. 'Original Sentiment' - The current sentiment
|
| 186 |
+
4. 'Merged General topic' - The consolidated general topic name (use the most descriptive)
|
| 187 |
+
5. 'Merged Subtopic' - The consolidated subtopic name (use the most descriptive)
|
| 188 |
+
6. 'Merged Sentiment' - The consolidated sentiment (use 'Mixed' if sentiments differ)
|
| 189 |
+
7. 'Merge Reason' - Brief explanation of why these topics should be merged
|
| 190 |
+
|
| 191 |
+
Only include rows where topics should actually be merged. If a topic has no semantic duplicates, do not include it in the table. Produce only a markdown table in the format described above. Do not add any other text to your response.
|
| 192 |
+
|
| 193 |
+
Topics to analyse:
|
| 194 |
+
{topics_table}
|
| 195 |
+
|
| 196 |
+
Merged topics table:"""
|
| 197 |
+
|
| 198 |
+
llm_deduplication_prompt_with_candidates = """You are given a table of topics with their General topics, Subtopics, and Sentiment classifications. Your task is to identify topics that are semantically similar and should be merged together, even if they use different wording.
|
| 199 |
+
|
| 200 |
+
Additionally, you have been provided with a list of candidate topics that represent preferred topic categories. When merging topics, prioritise fitting similar topics into these existing candidate categories rather than creating new ones. Only merge topics that are almost identical in terms of meaning - if in doubt, do not merge.
|
| 201 |
+
|
| 202 |
+
Analyse the following topics table and identify groups of topics that describe essentially the same concept but may use different words or phrases. For example:
|
| 203 |
+
- "Transportation issues" and "Public transport problems"
|
| 204 |
+
- "Housing costs" and "Rent prices"
|
| 205 |
+
- "Environmental concerns" and "Green issues"
|
| 206 |
+
|
| 207 |
+
When merging topics, consider the candidate topics provided below and try to map similar topics to these preferred categories when possible.
|
| 208 |
+
|
| 209 |
+
Create a markdown table with the following columns:
|
| 210 |
+
1. 'Original General topic' - The current general topic name
|
| 211 |
+
2. 'Original Subtopic' - The current subtopic name
|
| 212 |
+
3. 'Original Sentiment' - The current sentiment
|
| 213 |
+
4. 'Merged General topic' - The consolidated general topic name (prefer candidate topics when similar)
|
| 214 |
+
5. 'Merged Subtopic' - The consolidated subtopic name (prefer candidate topics when similar)
|
| 215 |
+
6. 'Merged Sentiment' - The consolidated sentiment (use 'Mixed' if sentiments differ)
|
| 216 |
+
7. 'Merge Reason' - Brief explanation of why these topics should be merged
|
| 217 |
+
|
| 218 |
+
Only include rows where topics should actually be merged. If a topic has no semantic duplicates, do not include it in the table. Produce only a markdown table in the format described above. Do not add any other text to your response.
|
| 219 |
+
|
| 220 |
+
Topics to analyse:
|
| 221 |
+
{topics_table}
|
| 222 |
+
|
| 223 |
+
Candidate topics to consider for mapping:
|
| 224 |
+
{candidate_topics_table}
|
| 225 |
+
|
| 226 |
+
Merged topics table:"""
|
| 227 |
+
|
| 228 |
+
###
|
| 229 |
+
# VERIFY EXISTING DESCRIPTIONS/TITLES - Currently not used
|
| 230 |
+
###
|
| 231 |
+
|
| 232 |
+
verify_assistant_prefill = "|"
|
| 233 |
+
|
| 234 |
+
verify_titles_system_prompt = system_prompt
|
| 235 |
+
|
| 236 |
+
verify_titles_prompt = """Response numbers alongside the Response text and assigned descriptions are shown in the table below:
|
| 237 |
+
{response_table}
|
| 238 |
+
|
| 239 |
+
The criteria for a suitable description for these responses is that they should be readable, concise, and fully encapsulate the main subject of the response.
|
| 240 |
+
|
| 241 |
+
Create a markdown table with four columns.
|
| 242 |
+
The first column is 'Response References', and should contain just the response number under consideration.
|
| 243 |
+
The second column is 'Is this a suitable description', answer the question with 'Yes' or 'No', with no other text.
|
| 244 |
+
The third column is 'Explanation', give a short explanation for your response in the second column.
|
| 245 |
+
The fourth column is 'Alternative description', suggest an alternative description for the response that meet the criteria stated above.
|
| 246 |
+
Do not add any other text to your response.
|
| 247 |
+
|
| 248 |
+
Output markdown table:"""
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
## The following didn't work well in testing and so is not currently used
|
| 252 |
+
|
| 253 |
+
create_general_topics_system_prompt = system_prompt
|
| 254 |
+
|
| 255 |
+
create_general_topics_prompt = """Subtopics known to be relevant to this dataset are shown in the following Topics table:
|
| 256 |
+
{topics}
|
| 257 |
+
|
| 258 |
+
Your task is to create a General topic name for each Subtopic. The new Topics table should have the columns 'General topic' and 'Subtopic' only. Write a 'General topic' text label relevant to the Subtopic next to it in the new table. The text label should describe the general theme of the Subtopic. Do not add any other text, thoughts, or notes to your response.
|
| 259 |
+
|
| 260 |
+
New Topics table:"""
|
windows_install_llama-cpp-python.txt
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
|
| 3 |
+
#How to build llama-cpp-python on Windows: Step-by-Step Guide
|
| 4 |
+
|
| 5 |
+
First, you need to set up a proper C++ development environment.
|
| 6 |
+
|
| 7 |
+
# Step 1: Install the C++ Compiler
|
| 8 |
+
Scroll down the page past the main programs to "Tools for Visual Studio" and download the "Build Tools for Visual Studio". This is a standalone installer that gives you the C++ compiler and libraries without installing the full Visual Studio IDE.
|
| 9 |
+
|
| 10 |
+
Run the installer. In the "Workloads" tab, check the box for "Desktop development with C++".
|
| 11 |
+
|
| 12 |
+
MSVC v143
|
| 13 |
+
C++ ATL
|
| 14 |
+
C++ Profiling tools
|
| 15 |
+
C++ CMake tools for Windows
|
| 16 |
+
C++ MFC
|
| 17 |
+
C++ Modules
|
| 18 |
+
Windows 10 SDK (10.0.20348.0)
|
| 19 |
+
|
| 20 |
+
Proceed with the installation.
|
| 21 |
+
|
| 22 |
+
Need to use 'x64 Native Tools Command Prompt for VS 2022' to install the below. Run as administrator
|
| 23 |
+
|
| 24 |
+
# Step 2: Install CMake
|
| 25 |
+
Go to the CMake download page: https://cmake.org/download
|
| 26 |
+
|
| 27 |
+
Download the latest Windows installer (e.g., cmake-x.xx.x-windows-x86_64.msi).
|
| 28 |
+
|
| 29 |
+
Run the installer. Crucially, when prompted, select the option to "Add CMake to the system PATH for all users" or "for the current user." This allows you to run cmake from any command prompt.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Step 3: (FOR CPU INFERENCE ONLY) Download and Place OpenBLAS
|
| 33 |
+
This is often the trickiest part.
|
| 34 |
+
|
| 35 |
+
Go to the OpenBLAS releases on GitHub.
|
| 36 |
+
|
| 37 |
+
Find a recent release and download the pre-compiled version for Windows. It will typically be a file named something like OpenBLAS-0.3.21-x64.zip (the version number will change). Make sure you get the 64-bit (x64) version if you are using 64-bit Python.
|
| 38 |
+
|
| 39 |
+
Create a folder somewhere easily accessible, for example, C:\libs\.
|
| 40 |
+
|
| 41 |
+
Extract the contents of the OpenBLAS zip file into that folder. Your final directory structure should look something like this:
|
| 42 |
+
|
| 43 |
+
C:\libs\OpenBLAS\
|
| 44 |
+
βββ bin\
|
| 45 |
+
βββ include\
|
| 46 |
+
βββ lib\
|
| 47 |
+
|
| 48 |
+
## 3.b. Install Chocolatey
|
| 49 |
+
https://chocolatey.org/install
|
| 50 |
+
|
| 51 |
+
Step 1: Install Chocolatey (if you don't already have it)
|
| 52 |
+
Open PowerShell as an Administrator. (Right-click the Start Menu -> "Windows PowerShell (Admin)" or "Terminal (Admin)").
|
| 53 |
+
|
| 54 |
+
Run the following command to install Chocolatey. It's a single, long line:
|
| 55 |
+
|
| 56 |
+
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
|
| 57 |
+
|
| 58 |
+
Once it's done, close the Administrator PowerShell window.
|
| 59 |
+
|
| 60 |
+
Step 2: Install pkg-config-lite using Chocolatey
|
| 61 |
+
IMPORTANT: Open a NEW command prompt or PowerShell window (as a regular user is fine). This is necessary so it recognises the new choco command.
|
| 62 |
+
|
| 63 |
+
Run the following command in console to install a lightweight version of pkg-config:
|
| 64 |
+
|
| 65 |
+
choco install pkgconfiglite
|
| 66 |
+
|
| 67 |
+
Approve the installation by typing Y or A if prompted.
|
| 68 |
+
|
| 69 |
+
# Step 4: Run the Installation Command
|
| 70 |
+
Now you have all the pieces. The final step is to run the command in a terminal that is aware of your new build environment.
|
| 71 |
+
|
| 72 |
+
Open the "Developer Command Prompt for VS" from your Start Menu. This is important! This special command prompt automatically configures all the necessary paths for the C++ compiler.
|
| 73 |
+
|
| 74 |
+
## For CPU
|
| 75 |
+
|
| 76 |
+
set PKG_CONFIG_PATH=C:\<path-to-openblas>\OpenBLAS\lib\pkgconfig # Set this in environment variables
|
| 77 |
+
|
| 78 |
+
pip install llama-cpp-python==0.3.16 --force-reinstall --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib"
|
| 79 |
+
|
| 80 |
+
pip install llama-cpp-python==0.3.16 --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/Users/s_cas/libs/OpenBLAS/include;-DBLAS_LIBRARIES=C:/Users/s_cas/OpenBLAS/lib/libopenblas.lib";-DPKG_CONFIG_PATH=C:/users/s_cas/openblas/lib/pkgconfig"
|
| 81 |
+
|
| 82 |
+
or to make a wheel:
|
| 83 |
+
|
| 84 |
+
pip install llama-cpp-python==0.3.16 --wheel-dir dist --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib"
|
| 85 |
+
|
| 86 |
+
pip wheel llama-cpp-python==0.3.16 --wheel-dir dist --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/Users/<user>/libs/OpenBLAS/include;-DBLAS_LIBRARIES=C:/Users/<user>/libs/OpenBLAS/lib/libopenblas.lib"
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
## With Cuda (NVIDIA GPUs only)
|
| 91 |
+
|
| 92 |
+
Make sure that the have the CUDA 12.4 toolkit for windows installed: https://developer.nvidia.com/cuda-12-4-0-download-archive
|
| 93 |
+
|
| 94 |
+
### Make sure you are using the x64 version of Developer command tools for the below, e.g. 'x64 Native Tools Command Prompt for VS 2022' ###
|
| 95 |
+
|
| 96 |
+
Use NVIDIA GPU (cuBLAS): If you have an NVIDIA GPU, using cuBLAS is often easier because the CUDA Toolkit installer handles most of the setup.
|
| 97 |
+
|
| 98 |
+
Install the NVIDIA CUDA Toolkit.
|
| 99 |
+
|
| 100 |
+
Run the install command specifying cuBLAS (for faster inference):
|
| 101 |
+
|
| 102 |
+
pip install llama-cpp-python==0.3.16 --force-reinstall --verbose -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
|
| 103 |
+
|
| 104 |
+
If you want to create a new wheel to help with future installs, you can run:
|
| 105 |
+
|
| 106 |
+
cd first to a folder that you have edit access for
|
| 107 |
+
|
| 108 |
+
pip wheel llama-cpp-python==0.3.16 --wheel-dir dist --verbose -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|