seanpedrickcase committed
Commit b6265c3 · 0 Parent(s)

Sync: Added functionality to save to S3 and save logs to DynamoDB when using cli_topics

Files changed (49)
  1. .dockerignore +27 -0
  2. .gitattributes +1 -0
  3. .github/workflows/ci.yml +196 -0
  4. .github/workflows/simple-test.yml +46 -0
  5. .github/workflows/sync_to_hf.yml +53 -0
  6. .gitignore +22 -0
  7. Dockerfile +166 -0
  8. README.md +176 -0
  9. app.py +0 -0
  10. cli_topics.py +1943 -0
  11. entrypoint.sh +18 -0
  12. example_data/case_note_headers_specific.csv +7 -0
  13. example_data/combined_case_notes.csv +19 -0
  14. example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_structured_summaries.xlsx +3 -0
  15. example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis.xlsx +3 -0
  16. example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis_grouped.xlsx +3 -0
  17. example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis.xlsx +3 -0
  18. example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis_zero_shot.xlsx +3 -0
  19. example_data/dummy_consultation_response.csv +31 -0
  20. example_data/dummy_consultation_response_themes.csv +26 -0
  21. intros/intro.txt +7 -0
  22. lambda_entrypoint.py +466 -0
  23. load_dynamo_logs.py +102 -0
  24. load_s3_logs.py +93 -0
  25. pyproject.toml +147 -0
  26. requirements.txt +29 -0
  27. requirements_cpu.txt +24 -0
  28. requirements_gpu.txt +28 -0
  29. requirements_lightweight.txt +18 -0
  30. test/README.md +87 -0
  31. test/__init__.py +5 -0
  32. test/mock_inference_server.py +225 -0
  33. test/mock_llm_calls.py +185 -0
  34. test/run_tests.py +34 -0
  35. test/test.py +1067 -0
  36. test/test_gui_only.py +189 -0
  37. tools/__init__.py +0 -0
  38. tools/auth.py +85 -0
  39. tools/aws_functions.py +387 -0
  40. tools/combine_sheets_into_xlsx.py +615 -0
  41. tools/config.py +950 -0
  42. tools/custom_csvlogger.py +333 -0
  43. tools/dedup_summaries.py +0 -0
  44. tools/example_table_outputs.py +94 -0
  45. tools/helper_functions.py +1245 -0
  46. tools/llm_api_call.py +0 -0
  47. tools/llm_funcs.py +1999 -0
  48. tools/prompts.py +260 -0
  49. windows_install_llama-cpp-python.txt +111 -0
.dockerignore ADDED
@@ -0,0 +1,27 @@
*.pdf
*.url
*.jpg
*.png
*.ipynb
*.xls
*.xlsx
examples/*
output/*
tools/__pycache__/*
build/*
dist/*
logs/*
usage/*
feedback/*
test_code/*
test/tmp/*
unsloth_compiled_cache/*
.vscode/*
llm_topic_modelling.egg-info/*
input/
output/
logs/
usage/
feedback/
config/
tmp/
.gitattributes ADDED
@@ -0,0 +1 @@
*.xlsx filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci.yml ADDED
@@ -0,0 +1,196 @@
name: CI/CD Pipeline

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  #schedule:
    # Run tests daily at 2 AM UTC
    # - cron: '0 2 * * *'

permissions:
  contents: read
  actions: read
  pull-requests: write
  issues: write

env:
  PYTHON_VERSION: "3.11"

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff black

      - name: Run Ruff linter
        run: ruff check .

      - name: Run Black formatter check
        run: black --check .

  test-unit:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.11, 3.12, 3.13]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-pip-

      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements_lightweight.txt
          pip install pytest pytest-cov pytest-html pytest-xdist

      - name: Verify example data files
        run: |
          echo "Checking if example data directory exists:"
          ls -la example_data/ || echo "example_data directory not found"
          echo "Checking for specific CSV files:"
          ls -la example_data/*.csv || echo "No CSV files found"

      - name: Run CLI and GUI tests
        run: |
          cd test
          python run_tests.py

      - name: Run tests with pytest
        run: |
          pytest test/test.py test/test_gui_only.py -v --tb=short --junitxml=test-results.xml

      - name: Run tests with coverage
        run: |
          pytest test/test.py test/test_gui_only.py --cov=. --cov-report=xml --cov-report=html --cov-report=term

      - name: Upload test results
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: test-results-python-${{ matrix.python-version }}
          path: |
            test-results.xml
            htmlcov/
            coverage.xml

  test-integration:
    runs-on: ubuntu-latest
    needs: [lint, test-unit]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements_lightweight.txt
          pip install pytest pytest-cov

      - name: Verify example data files
        run: |
          echo "Checking if example data directory exists:"
          ls -la example_data/
          echo "Checking for specific CSV files:"
          ls -la example_data/*.csv || echo "No CSV files found"

      - name: Run integration tests (CLI and GUI)
        run: |
          cd test
          python run_tests.py

      - name: Test CLI help
        run: |
          python cli_topics.py --help

      - name: Test CLI version
        run: |
          python -c "import sys; print(f'Python {sys.version}')"

  security:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install bandit

      - name: Run bandit security check
        run: |
          bandit -r . -f json -o bandit-report.json || true

      - name: Upload security report
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: security-report
          path: bandit-report.json

  build:
    runs-on: ubuntu-latest
    needs: [lint, test-unit]
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install build dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build twine

      - name: Build package
        run: |
          python -m build

      - name: Check package
        run: |
          twine check dist/*

      - name: Upload build artifacts
        uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist/
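The same checks can be reproduced locally before pushing; a minimal sketch using the commands the jobs above run (assumes a Python 3.11 environment):

```bash
pip install -r requirements_lightweight.txt ruff black pytest pytest-cov
ruff check . && black --check .                   # lint job
(cd test && python run_tests.py)                  # CLI and GUI tests
pytest test/test.py test/test_gui_only.py -v --tb=short
```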
.github/workflows/simple-test.yml ADDED
@@ -0,0 +1,46 @@
name: Simple Test Run

on:
  push:
    branches: [ dev ]
  pull_request:
    branches: [ dev ]

permissions:
  contents: read
  actions: read

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: "3.11"

      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements_lightweight.txt
          pip install pytest pytest-cov

      - name: Verify example data files
        run: |
          echo "Checking if example data directory exists:"
          ls -la example_data/ || echo "example_data directory not found"
          echo "Checking for specific CSV files:"
          ls -la example_data/*.csv || echo "No CSV files found"

      - name: Run CLI and GUI tests
        run: |
          cd test
          python run_tests.py

      - name: Run tests with pytest
        run: |
          pytest test/test.py test/test_gui_only.py -v --tb=short
.github/workflows/sync_to_hf.yml ADDED
@@ -0,0 +1,53 @@
name: Sync to Hugging Face hub
on:
  push:
    branches: [dev]

permissions:
  contents: read

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 1 # Only get the latest state
          lfs: true # Download actual LFS files so they can be pushed

      - name: Install Git LFS
        run: git lfs install

      - name: Recreate repo history (single-commit force push)
        run: |
          # 1. Capture the message BEFORE we delete the .git folder
          COMMIT_MSG=$(git log -1 --pretty=%B)
          echo "Syncing commit message: $COMMIT_MSG"

          # 2. DELETE the .git folder.
          #    This turns the repo into a standard folder of files.
          rm -rf .git

          # 3. Re-initialize a brand new git repo
          git init -b main
          git config --global user.name "$HF_USERNAME"
          git config --global user.email "$HF_EMAIL"

          # 4. Re-install LFS (needs to be done after git init)
          git lfs install

          # 5. Add the remote
          git remote add hf https://$HF_USERNAME:$HF_TOKEN@huggingface.co/spaces/$HF_USERNAME/$HF_REPO_ID

          # 6. Add all files
          #    Since this is a fresh init, Git sees EVERY file as "New"
          git add .

          # 7. Commit and Force Push
          git commit -m "Sync: $COMMIT_MSG"
          git push --force hf main
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          HF_USERNAME: ${{ secrets.HF_USERNAME }}
          HF_EMAIL: ${{ secrets.HF_EMAIL }}
          HF_REPO_ID: ${{ secrets.HF_REPO_ID }}
.gitignore ADDED
@@ -0,0 +1,22 @@
*.pdf
*.url
*.jpg
*.png
*.ipynb
*.xls
*.pyc
examples/*
output/*
tools/__pycache__/*
build/*
dist/*
logs/*
usage/*
feedback/*
test_code/*
config/*
tmp/*
test/tmp/*
unsloth_compiled_cache/*
.vscode/*
llm_topic_modelling.egg-info/*
Dockerfile ADDED
@@ -0,0 +1,166 @@
# This Dockerfile is optimised for AWS ECS using Python 3.11, and assumes CPU inference with OpenBLAS for local models.
# Stage 1: Build dependencies and download models
FROM public.ecr.aws/docker/library/python:3.11.13-slim-trixie AS builder

# Install system dependencies.
RUN apt-get update && apt-get install -y \
    build-essential \
    gcc \
    g++ \
    cmake \
    #libopenblas-dev \
    pkg-config \
    python3-dev \
    libffi-dev \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /src

COPY requirements_lightweight.txt .

# Set environment variables for OpenBLAS - not necessary if not building from source
# ENV OPENBLAS_VERBOSE=1
# ENV CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS"

ARG INSTALL_TORCH=False
ENV INSTALL_TORCH=${INSTALL_TORCH}

RUN if [ "$INSTALL_TORCH" = "True" ]; then \
    pip install --no-cache-dir --target=/install torch==2.9.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu; \
    fi

ARG INSTALL_LLAMA_CPP_PYTHON=False
ENV INSTALL_LLAMA_CPP_PYTHON=${INSTALL_LLAMA_CPP_PYTHON}

RUN if [ "$INSTALL_LLAMA_CPP_PYTHON" = "True" ]; then \
    pip install --no-cache-dir --target=/install https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl; \
    fi

RUN pip install --no-cache-dir --target=/install -r requirements_lightweight.txt

RUN rm requirements_lightweight.txt

# ===================================================================
# Stage 2: A common 'base' for both Lambda and Gradio
# ===================================================================
FROM public.ecr.aws/docker/library/python:3.11.13-slim-trixie AS base

# Set build-time and runtime environment variable for whether to run in Gradio mode or Lambda mode
ARG APP_MODE=gradio
ENV APP_MODE=${APP_MODE}

# Install runtime system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    libopenblas0 \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

ENV APP_HOME=/home/user

# Set env variables for Gradio & other apps
ENV GRADIO_TEMP_DIR=/tmp/gradio_tmp/ \
    MPLCONFIGDIR=/tmp/matplotlib_cache/ \
    GRADIO_OUTPUT_FOLDER=$APP_HOME/app/output/ \
    GRADIO_INPUT_FOLDER=$APP_HOME/app/input/ \
    FEEDBACK_LOGS_FOLDER=$APP_HOME/app/feedback/ \
    ACCESS_LOGS_FOLDER=$APP_HOME/app/logs/ \
    USAGE_LOGS_FOLDER=$APP_HOME/app/usage/ \
    CONFIG_FOLDER=$APP_HOME/app/config/ \
    GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_SERVER_PORT=7860 \
    PATH=$APP_HOME/.local/bin:$PATH \
    PYTHONPATH=$APP_HOME/app \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    GRADIO_ALLOW_FLAGGING=never \
    GRADIO_NUM_PORTS=1 \
    GRADIO_THEME=huggingface \
    SYSTEM=spaces

# Copy Python packages from the builder stage
COPY --from=builder /install /usr/local/lib/python3.11/site-packages/
COPY --from=builder /install/bin /usr/local/bin/

# Copy your application code and entrypoint
COPY . ${APP_HOME}/app
COPY entrypoint.sh ${APP_HOME}/app/entrypoint.sh
# Fix line endings and set execute permissions
RUN sed -i 's/\r$//' ${APP_HOME}/app/entrypoint.sh \
    && chmod +x ${APP_HOME}/app/entrypoint.sh

WORKDIR ${APP_HOME}/app

# ===================================================================
# FINAL Stage 3: The Lambda Image (runs as root for simplicity)
# ===================================================================
FROM base AS lambda
# Set runtime ENV for Lambda mode
ENV APP_MODE=lambda
ENTRYPOINT ["/home/user/app/entrypoint.sh"]
CMD ["lambda_entrypoint.lambda_handler"]

# ===================================================================
# FINAL Stage 4: The Gradio Image (runs as a secure, non-root user)
# ===================================================================
FROM base AS gradio
# Set runtime ENV for Gradio mode
ENV APP_MODE=gradio

# Create non-root user
RUN useradd -m -u 1000 user

# Create the base application directory and set its ownership
RUN mkdir -p ${APP_HOME}/app && chown user:user ${APP_HOME}/app

# Create required sub-folders within the app directory and set their permissions
RUN mkdir -p \
    ${APP_HOME}/app/output \
    ${APP_HOME}/app/input \
    ${APP_HOME}/app/logs \
    ${APP_HOME}/app/usage \
    ${APP_HOME}/app/feedback \
    ${APP_HOME}/app/config \
    && chown user:user \
    ${APP_HOME}/app/output \
    ${APP_HOME}/app/input \
    ${APP_HOME}/app/logs \
    ${APP_HOME}/app/usage \
    ${APP_HOME}/app/feedback \
    ${APP_HOME}/app/config \
    && chmod 755 \
    ${APP_HOME}/app/output \
    ${APP_HOME}/app/input \
    ${APP_HOME}/app/logs \
    ${APP_HOME}/app/usage \
    ${APP_HOME}/app/feedback \
    ${APP_HOME}/app/config

# Now handle the /tmp directories
RUN mkdir -p /tmp/gradio_tmp /tmp/matplotlib_cache /tmp /var/tmp \
    && chown user:user /tmp /var/tmp /tmp/gradio_tmp /tmp/matplotlib_cache \
    && chmod 1777 /tmp /var/tmp /tmp/gradio_tmp /tmp/matplotlib_cache

# Apply user ownership to all files in the home directory
RUN chown -R user:user /home/user

# Set permissions for Python executable
RUN chmod 755 /usr/local/bin/python

# Declare volumes
VOLUME ["/tmp/matplotlib_cache"]
VOLUME ["/tmp/gradio_tmp"]
VOLUME ["/home/user/app/output"]
VOLUME ["/home/user/app/input"]
VOLUME ["/home/user/app/logs"]
VOLUME ["/home/user/app/usage"]
VOLUME ["/home/user/app/feedback"]
VOLUME ["/home/user/app/config"]
VOLUME ["/tmp"]
VOLUME ["/var/tmp"]

USER user

EXPOSE $GRADIO_SERVER_PORT

ENTRYPOINT ["/home/user/app/entrypoint.sh"]
CMD ["python", "app.py"]
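A minimal build-and-run sketch for the Gradio target (the image tag is illustrative; the stage names, build args, and port come from the Dockerfile above):

```bash
# Build the Gradio image; set the build args to "True" to bake in torch or llama-cpp-python
docker build --target gradio --build-arg INSTALL_TORCH=False --build-arg INSTALL_LLAMA_CPP_PYTHON=False -t llm-topic-modelling .
# Run it, publishing the default Gradio port
docker run -p 7860:7860 llm-topic-modelling
```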
README.md ADDED
@@ -0,0 +1,176 @@
---
title: Large language model topic modelling
emoji: 📚
colorFrom: purple
colorTo: yellow
sdk: gradio
sdk_version: 6.0.2
app_file: app.py
pinned: true
license: agpl-3.0
short_description: Create thematic summaries for open text data with LLMs
---

# Large language model topic modelling

Version: 0.6.0

Extract topics and summarise outputs using Large Language Models (LLMs): Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure, or AWS Bedrock models (e.g. Claude, Nova models). The app queries the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets on the main app page, which will show you example outputs from a local model run. API keys for AWS, Azure, and Gemini services can be entered on the settings page (note that Gemini has a free public API).

NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** for harmful content, hallucinations, and accuracy.

Basic use:
1. On the front page, choose your model for inference. Gemma 3/GPT-OSS will use 'on-device' inference. Calls to Gemini or AWS will require an API key that can be input on the 'LLM and topic extraction' page.
2. Upload a csv/xlsx/parquet file containing at least one open text column.
3. Select the relevant open text column from the dropdown.
4. If you have your own suggested (zero-shot) topics, upload them (see the examples folder for an example file).
5. Write a one-sentence description of the consultation/context of the open text.
6. Click 'Extract topics, deduplicate, and summarise'. This will run through the whole analysis process from topic extraction, to topic deduplication, to topic-level and overall summaries.
7. A summary xlsx workbook will be created on the front page in the box 'Overall summary xlsx file'. This combines all the results from the different processes into one workbook.

# Installation guide

Here is a step-by-step guide to clone the repository, create a virtual environment, and install dependencies from the relevant `requirements` file. This guide assumes you have **Git** and **Python 3.11** installed.

-----

### Step 1: Clone the Git Repository

First, you need to copy the project files to your local machine. Navigate to the directory where you want to store the project using the `cd` (change directory) command. Then, use `git clone` with the repository's URL.

1. **Clone the repo:**

```bash
git clone https://github.com/seanpedrick-case/llm_topic_modelling.git
```

2. **Navigate into the new project folder:**

```bash
cd llm_topic_modelling
```
-----

### Step 2: Create and Activate a Virtual Environment

A virtual environment is a self-contained directory that holds a specific Python interpreter and its own set of installed packages. This is crucial for isolating your project's dependencies.

NOTE: Alternatively, you could create and activate a Conda environment instead of using venv below.

1. **Create the virtual environment:** We'll use Python's built-in `venv` module. It's common practice to name the environment folder `.venv`.

```bash
python -m venv .venv
```

*This command tells Python to create a new virtual environment in a folder named `.venv`.*

2. **Activate the environment:** You must "activate" the environment to start using it. The command differs based on your operating system and shell.

* **On macOS / Linux (bash/zsh):**

```bash
source .venv/bin/activate
```

* **On Windows (Command Prompt):**

```bash
.\.venv\Scripts\activate
```

* **On Windows (PowerShell):**

```powershell
.\.venv\Scripts\Activate.ps1
```

You'll know it's active because your command prompt will be prefixed with `(.venv)`.

-----

### Step 3: Install Dependencies

Now that your virtual environment is active, you can install all the required packages. You have two options: install from the pyproject.toml file (recommended), or install from requirements files.

1. **Install from pyproject.toml (recommended)**

You can install the 'lightweight' version of the app to access all available cloud provider or local inference (e.g. llama server, vLLM server) APIs. This version will not allow you to run local models such as Gemma 12b or GPT-OSS-20b 'in-app', i.e. accessible from the GUI interface directly. However, you will have access to AWS, Gemini, or Azure/OpenAI models with appropriate API keys. Use the following command in your environment to install the relevant packages:

```bash
pip install .
```

#### Install torch (optional)

If you want to run inference with transformers with full/quantised models, and the associated Unsloth package, you can run the following command for CPU inference. For GPU inference, please refer to the requirements_gpu.txt guide and the 'Install from a requirements file' section below:

```bash
pip install .[torch]
```

#### Install llama-cpp-python (optional)

You can run quantised GGUF models in-app using llama-cpp-python. However, installation of this package is not always straightforward, particularly as wheels for the latest version are only available for Linux. The package is not being updated regularly, so support for it may be removed in future. Long term, I would advise looking into running GGUF models using llama-server and calling the API from this app using the lightweight version (details here: https://github.com/ggml-org/llama.cpp).

If you do want to install llama-cpp-python in-app, first try the following command:

```bash
pip install .[llamacpp]
```

This will install the CPU version of llama-cpp-python. If you want GPU support, first try pip install with specific wheels for your system, e.g. for Linux see the files in https://github.com/abetlen/llama-cpp-python/releases/tag/v0.3.16-cu124. If you are still struggling, see here for more details on installation: https://llama-cpp-python.readthedocs.io/en/latest

**NOTE:** A sister repository contains [llama-cpp-python 3.16 wheels for Python version 3.11/10](https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/tag/v0.1.0) so that users can avoid having to build the package from source. I also have a guide to building the package on a Windows system [here](https://github.com/seanpedrick-case/llm_topic_modelling/blob/main/windows_install_llama-cpp-python.txt).
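For example, installing the prebuilt Linux CPU wheel for Python 3.11 from that sister repository directly (the same wheel the Dockerfile uses; substitute the wheel matching your system):

```bash
pip install https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
```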
#### Install mcp version of gradio

You can install an mcp-compatible version of gradio for this app with the following command:

```bash
pip install .[mcp]
```

2. **Install from a requirements file (not recommended)**

The repo provides several requirements files that are relevant for different situations. To start, I advise installing using the **requirements_lightweight.txt** file, which installs the app with access to all cloud provider or local inference (e.g. llama server, vLLM server) APIs. This approach is much simpler as a first step, and avoids the potentially complicated llama-cpp-python installation and GPU management described below.

If you want to run models locally 'in app', then you have two further requirements files to choose from:

- **requirements_cpu.txt**: Used for Python 3.11 CPU-only environments. Uncomment the requirements under 'Windows' for Windows compatibility. Make sure you have [Openblas](https://github.com/OpenMathLib/OpenBLAS) installed!
- **requirements_gpu.txt**: Used for Python 3.11 GPU-enabled environments. Uncomment the requirements under 'Windows' for Windows compatibility (CUDA 12.4).

Example: the instructions below show how to install the GPU-enabled version of the app for local inference.

**Install packages for local model 'in-app' inference from the requirements file:**
```bash
pip install -r requirements_gpu.txt
```
*This command reads every package name listed in the file and installs it into your `.venv` environment.*

NOTE: If the default llama-cpp-python installation does not work when installing from the above, go into the requirements_gpu.txt file and uncomment the lines to install a wheel for llama-cpp-python 0.3.16 relevant to your system.

### Step 4: Verify CUDA compatibility (if using a GPU environment)

Install the relevant toolkit for CUDA 12.4 from here: https://developer.nvidia.com/cuda-12-4-0-download-archive

Restart your computer.

Ensure you have the latest drivers for your NVIDIA GPU. Check your current version and memory availability by running `nvidia-smi`.

In the command line, CUDA compatibility can be checked by running `nvcc --version`, as shown below.
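Both checks, for convenience (standard NVIDIA tools, assumed to be on your PATH):

```bash
nvidia-smi      # driver version and GPU memory availability
nvcc --version  # installed CUDA toolkit version
```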
### Step 5: Ensure you have compatible NVIDIA drivers

Make sure you have the latest NVIDIA drivers installed on your system for your GPU (be careful in particular if using WSL that you have drivers compatible with this). Official drivers can be found here: https://www.nvidia.com/en-us/drivers

Current drivers can be found by running `nvidia-smi` in the command line.

### Step 6: Run the app

Go to the app project directory and run `python app.py`.

### Step 7: (optional) change default configuration

A number of configuration options can be seen in the tools/config.py file. You can either pass these variables in as environment variables, or you can create a file at config/app_config.env to read into the app on initialisation, as in the sketch below.
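For instance, a minimal `config/app_config.env` enabling the S3 output and DynamoDB logging options added in this commit might look like the following (variable names follow tools/config.py; all values are illustrative):

```bash
SAVE_OUTPUTS_TO_S3=True
S3_OUTPUTS_BUCKET=my-example-bucket
S3_OUTPUTS_FOLDER=topic-model-outputs/
SAVE_LOGS_TO_DYNAMODB=True
USAGE_LOG_DYNAMODB_TABLE_NAME=llm-topic-usage-logs
GEMINI_API_KEY=your-key-here
```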
app.py ADDED
The diff for this file is too large to render. See raw diff
 
cli_topics.py ADDED
@@ -0,0 +1,1943 @@
import argparse
import csv
import os
import time
import uuid
from datetime import datetime

import boto3
import botocore
import pandas as pd

from tools.aws_functions import download_file_from_s3, export_outputs_to_s3
from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output
from tools.config import (
    API_URL,
    AWS_ACCESS_KEY,
    AWS_REGION,
    AWS_SECRET_KEY,
    AZURE_OPENAI_API_KEY,
    AZURE_OPENAI_INFERENCE_ENDPOINT,
    BATCH_SIZE_DEFAULT,
    CHOSEN_INFERENCE_SERVER_MODEL,
    CSV_USAGE_LOG_HEADERS,
    DEDUPLICATION_THRESHOLD,
    DEFAULT_COST_CODE,
    DEFAULT_SAMPLED_SUMMARIES,
    DYNAMODB_USAGE_LOG_HEADERS,
    GEMINI_API_KEY,
    GRADIO_TEMP_DIR,
    HF_TOKEN,
    INPUT_FOLDER,
    LLM_MAX_NEW_TOKENS,
    LLM_SEED,
    LLM_TEMPERATURE,
    MAX_TIME_FOR_LOOP,
    OUTPUT_DEBUG_FILES,
    OUTPUT_FOLDER,
    RUN_AWS_FUNCTIONS,
    S3_OUTPUTS_BUCKET,
    S3_OUTPUTS_FOLDER,
    SAVE_LOGS_TO_CSV,
    SAVE_LOGS_TO_DYNAMODB,
    SAVE_OUTPUTS_TO_S3,
    SESSION_OUTPUT_FOLDER,
    USAGE_LOG_DYNAMODB_TABLE_NAME,
    USAGE_LOG_FILE_NAME,
    USAGE_LOGS_FOLDER,
    convert_string_to_boolean,
    default_model_choice,
    default_model_source,
    model_name_map,
)
from tools.dedup_summaries import (
    deduplicate_topics,
    deduplicate_topics_llm,
    overall_summary,
    wrapper_summarise_output_topics_per_group,
)
from tools.helper_functions import (
    load_in_data_file,
    load_in_previous_data_files,
)
from tools.llm_api_call import (
    all_in_one_pipeline,
    validate_topics_wrapper,
    wrapper_extract_topics_per_column_value,
)
from tools.prompts import (
    add_existing_topics_prompt,
    add_existing_topics_system_prompt,
    initial_table_prompt,
    initial_table_system_prompt,
    single_para_summary_format_prompt,
    two_para_summary_format_prompt,
)


def _generate_session_hash() -> str:
    """Generate a unique session hash for logging purposes."""
    return str(uuid.uuid4())[:8]


def _download_s3_file_if_needed(
    file_path: str,
    default_filename: str = "downloaded_file",
    aws_access_key: str = "",
    aws_secret_key: str = "",
    aws_region: str = "",
) -> str:
    """
    Download a file from S3 if the path starts with 's3://' or 'S3://', otherwise return the path as-is.

    Args:
        file_path: File path (either local or S3 URL)
        default_filename: Default filename to use if S3 key doesn't have a filename
        aws_access_key: AWS access key ID (optional, uses environment/config if not provided)
        aws_secret_key: AWS secret access key (optional, uses environment/config if not provided)
        aws_region: AWS region (optional, uses environment/config if not provided)

    Returns:
        Local file path (downloaded from S3 or original path)
    """
    if not file_path:
        return file_path

    # Check for S3 URL (case-insensitive)
    file_path_stripped = file_path.strip()
    file_path_upper = file_path_stripped.upper()
    if not file_path_upper.startswith("S3://"):
        return file_path

    # Ensure temp directory exists
    os.makedirs(GRADIO_TEMP_DIR, exist_ok=True)

    # Parse S3 URL: s3://bucket/key (preserve original case for bucket/key)
    # Remove 's3://' prefix (case-insensitive)
    s3_path = (
        file_path_stripped.split("://", 1)[1]
        if "://" in file_path_stripped
        else file_path_stripped
    )
    # Split bucket and key (first '/' separates bucket from key)
    if "/" in s3_path:
        bucket_name_s3, s3_key = s3_path.split("/", 1)
    else:
        # If no key provided, use bucket name as key (unlikely but handle it)
        bucket_name_s3 = s3_path
        s3_key = ""

    # Get the filename from the S3 key
    filename = os.path.basename(s3_key) if s3_key else bucket_name_s3
    if not filename:
        filename = default_filename

    # Create local file path in temp directory
    local_file_path = os.path.join(GRADIO_TEMP_DIR, filename)

    # Download file from S3
    try:
        download_file_from_s3(
            bucket_name=bucket_name_s3,
            key=s3_key,
            local_file_path=local_file_path,
            aws_access_key_textbox=aws_access_key,
            aws_secret_key_textbox=aws_secret_key,
            aws_region_textbox=aws_region,
        )
        print(f"S3 file downloaded successfully: {file_path} -> {local_file_path}")
        return local_file_path
    except Exception as e:
        print(f"Error downloading file from S3 ({file_path}): {e}")
        raise Exception(f"Failed to download file from S3: {e}")
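Since this helper resolves `s3://` paths transparently, file arguments can plausibly be given as S3 URLs; a hypothetical invocation (bucket and key are illustrative, and assume read access to the bucket):

```bash
python cli_topics.py --task extract --input_file "s3://my-example-bucket/inputs/combined_case_notes.csv" --text_column "Case Note"
```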
def get_username_and_folders(
    username: str = "",
    output_folder_textbox: str = OUTPUT_FOLDER,
    input_folder_textbox: str = INPUT_FOLDER,
    session_output_folder: bool = SESSION_OUTPUT_FOLDER,
):
    """Generate session hash and set up output/input folders."""
    # Generate session hash for logging. Either from input user name or generated
    if username:
        out_session_hash = username
    else:
        out_session_hash = _generate_session_hash()

    if session_output_folder:
        output_folder = output_folder_textbox + out_session_hash + "/"
        input_folder = input_folder_textbox + out_session_hash + "/"
    else:
        output_folder = output_folder_textbox
        input_folder = input_folder_textbox

    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
    if not os.path.exists(input_folder):
        os.makedirs(input_folder, exist_ok=True)

    return (
        out_session_hash,
        output_folder,
        out_session_hash,
        input_folder,
    )


def upload_outputs_to_s3_if_enabled(
    output_files: list,
    base_file_name: str = None,
    session_hash: str = "",
    s3_output_folder: str = S3_OUTPUTS_FOLDER,
    s3_bucket: str = S3_OUTPUTS_BUCKET,
    save_outputs_to_s3: bool = None,
):
    """
    Upload output files to S3 if SAVE_OUTPUTS_TO_S3 is enabled.

    Args:
        output_files: List of output file paths to upload
        base_file_name: Base file name (input file) for organizing S3 folder structure
        session_hash: Session hash to include in S3 path
        s3_output_folder: S3 output folder path
        s3_bucket: S3 bucket name
        save_outputs_to_s3: Override for SAVE_OUTPUTS_TO_S3 config (if None, uses config value)
    """
    # Use provided value or fall back to config
    if save_outputs_to_s3 is None:
        save_outputs_to_s3 = convert_string_to_boolean(SAVE_OUTPUTS_TO_S3)

    if not save_outputs_to_s3:
        return

    if not s3_bucket:
        print("Warning: S3_OUTPUTS_BUCKET not configured. Skipping S3 upload.")
        return

    if not output_files:
        print("No output files to upload to S3.")
        return

    # Filter out empty/None values and ensure files exist
    valid_files = []
    for file_path in output_files:
        if file_path and os.path.exists(file_path):
            valid_files.append(file_path)
        elif file_path:
            print(f"Warning: Output file does not exist, skipping: {file_path}")

    if not valid_files:
        print("No valid output files to upload to S3.")
        return

    # Construct S3 output folder path
    # Include session hash if provided and SESSION_OUTPUT_FOLDER is enabled
    s3_folder_path = s3_output_folder or ""
    if session_hash and convert_string_to_boolean(SESSION_OUTPUT_FOLDER):
        if s3_folder_path and not s3_folder_path.endswith("/"):
            s3_folder_path += "/"
        s3_folder_path += session_hash + "/"

    print(f"\nUploading {len(valid_files)} output file(s) to S3...")
    try:
        export_outputs_to_s3(
            file_list_state=valid_files,
            s3_output_folder_state_value=s3_folder_path,
            save_outputs_to_s3_flag=True,
            base_file_state=base_file_name,
            s3_bucket=s3_bucket,
        )
    except Exception as e:
        print(f"Warning: Failed to upload outputs to S3: {e}")
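Upload is gated on the SAVE_OUTPUTS_TO_S3 config value, so one way to enable it for a single run is via environment variables read by tools/config.py (a sketch; the bucket name is illustrative):

```bash
SAVE_OUTPUTS_TO_S3=True S3_OUTPUTS_BUCKET=my-example-bucket \
  python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note"
```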
def write_usage_log(
    session_hash: str,
    file_name: str,
    text_column: str,
    model_choice: str,
    conversation_metadata: str,
    input_tokens: int,
    output_tokens: int,
    number_of_calls: int,
    estimated_time_taken: float,
    cost_code: str = DEFAULT_COST_CODE,
    save_to_csv: bool = SAVE_LOGS_TO_CSV,
    save_to_dynamodb: bool = SAVE_LOGS_TO_DYNAMODB,
    include_conversation_metadata: bool = False,
):
    """
    Write usage log entry to CSV file and/or DynamoDB.

    Args:
        session_hash: Session identifier
        file_name: Name of the input file
        text_column: Column name used for analysis (as list for CSV)
        model_choice: LLM model used
        conversation_metadata: Metadata string
        input_tokens: Number of input tokens
        output_tokens: Number of output tokens
        number_of_calls: Number of LLM calls
        estimated_time_taken: Time taken in seconds
        cost_code: Cost code for tracking
        save_to_csv: Whether to save to CSV
        save_to_dynamodb: Whether to save to DynamoDB
        include_conversation_metadata: Whether to include conversation metadata in the log
    """
    # Convert boolean parameters if they're strings
    if isinstance(save_to_csv, str):
        save_to_csv = convert_string_to_boolean(save_to_csv)
    if isinstance(save_to_dynamodb, str):
        save_to_dynamodb = convert_string_to_boolean(save_to_dynamodb)

    # Return early if neither logging method is enabled
    if not save_to_csv and not save_to_dynamodb:
        return

    if not conversation_metadata:
        conversation_metadata = ""

    # Ensure usage logs folder exists
    os.makedirs(USAGE_LOGS_FOLDER, exist_ok=True)

    # Construct full file path
    usage_log_file_path = os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)

    # Prepare data row - order matches app.py component order
    # session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice,
    # conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num,
    # number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop
    data = [
        session_hash,
        file_name,
        (
            text_column
            if isinstance(text_column, str)
            else (text_column[0] if text_column else "")
        ),
        model_choice,
        conversation_metadata if conversation_metadata else "",
        input_tokens,
        output_tokens,
        number_of_calls,
        estimated_time_taken,
        cost_code,
    ]

    # Add id and timestamp
    generated_id = str(uuid.uuid4())
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    data.extend([generated_id, timestamp])

    # Use custom headers if available, otherwise use default
    # Note: CSVLogger_custom uses component labels, but we need to match what collect_output_csvs_and_create_excel_output expects
    if CSV_USAGE_LOG_HEADERS and len(CSV_USAGE_LOG_HEADERS) == len(data):
        headers = CSV_USAGE_LOG_HEADERS
    else:
        # Default headers - these should match what CSVLogger_custom creates from Gradio component labels
        # The components are: session_hash_textbox, original_data_file_name_textbox, in_colnames,
        # model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num,
        # number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop
        # Since these are hidden components without labels, CSVLogger_custom uses component variable names
        # or default labels. We need to match what collect_output_csvs_and_create_excel_output expects:
        # "Total LLM calls", "Total input tokens", "Total output tokens"
        # But the actual CSV from Gradio likely has: "Number of calls", "Input tokens", "Output tokens"
        # Let's use the names that match what the Excel function expects
        headers = [
            "Session hash",
            "Reference data file name",
            "Select the open text column of interest. In an Excel file, this shows columns across all sheets.",
            "Large language model for topic extraction and summarisation",
            "Conversation metadata",
            "Total input tokens",  # Changed from "Input tokens" to match Excel function
            "Total output tokens",  # Changed from "Output tokens" to match Excel function
            "Total LLM calls",  # Changed from "Number of calls" to match Excel function
            "Estimated time taken (seconds)",
            "Cost code",
            "id",
            "timestamp",
        ]

    # Write to CSV if enabled
    if save_to_csv:
        # Ensure usage logs folder exists
        os.makedirs(USAGE_LOGS_FOLDER, exist_ok=True)

        # Construct full file path
        usage_log_file_path = os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)

        # Write to CSV
        file_exists = os.path.exists(usage_log_file_path)
        with open(
            usage_log_file_path, "a", newline="", encoding="utf-8-sig"
        ) as csvfile:
            writer = csv.writer(csvfile)
            if not file_exists:
                # Write headers if file doesn't exist
                writer.writerow(headers)
            writer.writerow(data)

    # Write to DynamoDB if enabled
    if save_to_dynamodb:
        # DynamoDB logging implementation
        print("Saving to DynamoDB")

        try:
            # Connect to DynamoDB
            if RUN_AWS_FUNCTIONS == "1":
                try:
                    print("Connecting to DynamoDB via existing SSO connection")
                    dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION)
                    dynamodb.meta.client.list_tables()
                except Exception as e:
                    print("No SSO credentials found:", e)
                    if AWS_ACCESS_KEY and AWS_SECRET_KEY:
                        print("Trying DynamoDB credentials from environment variables")
                        dynamodb = boto3.resource(
                            "dynamodb",
                            aws_access_key_id=AWS_ACCESS_KEY,
                            aws_secret_access_key=AWS_SECRET_KEY,
                            region_name=AWS_REGION,
                        )
                    else:
                        raise Exception(
                            "AWS credentials for DynamoDB logging not found"
                        )
            else:
                raise Exception("AWS credentials for DynamoDB logging not found")

            # Get table name from config
            dynamodb_table_name = USAGE_LOG_DYNAMODB_TABLE_NAME
            if not dynamodb_table_name:
                raise ValueError(
                    "USAGE_LOG_DYNAMODB_TABLE_NAME not configured. Cannot save to DynamoDB."
                )

            # Determine headers for DynamoDB
            # Use DYNAMODB_USAGE_LOG_HEADERS if available and matches data length,
            # otherwise use CSV_USAGE_LOG_HEADERS if it matches, otherwise use default headers
            # Note: headers and data are guaranteed to have the same length and include id/timestamp
            if DYNAMODB_USAGE_LOG_HEADERS and len(DYNAMODB_USAGE_LOG_HEADERS) == len(
                data
            ):
                dynamodb_headers = list(DYNAMODB_USAGE_LOG_HEADERS)  # Make a copy
            elif CSV_USAGE_LOG_HEADERS and len(CSV_USAGE_LOG_HEADERS) == len(data):
                dynamodb_headers = list(CSV_USAGE_LOG_HEADERS)  # Make a copy
            else:
                # Use the headers we created which are guaranteed to match data
                dynamodb_headers = headers

            # Check if table exists, create if it doesn't
            try:
                table = dynamodb.Table(dynamodb_table_name)
                table.load()
            except botocore.exceptions.ClientError as e:
                if e.response["Error"]["Code"] == "ResourceNotFoundException":
                    print(
                        f"Table '{dynamodb_table_name}' does not exist. Creating it..."
                    )
                    attribute_definitions = [
                        {
                            "AttributeName": "id",
                            "AttributeType": "S",
                        }
                    ]

                    table = dynamodb.create_table(
                        TableName=dynamodb_table_name,
                        KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
                        AttributeDefinitions=attribute_definitions,
                        BillingMode="PAY_PER_REQUEST",
                    )
                    # Wait until the table exists
                    table.meta.client.get_waiter("table_exists").wait(
                        TableName=dynamodb_table_name
                    )
                    time.sleep(5)
                    print(f"Table '{dynamodb_table_name}' created successfully.")
                else:
                    raise

            # Prepare the DynamoDB item to upload
            # Map the headers to values (headers and data should match in length)
            if len(dynamodb_headers) == len(data):
                item = {
                    header: str(value) for header, value in zip(dynamodb_headers, data)
                }
            else:
                # Fallback: use the default headers which are guaranteed to match data
                print(
                    f"Warning: DynamoDB headers length ({len(dynamodb_headers)}) doesn't match data length ({len(data)}). Using default headers."
                )
                item = {header: str(value) for header, value in zip(headers, data)}

            # Upload to DynamoDB
            table.put_item(Item=item)
            print("Successfully uploaded log to DynamoDB")

        except Exception as e:
            print(f"Could not upload log to DynamoDB due to: {e}")
            import traceback

            traceback.print_exc()
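Logged items land in the table named by USAGE_LOG_DYNAMODB_TABLE_NAME, keyed on `id` with all values stringified. A quick way to inspect them (the table name is illustrative; the commit also adds load_dynamo_logs.py for this purpose):

```bash
aws dynamodb scan --table-name llm-topic-usage-logs --max-items 5
```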
486
+ # --- Main CLI Function ---
487
+ def main(direct_mode_args={}):
488
+ """
489
+ A unified command-line interface for topic extraction, validation, deduplication, and summarisation.
490
+
491
+ Args:
492
+ direct_mode_args (dict, optional): Dictionary of arguments for direct mode execution.
493
+ If provided, uses these instead of parsing command line arguments.
494
+ """
495
+ parser = argparse.ArgumentParser(
496
+ description="A versatile CLI for topic extraction, validation, deduplication, and summarisation using LLMs.",
497
+ formatter_class=argparse.RawTextHelpFormatter,
498
+ epilog="""
499
+ Examples:
500
+
501
+ To run these, you need to do the following:
502
+
503
+ - Open a terminal window
504
+
505
+ - CD to the app folder that contains this file (cli_topics.py)
506
+
507
+ - Load the virtual environment using either conda or venv depending on your setup
508
+
509
+ - Run one of the example commands below
510
+
511
+ - The examples below use the free Gemini 2.5 Flash Lite model, that is free with an API key that you can get from here: https://aistudio.google.com/api-keys. You can either set this or API keys for other services as an environment variable (e.g. in config/app_config.py. See the file tools/config.py for more details about variables relevant to each service) or you can set them manually at the time of the function call via command line arguments, such as the following:
512
+
513
+ Google/Gemini: --google_api_key
514
+ AWS Bedrock: --aws_access_key, --aws_secret_key, --aws_region
515
+ Hugging Face (for model download): --hf_token
516
+ Azure/OpenAI: --azure_api_key, --azure_endpoint
517
+ Inference Server endpoint for local models(e.g. llama server, vllm): --api_url
518
+
519
+ - Use --create_xlsx_output to create an Excel file combining all CSV outputs after task completion
520
+
521
+ - Look in the output/ folder to see output files:
522
+
523
+ # Topic Extraction
524
+
525
+ ## Extract topics from a CSV file with default settings:
526
+ python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note"
527
+
528
+ ## Extract topics with custom model and context:
529
+ python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note" --model_choice "gemini-2.5-flash-lite" --context "Social Care case notes for young people"
530
+
531
+ ## Extract topics with grouping:
532
+ python cli_topics.py --task extract --input_file example_data/combined_case_notes.csv --text_column "Case Note" --group_by "Client"
533
+
534
+ ## Extract topics with candidate topics (zero-shot):
535
+ python cli_topics.py --task extract --input_file example_data/dummy_consultation_response.csv --text_column "Response text" --candidate_topics example_data/dummy_consultation_response_themes.csv
536
+
537
+ # Topic Validation
538
+
539
+ ## Validate previously extracted topics:
540
+ python cli_topics.py --task validate --input_file example_data/combined_case_notes.csv --text_column "Case Note" --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv
541
+
542
+ # Deduplication
543
+
544
+ Note: you will need to change the reference to previous output files to match the exact file names created from the previous task. This includes the relative path to the app folder. Also, the function will create an xlsx output file by default. the --input_file and --text_column arguments are needed for this, unless you pass in --no_xlsx_output as seen below.
545
+
546
+ ## Deduplicate topics using fuzzy matching:
547
+ python cli_topics.py --task deduplicate --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv --similarity_threshold 90 --no_xlsx_output
548
+
549
+ ## Deduplicate topics using LLM:
550
+ python cli_topics.py --task deduplicate --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv --method llm --model_choice "gemini-2.5-flash-lite" --no_xlsx_output
551
+
552
+ # Summarisation
553
+
554
+ Note: you will need to change the reference to previous output files to match the exact file names created from the previous task. This includes the relative path to the app folder. Also, the function will create an xlsx output file by default. the --input_file and --text_column arguments are needed for this, unless you pass in --no_xlsx_output as seen below.
555
+
556
+ ## Summarise topics:
557
+ python cli_topics.py --task summarise --previous_output_files output/combined_case_notes_col_Case_Note_reference_table.csv output/combined_case_notes_col_Case_Note_unique_topics.csv --model_choice "gemini-2.5-flash-lite" --no_xlsx_output
558
+
559
+ ## Create overall summary:
560
+ python cli_topics.py --task overall_summary --previous_output_files output/combined_case_notes_col_Case_Note_unique_topics.csv --model_choice "gemini-2.5-flash-lite" --no_xlsx_output
561
+
562
+ # All-in-one pipeline
563
+
564
+ ## Run complete pipeline (extract, deduplicate, summarise):
565
+ python cli_topics.py --task all_in_one --input_file example_data/combined_case_notes.csv --text_column "Case Note" --model_choice "gemini-2.5-flash-lite"
566
+
567
+ """,
568
+ )
569
+
570
+ # --- Task Selection ---
571
+ task_group = parser.add_argument_group("Task Selection")
572
+ task_group.add_argument(
573
+ "--task",
574
+ choices=[
575
+ "extract",
576
+ "validate",
577
+ "deduplicate",
578
+ "summarise",
579
+ "overall_summary",
580
+ "all_in_one",
581
+ ],
582
+ default="extract",
583
+ help="Task to perform: extract (topic extraction), validate (validate topics), deduplicate (deduplicate topics), summarise (summarise topics), overall_summary (create overall summary), or all_in_one (complete pipeline).",
584
+ )
585
+
586
+ # --- General Arguments ---
587
+ general_group = parser.add_argument_group("General Options")
588
+ general_group.add_argument(
589
+ "--input_file",
590
+ nargs="+",
591
+ help="Path to the input file(s) to process. Separate multiple files with a space, and use quotes if there are spaces in the file name.",
592
+ )
593
+ general_group.add_argument(
594
+ "--output_dir", default=OUTPUT_FOLDER, help="Directory for all output files."
595
+ )
596
+ general_group.add_argument(
597
+ "--input_dir", default=INPUT_FOLDER, help="Directory for all input files."
598
+ )
599
+ general_group.add_argument(
600
+ "--text_column",
601
+ help="Name of the text column to process (required for extract, validate, and all_in_one tasks).",
602
+ )
603
+ general_group.add_argument(
604
+ "--previous_output_files",
605
+ nargs="+",
606
+ help="Path(s) to previous output files (reference_table and/or unique_topics files) for validate, deduplicate, summarise, and overall_summary tasks.",
607
+ )
608
+ general_group.add_argument(
609
+ "--username", default="", help="Username for the session."
610
+ )
611
+ general_group.add_argument(
612
+ "--save_to_user_folders",
613
+ default=SESSION_OUTPUT_FOLDER,
614
+ help="Whether to save to user folders or not.",
615
+ )
616
+ general_group.add_argument(
617
+ "--excel_sheets",
618
+ nargs="+",
619
+ default=list(),
620
+ help="Specific Excel sheet names to process.",
621
+ )
622
+ general_group.add_argument(
623
+ "--group_by",
624
+ help="Column name to group results by.",
625
+ )
626
+
627
+ # --- Model Configuration ---
628
+ model_group = parser.add_argument_group("Model Configuration")
629
+ model_group.add_argument(
630
+ "--model_choice",
631
+ default=default_model_choice,
632
+ help=f"LLM model to use. Default: {default_model_choice}",
633
+ )
634
+ model_group.add_argument(
635
+ "--model_source",
636
+ default=default_model_source,
637
+ help=f"Model source (e.g., 'Google', 'AWS', 'Local'). Default: {default_model_source}",
638
+ )
639
+ model_group.add_argument(
640
+ "--temperature",
641
+ type=float,
642
+ default=LLM_TEMPERATURE,
643
+ help=f"Temperature for LLM generation. Default: {LLM_TEMPERATURE}",
644
+ )
645
+ model_group.add_argument(
646
+ "--batch_size",
647
+ type=int,
648
+ default=BATCH_SIZE_DEFAULT,
649
+ help=f"Number of responses to submit in a single LLM query. Default: {BATCH_SIZE_DEFAULT}",
650
+ )
651
+ model_group.add_argument(
652
+ "--max_tokens",
653
+ type=int,
654
+ default=LLM_MAX_NEW_TOKENS,
655
+ help=f"Maximum tokens for LLM generation. Default: {LLM_MAX_NEW_TOKENS}",
656
+ )
657
+ model_group.add_argument(
658
+ "--google_api_key",
659
+ default=GEMINI_API_KEY,
660
+ help="Google API key for Gemini models.",
661
+ )
662
+ model_group.add_argument(
663
+ "--aws_access_key",
664
+ default=AWS_ACCESS_KEY,
665
+ help="AWS Access Key ID for Bedrock models.",
666
+ )
667
+ model_group.add_argument(
668
+ "--aws_secret_key",
669
+ default=AWS_SECRET_KEY,
670
+ help="AWS Secret Access Key for Bedrock models.",
671
+ )
672
+ model_group.add_argument(
673
+ "--aws_region",
674
+ default=AWS_REGION,
675
+ help="AWS region for Bedrock models.",
676
+ )
677
+ model_group.add_argument(
678
+ "--hf_token",
679
+ default=HF_TOKEN,
680
+ help="Hugging Face token for downloading gated models.",
681
+ )
682
+ model_group.add_argument(
683
+ "--azure_api_key",
684
+ default=AZURE_OPENAI_API_KEY,
685
+ help="Azure/OpenAI API key for Azure/OpenAI models.",
686
+ )
687
+ model_group.add_argument(
688
+ "--azure_endpoint",
689
+ default=AZURE_OPENAI_INFERENCE_ENDPOINT,
690
+ help="Azure Inference endpoint URL.",
691
+ )
692
+ model_group.add_argument(
693
+ "--api_url",
694
+ default=API_URL,
695
+ help=f"Inference server API URL (for local models). Default: {API_URL}",
696
+ )
697
+ model_group.add_argument(
698
+ "--inference_server_model",
699
+ default=CHOSEN_INFERENCE_SERVER_MODEL,
700
+ help=f"Inference server model name to use. Default: {CHOSEN_INFERENCE_SERVER_MODEL}",
701
+ )
702
+
703
+ # --- Topic Extraction Arguments ---
704
+ extract_group = parser.add_argument_group("Topic Extraction Options")
705
+ extract_group.add_argument(
706
+ "--context",
707
+ default="",
708
+ help="Context sentence to provide to the LLM for topic extraction.",
709
+ )
710
+ extract_group.add_argument(
711
+ "--candidate_topics",
712
+ help="Path to CSV file with candidate topics for zero-shot extraction.",
713
+ )
714
+ extract_group.add_argument(
715
+ "--force_zero_shot",
716
+ choices=["Yes", "No"],
717
+ default="No",
718
+ help="Force responses into suggested topics. Default: No",
719
+ )
720
+ extract_group.add_argument(
721
+ "--force_single_topic",
722
+ choices=["Yes", "No"],
723
+ default="No",
724
+ help="Ask the model to assign responses to only a single topic. Default: No",
725
+ )
726
+ extract_group.add_argument(
727
+ "--produce_structured_summary",
728
+ choices=["Yes", "No"],
729
+ default="No",
730
+ help="Produce structured summaries using suggested topics as headers. Default: No",
731
+ )
732
+ extract_group.add_argument(
733
+ "--sentiment",
734
+ choices=[
735
+ "Negative or Positive",
736
+ "Negative, Neutral, or Positive",
737
+ "Do not assess sentiment",
738
+ ],
739
+ default="Negative or Positive",
740
+ help="Response sentiment analysis option. Default: Negative or Positive",
741
+ )
742
+ extract_group.add_argument(
743
+ "--additional_summary_instructions",
744
+ default="",
745
+ help="Additional instructions for summary format.",
746
+ )
747
+
748
+ # --- Validation Arguments ---
749
+ validate_group = parser.add_argument_group("Topic Validation Options")
750
+ validate_group.add_argument(
751
+ "--additional_validation_issues",
752
+ default="",
753
+ help="Additional validation issues for the model to consider (bullet-point list).",
754
+ )
755
+ validate_group.add_argument(
756
+ "--show_previous_table",
757
+ choices=["Yes", "No"],
758
+ default="Yes",
759
+ help="Provide response data to validation process. Default: Yes",
760
+ )
761
+ validate_group.add_argument(
762
+ "--output_debug_files",
763
+ choices=["True", "False"],
764
+ default=OUTPUT_DEBUG_FILES,
765
+ help=f"Output debug files. Default: {OUTPUT_DEBUG_FILES}",
766
+ )
767
+ validate_group.add_argument(
768
+ "--max_time_for_loop",
769
+ type=int,
770
+ default=MAX_TIME_FOR_LOOP,
771
+ help=f"Maximum time for validation loop in seconds. Default: {MAX_TIME_FOR_LOOP}",
772
+ )
773
+
774
+ # --- Deduplication Arguments ---
775
+ dedup_group = parser.add_argument_group("Deduplication Options")
776
+ dedup_group.add_argument(
777
+ "--method",
778
+ choices=["fuzzy", "llm"],
779
+ default="fuzzy",
780
+ help="Deduplication method: fuzzy (fuzzy matching) or llm (LLM semantic matching). Default: fuzzy",
781
+ )
782
+ dedup_group.add_argument(
783
+ "--similarity_threshold",
784
+ type=int,
785
+ default=DEDUPLICATION_THRESHOLD,
786
+ help=f"Similarity threshold (0-100) for fuzzy matching. Default: {DEDUPLICATION_THRESHOLD}",
787
+ )
788
+ dedup_group.add_argument(
789
+ "--merge_sentiment",
790
+ choices=["Yes", "No"],
791
+ default="No",
792
+ help="Merge sentiment values together for duplicate subtopics. Default: No",
793
+ )
794
+ dedup_group.add_argument(
795
+ "--merge_general_topics",
796
+ choices=["Yes", "No"],
797
+ default="Yes",
798
+ help="Merge general topic values together for duplicate subtopics. Default: Yes",
799
+ )
800
+
801
+ # --- Summarisation Arguments ---
802
+ summarise_group = parser.add_argument_group("Summarisation Options")
803
+ summarise_group.add_argument(
804
+ "--summary_format",
805
+ choices=["two_paragraph", "single_paragraph"],
806
+ default="two_paragraph",
807
+ help="Summary format type. Default: two_paragraph",
808
+ )
809
+ summarise_group.add_argument(
810
+ "--sample_reference_table",
811
+ choices=["True", "False"],
812
+ default="True",
813
+ help="Sample reference table (recommended for large datasets). Default: True",
814
+ )
815
+ summarise_group.add_argument(
816
+ "--no_of_sampled_summaries",
817
+ type=int,
818
+ default=DEFAULT_SAMPLED_SUMMARIES,
819
+ help=f"Number of summaries per group. Default: {DEFAULT_SAMPLED_SUMMARIES}",
820
+ )
821
+ summarise_group.add_argument(
822
+ "--random_seed",
823
+ type=int,
824
+ default=LLM_SEED,
825
+ help=f"Random seed for sampling. Default: {LLM_SEED}",
826
+ )
827
+
828
+ # --- Output Format Arguments ---
829
+ output_group = parser.add_argument_group("Output Format Options")
830
+ output_group.add_argument(
831
+ "--no_xlsx_output",
832
+ dest="create_xlsx_output",
833
+ action="store_false",
834
+ default=True,
835
+ help="Disable creation of Excel (.xlsx) output file. By default, Excel output is created.",
836
+ )
837
+
838
+ # --- Logging Arguments ---
839
+ logging_group = parser.add_argument_group("Logging Options")
840
+ logging_group.add_argument(
841
+ "--save_logs_to_csv",
842
+ default=SAVE_LOGS_TO_CSV,
843
+ help="Save processing logs to CSV files.",
844
+ )
845
+ logging_group.add_argument(
846
+ "--save_logs_to_dynamodb",
847
+ default=SAVE_LOGS_TO_DYNAMODB,
848
+ help="Save processing logs to DynamoDB.",
849
+ )
850
+ logging_group.add_argument(
851
+ "--usage_logs_folder",
852
+ default=USAGE_LOGS_FOLDER,
853
+ help="Directory for usage log files.",
854
+ )
855
+ logging_group.add_argument(
856
+ "--cost_code",
857
+ default=DEFAULT_COST_CODE,
858
+ help="Cost code for tracking usage.",
859
+ )
860
+
861
+ # Parse arguments - either from command line or direct mode
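+ # Illustrative direct-mode call (assuming the enclosing entry point is main): the dict keys
+ # mirror the CLI flag names, and every attribute read below must be supplied, because
+ # argparse defaults are not applied to a Namespace built this way, e.g.:
+ # main(direct_mode_args={"task": "extract", "input_file": ["example_data/combined_case_notes.csv"], "text_column": "Case Note", ...})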
862
+ if direct_mode_args:
863
+ # Use direct mode arguments
864
+ args = argparse.Namespace(**direct_mode_args)
865
+ else:
866
+ # Parse command line arguments
867
+ args = parser.parse_args()
868
+
869
+ # --- Handle S3 file downloads ---
870
+ # Get AWS credentials from args or fall back to config values
871
+ aws_access_key = getattr(args, "aws_access_key", None) or AWS_ACCESS_KEY or ""
872
+ aws_secret_key = getattr(args, "aws_secret_key", None) or AWS_SECRET_KEY or ""
873
+ aws_region = getattr(args, "aws_region", None) or AWS_REGION or ""
874
+
875
+ # Download input files from S3 if needed
876
+ # Note: args.input_file is typically a list (from CLI nargs="+" or from direct mode)
877
+ # but we also handle pipe-separated strings for compatibility
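+ # e.g. "s3://my-bucket/notes_jan.csv|s3://my-bucket/notes_feb.csv" (placeholder bucket/keys)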
878
+ if args.input_file:
879
+ if isinstance(args.input_file, list):
880
+ # Handle list of files (may include S3 paths)
881
+ downloaded_files = []
882
+ for file_path in args.input_file:
883
+ downloaded_path = _download_s3_file_if_needed(
884
+ file_path,
885
+ aws_access_key=aws_access_key,
886
+ aws_secret_key=aws_secret_key,
887
+ aws_region=aws_region,
888
+ )
889
+ downloaded_files.append(downloaded_path)
890
+ args.input_file = downloaded_files
891
+ elif isinstance(args.input_file, str):
892
+ # Handle pipe-separated string (for direct mode compatibility)
893
+ if "|" in args.input_file:
894
+ file_list = [f.strip() for f in args.input_file.split("|") if f.strip()]
895
+ downloaded_files = []
896
+ for file_path in file_list:
897
+ downloaded_path = _download_s3_file_if_needed(
898
+ file_path,
899
+ aws_access_key=aws_access_key,
900
+ aws_secret_key=aws_secret_key,
901
+ aws_region=aws_region,
902
+ )
903
+ downloaded_files.append(downloaded_path)
904
+ args.input_file = downloaded_files
905
+ else:
906
+ # Single file path
907
+ args.input_file = [
908
+ _download_s3_file_if_needed(
909
+ args.input_file,
910
+ aws_access_key=aws_access_key,
911
+ aws_secret_key=aws_secret_key,
912
+ aws_region=aws_region,
913
+ )
914
+ ]
915
+
916
+ # Download candidate topics file from S3 if needed
917
+ if args.candidate_topics:
918
+ args.candidate_topics = _download_s3_file_if_needed(
919
+ args.candidate_topics,
920
+ default_filename="downloaded_candidate_topics",
921
+ aws_access_key=aws_access_key,
922
+ aws_secret_key=aws_secret_key,
923
+ aws_region=aws_region,
924
+ )
925
+
926
+ # --- Override model_choice with inference_server_model if provided ---
927
+ # If inference_server_model is explicitly provided, use it to override model_choice
928
+ # This allows users to specify which inference-server model to use
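+ # e.g. passing --inference_server_model "my-served-model" (placeholder name) makes that model the effective model_choice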
929
+ if args.inference_server_model:
930
+ # Check if the current model_choice is an inference-server model
931
+ model_source = model_name_map.get(args.model_choice, {}).get(
932
+ "source", default_model_source
933
+ )
934
+ # If model_source is "inference-server" OR if inference_server_model is explicitly provided
935
+ # (different from default), use it
936
+ if (
937
+ model_source == "inference-server"
938
+ or args.inference_server_model != CHOSEN_INFERENCE_SERVER_MODEL
939
+ ):
940
+ args.model_choice = args.inference_server_model
941
+ # Ensure the model is registered in model_name_map with inference-server source
942
+ if args.model_choice not in model_name_map:
943
+ model_name_map[args.model_choice] = {
944
+ "short_name": args.model_choice,
945
+ "source": "inference-server",
946
+ }
947
+ # Also update the model_source to ensure it's set correctly
948
+ model_name_map[args.model_choice]["source"] = "inference-server"
949
+
950
+ # --- Initial Setup ---
951
+ # Convert string boolean variables to boolean
952
+ args.save_to_user_folders = convert_string_to_boolean(args.save_to_user_folders)
953
+ args.save_logs_to_csv = convert_string_to_boolean(str(args.save_logs_to_csv))
954
+ args.save_logs_to_dynamodb = convert_string_to_boolean(
955
+ str(args.save_logs_to_dynamodb)
956
+ )
957
+ args.sample_reference_table = convert_string_to_boolean(str(args.sample_reference_table))  # robust to booleans passed via direct mode
958
+ args.output_debug_files = convert_string_to_boolean(str(args.output_debug_files))
959
+
960
+ # Get username and folders
961
+ (
962
+ session_hash,
963
+ args.output_dir,
964
+ _,
965
+ args.input_dir,
966
+ ) = get_username_and_folders(
967
+ username=args.username,
968
+ output_folder_textbox=args.output_dir,
969
+ input_folder_textbox=args.input_dir,
970
+ session_output_folder=args.save_to_user_folders,
971
+ )
972
+
973
+ print(
974
+ f"Conducting analyses with user {args.username or session_hash}. Outputs will be saved to {args.output_dir}."
975
+ )
976
+
977
+ # --- Route to the Correct Workflow Based on Task ---
978
+
979
+ # Validate input_file requirement for tasks that need it
980
+ if args.task in ["extract", "validate", "all_in_one"] and not args.input_file:
981
+ print(f"Error: --input_file is required for '{args.task}' task.")
982
+ return
983
+
984
+ if (
985
+ args.task in ["validate", "deduplicate", "summarise", "overall_summary"]
986
+ and not args.previous_output_files
987
+ ):
988
+ print(f"Error: --previous_output_files is required for '{args.task}' task.")
989
+ return
990
+
991
+ if args.task in ["extract", "validate", "all_in_one"] and not args.text_column:
992
+ print(f"Error: --text_column is required for '{args.task}' task.")
993
+ return
994
+
995
+ start_time = time.time()
996
+
997
+ try:
998
+ # Task 1: Extract Topics
999
+ if args.task == "extract":
1000
+ print("--- Starting Topic Extraction Workflow... ---")
1001
+
1002
+ # Load data file
1003
+ if isinstance(args.input_file, str):
1004
+ args.input_file = [args.input_file]
1005
+
1006
+ file_data, file_name, total_number_of_batches = load_in_data_file(
1007
+ file_paths=args.input_file,
1008
+ in_colnames=[args.text_column],
1009
+ batch_size=args.batch_size,
1010
+ in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1011
+ )
1012
+
1013
+ # Prepare candidate topics if provided
1014
+ candidate_topics = None
1015
+ if args.candidate_topics:
1016
+ candidate_topics = args.candidate_topics
1017
+
1018
+ # Determine summary format prompt
1019
+ summary_format_prompt = (
1020
+ two_para_summary_format_prompt
1021
+ if args.summary_format == "two_paragraph"
1022
+ else single_para_summary_format_prompt
1023
+ )
1024
+
1025
+ # Run extraction
1026
+ (
1027
+ display_markdown,
1028
+ master_topic_df_state,
1029
+ master_unique_topics_df_state,
1030
+ master_reference_df_state,
1031
+ topic_extraction_output_files,
1032
+ text_output_file_list_state,
1033
+ latest_batch_completed,
1034
+ log_files_output,
1035
+ log_files_output_list_state,
1036
+ conversation_metadata_textbox,
1037
+ estimated_time_taken_number,
1038
+ deduplication_input_files,
1039
+ summarisation_input_files,
1040
+ modifiable_unique_topics_df_state,
1041
+ modification_input_files,
1042
+ in_join_files,
1043
+ missing_df_state,
1044
+ input_tokens_num,
1045
+ output_tokens_num,
1046
+ number_of_calls_num,
1047
+ output_messages_textbox,
1048
+ logged_content_df,
1049
+ ) = wrapper_extract_topics_per_column_value(
1050
+ grouping_col=args.group_by,
1051
+ in_data_file=args.input_file,
1052
+ file_data=file_data,
1053
+ initial_existing_topics_table=pd.DataFrame(),
1054
+ initial_existing_reference_df=pd.DataFrame(),
1055
+ initial_existing_topic_summary_df=pd.DataFrame(),
1056
+ initial_unique_table_df_display_table_markdown="",
1057
+ original_file_name=file_name,
1058
+ total_number_of_batches=total_number_of_batches,
1059
+ in_api_key=args.google_api_key,
1060
+ temperature=args.temperature,
1061
+ chosen_cols=[args.text_column],
1062
+ model_choice=args.model_choice,
1063
+ candidate_topics=candidate_topics,
1064
+ initial_first_loop_state=True,
1065
+ initial_all_metadata_content_str="",
1066
+ initial_latest_batch_completed=0,
1067
+ initial_time_taken=0,
1068
+ batch_size=args.batch_size,
1069
+ context_textbox=args.context,
1070
+ sentiment_checkbox=args.sentiment,
1071
+ force_zero_shot_radio=args.force_zero_shot,
1072
+ in_excel_sheets=args.excel_sheets,
1073
+ force_single_topic_radio=args.force_single_topic,
1074
+ produce_structured_summary_radio=args.produce_structured_summary,
1075
+ aws_access_key_textbox=args.aws_access_key,
1076
+ aws_secret_key_textbox=args.aws_secret_key,
1077
+ aws_region_textbox=args.aws_region,
1078
+ hf_api_key_textbox=args.hf_token,
1079
+ azure_api_key_textbox=args.azure_api_key,
1080
+ azure_endpoint_textbox=args.azure_endpoint,
1081
+ output_folder=args.output_dir,
1082
+ existing_logged_content=list(),
1083
+ additional_instructions_summary_format=args.additional_summary_instructions,
1084
+ additional_validation_issues_provided="",
1085
+ show_previous_table="Yes",
1086
+ api_url=args.api_url if args.api_url else API_URL,
1087
+ max_tokens=args.max_tokens,
1088
+ model_name_map=model_name_map,
1089
+ max_time_for_loop=99999,
1090
+ reasoning_suffix="",
1091
+ CHOSEN_LOCAL_MODEL_TYPE="",
1092
+ output_debug_files=str(args.output_debug_files),
1093
+ model=None,
1094
+ tokenizer=None,
1095
+ assistant_model=None,
1096
+ max_rows=999999,
1097
+ )
1098
+
1099
+ end_time = time.time()
1100
+ processing_time = end_time - start_time
1101
+
1102
+ print("\n--- Topic Extraction Complete ---")
1103
+ print(f"Processing time: {processing_time:.2f} seconds")
1104
+ print(f"\nOutput files saved to: {args.output_dir}")
1105
+ if topic_extraction_output_files:
1106
+ print("Generated Files:", sorted(topic_extraction_output_files))
1107
+
1108
+ # Write usage log (before Excel creation so it can be included in Excel)
1109
+ write_usage_log(
1110
+ session_hash=session_hash,
1111
+ file_name=file_name,
1112
+ text_column=args.text_column,
1113
+ model_choice=args.model_choice,
1114
+ conversation_metadata=conversation_metadata_textbox or "",
1115
+ input_tokens=input_tokens_num or 0,
1116
+ output_tokens=output_tokens_num or 0,
1117
+ number_of_calls=number_of_calls_num or 0,
1118
+ estimated_time_taken=estimated_time_taken_number or processing_time,
1119
+ cost_code=args.cost_code,
1120
+ save_to_csv=args.save_logs_to_csv,
1121
+ save_to_dynamodb=args.save_logs_to_dynamodb,
1122
+ )
1123
+
1124
+ # Create Excel output if requested
1125
+ xlsx_files = []
1126
+ if args.create_xlsx_output:
1127
+ print("\nCreating Excel output file...")
1128
+ try:
1129
+ xlsx_files, _ = collect_output_csvs_and_create_excel_output(
1130
+ in_data_files=args.input_file,
1131
+ chosen_cols=[args.text_column],
1132
+ reference_data_file_name_textbox=file_name,
1133
+ in_group_col=args.group_by,
1134
+ model_choice=args.model_choice,
1135
+ master_reference_df_state=master_reference_df_state,
1136
+ master_unique_topics_df_state=master_unique_topics_df_state,
1137
+ summarised_output_df=pd.DataFrame(), # No summaries yet
1138
+ missing_df_state=missing_df_state,
1139
+ excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1140
+ usage_logs_location=(
1141
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1142
+ if args.save_logs_to_csv
1143
+ else ""
1144
+ ),
1145
+ model_name_map=model_name_map,
1146
+ output_folder=args.output_dir,
1147
+ structured_summaries=args.produce_structured_summary,
1148
+ )
1149
+ if xlsx_files:
1150
+ print(f"Excel output created: {sorted(xlsx_files)}")
1151
+ except Exception as e:
1152
+ print(f"Warning: Could not create Excel output: {e}")
1153
+
1154
+ # Upload outputs to S3 if enabled
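+ # (By construction this is a no-op when S3 saving is not enabled; the destination bucket is assumed to come from config)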
1155
+ all_output_files = (
1156
+ list(topic_extraction_output_files)
1157
+ if topic_extraction_output_files
1158
+ else []
1159
+ )
1160
+ if xlsx_files:
1161
+ all_output_files.extend(xlsx_files)
1162
+ upload_outputs_to_s3_if_enabled(
1163
+ output_files=all_output_files,
1164
+ base_file_name=file_name,
1165
+ session_hash=session_hash,
1166
+ )
1167
+
1168
+ # Task 2: Validate Topics
1169
+ elif args.task == "validate":
1170
+ print("--- Starting Topic Validation Workflow... ---")
1171
+
1172
+ # Load data file
1173
+ if isinstance(args.input_file, str):
1174
+ args.input_file = [args.input_file]
1175
+
1176
+ file_data, file_name, total_number_of_batches = load_in_data_file(
1177
+ file_paths=args.input_file,
1178
+ in_colnames=[args.text_column],
1179
+ batch_size=args.batch_size,
1180
+ in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1181
+ )
1182
+
1183
+ # Load previous output files
1184
+ (
1185
+ reference_df,
1186
+ topic_summary_df,
1187
+ latest_batch_completed_no_loop,
1188
+ deduplication_input_files_status,
1189
+ working_data_file_name_textbox,
1190
+ unique_topics_table_file_name_textbox,
1191
+ ) = load_in_previous_data_files(args.previous_output_files)
1192
+
1193
+ # Run validation
1194
+ (
1195
+ display_markdown,
1196
+ master_topic_df_state,
1197
+ master_unique_topics_df_state,
1198
+ master_reference_df_state,
1199
+ validation_output_files,
1200
+ text_output_file_list_state,
1201
+ latest_batch_completed,
1202
+ log_files_output,
1203
+ log_files_output_list_state,
1204
+ conversation_metadata_textbox,
1205
+ estimated_time_taken_number,
1206
+ deduplication_input_files,
1207
+ summarisation_input_files,
1208
+ modifiable_unique_topics_df_state,
1209
+ modification_input_files,
1210
+ in_join_files,
1211
+ missing_df_state,
1212
+ input_tokens_num,
1213
+ output_tokens_num,
1214
+ number_of_calls_num,
1215
+ output_messages_textbox,
1216
+ logged_content_df,
1217
+ ) = validate_topics_wrapper(
1218
+ file_data=file_data,
1219
+ reference_df=reference_df,
1220
+ topic_summary_df=topic_summary_df,
1221
+ file_name=working_data_file_name_textbox,
1222
+ chosen_cols=[args.text_column],
1223
+ batch_size=args.batch_size,
1224
+ model_choice=args.model_choice,
1225
+ in_api_key=args.google_api_key,
1226
+ temperature=args.temperature,
1227
+ max_tokens=args.max_tokens,
1228
+ azure_api_key_textbox=args.azure_api_key,
1229
+ azure_endpoint_textbox=args.azure_endpoint,
1230
+ reasoning_suffix="",
1231
+ group_name=args.group_by or "All",
1232
+ produce_structured_summary_radio=args.produce_structured_summary,
1233
+ force_zero_shot_radio=args.force_zero_shot,
1234
+ force_single_topic_radio=args.force_single_topic,
1235
+ context_textbox=args.context,
1236
+ additional_instructions_summary_format=args.additional_summary_instructions,
1237
+ output_folder=args.output_dir,
1238
+ output_debug_files=str(args.output_debug_files),
1239
+ original_full_file_name=file_name,
1240
+ additional_validation_issues_provided=args.additional_validation_issues,
1241
+ max_time_for_loop=args.max_time_for_loop,
1242
+ in_data_files=args.input_file,
1243
+ sentiment_checkbox=args.sentiment,
1244
+ logged_content=None,
1245
+ show_previous_table=args.show_previous_table,
1246
+ aws_access_key_textbox=args.aws_access_key,
1247
+ aws_secret_key_textbox=args.aws_secret_key,
1248
+ aws_region_textbox=args.aws_region,
1249
+ api_url=args.api_url if args.api_url else API_URL,
1250
+ )
1251
+
1252
+ end_time = time.time()
1253
+ processing_time = end_time - start_time
1254
+
1255
+ print("\n--- Topic Validation Complete ---")
1256
+ print(f"Processing time: {processing_time:.2f} seconds")
1257
+ print(f"\nOutput files saved to: {args.output_dir}")
1258
+ if validation_output_files:
1259
+ print("Generated Files:", sorted(validation_output_files))
1260
+
1261
+ # Write usage log
1262
+ write_usage_log(
1263
+ session_hash=session_hash,
1264
+ file_name=file_name,
1265
+ text_column=args.text_column,
1266
+ model_choice=args.model_choice,
1267
+ conversation_metadata=conversation_metadata_textbox or "",
1268
+ input_tokens=input_tokens_num or 0,
1269
+ output_tokens=output_tokens_num or 0,
1270
+ number_of_calls=number_of_calls_num or 0,
1271
+ estimated_time_taken=estimated_time_taken_number or processing_time,
1272
+ cost_code=args.cost_code,
1273
+ save_to_csv=args.save_logs_to_csv,
1274
+ save_to_dynamodb=args.save_logs_to_dynamodb,
1275
+ )
1276
+
1277
+ # Create Excel output if requested
1278
+ if args.create_xlsx_output:
1279
+ print("\nCreating Excel output file...")
1280
+ try:
1281
+ xlsx_files, _ = collect_output_csvs_and_create_excel_output(
1282
+ in_data_files=args.input_file,
1283
+ chosen_cols=[args.text_column],
1284
+ reference_data_file_name_textbox=file_name,
1285
+ in_group_col=args.group_by,
1286
+ model_choice=args.model_choice,
1287
+ master_reference_df_state=master_reference_df_state,
1288
+ master_unique_topics_df_state=master_unique_topics_df_state,
1289
+ summarised_output_df=pd.DataFrame(), # No summaries yet
1290
+ missing_df_state=missing_df_state,
1291
+ excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1292
+ usage_logs_location=(
1293
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1294
+ if args.save_logs_to_csv
1295
+ else ""
1296
+ ),
1297
+ model_name_map=model_name_map,
1298
+ output_folder=args.output_dir,
1299
+ structured_summaries=args.produce_structured_summary,
1300
+ )
1301
+ if xlsx_files:
1302
+ print(f"Excel output created: {sorted(xlsx_files)}")
1303
+ except Exception as e:
1304
+ print(f"Warning: Could not create Excel output: {e}")
1305
+
1306
+ # Task 3: Deduplicate Topics
1307
+ elif args.task == "deduplicate":
1308
+ print("--- Starting Topic Deduplication Workflow... ---")
1309
+
1310
+ # Load previous output files
1311
+ (
1312
+ reference_df,
1313
+ topic_summary_df,
1314
+ latest_batch_completed_no_loop,
1315
+ deduplication_input_files_status,
1316
+ working_data_file_name_textbox,
1317
+ unique_topics_table_file_name_textbox,
1318
+ ) = load_in_previous_data_files(args.previous_output_files)
1319
+
1320
+ if args.method == "fuzzy":
1321
+ # Fuzzy matching deduplication
1322
+ (
1323
+ ref_df_after_dedup,
1324
+ unique_df_after_dedup,
1325
+ summarisation_input_files,
1326
+ log_files_output,
1327
+ summarised_output_markdown,
1328
+ ) = deduplicate_topics(
1329
+ reference_df=reference_df,
1330
+ topic_summary_df=topic_summary_df,
1331
+ reference_table_file_name=working_data_file_name_textbox,
1332
+ unique_topics_table_file_name=unique_topics_table_file_name_textbox,
1333
+ in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1334
+ merge_sentiment=args.merge_sentiment,
1335
+ merge_general_topics=args.merge_general_topics,
1336
+ score_threshold=args.similarity_threshold,
1337
+ in_data_files=args.input_file if args.input_file else list(),
1338
+ chosen_cols=[args.text_column] if args.text_column else list(),
1339
+ output_folder=args.output_dir,
1340
+ )
1341
+ else:
1342
+ # LLM deduplication
1343
+ model_source = model_name_map.get(args.model_choice, {}).get(
1344
+ "source", default_model_source
1345
+ )
1346
+ (
1347
+ ref_df_after_dedup,
1348
+ unique_df_after_dedup,
1349
+ summarisation_input_files,
1350
+ log_files_output,
1351
+ summarised_output_markdown,
1352
+ input_tokens_num,
1353
+ output_tokens_num,
1354
+ number_of_calls_num,
1355
+ estimated_time_taken_number,
1356
+ ) = deduplicate_topics_llm(
1357
+ reference_df=reference_df,
1358
+ topic_summary_df=topic_summary_df,
1359
+ reference_table_file_name=working_data_file_name_textbox,
1360
+ unique_topics_table_file_name=unique_topics_table_file_name_textbox,
1361
+ model_choice=args.model_choice,
1362
+ in_api_key=args.google_api_key,
1363
+ temperature=args.temperature,
1364
+ model_source=model_source,
1365
+ bedrock_runtime=None,
1366
+ local_model=None,
1367
+ tokenizer=None,
1368
+ assistant_model=None,
1369
+ in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1370
+ merge_sentiment=args.merge_sentiment,
1371
+ merge_general_topics=args.merge_general_topics,
1372
+ in_data_files=args.input_file if args.input_file else list(),
1373
+ chosen_cols=[args.text_column] if args.text_column else list(),
1374
+ output_folder=args.output_dir,
1375
+ candidate_topics=(
1376
+ args.candidate_topics if args.candidate_topics else None
1377
+ ),
1378
+ azure_endpoint=args.azure_endpoint,
1379
+ output_debug_files=str(args.output_debug_files),
1380
+ api_url=args.api_url if args.api_url else API_URL,
1381
+ )
1382
+
1383
+ end_time = time.time()
1384
+ processing_time = end_time - start_time
1385
+
1386
+ print("\n--- Topic Deduplication Complete ---")
1387
+ print(f"Processing time: {processing_time:.2f} seconds")
1388
+ print(f"\nOutput files saved to: {args.output_dir}")
1389
+ if summarisation_input_files:
1390
+ print("Generated Files:", sorted(summarisation_input_files))
1391
+
1392
+ # Write usage log (only for LLM deduplication which has token counts)
1393
+ if args.method == "llm":
1394
+ # Extract token counts from LLM deduplication result
1395
+ llm_input_tokens = (
1396
+ input_tokens_num if "input_tokens_num" in locals() else 0
1397
+ )
1398
+ llm_output_tokens = (
1399
+ output_tokens_num if "output_tokens_num" in locals() else 0
1400
+ )
1401
+ llm_calls = (
1402
+ number_of_calls_num if "number_of_calls_num" in locals() else 0
1403
+ )
1404
+ llm_time = (
1405
+ estimated_time_taken_number
1406
+ if "estimated_time_taken_number" in locals()
1407
+ else processing_time
1408
+ )
1409
+
1410
+ write_usage_log(
1411
+ session_hash=session_hash,
1412
+ file_name=working_data_file_name_textbox,
1413
+ text_column=args.text_column if args.text_column else "",
1414
+ model_choice=args.model_choice,
1415
+ conversation_metadata="",
1416
+ input_tokens=llm_input_tokens,
1417
+ output_tokens=llm_output_tokens,
1418
+ number_of_calls=llm_calls,
1419
+ estimated_time_taken=llm_time,
1420
+ cost_code=args.cost_code,
1421
+ save_to_csv=args.save_logs_to_csv,
1422
+ save_to_dynamodb=args.save_logs_to_dynamodb,
1423
+ )
1424
+
1425
+ # Create Excel output if requested
1426
+ xlsx_files = []
1427
+ if args.create_xlsx_output:
1428
+ print("\nCreating Excel output file...")
1429
+ try:
1430
+ # Use the deduplicated dataframes
1431
+ xlsx_files, _ = collect_output_csvs_and_create_excel_output(
1432
+ in_data_files=args.input_file if args.input_file else [],
1433
+ chosen_cols=[args.text_column] if args.text_column else [],
1434
+ reference_data_file_name_textbox=working_data_file_name_textbox,
1435
+ in_group_col=args.group_by,
1436
+ model_choice=args.model_choice,
1437
+ master_reference_df_state=ref_df_after_dedup,
1438
+ master_unique_topics_df_state=unique_df_after_dedup,
1439
+ summarised_output_df=pd.DataFrame(), # No summaries yet
1440
+ missing_df_state=pd.DataFrame(),
1441
+ excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1442
+ usage_logs_location=(
1443
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1444
+ if args.save_logs_to_csv
1445
+ else ""
1446
+ ),
1447
+ model_name_map=model_name_map,
1448
+ output_folder=args.output_dir,
1449
+ structured_summaries=args.produce_structured_summary,
1450
+ )
1451
+ if xlsx_files:
1452
+ print(f"Excel output created: {sorted(xlsx_files)}")
1453
+ except Exception as e:
1454
+ print(f"Warning: Could not create Excel output: {e}")
1455
+
1456
+ # Upload outputs to S3 if enabled
1457
+ all_output_files = (
1458
+ list(summarisation_input_files) if summarisation_input_files else []
1459
+ )
1460
+ if xlsx_files:
1461
+ all_output_files.extend(xlsx_files)
1462
+ upload_outputs_to_s3_if_enabled(
1463
+ output_files=all_output_files,
1464
+ base_file_name=working_data_file_name_textbox,
1465
+ session_hash=session_hash,
1466
+ )
1467
+
1468
+ # Task 4: Summarise Topics
1469
+ elif args.task == "summarise":
1470
+ print("--- Starting Topic Summarisation Workflow... ---")
1471
+
1472
+ # Load previous output files
1473
+ (
1474
+ reference_df,
1475
+ topic_summary_df,
1476
+ latest_batch_completed_no_loop,
1477
+ deduplication_input_files_status,
1478
+ working_data_file_name_textbox,
1479
+ unique_topics_table_file_name_textbox,
1480
+ ) = load_in_previous_data_files(args.previous_output_files)
1481
+
1482
+ # Determine summary format prompt
1483
+ summary_format_prompt = (
1484
+ two_para_summary_format_prompt
1485
+ if args.summary_format == "two_paragraph"
1486
+ else single_para_summary_format_prompt
1487
+ )
1488
+
1489
+ # Run summarisation
1490
+ (
1491
+ summary_reference_table_sample_state,
1492
+ master_unique_topics_df_revised_summaries_state,
1493
+ master_reference_df_revised_summaries_state,
1494
+ summary_output_files,
1495
+ summarised_outputs_list,
1496
+ latest_summary_completed_num,
1497
+ conversation_metadata_textbox,
1498
+ summarised_output_markdown,
1499
+ log_files_output,
1500
+ overall_summarisation_input_files,
1501
+ input_tokens_num,
1502
+ output_tokens_num,
1503
+ number_of_calls_num,
1504
+ estimated_time_taken_number,
1505
+ output_messages_textbox,
1506
+ logged_content_df,
1507
+ ) = wrapper_summarise_output_topics_per_group(
1508
+ grouping_col=args.group_by,
1509
+ sampled_reference_table_df=reference_df.copy(), # Will be sampled if sample_reference_table=True
1510
+ topic_summary_df=topic_summary_df,
1511
+ reference_table_df=reference_df,
1512
+ model_choice=args.model_choice,
1513
+ in_api_key=args.google_api_key,
1514
+ temperature=args.temperature,
1515
+ reference_data_file_name=working_data_file_name_textbox,
1516
+ summarised_outputs=list(),
1517
+ latest_summary_completed=0,
1518
+ out_metadata_str="",
1519
+ in_data_files=args.input_file if args.input_file else list(),
1520
+ in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1521
+ chosen_cols=[args.text_column] if args.text_column else list(),
1522
+ log_output_files=list(),
1523
+ summarise_format_radio=summary_format_prompt,
1524
+ output_folder=args.output_dir,
1525
+ context_textbox=args.context,
1526
+ aws_access_key_textbox=args.aws_access_key,
1527
+ aws_secret_key_textbox=args.aws_secret_key,
1528
+ aws_region_textbox=args.aws_region,
1529
+ model_name_map=model_name_map,
1530
+ hf_api_key_textbox=args.hf_token,
1531
+ azure_endpoint_textbox=args.azure_endpoint,
1532
+ existing_logged_content=list(),
1533
+ sample_reference_table=args.sample_reference_table,
1534
+ no_of_sampled_summaries=args.no_of_sampled_summaries,
1535
+ random_seed=args.random_seed,
1536
+ api_url=args.api_url if args.api_url else API_URL,
1537
+ additional_summary_instructions_provided=args.additional_summary_instructions,
1538
+ output_debug_files=str(args.output_debug_files),
1539
+ reasoning_suffix="",
1540
+ local_model=None,
1541
+ tokenizer=None,
1542
+ assistant_model=None,
1543
+ do_summaries="Yes",
1544
+ )
1545
+
1546
+ end_time = time.time()
1547
+ processing_time = end_time - start_time
1548
+
1549
+ print("\n--- Topic Summarisation Complete ---")
1550
+ print(f"Processing time: {processing_time:.2f} seconds")
1551
+ print(f"\nOutput files saved to: {args.output_dir}")
1552
+ if summary_output_files:
1553
+ print("Generated Files:", sorted(summary_output_files))
1554
+
1555
+ # Write usage log
1556
+ write_usage_log(
1557
+ session_hash=session_hash,
1558
+ file_name=working_data_file_name_textbox,
1559
+ text_column=args.text_column if args.text_column else "",
1560
+ model_choice=args.model_choice,
1561
+ conversation_metadata=conversation_metadata_textbox or "",
1562
+ input_tokens=input_tokens_num or 0,
1563
+ output_tokens=output_tokens_num or 0,
1564
+ number_of_calls=number_of_calls_num or 0,
1565
+ estimated_time_taken=estimated_time_taken_number or processing_time,
1566
+ cost_code=args.cost_code,
1567
+ save_to_csv=args.save_logs_to_csv,
1568
+ save_to_dynamodb=args.save_logs_to_dynamodb,
1569
+ )
1570
+
1571
+ # Create Excel output if requested
1572
+ xlsx_files = []
1573
+ if args.create_xlsx_output:
1574
+ print("\nCreating Excel output file...")
1575
+ try:
1576
+ xlsx_files, _ = collect_output_csvs_and_create_excel_output(
1577
+ in_data_files=args.input_file if args.input_file else [],
1578
+ chosen_cols=[args.text_column] if args.text_column else [],
1579
+ reference_data_file_name_textbox=working_data_file_name_textbox,
1580
+ in_group_col=args.group_by,
1581
+ model_choice=args.model_choice,
1582
+ master_reference_df_state=master_reference_df_revised_summaries_state,
1583
+ master_unique_topics_df_state=master_unique_topics_df_revised_summaries_state,
1584
+ summarised_output_df=pd.DataFrame(), # Summaries are in the revised dataframes
1585
+ missing_df_state=pd.DataFrame(),
1586
+ excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1587
+ usage_logs_location=(
1588
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1589
+ if args.save_logs_to_csv
1590
+ else ""
1591
+ ),
1592
+ model_name_map=model_name_map,
1593
+ output_folder=args.output_dir,
1594
+ structured_summaries=args.produce_structured_summary,
1595
+ )
1596
+ if xlsx_files:
1597
+ print(f"Excel output created: {sorted(xlsx_files)}")
1598
+ except Exception as e:
1599
+ print(f"Warning: Could not create Excel output: {e}")
1600
+
1601
+ # Upload outputs to S3 if enabled
1602
+ all_output_files = (
1603
+ list(summary_output_files) if summary_output_files else []
1604
+ )
1605
+ if xlsx_files:
1606
+ all_output_files.extend(xlsx_files)
1607
+ upload_outputs_to_s3_if_enabled(
1608
+ output_files=all_output_files,
1609
+ base_file_name=working_data_file_name_textbox,
1610
+ session_hash=session_hash,
1611
+ )
1612
+
1613
+ # Task 5: Overall Summary
1614
+ elif args.task == "overall_summary":
1615
+ print("--- Starting Overall Summary Workflow... ---")
1616
+
1617
+ # Load previous output files
1618
+ (
1619
+ reference_df,
1620
+ topic_summary_df,
1621
+ latest_batch_completed_no_loop,
1622
+ deduplication_input_files_status,
1623
+ working_data_file_name_textbox,
1624
+ unique_topics_table_file_name_textbox,
1625
+ ) = load_in_previous_data_files(args.previous_output_files)
1626
+
1627
+ # Run overall summary
1628
+ (
1629
+ overall_summary_output_files,
1630
+ overall_summarised_output_markdown,
1631
+ summarised_output_df,
1632
+ conversation_metadata_textbox,
1633
+ input_tokens_num,
1634
+ output_tokens_num,
1635
+ number_of_calls_num,
1636
+ estimated_time_taken_number,
1637
+ output_messages_textbox,
1638
+ logged_content_df,
1639
+ ) = overall_summary(
1640
+ topic_summary_df=topic_summary_df,
1641
+ model_choice=args.model_choice,
1642
+ in_api_key=args.google_api_key,
1643
+ temperature=args.temperature,
1644
+ reference_data_file_name=working_data_file_name_textbox,
1645
+ output_folder=args.output_dir,
1646
+ chosen_cols=[args.text_column] if args.text_column else list(),
1647
+ context_textbox=args.context,
1648
+ aws_access_key_textbox=args.aws_access_key,
1649
+ aws_secret_key_textbox=args.aws_secret_key,
1650
+ aws_region_textbox=args.aws_region,
1651
+ model_name_map=model_name_map,
1652
+ hf_api_key_textbox=args.hf_token,
1653
+ azure_endpoint_textbox=args.azure_endpoint,
1654
+ existing_logged_content=list(),
1655
+ api_url=args.api_url if args.api_url else API_URL,
1656
+ output_debug_files=str(args.output_debug_files),
1657
+ log_output_files=list(),
1658
+ reasoning_suffix="",
1659
+ local_model=None,
1660
+ tokenizer=None,
1661
+ assistant_model=None,
1662
+ do_summaries="Yes",
1663
+ )
1664
+
1665
+ end_time = time.time()
1666
+ processing_time = end_time - start_time
1667
+
1668
+ print("\n--- Overall Summary Complete ---")
1669
+ print(f"Processing time: {processing_time:.2f} seconds")
1670
+ print(f"\nOutput files saved to: {args.output_dir}")
1671
+ if overall_summary_output_files:
1672
+ print("Generated Files:", sorted(overall_summary_output_files))
1673
+
1674
+ # Write usage log
1675
+ write_usage_log(
1676
+ session_hash=session_hash,
1677
+ file_name=working_data_file_name_textbox,
1678
+ text_column=args.text_column if args.text_column else "",
1679
+ model_choice=args.model_choice,
1680
+ conversation_metadata=conversation_metadata_textbox or "",
1681
+ input_tokens=input_tokens_num or 0,
1682
+ output_tokens=output_tokens_num or 0,
1683
+ number_of_calls=number_of_calls_num or 0,
1684
+ estimated_time_taken=estimated_time_taken_number or processing_time,
1685
+ cost_code=args.cost_code,
1686
+ save_to_csv=args.save_logs_to_csv,
1687
+ save_to_dynamodb=args.save_logs_to_dynamodb,
1688
+ )
1689
+
1690
+ # Create Excel output if requested
1691
+ xlsx_files = []
1692
+ if args.create_xlsx_output:
1693
+ print("\nCreating Excel output file...")
1694
+ try:
1695
+ xlsx_files, _ = collect_output_csvs_and_create_excel_output(
1696
+ in_data_files=args.input_file if args.input_file else [],
1697
+ chosen_cols=[args.text_column] if args.text_column else [],
1698
+ reference_data_file_name_textbox=working_data_file_name_textbox,
1699
+ in_group_col=args.group_by,
1700
+ model_choice=args.model_choice,
1701
+ master_reference_df_state=reference_df, # Use original reference_df
1702
+ master_unique_topics_df_state=topic_summary_df, # Use original topic_summary_df
1703
+ summarised_output_df=summarised_output_df,
1704
+ missing_df_state=pd.DataFrame(),
1705
+ excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1706
+ usage_logs_location=(
1707
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1708
+ if args.save_logs_to_csv
1709
+ else ""
1710
+ ),
1711
+ model_name_map=model_name_map,
1712
+ output_folder=args.output_dir,
1713
+ structured_summaries=args.produce_structured_summary,
1714
+ )
1715
+ if xlsx_files:
1716
+ print(f"Excel output created: {sorted(xlsx_files)}")
1717
+ except Exception as e:
1718
+ print(f"Warning: Could not create Excel output: {e}")
1719
+
1720
+ # Upload outputs to S3 if enabled
1721
+ all_output_files = (
1722
+ list(overall_summary_output_files)
1723
+ if overall_summary_output_files
1724
+ else []
1725
+ )
1726
+ if xlsx_files:
1727
+ all_output_files.extend(xlsx_files)
1728
+ upload_outputs_to_s3_if_enabled(
1729
+ output_files=all_output_files,
1730
+ base_file_name=working_data_file_name_textbox,
1731
+ session_hash=session_hash,
1732
+ )
1733
+
1734
+ # Task 6: All-in-One Pipeline
1735
+ elif args.task == "all_in_one":
1736
+ print("--- Starting All-in-One Pipeline Workflow... ---")
1737
+
1738
+ # Load data file
1739
+ if isinstance(args.input_file, str):
1740
+ args.input_file = [args.input_file]
1741
+
1742
+ file_data, file_name, total_number_of_batches = load_in_data_file(
1743
+ file_paths=args.input_file,
1744
+ in_colnames=[args.text_column],
1745
+ batch_size=args.batch_size,
1746
+ in_excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1747
+ )
1748
+
1749
+ # Prepare candidate topics if provided
1750
+ candidate_topics = None
1751
+ if args.candidate_topics:
1752
+ candidate_topics = args.candidate_topics
1753
+
1754
+ # Determine summary format prompt
1755
+ summary_format_prompt = (
1756
+ two_para_summary_format_prompt
1757
+ if args.summary_format == "two_paragraph"
1758
+ else single_para_summary_format_prompt
1759
+ )
1760
+
1761
+ # Run all-in-one pipeline
1762
+ (
1763
+ display_markdown,
1764
+ master_topic_df_state,
1765
+ master_unique_topics_df_state,
1766
+ master_reference_df_state,
1767
+ topic_extraction_output_files,
1768
+ text_output_file_list_state,
1769
+ latest_batch_completed,
1770
+ log_files_output,
1771
+ log_files_output_list_state,
1772
+ conversation_metadata_textbox,
1773
+ estimated_time_taken_number,
1774
+ deduplication_input_files,
1775
+ summarisation_input_files,
1776
+ modifiable_unique_topics_df_state,
1777
+ modification_input_files,
1778
+ in_join_files,
1779
+ missing_df_state,
1780
+ input_tokens_num,
1781
+ output_tokens_num,
1782
+ number_of_calls_num,
1783
+ output_messages_textbox,
1784
+ summary_reference_table_sample_state,
1785
+ summarised_references_markdown,
1786
+ master_unique_topics_df_revised_summaries_state,
1787
+ master_reference_df_revised_summaries_state,
1788
+ summary_output_files,
1789
+ summarised_outputs_list,
1790
+ latest_summary_completed_num,
1791
+ overall_summarisation_input_files,
1792
+ overall_summary_output_files,
1793
+ overall_summarised_output_markdown,
1794
+ summarised_output_df,
1795
+ logged_content_df,
1796
+ ) = all_in_one_pipeline(
1797
+ grouping_col=args.group_by,
1798
+ in_data_files=args.input_file,
1799
+ file_data=file_data,
1800
+ existing_topics_table=pd.DataFrame(),
1801
+ existing_reference_df=pd.DataFrame(),
1802
+ existing_topic_summary_df=pd.DataFrame(),
1803
+ unique_table_df_display_table_markdown="",
1804
+ original_file_name=file_name,
1805
+ total_number_of_batches=total_number_of_batches,
1806
+ in_api_key=args.google_api_key,
1807
+ temperature=args.temperature,
1808
+ chosen_cols=[args.text_column],
1809
+ model_choice=args.model_choice,
1810
+ candidate_topics=candidate_topics,
1811
+ first_loop_state=True,
1812
+ conversation_metadata_text="",
1813
+ latest_batch_completed=0,
1814
+ time_taken_so_far=0,
1815
+ initial_table_prompt_text=initial_table_prompt,
1816
+ initial_table_system_prompt_text=initial_table_system_prompt,
1817
+ add_existing_topics_system_prompt_text=add_existing_topics_system_prompt,
1818
+ add_existing_topics_prompt_text=add_existing_topics_prompt,
1819
+ number_of_prompts_used=1,
1820
+ batch_size=args.batch_size,
1821
+ context_text=args.context,
1822
+ sentiment_choice=args.sentiment,
1823
+ force_zero_shot_choice=args.force_zero_shot,
1824
+ in_excel_sheets=args.excel_sheets,
1825
+ force_single_topic_choice=args.force_single_topic,
1826
+ produce_structures_summary_choice=args.produce_structured_summary,
1827
+ aws_access_key_text=args.aws_access_key,
1828
+ aws_secret_key_text=args.aws_secret_key,
1829
+ aws_region_text=args.aws_region,
1830
+ hf_api_key_text=args.hf_token,
1831
+ azure_api_key_text=args.azure_api_key,
1832
+ azure_endpoint_text=args.azure_endpoint,
1833
+ output_folder=args.output_dir,
1834
+ merge_sentiment=args.merge_sentiment,
1835
+ merge_general_topics=args.merge_general_topics,
1836
+ score_threshold=args.similarity_threshold,
1837
+ summarise_format=summary_format_prompt,
1838
+ random_seed=args.random_seed,
1839
+ log_files_output_list_state=list(),
1840
+ model_name_map_state=model_name_map,
1841
+ usage_logs_location=(
1842
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1843
+ if args.save_logs_to_csv
1844
+ else ""
1845
+ ),
1846
+ existing_logged_content=list(),
1847
+ additional_instructions_summary_format=args.additional_summary_instructions,
1848
+ additional_validation_issues_provided="",
1849
+ show_previous_table="Yes",
1850
+ sample_reference_table_checkbox=args.sample_reference_table,
1851
+ api_url=args.api_url if args.api_url else API_URL,
1852
+ output_debug_files=str(args.output_debug_files),
1853
+ model=None,
1854
+ tokenizer=None,
1855
+ assistant_model=None,
1856
+ max_rows=999999,
1857
+ )
1858
+
1859
+ end_time = time.time()
1860
+ processing_time = end_time - start_time
1861
+
1862
+ print("\n--- All-in-One Pipeline Complete ---")
1863
+ print(f"Processing time: {processing_time:.2f} seconds")
1864
+ print(f"\nOutput files saved to: {args.output_dir}")
1865
+ if overall_summary_output_files:
1866
+ print("Generated Files:", sorted(overall_summary_output_files))
1867
+
1868
+ # Write usage log
1869
+ write_usage_log(
1870
+ session_hash=session_hash,
1871
+ file_name=file_name,
1872
+ text_column=args.text_column,
1873
+ model_choice=args.model_choice,
1874
+ conversation_metadata=conversation_metadata_textbox or "",
1875
+ input_tokens=input_tokens_num or 0,
1876
+ output_tokens=output_tokens_num or 0,
1877
+ number_of_calls=number_of_calls_num or 0,
1878
+ estimated_time_taken=estimated_time_taken_number or processing_time,
1879
+ cost_code=args.cost_code,
1880
+ save_to_csv=args.save_logs_to_csv,
1881
+ save_to_dynamodb=args.save_logs_to_dynamodb,
1882
+ )
1883
+
1884
+ # Create Excel output if requested
1885
+ xlsx_files = []
1886
+ if args.create_xlsx_output:
1887
+ print("\nCreating Excel output file...")
1888
+ try:
1889
+ xlsx_files, _ = collect_output_csvs_and_create_excel_output(
1890
+ in_data_files=args.input_file,
1891
+ chosen_cols=[args.text_column],
1892
+ reference_data_file_name_textbox=file_name,
1893
+ in_group_col=args.group_by,
1894
+ model_choice=args.model_choice,
1895
+ master_reference_df_state=master_reference_df_revised_summaries_state,
1896
+ master_unique_topics_df_state=master_unique_topics_df_revised_summaries_state,
1897
+ summarised_output_df=summarised_output_df,
1898
+ missing_df_state=missing_df_state,
1899
+ excel_sheets=args.excel_sheets[0] if args.excel_sheets else "",
1900
+ usage_logs_location=(
1901
+ os.path.join(USAGE_LOGS_FOLDER, USAGE_LOG_FILE_NAME)
1902
+ if args.save_logs_to_csv
1903
+ else ""
1904
+ ),
1905
+ model_name_map=model_name_map,
1906
+ output_folder=args.output_dir,
1907
+ structured_summaries=args.produce_structured_summary,
1908
+ )
1909
+ if xlsx_files:
1910
+ print(f"Excel output created: {sorted(xlsx_files)}")
1911
+ except Exception as e:
1912
+ print(f"Warning: Could not create Excel output: {e}")
1913
+
1914
+ # Upload outputs to S3 if enabled
1915
+ # Collect all output files from the pipeline
1916
+ all_output_files = []
1917
+ if topic_extraction_output_files:
1918
+ all_output_files.extend(topic_extraction_output_files)
1919
+ if overall_summary_output_files:
1920
+ all_output_files.extend(overall_summary_output_files)
1921
+ if xlsx_files:
1922
+ all_output_files.extend(xlsx_files)
1923
+ upload_outputs_to_s3_if_enabled(
1924
+ output_files=all_output_files,
1925
+ base_file_name=file_name,
1926
+ session_hash=session_hash,
1927
+ )
1928
+
1929
+ else:
1930
+ print(f"Error: Invalid task '{args.task}'.")
1931
+ print(
1932
+ "Valid options: 'extract', 'validate', 'deduplicate', 'summarise', 'overall_summary', or 'all_in_one'"
1933
+ )
1934
+
1935
+ except Exception as e:
1936
+ print(f"\nAn error occurred during the workflow: {e}")
1937
+ import traceback
1938
+
1939
+ traceback.print_exc()
1940
+
1941
+
1942
+ if __name__ == "__main__":
1943
+ main()
entrypoint.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ # Exit immediately if a command exits with a non-zero status.
4
+ set -e
5
+
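+ # APP_MODE is expected to be supplied by the container environment, e.g. (placeholder image and handler names):
+ #   docker run -e APP_MODE=lambda my-image lambda_entrypoint.handler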
6
+ echo "Starting in APP_MODE: $APP_MODE"
7
+
8
+ # --- Start the app based on mode ---
9
+
10
+ if [ "$APP_MODE" = "lambda" ]; then
11
+ echo "Starting in Lambda mode..."
12
+ # The CMD from Dockerfile will be passed as "$@"
13
+ exec python -m awslambdaric "$@"
14
+ else
15
+ echo "Starting in Gradio mode..."
16
+ exec python app.py
17
+ fi
18
+
example_data/case_note_headers_specific.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ General Topic,Subtopic
2
+ Mental health,Anger
3
+ Mental health,Social issues
4
+ Physical health,General
5
+ Physical health,Substance misuse
6
+ Behaviour at school,Behaviour at school
7
+ Trends over time,Trends over time
example_data/combined_case_notes.csv ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Date,Social Worker,Client,Case Note
2
+ "January 3, 2023",Jane Smith,Alex D.,"Met with Alex at school following reports of increased absences and declining grades. Alex appeared sullen and avoided eye contact. When prompted about school, Alex expressed feelings of isolation and stated, ""No one gets me."" Scheduled a follow-up meeting to further explore these feelings."
3
+ "January 17, 2023",Jane Smith,Alex D.,"Met with Alex at the community center. Alex displayed sudden outbursts of anger when discussing home life, particularly in relation to a new stepfather. Alex mentioned occasional substance use, but did not specify which substances. Recommended a comprehensive assessment."
4
+ "February 5, 2023",Jane Smith,Alex D.,Home visit conducted. Alex's mother reported frequent arguments at home. She expressed concerns about Alex's new group of friends and late-night outings. Noted potential signs of substance abuse. Suggested family counseling.
5
+ "February 21, 2023",Jane Smith,Alex D.,"Met with Alex alone at my office. Alex appeared more agitated than in previous meetings. There were visible signs of self-harm on Alex's arms. When questioned, Alex became defensive. Immediate referral made to a mental health professional."
6
+ "March 10, 2023",Jane Smith,Alex D.,Attended joint session with Alex and a therapist. Alex shared feelings of hopelessness and admitted to occasional thoughts of self-harm. Therapist recommended a comprehensive mental health evaluation and ongoing therapy.
7
+ "March 25, 2023",Jane Smith,Alex D.,"Received a call from Alex's school about a physical altercation with another student. Met with Alex, who displayed high levels of frustration and admitted to the use of alcohol. Discussed the importance of seeking help and finding positive coping mechanisms. Recommended enrollment in an anger management program."
8
+ "April 15, 2023",Jane Smith,Alex D.,Met with Alex and mother to discuss progress. Alex's mother expressed concerns about Alex's increasing aggression at home. Alex acknowledged the issues but blamed others for provoking the behavior. It was decided that a more intensive intervention may be needed.
9
+ "April 30, 2023",Jane Smith,Alex D.,"Met with Alex and a psychiatrist. Psychiatrist diagnosed Alex with Oppositional Defiant Disorder (ODD) and co-morbid substance use disorder. A treatment plan was discussed, including medication, therapy, and family counseling."
10
+ "May 20, 2023",Jane Smith,Alex D.,"Met with Alex to discuss progress. Alex has started attending group therapy and has shown slight improvements in behavior. Still, concerns remain about substance use. Discussed potential for a short-term residential treatment program."
11
+ "January 3, 2023",Jane Smith,Jamie L.,"Met with Jamie at school after receiving reports of consistent tardiness and decreased participation in class. Jamie appeared withdrawn and exhibited signs of sadness. When asked about feelings, Jamie expressed feeling ""empty"" and ""hopeless"" at times. Scheduled a follow-up meeting to further explore these feelings."
12
+ "January 17, 2023",Jane Smith,Jamie L.,"Met with Jamie at the community center. Jamie shared feelings of low self-worth, mentioning that it's hard to find motivation for daily tasks. Discussed potential triggers and learned about recent family financial struggles. Recommended counseling and possible group therapy for peer support."
13
+ "February 5, 2023",Jane Smith,Jamie L.,Home visit conducted. Jamie's parents shared concerns about Jamie's increasing withdrawal from family activities and lack of interest in hobbies. Parents mentioned that Jamie spends a lot of time alone in the room. Suggested family therapy to open communication channels.
14
+ "February 21, 2023",Jane Smith,Jamie L.,Met with Jamie in my office. Jamie opened up about feelings of isolation and mentioned difficulty sleeping. No signs of self-harm or suicidal ideation were noted. Recommended a comprehensive mental health assessment to better understand the depth of the depression.
15
+ "March 10, 2023",Jane Smith,Jamie L.,"Attended a joint session with Jamie and a therapist. The therapist noted signs of moderate depression. Together, we discussed coping strategies and potential interventions. Jamie showed interest in art therapy."
16
+ "March 25, 2023",Jane Smith,Jamie L.,"Received feedback from Jamie's school that academic performance has slightly improved. However, social interactions remain limited. Encouraged Jamie to join school clubs or groups to foster connection."
17
+ "April 15, 2023",Jane Smith,Jamie L.,"Met with Jamie and parents to discuss progress. Parents have observed slight improvements in mood on some days, but overall, Jamie still appears to struggle. It was decided to explore medication as a potential aid alongside therapy."
18
+ "April 30, 2023",Jane Smith,Jamie L.,Met with Jamie and a psychiatrist. The psychiatrist diagnosed Jamie with Major Depressive Disorder (MDD) and suggested considering antidepressant medication. Discussed the potential benefits and side effects. Jamie and parents will think it over.
19
+ "May 20, 2023",Jane Smith,Jamie L.,"Jamie has started on a low dose of an antidepressant. Initial feedback is positive, with some improvement in mood and energy levels. Will continue monitoring and adjusting as necessary."
example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_structured_summaries.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:322a081b29d4fb40ccae7d47aa74fda772a002eda576ddc98d6acc86366cff11
3
+ size 13502
example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3dcc1ea155169c23d913043b1ad87da2f2912be36d9fb1521c72ee05b8dcf36
3
+ size 25299
example_data/combined_case_notes_col_Case_Note_Gemma_3_4B_topic_analysis_grouped.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e1eaede9af75b6ab695b1cfc6c01ec875abf14521249ba7257bd4bb0afd7ee8
3
+ size 28673
example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis.xlsx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30947f1355eacc74c92d09b766e8e3d71092b9a240e7f8acd381874b7d7ebcb3
3
+ size 24673
example_data/dummy_consultation_r_col_Response_text_Gemma_3_4B_topic_analysis_zero_shot.xlsx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5f0e36143d8362391e3b11d1c20e3a2a1b7536b8f0c972e3d44644eb9ae4e82
3
+ size 27592
example_data/dummy_consultation_response.csv ADDED
@@ -0,0 +1,31 @@
1
+ Response Reference,Object to or Support application,Response text
2
+ R1,Object,I strongly object to the proposed five-storey apartment block on Main Street. It is completely out of keeping with the existing character of the area and will overshadow the existing buildings.
3
+ R2,Support,"I fully support the proposed development. The town needs more housing, and this development will provide much-needed homes."
4
+ R3,Object,The proposed development is too tall and will have a negative impact on the views from the surrounding area.
5
+ R4,Object,The loss of the well-loved cafe will be a great loss to the community.
6
+ R5,Support,The development will bring much-needed investment to the area and create jobs.
7
+ R6,Object,The increased traffic generated by the development will cause congestion on Main Street.
8
+ R7,Support,The development will provide much-needed affordable housing.
9
+ R8,Object,The development will have a negative impact on the local environment.
10
+ R9,Support,The development will improve the appearance of Main Street.
11
+ R10,Object,The development will overshadow the existing buildings and make them feel cramped.
12
+ R11,Support,The development will provide much-needed amenities for the local community.
13
+ R12,Object,The development will have a negative impact on the local wildlife.
14
+ R13,Support,The development will help to revitalise the town centre.
15
+ R14,Object,The development will increase noise pollution in the area.
16
+ R15,Support,The development will provide much-needed parking spaces.
17
+ R16,Object,The development will have a negative impact on the local businesses.
18
+ R17,Support,The development will provide much-needed green space.
19
+ R18,Object,The development will have a negative impact on the local heritage.
20
+ R19,Support,The development will provide much-needed facilities for young people.
21
+ R20,Object,The development will have a negative impact on the local schools.
22
+ R21,Support,The development will provide much-needed social housing.
23
+ R22,Object,The development will have a negative impact on the local infrastructure.
24
+ R23,Support,The development will provide much-needed jobs for local people.
25
+ R24,Object,The development will have a negative impact on the local economy.
26
+ R25,Support,The development will provide much-needed community facilities.
27
+ R26,Object,The development will have a negative impact on the local amenities.
28
+ R27,Support,The development will provide much-needed housing for young people.
29
+ R28,Object,The development will have a negative impact on the local character.
30
+ R29,Support,The development will provide much-needed housing for families.
31
+ R30,Object,The development will have a negative impact on the local quality of life.
example_data/dummy_consultation_response_themes.csv ADDED
@@ -0,0 +1,26 @@
1
+ ο»Ώtopics
2
+ Need for family housing
3
+ Impact on the character of the area
4
+ Amenities for the local community
5
+ Revitalisation of the town centre
6
+ Impact on local wildlife
7
+ Parking
8
+ Impact on local businesses
9
+ Green space
10
+ Noise pollution
11
+ Impact on local heritage
12
+ Facilities for young people
13
+ Impact on local schools
14
+ Impact on views
15
+ Loss of cafe
16
+ Investment and job creation
17
+ Traffic congestion
18
+ Affordable housing
19
+ Impact on the local environment
20
+ Improvement of main street
21
+ Impact on local infrastructure
22
+ Investment and job creation
23
+ Impact on local schools
24
+ Provision of community facilities
25
+ Impact on local heritage
26
+ Impact on quality of life
intros/intro.txt ADDED
@@ -0,0 +1,7 @@
1
+ # Create thematic summaries from your data
2
+
3
+ Extract topics and summarise open text using Large Language Models (LLMs). The model will loop through all text rows to find the most relevant general topics and subtopics, and provide a short summary of each. If you have specific topics in mind, you can enter them in 'Provide a list of specific topics' below.
4
+
5
+ NOTE: LLMs are not 100% accurate and may produce biased or incorrect responses. All files downloaded from this app **need to be checked by a human** before they are used in further outputs.
6
+
7
+ Unsure how to use this app? Try an example by clicking on one of the example datasets below to see typical outputs the app can produce. There is also a user guide provided alongside this app - please ask your system administrator if you do not have access.
lambda_entrypoint.py ADDED
@@ -0,0 +1,466 @@
1
+ import json
2
+ import os
3
+
4
+ import boto3
5
+ from dotenv import load_dotenv
6
+
7
+ # Import the main function from your CLI script
8
+ from cli_topics import main as cli_main
9
+ from tools.config import (
10
+ AWS_REGION,
11
+ BATCH_SIZE_DEFAULT,
12
+ DEDUPLICATION_THRESHOLD,
13
+ DEFAULT_COST_CODE,
14
+ DEFAULT_SAMPLED_SUMMARIES,
15
+ LLM_MAX_NEW_TOKENS,
16
+ LLM_SEED,
17
+ LLM_TEMPERATURE,
18
+ OUTPUT_DEBUG_FILES,
19
+ SAVE_LOGS_TO_CSV,
20
+ SAVE_LOGS_TO_DYNAMODB,
21
+ SESSION_OUTPUT_FOLDER,
22
+ USAGE_LOGS_FOLDER,
23
+ convert_string_to_boolean,
24
+ )
25
+
26
+
27
+ def _get_env_list(env_var_name: str | list[str] | None) -> list[str]:
28
+ """Parses a comma-separated environment variable into a list of strings."""
29
+ if isinstance(env_var_name, list):
30
+ return env_var_name
31
+ if env_var_name is None:
32
+ return []
33
+
34
+ # Handle string input
35
+ value = str(env_var_name).strip()
36
+ if not value or value == "[]":
37
+ return []
38
+
39
+ # Remove brackets if present (e.g., "[item1, item2]" -> "item1, item2")
40
+ if value.startswith("[") and value.endswith("]"):
41
+ value = value[1:-1]
42
+
43
+ # Remove quotes and split by comma
44
+ value = value.replace('"', "").replace("'", "")
45
+ if not value:
46
+ return []
47
+
48
+ # Split by comma and filter out any empty strings
49
+ return [s.strip() for s in value.split(",") if s.strip()]
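+ 
+ 
+ # Illustrative sanity checks for the input formats _get_env_list accepts
+ # (None, plain comma-separated strings, bracketed strings, and real lists);
+ # a sketch only, with placeholder values:
+ assert _get_env_list(None) == []
+ assert _get_env_list("item1, item2") == ["item1", "item2"]
+ assert _get_env_list("['item1', 'item2']") == ["item1", "item2"]
+ assert _get_env_list(["item1", "item2"]) == ["item1", "item2"]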
50
+
51
+
52
+ print("Lambda entrypoint loading...")
53
+
54
+ # Initialize S3 client outside the handler for connection reuse
55
+ s3_client = boto3.client("s3", region_name=os.getenv("AWS_REGION", AWS_REGION))
56
+ print("S3 client initialised")
57
+
58
+ # Lambda's only writable directory is /tmp. Ensure that all temporary files are stored in this directory.
59
+ TMP_DIR = "/tmp"
60
+ INPUT_DIR = os.path.join(TMP_DIR, "input")
61
+ OUTPUT_DIR = os.path.join(TMP_DIR, "output")
62
+ os.environ["GRADIO_TEMP_DIR"] = os.path.join(TMP_DIR, "gradio_tmp")
63
+ os.environ["MPLCONFIGDIR"] = os.path.join(TMP_DIR, "matplotlib_cache")
64
+ os.environ["FEEDBACK_LOGS_FOLDER"] = os.path.join(TMP_DIR, "feedback")
65
+ os.environ["ACCESS_LOGS_FOLDER"] = os.path.join(TMP_DIR, "logs")
66
+ os.environ["USAGE_LOGS_FOLDER"] = os.path.join(TMP_DIR, "usage")
67
+
68
+ # Define compatible file types for processing
69
+ COMPATIBLE_FILE_TYPES = {
70
+ ".csv",
71
+ ".xlsx",
72
+ ".xls",
73
+ ".parquet",
74
+ }
75
+
76
+
77
+ def download_file_from_s3(bucket_name, key, download_path):
78
+ """Download a file from S3 to the local filesystem."""
79
+ try:
80
+ s3_client.download_file(bucket_name, key, download_path)
81
+ print(f"Successfully downloaded s3://{bucket_name}/{key} to {download_path}")
82
+ except Exception as e:
83
+ print(f"Error downloading from S3: {e}")
84
+ raise
85
+
86
+
87
+ def upload_directory_to_s3(local_directory, bucket_name, s3_prefix):
88
+ """Upload all files from a local directory to an S3 prefix."""
89
+ for root, _, files in os.walk(local_directory):
90
+ for file_name in files:
91
+ local_file_path = os.path.join(root, file_name)
92
+ # Create a relative path to maintain directory structure if needed
93
+ relative_path = os.path.relpath(local_file_path, local_directory)
94
+ output_key = os.path.join(s3_prefix, relative_path).replace("\\", "/")
95
+
96
+ try:
97
+ s3_client.upload_file(local_file_path, bucket_name, output_key)
98
+ print(
99
+ f"Successfully uploaded {local_file_path} to s3://{bucket_name}/{output_key}"
100
+ )
101
+ except Exception as e:
102
+ print(f"Error uploading to S3: {e}")
103
+ raise
104
+
105
+
106
+ def lambda_handler(event, context):
107
+ print(f"Received event: {json.dumps(event)}")
108
+
109
+ # 1. Setup temporary directories
110
+ os.makedirs(INPUT_DIR, exist_ok=True)
111
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
112
+
113
+ # 2. Extract information from the event
114
+ # Assumes the event is triggered by S3 and may contain an 'arguments' payload
115
+ try:
116
+ record = event["Records"][0]
117
+ bucket_name = record["s3"]["bucket"]["name"]
118
+ input_key = record["s3"]["object"]["key"]
119
+
120
+ # The user metadata can be used to pass arguments
121
+ # This is more robust than embedding them in the main event body
122
+ try:
123
+ response = s3_client.head_object(Bucket=bucket_name, Key=input_key)
124
+ metadata = response.get("Metadata", dict())
125
+ print(f"S3 object metadata: {metadata}")
126
+
127
+ # Arguments can be passed as a JSON string in metadata
128
+ arguments_str = metadata.get("arguments", "{}")
129
+ print(f"Arguments string from metadata: '{arguments_str}'")
130
+
131
+ if arguments_str and arguments_str != "{}":
132
+ arguments = json.loads(arguments_str)
133
+ print(f"Successfully parsed arguments from metadata: {arguments}")
134
+ else:
135
+ arguments = dict()
136
+ print("No arguments found in metadata, using empty dictionary")
137
+ except Exception as e:
138
+ print(f"Warning: Could not parse metadata arguments: {e}")
139
+ print("Using empty arguments dictionary")
140
+ arguments = dict()
141
+
142
+ except (KeyError, IndexError) as e:
143
+ print(
144
+ f"Could not parse S3 event record: {e}. Checking for direct invocation payload."
145
+ )
146
+ # Fallback for direct invocation (e.g., from Step Functions or manual test)
147
+ bucket_name = event.get("bucket_name")
148
+ input_key = event.get("input_key")
149
+ arguments = event.get("arguments", dict())
150
+ if not all([bucket_name, input_key]):
151
+ raise ValueError(
152
+ "Missing 'bucket_name' or 'input_key' in direct invocation event."
153
+ )
154
+
155
+ # Log file type information
156
+ file_extension = os.path.splitext(input_key)[1].lower()
157
+ print(f"Detected file extension: '{file_extension}'")
158
+
159
+ # 3. Download the main input file
160
+ input_file_path = os.path.join(INPUT_DIR, os.path.basename(input_key))
161
+ download_file_from_s3(bucket_name, input_key, input_file_path)
162
+
163
+ # 3.1. Validate file type compatibility
164
+ is_env_file = input_key.lower().endswith(".env")
165
+
166
+ if not is_env_file and file_extension not in COMPATIBLE_FILE_TYPES:
167
+ error_message = f"File type '{file_extension}' is not supported for processing. Compatible file types are: {', '.join(sorted(COMPATIBLE_FILE_TYPES))}"
168
+ print(f"ERROR: {error_message}")
169
+ print(f"File was not processed due to unsupported file type: {file_extension}")
170
+ return {
171
+ "statusCode": 400,
172
+ "body": json.dumps(
173
+ {
174
+ "error": "Unsupported file type",
175
+ "message": error_message,
176
+ "supported_types": list(COMPATIBLE_FILE_TYPES),
177
+ "received_type": file_extension,
178
+ "file_processed": False,
179
+ }
180
+ ),
181
+ }
182
+
183
+ print(f"File type '{file_extension}' is compatible for processing")
184
+ if is_env_file:
185
+ print("Processing .env file for configuration")
186
+ else:
187
+ print(f"Processing {file_extension} file for topic modelling")
188
+
189
+ # 3.5. Check if the downloaded file is a .env file and handle accordingly
190
+ actual_input_file_path = input_file_path
191
+ if input_key.lower().endswith(".env"):
192
+ print("Detected .env file, loading environment variables...")
193
+
194
+ # Load environment variables from the .env file
195
+ print(f"Loading .env file from: {input_file_path}")
196
+
197
+ # Check if file exists and is readable
198
+ if os.path.exists(input_file_path):
199
+ print(".env file exists and is readable")
200
+ with open(input_file_path, "r") as f:
201
+ content = f.read()
202
+ print(f".env file content preview: {content[:200]}...")
203
+ else:
204
+ print(f"ERROR: .env file does not exist at {input_file_path}")
205
+
206
+ load_dotenv(input_file_path, override=True)
207
+ print("Environment variables loaded from .env file")
208
+
209
+ # Extract the actual input file path from environment variables
210
+ env_input_file = os.getenv("INPUT_FILE")
211
+
212
+ if env_input_file:
213
+ print(f"Found input file path in environment: {env_input_file}")
214
+
215
+ # If the path is an S3 path, download it
216
+ if env_input_file.startswith("s3://"):
217
+ # Parse S3 path: s3://bucket/key
218
+ s3_path_parts = env_input_file[5:].split("/", 1)
219
+ if len(s3_path_parts) == 2:
220
+ env_bucket = s3_path_parts[0]
221
+ env_key = s3_path_parts[1]
222
+ actual_input_file_path = os.path.join(
223
+ INPUT_DIR, os.path.basename(env_key)
224
+ )
225
+ print(
226
+ f"Downloading actual input file from s3://{env_bucket}/{env_key}"
227
+ )
228
+ download_file_from_s3(env_bucket, env_key, actual_input_file_path)
229
+ else:
230
+ print("Warning: Invalid S3 path format in environment variable")
231
+ actual_input_file_path = input_file_path
232
+ else:
233
+ # Assume it's a local path or relative path
234
+ actual_input_file_path = env_input_file
235
+ print(
236
+ f"Using input file path from environment: {actual_input_file_path}"
237
+ )
238
+ else:
239
+ print("Warning: No input file path found in environment variables")
240
+ # Fall back to using the .env file itself (though this might not be what we want)
241
+ actual_input_file_path = input_file_path
242
+ else:
243
+ print("File is not a .env file, proceeding with normal processing")
244
+
245
+ # 4. Prepare arguments for the CLI function
246
+ # This dictionary should mirror the arguments that cli_topics.main() expects via direct_mode_args
247
+
248
+ cli_args = {
249
+ # Task Selection
250
+ "task": arguments.get("task", os.getenv("DIRECT_MODE_TASK", "extract")),
251
+ # General Arguments
252
+ "input_file": [actual_input_file_path] if actual_input_file_path else None,
253
+ "output_dir": arguments.get(
254
+ "output_dir", os.getenv("DIRECT_MODE_OUTPUT_DIR", OUTPUT_DIR)
255
+ ),
256
+ "input_dir": arguments.get("input_dir", INPUT_DIR),
257
+ "text_column": arguments.get(
258
+ "text_column", os.getenv("DIRECT_MODE_TEXT_COLUMN", "")
259
+ ),
260
+ "previous_output_files": _get_env_list(
261
+ arguments.get(
262
+ "previous_output_files",
263
+ os.getenv("DIRECT_MODE_PREVIOUS_OUTPUT_FILES", list()),
264
+ )
265
+ ),
266
+ "username": arguments.get("username", os.getenv("DIRECT_MODE_USERNAME", "")),
267
+ "save_to_user_folders": convert_string_to_boolean(
268
+ arguments.get(
269
+ "save_to_user_folders",
270
+ os.getenv("SESSION_OUTPUT_FOLDER", str(SESSION_OUTPUT_FOLDER)),
271
+ )
272
+ ),
273
+ "excel_sheets": _get_env_list(
274
+ arguments.get("excel_sheets", os.getenv("DIRECT_MODE_EXCEL_SHEETS", list()))
275
+ ),
276
+ "group_by": arguments.get("group_by", os.getenv("DIRECT_MODE_GROUP_BY", "")),
277
+ # Model Configuration
278
+ "model_choice": arguments.get(
279
+ "model_choice", os.getenv("DIRECT_MODE_MODEL_CHOICE", "")
280
+ ),
281
+ "temperature": float(
282
+ arguments.get(
283
+ "temperature",
284
+ os.getenv("DIRECT_MODE_TEMPERATURE", str(LLM_TEMPERATURE)),
285
+ )
286
+ ),
287
+ "batch_size": int(
288
+ arguments.get(
289
+ "batch_size",
290
+ os.getenv("DIRECT_MODE_BATCH_SIZE", str(BATCH_SIZE_DEFAULT)),
291
+ )
292
+ ),
293
+ "max_tokens": int(
294
+ arguments.get(
295
+ "max_tokens",
296
+ os.getenv("DIRECT_MODE_MAX_TOKENS", str(LLM_MAX_NEW_TOKENS)),
297
+ )
298
+ ),
299
+ "google_api_key": arguments.get(
300
+ "google_api_key", os.getenv("GEMINI_API_KEY", "")
301
+ ),
302
+ "aws_access_key": None, # Use IAM Role instead of keys
303
+ "aws_secret_key": None, # Use IAM Role instead of keys
304
+ "aws_region": os.getenv("AWS_REGION", AWS_REGION),
305
+ "hf_token": arguments.get("hf_token", os.getenv("HF_TOKEN", "")),
306
+ "azure_api_key": arguments.get(
307
+ "azure_api_key", os.getenv("AZURE_OPENAI_API_KEY", "")
308
+ ),
309
+ "azure_endpoint": arguments.get(
310
+ "azure_endpoint", os.getenv("AZURE_OPENAI_INFERENCE_ENDPOINT", "")
311
+ ),
312
+ "api_url": arguments.get("api_url", os.getenv("API_URL", "")),
313
+ "inference_server_model": arguments.get(
314
+ "inference_server_model", os.getenv("CHOSEN_INFERENCE_SERVER_MODEL", "")
315
+ ),
316
+ # Topic Extraction Arguments
317
+ "context": arguments.get("context", os.getenv("DIRECT_MODE_CONTEXT", "")),
318
+ "candidate_topics": arguments.get(
319
+ "candidate_topics", os.getenv("DIRECT_MODE_CANDIDATE_TOPICS", "")
320
+ ),
321
+ "force_zero_shot": arguments.get(
322
+ "force_zero_shot", os.getenv("DIRECT_MODE_FORCE_ZERO_SHOT", "No")
323
+ ),
324
+ "force_single_topic": arguments.get(
325
+ "force_single_topic", os.getenv("DIRECT_MODE_FORCE_SINGLE_TOPIC", "No")
326
+ ),
327
+ "produce_structured_summary": arguments.get(
328
+ "produce_structured_summary",
329
+ os.getenv("DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY", "No"),
330
+ ),
331
+ "sentiment": arguments.get(
332
+ "sentiment", os.getenv("DIRECT_MODE_SENTIMENT", "Negative or Positive")
333
+ ),
334
+ "additional_summary_instructions": arguments.get(
335
+ "additional_summary_instructions",
336
+ os.getenv("DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS", ""),
337
+ ),
338
+ # Validation Arguments
339
+ "additional_validation_issues": arguments.get(
340
+ "additional_validation_issues",
341
+ os.getenv("DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES", ""),
342
+ ),
343
+ "show_previous_table": arguments.get(
344
+ "show_previous_table", os.getenv("DIRECT_MODE_SHOW_PREVIOUS_TABLE", "Yes")
345
+ ),
346
+ "output_debug_files": arguments.get(
347
+ "output_debug_files", str(OUTPUT_DEBUG_FILES)
348
+ ),
349
+ "max_time_for_loop": int(
350
+ arguments.get("max_time_for_loop", os.getenv("MAX_TIME_FOR_LOOP", "99999"))
351
+ ),
352
+ # Deduplication Arguments
353
+ "method": arguments.get(
354
+ "method", os.getenv("DIRECT_MODE_DEDUPLICATION_METHOD", "fuzzy")
355
+ ),
356
+ "similarity_threshold": int(
357
+ arguments.get(
358
+ "similarity_threshold",
359
+ os.getenv("DEDUPLICATION_THRESHOLD", DEDUPLICATION_THRESHOLD),
360
+ )
361
+ ),
362
+ "merge_sentiment": arguments.get(
363
+ "merge_sentiment", os.getenv("DIRECT_MODE_MERGE_SENTIMENT", "No")
364
+ ),
365
+ "merge_general_topics": arguments.get(
366
+ "merge_general_topics", os.getenv("DIRECT_MODE_MERGE_GENERAL_TOPICS", "Yes")
367
+ ),
368
+ # Summarisation Arguments
369
+ "summary_format": arguments.get(
370
+ "summary_format", os.getenv("DIRECT_MODE_SUMMARY_FORMAT", "two_paragraph")
371
+ ),
372
+ "sample_reference_table": arguments.get(
373
+ "sample_reference_table",
374
+ os.getenv("DIRECT_MODE_SAMPLE_REFERENCE_TABLE", "True"),
375
+ ),
376
+ "no_of_sampled_summaries": int(
377
+ arguments.get(
378
+ "no_of_sampled_summaries",
379
+ os.getenv("DEFAULT_SAMPLED_SUMMARIES", DEFAULT_SAMPLED_SUMMARIES),
380
+ )
381
+ ),
382
+ "random_seed": int(
383
+ arguments.get("random_seed", os.getenv("LLM_SEED", LLM_SEED))
384
+ ),
385
+ # Output Format Arguments
386
+ "create_xlsx_output": convert_string_to_boolean(
387
+ arguments.get(
388
+ "create_xlsx_output",
389
+ os.getenv("DIRECT_MODE_CREATE_XLSX_OUTPUT", "True"),
390
+ )
391
+ ),
392
+ # Logging Arguments
393
+ "save_logs_to_csv": convert_string_to_boolean(
394
+ arguments.get(
395
+ "save_logs_to_csv", os.getenv("SAVE_LOGS_TO_CSV", str(SAVE_LOGS_TO_CSV))
396
+ )
397
+ ),
398
+ "save_logs_to_dynamodb": convert_string_to_boolean(
399
+ arguments.get(
400
+ "save_logs_to_dynamodb",
401
+ os.getenv("SAVE_LOGS_TO_DYNAMODB", str(SAVE_LOGS_TO_DYNAMODB)),
402
+ )
403
+ ),
404
+ "usage_logs_folder": arguments.get("usage_logs_folder", USAGE_LOGS_FOLDER),
405
+ "cost_code": arguments.get(
406
+ "cost_code", os.getenv("DEFAULT_COST_CODE", DEFAULT_COST_CODE)
407
+ ),
408
+ }
409
+
410
+ # Download optional files if they are specified
411
+ candidate_topics_key = arguments.get("candidate_topics_s3_key")
412
+ if candidate_topics_key:
413
+ candidate_topics_path = os.path.join(INPUT_DIR, "candidate_topics.csv")
414
+ download_file_from_s3(bucket_name, candidate_topics_key, candidate_topics_path)
415
+ cli_args["candidate_topics"] = candidate_topics_path
416
+
417
+ # Download previous output files if they are S3 keys
418
+ if cli_args["previous_output_files"]:
419
+ downloaded_previous_files = []
420
+ for prev_file in cli_args["previous_output_files"]:
421
+ if prev_file.startswith("s3://"):
422
+ # Parse S3 path
423
+ s3_path_parts = prev_file[5:].split("/", 1)
424
+ if len(s3_path_parts) == 2:
425
+ prev_bucket = s3_path_parts[0]
426
+ prev_key = s3_path_parts[1]
427
+ local_prev_path = os.path.join(
428
+ INPUT_DIR, os.path.basename(prev_key)
429
+ )
430
+ download_file_from_s3(prev_bucket, prev_key, local_prev_path)
431
+ downloaded_previous_files.append(local_prev_path)
432
+ else:
433
+ downloaded_previous_files.append(prev_file)
434
+ else:
435
+ downloaded_previous_files.append(prev_file)
436
+ cli_args["previous_output_files"] = downloaded_previous_files
437
+
438
+ # 5. Execute the main application logic
439
+ try:
440
+ print("--- Starting CLI Topics Main Function ---")
441
+ print(
442
+ f"Arguments passed to cli_main: {json.dumps({k: v for k, v in cli_args.items() if k not in ['aws_access_key', 'aws_secret_key']}, default=str)}"
443
+ )
444
+ cli_main(direct_mode_args=cli_args)
445
+ print("--- CLI Topics Main Function Finished ---")
446
+ except Exception as e:
447
+ print(f"An error occurred during CLI execution: {e}")
448
+ import traceback
449
+
450
+ traceback.print_exc()
451
+ # Optionally, re-raise the exception to make the Lambda fail
452
+ raise
453
+
454
+ # 6. Upload results back to S3
455
+ output_s3_prefix = f"output/{os.path.splitext(os.path.basename(input_key))[0]}"
456
+ print(
457
+ f"Uploading contents of {OUTPUT_DIR} to s3://{bucket_name}/{output_s3_prefix}/"
458
+ )
459
+ upload_directory_to_s3(OUTPUT_DIR, bucket_name, output_s3_prefix)
460
+
461
+ return {
462
+ "statusCode": 200,
463
+ "body": json.dumps(
464
+ f"Processing complete for {input_key}. Output saved to s3://{bucket_name}/{output_s3_prefix}/"
465
+ ),
466
+ }
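+ 
+ 
+ # A minimal local test sketch for lambda_handler using the direct-invocation
+ # fallback above. The bucket name is an illustrative placeholder, and the key
+ # assumes the example data has been uploaded to that bucket beforehand.
+ if __name__ == "__main__":
+     test_event = {
+         "bucket_name": "example-bucket",  # hypothetical bucket
+         "input_key": "input/dummy_consultation_response.csv",
+         "arguments": {"task": "extract", "text_column": "Response text"},
+     }
+     lambda_handler(test_event, context=None)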
load_dynamo_logs.py ADDED
@@ -0,0 +1,102 @@
1
+ import csv
2
+ import datetime
3
+ from decimal import Decimal
4
+
5
+ import boto3
6
+
7
+ from tools.config import (
8
+ AWS_REGION,
9
+ OUTPUT_FOLDER,
10
+ USAGE_LOG_DYNAMODB_TABLE_NAME,
11
+ )
12
+
13
+ # Replace with your actual table name and region
14
+ TABLE_NAME = USAGE_LOG_DYNAMODB_TABLE_NAME # Choose as appropriate
15
+ REGION = AWS_REGION
16
+ CSV_OUTPUT = OUTPUT_FOLDER + "dynamodb_logs_export.csv"
17
+
18
+ # Create DynamoDB resource
19
+ dynamodb = boto3.resource("dynamodb", region_name=REGION)
20
+ table = dynamodb.Table(TABLE_NAME)
21
+
22
+
23
+ # Helper function to convert Decimal to float or int
24
+ def convert_types(item):
25
+ new_item = {}
26
+ for key, value in item.items():
27
+ # Handle Decimals first
28
+ if isinstance(value, Decimal):
29
+ new_item[key] = int(value) if value % 1 == 0 else float(value)
30
+ # Handle Strings that might be dates
31
+ elif isinstance(value, str):
32
+ try:
33
+ # Attempt to parse a common ISO 8601 format.
34
+ # The .replace() handles the 'Z' for Zulu/UTC time.
35
+ dt_obj = datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))
36
+ # Now that we have a datetime object, format it as desired
37
+ new_item[key] = dt_obj.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
38
+ except (ValueError, TypeError):
39
+ # If it fails to parse, it's just a regular string
40
+ new_item[key] = value
41
+ # Handle all other types
42
+ else:
43
+ new_item[key] = value
44
+ return new_item
45
+
46
+
47
+ # Paginated scan
48
+ def scan_table():
49
+ items = []
50
+ response = table.scan()
51
+ items.extend(response["Items"])
52
+
53
+ while "LastEvaluatedKey" in response:
54
+ response = table.scan(ExclusiveStartKey=response["LastEvaluatedKey"])
55
+ items.extend(response["Items"])
56
+
57
+ return items
58
+
59
+
60
+ # Export to CSV
62
+ def export_to_csv(items, output_path, fields_to_drop: list = None):
63
+ if not items:
64
+ print("No items found.")
65
+ return
66
+
67
+ # Use a set for efficient lookup
68
+ drop_set = set(fields_to_drop or [])
69
+
70
+ # Get a comprehensive list of all possible headers from all items
71
+ all_keys = set()
72
+ for item in items:
73
+ all_keys.update(item.keys())
74
+
75
+ # Determine the final fieldnames by subtracting the ones to drop
76
+ fieldnames = sorted(list(all_keys - drop_set))
77
+
78
+ print("Final CSV columns will be:", fieldnames)
79
+
80
+ with open(output_path, "w", newline="", encoding="utf-8-sig") as csvfile:
81
+ # The key fix is here: extrasaction='ignore'
82
+ # restval='' is also good practice to handle rows that are missing a key
83
+ writer = csv.DictWriter(
84
+ csvfile, fieldnames=fieldnames, extrasaction="ignore", restval=""
85
+ )
86
+ writer.writeheader()
87
+
88
+ for item in items:
89
+ # The convert_types function can now return the full dict,
90
+ # and the writer will simply ignore the extra fields.
91
+ writer.writerow(convert_types(item))
92
+
93
+ print(f"Exported {len(items)} items to {output_path}")
94
+
95
+
96
+ # Run export
97
+ items = scan_table()
98
+ export_to_csv(
99
+ items,
100
+ CSV_OUTPUT,
101
+ fields_to_drop=["Query metadata - usage counts and other parameters"],
102
+ )
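+ 
+ # Quick illustrative check of convert_types on a representative item;
+ # the field names below are placeholders, not real table attributes.
+ sample_item = {
+     "input_tokens": Decimal("1520"),
+     "cost": Decimal("0.25"),
+     "timestamp": "2025-04-09T12:30:00Z",
+ }
+ print(convert_types(sample_item))
+ # -> {'input_tokens': 1520, 'cost': 0.25, 'timestamp': '2025-04-09 12:30:00.000'}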
load_s3_logs.py ADDED
@@ -0,0 +1,93 @@
1
+ from datetime import datetime
2
+ from io import StringIO
3
+
4
+ import boto3
5
+ import pandas as pd
6
+
7
+ from tools.config import (
8
+ AWS_ACCESS_KEY,
9
+ AWS_REGION,
10
+ AWS_SECRET_KEY,
11
+ DOCUMENT_REDACTION_BUCKET,
12
+ OUTPUT_FOLDER,
13
+ )
14
+
15
+ # Combine log files together so they can then be used for e.g. dashboarding and financial tracking.
16
+
17
+ # S3 setup. Try to use provided keys (needs S3 permissions), otherwise assume AWS SSO connection
18
+ if AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION:
19
+ s3 = boto3.client(
20
+ "s3",
21
+ aws_access_key_id=AWS_ACCESS_KEY,
22
+ aws_secret_access_key=AWS_SECRET_KEY,
23
+ region_name=AWS_REGION,
24
+ )
25
+ else:
26
+ s3 = boto3.client("s3")
27
+
28
+ bucket_name = DOCUMENT_REDACTION_BUCKET
29
+ prefix = "usage/" # 'feedback/' # 'logs/' # Change as needed - top-level folder where logs are stored
30
+ earliest_date = "20250409" # Earliest date of logs folder retrieved
31
+ latest_date = "20250423" # Latest date of logs folder retrieved
32
+
33
+
34
+ # Function to list all files in a folder
35
+ def list_files_in_s3(bucket, prefix):
36
+ response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
37
+ if "Contents" in response:
38
+ return [content["Key"] for content in response["Contents"]]
39
+ return []
40
+
41
+
42
+ # Function to filter date range
43
+ def is_within_date_range(date_str, start_date, end_date):
44
+ date_obj = datetime.strptime(date_str, "%Y%m%d")
45
+ return start_date <= date_obj <= end_date
46
+
47
+
48
+ # Define the date range
49
+ start_date = datetime.strptime(earliest_date, "%Y%m%d") # Replace with your start date
50
+ end_date = datetime.strptime(latest_date, "%Y%m%d") # Replace with your end date
51
+
52
+ # List all subfolders under 'usage/'
53
+ all_files = list_files_in_s3(bucket_name, prefix)
54
+
55
+ # Filter based on date range
56
+ log_files = []
57
+ for file in all_files:
58
+ parts = file.split("/")
59
+ if len(parts) >= 3:
60
+ date_str = parts[1]
61
+ if (
62
+ is_within_date_range(date_str, start_date, end_date)
63
+ and parts[-1] == "log.csv"
64
+ ):
65
+ log_files.append(file)
66
+
67
+ # Download, read and concatenate CSV files into a pandas DataFrame
68
+ df_list = []
69
+ for log_file in log_files:
70
+ # Download the file
71
+ obj = s3.get_object(Bucket=bucket_name, Key=log_file)
72
+ try:
73
+ csv_content = (
74
+ obj["Body"].read().decode("utf-8")
75
+ ) # Suggest trying latin-1 instead of utf-8 if this fails
76
+ except Exception as e:
77
+ print("Could not load in log file:", log_file, "due to:", e)
78
+ continue
79
+
80
+ # Read CSV content into pandas DataFrame
81
+ df = pd.read_csv(StringIO(csv_content))
82
+
83
+ df_list.append(df)
84
+
85
+ # Concatenate all DataFrames
86
+ if df_list:
87
+ concatenated_df = pd.concat(df_list, ignore_index=True)
88
+
89
+ # Save the concatenated DataFrame to a CSV file
90
+ concatenated_df.to_csv(OUTPUT_FOLDER + "consolidated_s3_logs.csv", index=False)
91
+ print("Consolidated CSV saved as 'consolidated_s3_logs.csv'")
92
+ else:
93
+ print("No log files found in the given date range.")
pyproject.toml ADDED
@@ -0,0 +1,147 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "llm_topic_modelling"
7
+ version = "0.6.0"
8
+ description = "Generate thematic summaries from open text in tabular data files with a large language model."
9
+ requires-python = ">=3.10"
10
+ readme = "README.md"
11
+ authors = [
12
+ { name = "Sean Pedrick-Case", email = "spedrickcase@lambeth.gov.uk" },
13
+ ]
14
+ maintainers = [
15
+ { name = "Sean Pedrick-Case", email = "spedrickcase@lambeth.gov.uk" },
16
+ ]
17
+ keywords = [
18
+ "topic-modelling",
19
+ "topic-modeling",
20
+ "llm",
21
+ "large-language-models",
22
+ "thematic-analysis",
23
+ "text-analysis",
24
+ "nlp",
25
+ "natural-language-processing",
26
+ "text-summarization",
27
+ "text-summarisation",
28
+ "thematic-summaries",
29
+ "gradio",
30
+ "data-analysis",
31
+ "tabular-data",
32
+ "excel",
33
+ "csv",
34
+ "open-text",
35
+ "text-mining"
36
+ ]
37
+ classifiers = [
38
+ "Development Status :: 4 - Beta",
39
+ "Intended Audience :: Developers",
40
+ "Intended Audience :: Science/Research",
41
+ "Intended Audience :: Information Technology",
42
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
43
+ "Topic :: Text Processing :: Linguistic",
44
+ "Topic :: Text Processing :: Markup",
45
+ "Topic :: Scientific/Engineering :: Information Analysis",
46
+ "Programming Language :: Python :: 3",
47
+ "Programming Language :: Python :: 3.10",
48
+ "Programming Language :: Python :: 3.11",
49
+ "Programming Language :: Python :: 3.12",
50
+ "Programming Language :: Python :: 3.13",
51
+ ]
52
+
53
+ dependencies = [
54
+ "gradio==6.0.2",
55
+ "transformers==4.57.2",
56
+ "spaces==0.42.1",
57
+ "boto3==1.42.1",
58
+ "pandas<=2.3.3",
59
+ "pyarrow>=21.0.0",
60
+ "openpyxl>=3.1.5",
61
+ "markdown>=3.7",
62
+ "tabulate>=0.9.0",
63
+ "lxml>=5.3.0",
64
+ "google-genai<=1.52.0",
65
+ "openai<=2.8.1",
66
+ "html5lib>=1.1",
67
+ "beautifulsoup4>=4.12.3",
68
+ "rapidfuzz>=3.13.0",
69
+ "python-dotenv>=1.1.0"
70
+ ]
71
+
72
+ [project.optional-dependencies]
73
+ dev = ["pytest"]
74
+ test = ["pytest", "pytest-cov"]
75
+
76
+ # Extra dependencies for local VLM models. This extra also installs the unsloth package.
77
+ # For torch, install with --index-url https://download.pytorch.org/whl/cu128.
78
+ torch = [
79
+ "torch<=2.9.1",
80
+ "torchvision",
81
+ "accelerate",
82
+ "bitsandbytes",
83
+ "unsloth==2025.11.6",
84
+ "unsloth_zoo==2025.11.6",
85
+ "timm",
86
+ "xformers"
87
+ ]
88
+
89
+ # If you want to install llama-cpp-python in GPU mode, use cmake.args="-DGGML_CUDA=on" . If that doesn't work, try specific wheels for your system, e.g. for Linux see files in https://github.com/JamePeng/llama-cpp-python/releases. More details on installation here: https://llama-cpp-python.readthedocs.io/en/latest
90
+ llamacpp = [
91
+ "llama-cpp-python>=0.3.16",
92
+ ]
93
+
94
+ # Run Gradio as an mcp server
95
+ mcp = [
96
+ "gradio[mcp]==6.0.2"
97
+ ]
98
+
99
+ [project.urls]
100
+ Homepage = "https://github.com/seanpedrick-case/llm_topic_modelling"
101
+ repository = "https://github.com/seanpedrick-case/llm_topic_modelling"
102
+
103
+ [tool.setuptools]
104
+ packages = ["tools"]
105
+ py-modules = ["app"]
106
+
107
+ # Configuration for Ruff linter:
108
+ [tool.ruff]
109
+ line-length = 88
110
+
111
+ [tool.ruff.lint]
112
+ select = ["E", "F", "I"]
113
+ ignore = [
114
+ "E501", # line-too-long (handled with Black)
115
+ "E402", # module-import-not-at-top-of-file (sometimes needed for conditional imports)
116
+ ]
117
+
118
+ [tool.ruff.lint.per-file-ignores]
119
+ "__init__.py" = ["F401"] # Allow unused imports in __init__.py
120
+
121
+ # Configuration for a Black formatter:
122
+ [tool.black]
123
+ line-length = 88
124
+ target-version = ['py310']
125
+
126
+ # Configuration for pytest:
127
+ [tool.pytest.ini_options]
128
+ filterwarnings = [
129
+ "ignore::DeprecationWarning:click.parser",
130
+ "ignore::DeprecationWarning:weasel.util.config",
131
+ "ignore::DeprecationWarning:builtin type",
132
+ "ignore::DeprecationWarning:websockets.legacy",
133
+ "ignore::DeprecationWarning:websockets.server",
134
+ "ignore::DeprecationWarning:spacy.cli._util",
136
+ "ignore::DeprecationWarning:importlib._bootstrap",
137
+ ]
138
+ testpaths = ["test"]
139
+ python_files = ["test_*.py", "*_test.py"]
140
+ python_classes = ["Test*"]
141
+ python_functions = ["test_*"]
142
+ addopts = [
143
+ "-v",
144
+ "--tb=short",
145
+ "--strict-markers",
146
+ "--disable-warnings",
147
+ ]
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ # Note: this requirements file is optimised for Hugging Face spaces / Python 3.10. Use requirements_lightweight.txt for installation without local model inference (the simplest way to get going), requirements_cpu.txt for CPU instances, and requirements_gpu.txt for GPU instances using Python 3.11
2
+ gradio==6.0.2
3
+ transformers==4.57.2
4
+ spaces==0.42.1
5
+ boto3>=1.42.1
6
+ pandas>=2.3.3
7
+ pyarrow>=21.0.0
8
+ openpyxl>=3.1.5
9
+ markdown>=3.7
10
+ tabulate>=0.9.0
11
+ lxml>=5.3.0
12
+ google-genai>=1.52.0
13
+ openai>=2.8.1
14
+ html5lib>=1.1
15
+ beautifulsoup4>=4.12.3
16
+ rapidfuzz>=3.13.0
17
+ python-dotenv>=1.1.0
18
+ # GPU (for huggingface instance)
19
+ # Torch/Unsloth and llama-cpp-python
20
+ # Latest compatible with CUDA 12.4
21
+ torch<=2.9.1 --extra-index-url https://download.pytorch.org/whl/cu128
22
+ unsloth[cu128-torch280]<=2025.11.6
23
+ unsloth_zoo<=2025.11.6
24
+ timm
25
+ # llama-cpp-python direct wheel link for GPU compatible version 3.17 for use with Python 3.10 and Hugging Face
26
+ https://github.com/JamePeng/llama-cpp-python/releases/download/v0.3.17-cu128-Basic-linux-20251202/llama_cpp_python-0.3.17-cp310-cp310-linux_x86_64.whl
27
+ #https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
28
+
29
+
requirements_cpu.txt ADDED
@@ -0,0 +1,24 @@
1
+ gradio==6.0.2
2
+ transformers==4.57.2
3
+ spaces==0.42.1
4
+ pandas>=2.3.3
5
+ boto3>=1.42.1
6
+ pyarrow>=21.0.0
7
+ openpyxl>=3.1.5
8
+ markdown>=3.7
9
+ tabulate>=0.9.0
10
+ lxml>=5.3.0
11
+ google-genai>=1.52.0
12
+ openai>=2.8.1
13
+ html5lib>=1.1
14
+ beautifulsoup4>=4.12.3
15
+ rapidfuzz>=3.13.0
16
+ python-dotenv>=1.1.0
17
+ torch<=2.9.1 --extra-index-url https://download.pytorch.org/whl/cpu
18
+ llama-cpp-python==0.3.16 -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
19
+ # Direct wheel links if above doesn't work
20
+ # I have created CPU Linux, Python 3.11 compatible wheels:
21
+ # https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
22
+ # Windows, Python 3.11 compatible CPU wheels available:
23
+ # https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-win_amd64_cpu_openblas.whl
24
+ # If the above doesn't work for Windows, try looking at 'windows_install_llama-cpp-python.txt' for instructions on how to build from source
requirements_gpu.txt ADDED
@@ -0,0 +1,28 @@
1
+
2
+ gradio==6.0.2
3
+ transformers==4.57.2
4
+ spaces==0.42.1
5
+ boto3>=1.42.1
6
+ pandas>=2.3.3
7
+ pyarrow>=21.0.0
8
+ openpyxl>=3.1.5
9
+ markdown>=3.7
10
+ tabulate>=0.9.0
11
+ lxml>=5.3.0
12
+ google-genai>=1.52.0
13
+ openai>=2.8.1
14
+ html5lib>=1.1
15
+ beautifulsoup4>=4.12.3
16
+ rapidfuzz>=3.13.0
17
+ python-dotenv>=1.1.0
18
+ # Torch/Unsloth
19
+ # Latest compatible with CUDA 12.4
20
+ torch<=2.9.1 --extra-index-url https://download.pytorch.org/whl/cu128
21
+ unsloth[cu128-torch280]<=2025.11.6 # Refer here for more details on installation: https://pypi.org/project/unsloth
22
+ unsloth_zoo<=2025.11.6
23
+ # Additional for Windows and CUDA 12.4 older GPUS (RTX 3x or similar):
24
+ #triton-windows<3.3
25
+ timm
26
+ # Llama CPP Python
27
+ llama-cpp-python>=0.3.16 -C cmake.args="-DGGML_CUDA=on"
28
+ # If above doesn't work, try specific wheels for your system, see files in https://github.com/JamePeng/llama-cpp-python/releases for different python versions
requirements_lightweight.txt ADDED
@@ -0,0 +1,18 @@
1
+ # This requirements file is optimised for AWS ECS using Python 3.11 alongside the Dockerfile, without local torch and llama-cpp-python. For AWS ECS, torch and llama-cpp-python are optionally installed in the main Dockerfile
2
+ gradio==6.0.2
3
+ transformers==4.57.2
4
+ spaces==0.42.1
5
+ boto3>=1.42.1
6
+ pandas>=2.3.3
7
+ pyarrow>=21.0.0
8
+ openpyxl>=3.1.5
9
+ markdown>=3.7
10
+ tabulate>=0.9.0
11
+ lxml>=5.3.0
12
+ google-genai>=1.52.0
13
+ openai>=2.8.1
14
+ html5lib>=1.1
15
+ beautifulsoup4>=4.12.3
16
+ rapidfuzz>=3.13.0
17
+ python-dotenv>=1.1.0
18
+ awslambdaric==3.1.1
test/README.md ADDED
@@ -0,0 +1,87 @@
1
+ # Test Suite for LLM Topic Modeller
2
+
3
+ This test suite provides comprehensive testing for the CLI interface (`cli_topics.py`) and GUI application (`app.py`).
4
+
5
+ ## Overview
6
+
7
+ The test suite includes:
8
+ - **CLI Tests**: Tests based on examples from the `cli_topics.py` epilog
9
+ - **GUI Tests**: Tests to verify the Gradio interface loads correctly
10
+ - **Mock Inference Server**: A dummy inference-server endpoint that avoids API costs during testing
11
+
12
+ ## Structure
13
+
14
+ - `test.py`: Main test suite with CLI tests
15
+ - `test_gui_only.py`: GUI-specific tests
16
+ - `mock_inference_server.py`: Mock HTTP server that mimics an inference-server API
+ - `mock_llm_calls.py`: Patches `requests.post` to return mock LLM responses for testing
17
+ - `run_tests.py`: Test runner script
18
+ - `__init__.py`: Package initialization
19
+
20
+ ## Running Tests
21
+
22
+ ### Run All Tests
23
+
24
+ From the project root directory:
25
+
26
+ ```bash
27
+ python test/run_tests.py
28
+ ```
29
+
30
+ Or from the test directory:
31
+
32
+ ```bash
33
+ python run_tests.py
34
+ ```
35
+
36
+ ### Run Only CLI Tests
37
+
38
+ ```bash
39
+ python -m unittest test.test.TestCLITopicsExamples
40
+ ```
41
+
42
+ ### Run Only GUI Tests
43
+
44
+ ```bash
45
+ python test/test_gui_only.py
46
+ ```
47
+
48
+ ## Mock Inference Server
49
+
50
+ The test suite uses a mock inference server to avoid API costs during testing. The mock server:
51
+
52
+ - Listens on `localhost:8080` by default
53
+ - Responds to `/v1/chat/completions` endpoint
54
+ - Returns valid markdown table responses that satisfy validation requirements
55
+ - Provides token counts for usage tracking
56
+
57
+ The mock server is automatically started before tests and stopped after tests complete.
58
+
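+ You can also exercise the mock server by hand. A minimal sketch (the endpoint and response shape below match `mock_inference_server.py`):
+ 
+ ```python
+ import requests
+ 
+ from test.mock_inference_server import MockInferenceServer
+ 
+ with MockInferenceServer() as server:
+     resp = requests.post(
+         f"{server.get_url()}/v1/chat/completions",
+         json={"messages": [{"role": "user", "content": "Extract topics"}]},
+     )
+     print(resp.json()["choices"][0]["message"]["content"])
+ ```
+ 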
59
+ ## Test Coverage
60
+
61
+ The CLI tests cover:
62
+
63
+ 1. **Topic Extraction**
64
+ - Default settings
65
+ - Custom model and context
66
+ - Grouping by column
67
+ - Zero-shot extraction with candidate topics
68
+
69
+ 2. **Topic Deduplication**
70
+ - Fuzzy matching
71
+ - LLM-based deduplication
72
+
73
+ 3. **All-in-One Pipeline**
74
+ - Complete workflow (extract, deduplicate, summarise)
75
+
76
+ ## Requirements
77
+
78
+ - Python 3.10+
79
+ - All dependencies from `requirements.txt`
80
+ - Example data files in `example_data/` directory
81
+
82
+ ## Notes
83
+
84
+ - Tests will be skipped if required example files are not found
85
+ - The mock server must be running for CLI tests to work
86
+ - Tests use temporary output directories that are cleaned up after execution
87
+
test/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """
2
+ Test suite for LLM Topic Modeller CLI.
3
+
4
+ This package contains tests for the CLI interface and GUI application.
5
+ """
test/mock_inference_server.py ADDED
@@ -0,0 +1,225 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Mock inference server for testing CLI topic extraction without API costs.
4
+
5
+ This server mimics an inference-server API endpoint and returns dummy
6
+ responses that satisfy the validation requirements (markdown tables with |).
7
+ """
8
+
9
+ import json
10
+ import threading
11
+ from http.server import BaseHTTPRequestHandler, HTTPServer
12
+ from typing import Optional
13
+
14
+
15
+ class MockInferenceServerHandler(BaseHTTPRequestHandler):
16
+ """HTTP request handler for the mock inference server."""
17
+
18
+ def _generate_mock_response(self, prompt: str, system_prompt: str) -> str:
19
+ """
20
+ Generate a mock response that satisfies validation requirements.
21
+
22
+ The response must:
23
+ - Be longer than 120 characters
24
+ - Contain a markdown table (with | characters)
25
+
26
+ Args:
27
+ prompt: The user prompt
28
+ system_prompt: The system prompt
29
+
30
+ Returns:
31
+ A mock markdown table response
32
+ """
33
+ # Generate a simple markdown table that satisfies the validation
34
+ # This mimics a topic extraction table response
35
+ mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
36
+ |-----------|---------------|-----------|-----------|
37
+ | 1 | Test Topic | Test Subtopic | Positive |
38
+ | 2 | Another Topic | Another Subtopic | Neutral |
39
+ | 3 | Third Topic | Third Subtopic | Negative |
40
+
41
+ This is a mock response from the test inference server. The actual content would be generated by a real LLM model, but for testing purposes, this dummy response allows us to verify that the CLI commands work correctly without incurring API costs."""
42
+
43
+ return mock_table
44
+
45
+ def _estimate_tokens(self, text: str) -> int:
46
+ """Estimate token count (rough approximation: ~4 characters per token)."""
47
+ return max(1, len(text) // 4)
48
+
49
+ def do_POST(self):
50
+ """Handle POST requests to /v1/chat/completions."""
51
+ print(f"[Mock Server] Received POST request to: {self.path}")
52
+ if self.path == "/v1/chat/completions":
53
+ try:
54
+ # Read request body
55
+ content_length = int(self.headers.get("Content-Length", 0))
56
+ print(f"[Mock Server] Content-Length: {content_length}")
57
+ body = self.rfile.read(content_length)
58
+ payload = json.loads(body.decode("utf-8"))
59
+ print("[Mock Server] Payload received, processing...")
60
+
61
+ # Extract messages
62
+ messages = payload.get("messages", [])
63
+ system_prompt = ""
64
+ user_prompt = ""
65
+
66
+ for msg in messages:
67
+ role = msg.get("role", "")
68
+ content = msg.get("content", "")
69
+ if role == "system":
70
+ system_prompt = content
71
+ elif role == "user":
72
+ user_prompt = content
73
+
74
+ # Generate mock response
75
+ response_text = self._generate_mock_response(user_prompt, system_prompt)
76
+
77
+ # Estimate tokens
78
+ input_tokens = self._estimate_tokens(system_prompt + "\n" + user_prompt)
79
+ output_tokens = self._estimate_tokens(response_text)
80
+
81
+ # Check if streaming is requested
82
+ stream = payload.get("stream", False)
83
+
84
+ if stream:
85
+ # Handle streaming response
86
+ self.send_response(200)
87
+ self.send_header("Content-Type", "text/event-stream")
88
+ self.send_header("Cache-Control", "no-cache")
89
+ self.send_header("Connection", "keep-alive")
90
+ self.end_headers()
91
+
92
+ # Send streaming chunks
93
+ chunk_size = 20 # Characters per chunk
94
+ for i in range(0, len(response_text), chunk_size):
95
+ chunk = response_text[i : i + chunk_size]
96
+ chunk_data = {
97
+ "choices": [
98
+ {
99
+ "delta": {"content": chunk},
100
+ "index": 0,
101
+ "finish_reason": None,
102
+ }
103
+ ]
104
+ }
105
+ self.wfile.write(f"data: {json.dumps(chunk_data)}\n\n".encode())
106
+ self.wfile.flush()
107
+
108
+ # Send final done message
109
+ self.wfile.write(b"data: [DONE]\n\n")
110
+ self.wfile.flush()
111
+ else:
112
+ # Handle non-streaming response
113
+ response_data = {
114
+ "choices": [
115
+ {
116
+ "index": 0,
117
+ "finish_reason": "stop",
118
+ "message": {
119
+ "role": "assistant",
120
+ "content": response_text,
121
+ },
122
+ }
123
+ ],
124
+ "usage": {
125
+ "prompt_tokens": input_tokens,
126
+ "completion_tokens": output_tokens,
127
+ "total_tokens": input_tokens + output_tokens,
128
+ },
129
+ }
130
+
131
+ self.send_response(200)
132
+ self.send_header("Content-Type", "application/json")
133
+ self.end_headers()
134
+ self.wfile.write(json.dumps(response_data).encode())
135
+
136
+ except Exception as e:
137
+ self.send_response(500)
138
+ self.send_header("Content-Type", "application/json")
139
+ self.end_headers()
140
+ error_response = {"error": {"message": str(e), "type": "server_error"}}
141
+ self.wfile.write(json.dumps(error_response).encode())
142
+ else:
143
+ self.send_response(404)
144
+ self.end_headers()
145
+
146
+ def log_message(self, format, *args):
147
+ """Log messages for debugging."""
148
+ # Enable logging for debugging
149
+ print(f"[Mock Server] {format % args}")
150
+
151
+
152
+ class MockInferenceServer:
153
+ """Mock inference server that can be started and stopped for testing."""
154
+
155
+ def __init__(self, host: str = "localhost", port: int = 8080):
156
+ """
157
+ Initialize the mock server.
158
+
159
+ Args:
160
+ host: Host to bind to (default: localhost)
161
+ port: Port to bind to (default: 8080)
162
+ """
163
+ self.host = host
164
+ self.port = port
165
+ self.server: Optional[HTTPServer] = None
166
+ self.server_thread: Optional[threading.Thread] = None
167
+ self.running = False
168
+
169
+ def start(self):
170
+ """Start the mock server in a separate thread."""
171
+ if self.running:
172
+ return
173
+
174
+ def run_server():
175
+ self.server = HTTPServer((self.host, self.port), MockInferenceServerHandler)
176
+ self.running = True
177
+ self.server.serve_forever()
178
+
179
+ self.server_thread = threading.Thread(target=run_server, daemon=True)
180
+ self.server_thread.start()
181
+
182
+ # Wait a moment for server to start
183
+ import time
184
+
185
+ time.sleep(0.5)
186
+
187
+ def stop(self):
188
+ """Stop the mock server."""
189
+ if self.server and self.running:
190
+ self.server.shutdown()
191
+ self.server.server_close()
192
+ self.running = False
193
+
194
+ def get_url(self) -> str:
195
+ """Get the server URL."""
196
+ return f"http://{self.host}:{self.port}"
197
+
198
+ def __enter__(self):
199
+ """Context manager entry."""
200
+ self.start()
201
+ return self
202
+
203
+ def __exit__(self, exc_type, exc_val, exc_tb):
204
+ """Context manager exit."""
205
+ self.stop()
206
+
207
+
208
+ if __name__ == "__main__":
209
+ # Test the server
210
+ print("Starting mock inference server on http://localhost:8080")
211
+ print("Press Ctrl+C to stop")
212
+
213
+ server = MockInferenceServer()
214
+ try:
215
+ server.start()
216
+ print(f"Server running at {server.get_url()}")
217
+ # Keep running
218
+ while True:
219
+ import time
220
+
221
+ time.sleep(1)
222
+ except KeyboardInterrupt:
223
+ print("\nStopping server...")
224
+ server.stop()
225
+ print("Server stopped")
test/mock_llm_calls.py ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Mock LLM function calls for testing CLI topic extraction without API costs.
4
+
5
+ This module patches requests.post to intercept HTTP calls to inference servers
6
+ and return mock responses instead.
7
+ """
8
+
9
+ import json
10
+ import os
11
+
12
+ # Store original requests if it exists
13
+ _original_requests = None
14
+
15
+
16
+ def _generate_mock_response(prompt: str, system_prompt: str) -> str:
17
+ """
18
+ Generate a mock response that satisfies validation requirements.
19
+
20
+ The response must:
21
+ - Be longer than 120 characters
22
+ - Contain a markdown table (with | characters)
23
+
24
+ Args:
25
+ prompt: The user prompt
26
+ system_prompt: The system prompt
27
+
28
+ Returns:
29
+ A mock markdown table response
30
+ """
31
+ # Generate a simple markdown table that satisfies the validation
32
+ # This mimics a topic extraction table response
33
+ mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
34
+ |-----------|---------------|-----------|-----------|
35
+ | 1 | Test Topic | Test Subtopic | Positive |
36
+ | 2 | Another Topic | Another Subtopic | Neutral |
37
+ | 3 | Third Topic | Third Subtopic | Negative |
38
+
39
+ This is a mock response from the test inference server. The actual content would be generated by a real LLM model, but for testing purposes, this dummy response allows us to verify that the CLI commands work correctly without incurring API costs."""
40
+
41
+ return mock_table
42
+
43
+
44
+ def _estimate_tokens(text: str) -> int:
45
+ """Estimate token count (rough approximation: ~4 characters per token)."""
46
+ return max(1, len(text) // 4)
47
+
48
+
49
+ def mock_requests_post(url, **kwargs):
50
+ """
51
+ Mock version of requests.post that intercepts inference-server calls.
52
+
53
+ Returns a mock response object that mimics the real requests.Response.
54
+ """
55
+ # Only mock inference-server URLs
56
+ if "/v1/chat/completions" not in url:
57
+ # For non-inference-server URLs, use real requests
58
+ import requests
59
+
60
+ return requests.post(url, **kwargs)
61
+
62
+ # Extract payload
63
+ payload = kwargs.get("json", {})
64
+ messages = payload.get("messages", [])
65
+
66
+ # Extract prompts
67
+ system_prompt = ""
68
+ user_prompt = ""
69
+ for msg in messages:
70
+ role = msg.get("role", "")
71
+ content = msg.get("content", "")
72
+ if role == "system":
73
+ system_prompt = content
74
+ elif role == "user":
75
+ user_prompt = content
76
+
77
+ # Generate mock response
78
+ response_text = _generate_mock_response(user_prompt, system_prompt)
79
+
80
+ # Estimate tokens
81
+ input_tokens = _estimate_tokens(system_prompt + "\n" + user_prompt)
82
+ output_tokens = _estimate_tokens(response_text)
83
+
84
+ # Check if streaming is requested
85
+ stream = payload.get("stream", False)
86
+
87
+ if stream:
88
+ # For streaming, create a mock response with iter_lines
89
+ class MockStreamResponse:
90
+ def __init__(self, text):
91
+ self.text = text
92
+ self.status_code = 200
93
+ self.lines = []
94
+ # Simulate streaming chunks
95
+ chunk_size = 20
96
+ for i in range(0, len(text), chunk_size):
97
+ chunk = text[i : i + chunk_size]
98
+ chunk_data = {
99
+ "choices": [
100
+ {
101
+ "delta": {"content": chunk},
102
+ "index": 0,
103
+ "finish_reason": None,
104
+ }
105
+ ]
106
+ }
107
+ self.lines.append(f"data: {json.dumps(chunk_data)}\n\n".encode())
108
+ self.lines.append(b"data: [DONE]\n\n")
109
+ self._line_index = 0
110
+
111
+ def raise_for_status(self):
112
+ pass
113
+
114
+ def iter_lines(self):
115
+ for line in self.lines:
116
+ yield line
117
+
118
+ return MockStreamResponse(response_text)
119
+ else:
120
+ # For non-streaming, create a simple mock response
121
+ class MockResponse:
122
+ def __init__(self, text, input_tokens, output_tokens):
123
+ self._json_data = {
124
+ "choices": [
125
+ {
126
+ "index": 0,
127
+ "finish_reason": "stop",
128
+ "message": {
129
+ "role": "assistant",
130
+ "content": text,
131
+ },
132
+ }
133
+ ],
134
+ "usage": {
135
+ "prompt_tokens": input_tokens,
136
+ "completion_tokens": output_tokens,
137
+ "total_tokens": input_tokens + output_tokens,
138
+ },
139
+ }
140
+ self.status_code = 200
141
+
142
+ def raise_for_status(self):
143
+ pass
144
+
145
+ def json(self):
146
+ return self._json_data
147
+
148
+ return MockResponse(response_text, input_tokens, output_tokens)
149
+
150
+
151
+ def apply_mock_patches():
152
+ """
153
+ Apply patches to mock HTTP requests.
154
+ This should be called before importing modules that use requests.
155
+ """
156
+ global _original_requests
157
+
158
+ try:
159
+ import requests
160
+
161
+ _original_requests = requests.post
162
+ requests.post = mock_requests_post
163
+ print("[Mock] Patched requests.post for inference-server calls")
164
+ except ImportError:
165
+ # requests not imported yet, will be patched when imported
166
+ pass
167
+
168
+
169
+ def restore_original():
170
+ """Restore original requests.post if it was patched."""
171
+ global _original_requests
172
+ if _original_requests:
173
+ try:
174
+ import requests
175
+
176
+ requests.post = _original_requests
177
+ _original_requests = None
178
+ print("[Mock] Restored original requests.post")
179
+ except ImportError:
180
+ pass
181
+
182
+
183
+ # Auto-apply patches if TEST_MODE environment variable is set
184
+ if os.environ.get("TEST_MODE") == "1" or os.environ.get("USE_MOCK_LLM") == "1":
185
+ apply_mock_patches()
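+ 
+ 
+ # Minimal smoke-test sketch: apply the patch, post to a fake inference-server
+ # URL (never actually contacted once patched), then restore the original.
+ if __name__ == "__main__":
+     apply_mock_patches()
+     try:
+         import requests
+ 
+         resp = requests.post(
+             "http://localhost:9999/v1/chat/completions",
+             json={"messages": [{"role": "user", "content": "Extract topics"}]},
+         )
+         print(resp.json()["choices"][0]["message"]["content"][:80])
+     finally:
+         restore_original()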
test/run_tests.py ADDED
@@ -0,0 +1,34 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple script to run the CLI topics test suite.
4
+
5
+ This script demonstrates how to run the comprehensive test suite
6
+ that covers all the examples from the CLI epilog.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+
12
+ # Add the parent directory to the path so we can import the test module
13
+ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+ sys.path.insert(0, parent_dir)
15
+
16
+ # Import test functions
17
+ from test.test import run_all_tests
18
+
19
+ if __name__ == "__main__":
20
+ print("Starting LLM Topic Modeller Test Suite...")
21
+ print("This will test:")
22
+ print("- CLI examples from the epilog")
23
+ print("- GUI application functionality")
24
+ print("Using a mock inference-server to avoid API costs.")
25
+ print("=" * 60)
26
+
27
+ success = run_all_tests()
28
+
29
+ if success:
30
+ print("\nπŸŽ‰ All tests passed successfully!")
31
+ sys.exit(0)
32
+ else:
33
+ print("\n❌ Some tests failed. Check the output above for details.")
34
+ sys.exit(1)
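
If you prefer plain unittest to the wrapper above, an equivalent invocation might look like this sketch; the discovery pattern is an assumption about the naming of the test files in this folder.

    import os
    import sys
    import unittest

    os.environ["USE_MOCK_LLM"] = "1"
    os.environ["TEST_MODE"] = "1"

    suite = unittest.defaultTestLoader.discover("test", pattern="test*.py")
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    sys.exit(0 if result.wasSuccessful() else 1)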
test/test.py ADDED
@@ -0,0 +1,1067 @@
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ import sys
5
+ import tempfile
6
+ import time
7
+ import unittest
8
+ from typing import List, Optional
9
+
10
+ # Mock LLM calls are automatically applied via environment variables
11
+ # No need to import - the mock patches are applied when USE_MOCK_LLM=1 is set
12
+
13
+
14
+ def run_cli_topics(
15
+ script_path: str,
16
+ task: str,
17
+ output_dir: str,
18
+ input_file: Optional[str] = None,
19
+ text_column: Optional[str] = None,
20
+ previous_output_files: Optional[List[str]] = None,
21
+ timeout: int = 600, # 10-minute timeout
22
+ # General Arguments
23
+ username: Optional[str] = None,
24
+ save_to_user_folders: Optional[bool] = None,
25
+ excel_sheets: Optional[List[str]] = None,
26
+ group_by: Optional[str] = None,
27
+ # Model Configuration
28
+ model_choice: Optional[str] = None,
29
+ temperature: Optional[float] = None,
30
+ batch_size: Optional[int] = None,
31
+ max_tokens: Optional[int] = None,
32
+ api_url: Optional[str] = None,
33
+ inference_server_model: Optional[str] = None,
34
+ # Topic Extraction Arguments
35
+ context: Optional[str] = None,
36
+ candidate_topics: Optional[str] = None,
37
+ force_zero_shot: Optional[str] = None,
38
+ force_single_topic: Optional[str] = None,
39
+ produce_structured_summary: Optional[str] = None,
40
+ sentiment: Optional[str] = None,
41
+ additional_summary_instructions: Optional[str] = None,
42
+ # Validation Arguments
43
+ additional_validation_issues: Optional[str] = None,
44
+ show_previous_table: Optional[str] = None,
45
+ output_debug_files: Optional[str] = None,
46
+ max_time_for_loop: Optional[int] = None,
47
+ # Deduplication Arguments
48
+ method: Optional[str] = None,
49
+ similarity_threshold: Optional[int] = None,
50
+ merge_sentiment: Optional[str] = None,
51
+ merge_general_topics: Optional[str] = None,
52
+ # Summarisation Arguments
53
+ summary_format: Optional[str] = None,
54
+ sample_reference_table: Optional[str] = None,
55
+ no_of_sampled_summaries: Optional[int] = None,
56
+ random_seed: Optional[int] = None,
57
+ # Output Format Arguments
58
+ create_xlsx_output: Optional[bool] = None,
59
+ # Logging Arguments
60
+ save_logs_to_csv: Optional[bool] = None,
61
+ save_logs_to_dynamodb: Optional[bool] = None,
62
+ cost_code: Optional[str] = None,
63
+ ) -> bool:
64
+ """
65
+ Executes the cli_topics.py script with specified arguments using a subprocess.
66
+
67
+ Args:
68
+ script_path (str): The path to the cli_topics.py script.
69
+ task (str): The main task to perform ('extract', 'validate', 'deduplicate', 'summarise', 'overall_summary', or 'all_in_one').
70
+ output_dir (str): The path to the directory for output files.
71
+ input_file (str, optional): Path to the input file to process.
72
+ text_column (str, optional): Name of the text column to process.
73
+ previous_output_files (List[str], optional): Path(s) to previous output files.
74
+ timeout (int): Timeout in seconds for the subprocess.
75
+
76
+ All other arguments match the CLI arguments from cli_topics.py.
77
+
78
+ Returns:
79
+ bool: True if the script executed successfully, False otherwise.
80
+ """
81
+ # 1. Get absolute paths and perform pre-checks
82
+ script_abs_path = os.path.abspath(script_path)
83
+ output_abs_dir = os.path.abspath(output_dir)
84
+
85
+ # Handle input file based on task
86
+ if task in ["extract", "validate", "all_in_one"] and input_file is None:
87
+ raise ValueError(f"Input file is required for '{task}' task")
88
+
89
+ if input_file:
90
+ input_abs_path = os.path.abspath(input_file)
91
+ if not os.path.isfile(input_abs_path):
92
+ raise FileNotFoundError(f"Input file not found: {input_abs_path}")
93
+
94
+ if not os.path.isfile(script_abs_path):
95
+ raise FileNotFoundError(f"Script not found: {script_abs_path}")
96
+
97
+ if not os.path.isdir(output_abs_dir):
98
+ # Create the output directory if it doesn't exist
99
+ print(f"Output directory not found. Creating: {output_abs_dir}")
100
+ os.makedirs(output_abs_dir)
101
+
102
+ script_folder = os.path.dirname(script_abs_path)
103
+
104
+ # 2. Dynamically build the command list
105
+ command = [
106
+ sys.executable,  # current interpreter, not whatever "python" resolves to
107
+ script_abs_path,
108
+ "--output_dir",
109
+ output_abs_dir,
110
+ "--task",
111
+ task,
112
+ ]
113
+
114
+ # Add input_file only if it's not None
115
+ if input_file:
116
+ command.extend(["--input_file", input_abs_path])
117
+
118
+ # Add general arguments
119
+ if text_column:
120
+ command.extend(["--text_column", text_column])
121
+ if previous_output_files:
122
+ command.extend(["--previous_output_files"] + previous_output_files)
123
+ if username:
124
+ command.extend(["--username", username])
125
+ if save_to_user_folders is not None:
126
+ command.extend(["--save_to_user_folders", str(save_to_user_folders)])
127
+ if excel_sheets:
128
+ command.append("--excel_sheets")
129
+ command.extend(excel_sheets)
130
+ if group_by:
131
+ command.extend(["--group_by", group_by])
132
+
133
+ # Add model configuration arguments
134
+ if model_choice:
135
+ command.extend(["--model_choice", model_choice])
136
+ if temperature is not None:
137
+ command.extend(["--temperature", str(temperature)])
138
+ if batch_size is not None:
139
+ command.extend(["--batch_size", str(batch_size)])
140
+ if max_tokens is not None:
141
+ command.extend(["--max_tokens", str(max_tokens)])
142
+ if api_url:
143
+ command.extend(["--api_url", api_url])
144
+ if inference_server_model:
145
+ command.extend(["--inference_server_model", inference_server_model])
146
+
147
+ # Add topic extraction arguments
148
+ if context:
149
+ command.extend(["--context", context])
150
+ if candidate_topics:
151
+ command.extend(["--candidate_topics", candidate_topics])
152
+ if force_zero_shot:
153
+ command.extend(["--force_zero_shot", force_zero_shot])
154
+ if force_single_topic:
155
+ command.extend(["--force_single_topic", force_single_topic])
156
+ if produce_structured_summary:
157
+ command.extend(["--produce_structured_summary", produce_structured_summary])
158
+ if sentiment:
159
+ command.extend(["--sentiment", sentiment])
160
+ if additional_summary_instructions:
161
+ command.extend(
162
+ ["--additional_summary_instructions", additional_summary_instructions]
163
+ )
164
+
165
+ # Add validation arguments
166
+ if additional_validation_issues:
167
+ command.extend(["--additional_validation_issues", additional_validation_issues])
168
+ if show_previous_table:
169
+ command.extend(["--show_previous_table", show_previous_table])
170
+ if output_debug_files:
171
+ command.extend(["--output_debug_files", output_debug_files])
172
+ if max_time_for_loop is not None:
173
+ command.extend(["--max_time_for_loop", str(max_time_for_loop)])
174
+
175
+ # Add deduplication arguments
176
+ if method:
177
+ command.extend(["--method", method])
178
+ if similarity_threshold is not None:
179
+ command.extend(["--similarity_threshold", str(similarity_threshold)])
180
+ if merge_sentiment:
181
+ command.extend(["--merge_sentiment", merge_sentiment])
182
+ if merge_general_topics:
183
+ command.extend(["--merge_general_topics", merge_general_topics])
184
+
185
+ # Add summarisation arguments
186
+ if summary_format:
187
+ command.extend(["--summary_format", summary_format])
188
+ if sample_reference_table:
189
+ command.extend(["--sample_reference_table", sample_reference_table])
190
+ if no_of_sampled_summaries is not None:
191
+ command.extend(["--no_of_sampled_summaries", str(no_of_sampled_summaries)])
192
+ if random_seed is not None:
193
+ command.extend(["--random_seed", str(random_seed)])
194
+
195
+ # Add output format arguments
196
+ if create_xlsx_output is False:
197
+ command.append("--no_xlsx_output")
198
+
199
+ # Add logging arguments
200
+ if save_logs_to_csv is not None:
201
+ command.extend(["--save_logs_to_csv", str(save_logs_to_csv)])
202
+ if save_logs_to_dynamodb is not None:
203
+ command.extend(["--save_logs_to_dynamodb", str(save_logs_to_dynamodb)])
204
+ if cost_code:
205
+ command.extend(["--cost_code", cost_code])
206
+
207
+ # Filter out None values before joining
208
+ command_str = " ".join(str(arg) for arg in command if arg is not None)
209
+ print(f"Executing command: {command_str}")
210
+
211
+ # 3. Execute the command using subprocess
212
+ try:
213
+ # Use unbuffered output to avoid hanging
214
+ env = os.environ.copy()
215
+ env["PYTHONUNBUFFERED"] = "1"
216
+ # Ensure inference server is enabled for testing
217
+ env["RUN_INFERENCE_SERVER"] = "1"
218
+ # Enable mock mode
219
+ env["USE_MOCK_LLM"] = "1"
220
+ env["TEST_MODE"] = "1"
221
+
222
+ result = subprocess.Popen(
223
+ command,
224
+ stdout=subprocess.PIPE,
225
+ stderr=subprocess.STDOUT, # Combine stderr with stdout to avoid deadlocks
226
+ text=True,
227
+ cwd=script_folder, # Important for relative paths within the script
228
+ env=env,
229
+ bufsize=0, # Unbuffered
230
+ )
231
+
232
+ # Read output in real-time to avoid deadlocks
233
+ start_time = time.time()
234
+
235
+ # For Windows, we need a different approach
236
+ if sys.platform == "win32":
237
+ # On Windows, use communicate with timeout
238
+ try:
239
+ stdout, stderr = result.communicate(timeout=timeout)
240
+ except subprocess.TimeoutExpired:
241
+ result.kill()
242
+ stdout, stderr = result.communicate()
243
+ raise subprocess.TimeoutExpired(result.args, timeout)
244
+ else:
245
+ # On Unix, we can use select for real-time reading
246
+ import select
247
+
248
+ stdout_lines = []
249
+ while result.poll() is None:
250
+ ready, _, _ = select.select([result.stdout], [], [], 0.1)
251
+ if ready:
252
+ line = result.stdout.readline()
253
+ if line:
254
+ print(line.rstrip(), flush=True)
255
+ stdout_lines.append(line)
256
+ # Check timeout
257
+ if time.time() - start_time > timeout:
258
+ result.kill()
259
+ raise subprocess.TimeoutExpired(result.args, timeout)
260
+
261
+ # Read remaining output
262
+ remaining = result.stdout.read()
263
+ if remaining:
264
+ print(remaining, end="", flush=True)
265
+ stdout_lines.append(remaining)
266
+
267
+ stdout = "".join(stdout_lines)
268
+ stderr = "" # Combined with stdout
269
+
270
+ print("--- SCRIPT STDOUT ---")
271
+ if stdout:
272
+ print(stdout)
273
+ print("--- SCRIPT STDERR ---")
274
+ if stderr:
275
+ print(stderr)
276
+ print("---------------------")
277
+
278
+ # Analyze the output for errors and success indicators
279
+ analysis = analyze_test_output(stdout, stderr)
280
+
281
+ if analysis["has_errors"]:
282
+ print("❌ Errors detected in output:")
283
+ for i, error_type in enumerate(analysis["error_types"]):
284
+ print(f" {i+1}. {error_type}")
285
+ if analysis["error_messages"]:
286
+ print(" Error messages:")
287
+ for msg in analysis["error_messages"][
288
+ :3
289
+ ]: # Show first 3 error messages
290
+ print(f" - {msg}")
291
+ return False
292
+ elif result.returncode == 0:
293
+ success_msg = "βœ… Script executed successfully."
294
+ if analysis["success_indicators"]:
295
+ success_msg += f" (Success indicators: {', '.join(analysis['success_indicators'][:3])})"
296
+ print(success_msg)
297
+ return True
298
+ else:
299
+ print(f"❌ Command failed with return code {result.returncode}")
300
+ return False
301
+
302
+ except subprocess.TimeoutExpired:
303
+ result.kill()
304
+ print(f"❌ Subprocess timed out after {timeout} seconds.")
305
+ return False
306
+ except Exception as e:
307
+ print(f"❌ An unexpected error occurred: {e}")
308
+ return False
309
+
310
+
311
+ def analyze_test_output(stdout: str, stderr: str) -> dict:
312
+ """
313
+ Analyze test output to provide detailed error information.
314
+
315
+ Args:
316
+ stdout (str): Standard output from the test
317
+ stderr (str): Standard error from the test
318
+
319
+ Returns:
320
+ dict: Analysis results with error details
321
+ """
322
+ combined_output = (stdout or "") + (stderr or "")
323
+
324
+ analysis = {
325
+ "has_errors": False,
326
+ "error_types": [],
327
+ "error_messages": [],
328
+ "success_indicators": [],
329
+ "warning_indicators": [],
330
+ }
331
+
332
+ # Error patterns
333
+ error_patterns = {
334
+ "An error occurred": "General error message",
335
+ "Error:": "Error prefix",
336
+ "Exception:": "Exception occurred",
337
+ "Traceback": "Python traceback",
338
+ "Failed to": "Operation failure",
339
+ "Cannot": "Operation not possible",
340
+ "Unable to": "Operation not possible",
341
+ "KeyError:": "Missing key/dictionary error",
342
+ "AttributeError:": "Missing attribute error",
343
+ "TypeError:": "Type mismatch error",
344
+ "ValueError:": "Invalid value error",
345
+ "FileNotFoundError:": "File not found",
346
+ "ImportError:": "Import failure",
347
+ "ModuleNotFoundError:": "Module not found",
348
+ }
349
+
350
+ # Success indicators
351
+ success_patterns = [
352
+ "Successfully",
353
+ "Completed",
354
+ "Finished",
355
+ "Processed",
356
+ "Complete",
357
+ "Output files saved",
358
+ ]
359
+
360
+ # Warning indicators
361
+ warning_patterns = ["Warning:", "WARNING:", "Deprecated", "DeprecationWarning"]
362
+
363
+ # Check for errors
364
+ for pattern, description in error_patterns.items():
365
+ if pattern.lower() in combined_output.lower():
366
+ analysis["has_errors"] = True
367
+ analysis["error_types"].append(description)
368
+
369
+ # Extract the actual error message
370
+ lines = combined_output.split("\n")
371
+ for line in lines:
372
+ if pattern.lower() in line.lower():
373
+ analysis["error_messages"].append(line.strip())
374
+
375
+ # Check for success indicators
376
+ for pattern in success_patterns:
377
+ if pattern.lower() in combined_output.lower():
378
+ analysis["success_indicators"].append(pattern)
379
+
380
+ # Check for warnings
381
+ for pattern in warning_patterns:
382
+ if pattern.lower() in combined_output.lower():
383
+ analysis["warning_indicators"].append(pattern)
384
+
385
+ return analysis
386
+
387
+
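
A quick sanity check of analyze_test_output (the transcript strings below are invented examples, not real CLI output):

    report = analyze_test_output(
        stdout="Processed 30 rows\nOutput files saved to output/",
        stderr="",
    )
    assert report["has_errors"] is False
    assert "Processed" in report["success_indicators"]

    report = analyze_test_output(stdout="", stderr="Traceback (most recent call last):")
    assert report["has_errors"] is True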
388
+ def run_app_direct_mode(
389
+ app_path: str,
390
+ task: str,
391
+ output_dir: str,
392
+ input_file: Optional[str] = None,
393
+ text_column: Optional[str] = None,
394
+ previous_output_files: Optional[List[str]] = None,
395
+ timeout: int = 600,
396
+ # General Arguments
397
+ username: Optional[str] = None,
398
+ save_to_user_folders: Optional[bool] = None,
399
+ excel_sheets: Optional[List[str]] = None,
400
+ group_by: Optional[str] = None,
401
+ # Model Configuration
402
+ model_choice: Optional[str] = None,
403
+ temperature: Optional[float] = None,
404
+ batch_size: Optional[int] = None,
405
+ max_tokens: Optional[int] = None,
406
+ api_url: Optional[str] = None,
407
+ inference_server_model: Optional[str] = None,
408
+ # Topic Extraction Arguments
409
+ context: Optional[str] = None,
410
+ candidate_topics: Optional[str] = None,
411
+ force_zero_shot: Optional[str] = None,
412
+ force_single_topic: Optional[str] = None,
413
+ produce_structured_summary: Optional[str] = None,
414
+ sentiment: Optional[str] = None,
415
+ additional_summary_instructions: Optional[str] = None,
416
+ # Validation Arguments
417
+ additional_validation_issues: Optional[str] = None,
418
+ show_previous_table: Optional[str] = None,
419
+ output_debug_files: Optional[str] = None,
420
+ max_time_for_loop: Optional[int] = None,
421
+ # Deduplication Arguments
422
+ method: Optional[str] = None,
423
+ similarity_threshold: Optional[int] = None,
424
+ merge_sentiment: Optional[str] = None,
425
+ merge_general_topics: Optional[str] = None,
426
+ # Summarisation Arguments
427
+ summary_format: Optional[str] = None,
428
+ sample_reference_table: Optional[str] = None,
429
+ no_of_sampled_summaries: Optional[int] = None,
430
+ random_seed: Optional[int] = None,
431
+ # Output Format Arguments
432
+ create_xlsx_output: Optional[bool] = None,
433
+ # Logging Arguments
434
+ save_logs_to_csv: Optional[bool] = None,
435
+ save_logs_to_dynamodb: Optional[bool] = None,
436
+ cost_code: Optional[str] = None,
437
+ ) -> bool:
438
+ """
439
+ Executes the app.py script in direct mode with specified environment variables.
440
+
441
+ Args:
442
+ app_path (str): The path to the app.py script.
443
+ task (str): The main task to perform ('extract', 'validate', 'deduplicate', 'summarise', 'overall_summary', or 'all_in_one').
444
+ output_dir (str): The path to the directory for output files.
445
+ input_file (str, optional): Path to the input file to process.
446
+ text_column (str, optional): Name of the text column to process.
447
+ previous_output_files (List[str], optional): Path(s) to previous output files.
448
+ timeout (int): Timeout in seconds for the subprocess.
449
+
450
+ All other arguments match the CLI arguments from cli_topics.py, but are set as environment variables.
451
+
452
+ Returns:
453
+ bool: True if the script executed successfully, False otherwise.
454
+ """
455
+ # 1. Get absolute paths and perform pre-checks
456
+ app_abs_path = os.path.abspath(app_path)
457
+ output_abs_dir = os.path.abspath(output_dir)
458
+
459
+ # Handle input file based on task
460
+ if task in ["extract", "validate", "all_in_one"] and input_file is None:
461
+ raise ValueError(f"Input file is required for '{task}' task")
462
+
463
+ if input_file:
464
+ input_abs_path = os.path.abspath(input_file)
465
+ if not os.path.isfile(input_abs_path):
466
+ raise FileNotFoundError(f"Input file not found: {input_abs_path}")
467
+
468
+ if not os.path.isfile(app_abs_path):
469
+ raise FileNotFoundError(f"App script not found: {app_abs_path}")
470
+
471
+ if not os.path.isdir(output_abs_dir):
472
+ # Create the output directory if it doesn't exist
473
+ print(f"Output directory not found. Creating: {output_abs_dir}")
474
+ os.makedirs(output_abs_dir)
475
+
476
+ script_folder = os.path.dirname(app_abs_path)
477
+
478
+ # 2. Build environment variables for direct mode
479
+ env = os.environ.copy()
480
+ env["PYTHONUNBUFFERED"] = "1"
481
+ env["RUN_INFERENCE_SERVER"] = "1"
482
+ env["USE_MOCK_LLM"] = "1"
483
+ env["TEST_MODE"] = "1"
484
+
485
+ # Enable direct mode
486
+ env["RUN_DIRECT_MODE"] = "1"
487
+
488
+ # Task selection
489
+ env["DIRECT_MODE_TASK"] = task
490
+
491
+ # General arguments
492
+ if input_file:
493
+ # Store the absolute path so the subprocess resolves it regardless of cwd
494
+ env["DIRECT_MODE_INPUT_FILE"] = input_abs_path
495
+ env["DIRECT_MODE_OUTPUT_DIR"] = output_abs_dir
496
+ if text_column:
497
+ env["DIRECT_MODE_TEXT_COLUMN"] = text_column
498
+ if previous_output_files:
499
+ # Use pipe separator to handle file paths with spaces
500
+ env["DIRECT_MODE_PREVIOUS_OUTPUT_FILES"] = "|".join(previous_output_files)
501
+ if username:
502
+ env["DIRECT_MODE_USERNAME"] = username
503
+ if save_to_user_folders is not None:
504
+ env["SESSION_OUTPUT_FOLDER"] = str(save_to_user_folders)
505
+ if excel_sheets:
506
+ env["DIRECT_MODE_EXCEL_SHEETS"] = ",".join(excel_sheets)
507
+ if group_by:
508
+ env["DIRECT_MODE_GROUP_BY"] = group_by
509
+
510
+ # Model configuration
511
+ if model_choice:
512
+ env["DIRECT_MODE_MODEL_CHOICE"] = model_choice
513
+ if temperature is not None:
514
+ env["DIRECT_MODE_TEMPERATURE"] = str(temperature)
515
+ if batch_size is not None:
516
+ env["DIRECT_MODE_BATCH_SIZE"] = str(batch_size)
517
+ if max_tokens is not None:
518
+ env["DIRECT_MODE_MAX_TOKENS"] = str(max_tokens)
519
+ if api_url:
520
+ env["API_URL"] = api_url
521
+ if inference_server_model:
522
+ env["DIRECT_MODE_INFERENCE_SERVER_MODEL"] = inference_server_model
523
+
524
+ # Topic extraction arguments
525
+ if context:
526
+ env["DIRECT_MODE_CONTEXT"] = context
527
+ if candidate_topics:
528
+ env["DIRECT_MODE_CANDIDATE_TOPICS"] = candidate_topics
529
+ if force_zero_shot:
530
+ env["DIRECT_MODE_FORCE_ZERO_SHOT"] = force_zero_shot
531
+ if force_single_topic:
532
+ env["DIRECT_MODE_FORCE_SINGLE_TOPIC"] = force_single_topic
533
+ if produce_structured_summary:
534
+ env["DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY"] = produce_structured_summary
535
+ if sentiment:
536
+ env["DIRECT_MODE_SENTIMENT"] = sentiment
537
+ if additional_summary_instructions:
538
+ env["DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS"] = (
539
+ additional_summary_instructions
540
+ )
541
+
542
+ # Validation arguments
543
+ if additional_validation_issues:
544
+ env["DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES"] = additional_validation_issues
545
+ if show_previous_table:
546
+ env["DIRECT_MODE_SHOW_PREVIOUS_TABLE"] = show_previous_table
547
+ if output_debug_files:
548
+ env["OUTPUT_DEBUG_FILES"] = output_debug_files
549
+ if max_time_for_loop is not None:
550
+ env["DIRECT_MODE_MAX_TIME_FOR_LOOP"] = str(max_time_for_loop)
551
+
552
+ # Deduplication arguments
553
+ if method:
554
+ env["DIRECT_MODE_DEDUP_METHOD"] = method
555
+ if similarity_threshold is not None:
556
+ env["DIRECT_MODE_SIMILARITY_THRESHOLD"] = str(similarity_threshold)
557
+ if merge_sentiment:
558
+ env["DIRECT_MODE_MERGE_SENTIMENT"] = merge_sentiment
559
+ if merge_general_topics:
560
+ env["DIRECT_MODE_MERGE_GENERAL_TOPICS"] = merge_general_topics
561
+
562
+ # Summarisation arguments
563
+ if summary_format:
564
+ env["DIRECT_MODE_SUMMARY_FORMAT"] = summary_format
565
+ if sample_reference_table:
566
+ env["DIRECT_MODE_SAMPLE_REFERENCE_TABLE"] = sample_reference_table
567
+ if no_of_sampled_summaries is not None:
568
+ env["DIRECT_MODE_NO_OF_SAMPLED_SUMMARIES"] = str(no_of_sampled_summaries)
569
+ if random_seed is not None:
570
+ env["DIRECT_MODE_RANDOM_SEED"] = str(random_seed)
571
+
572
+ # Output format arguments
573
+ if create_xlsx_output is not None:
574
+ env["DIRECT_MODE_CREATE_XLSX_OUTPUT"] = str(create_xlsx_output)
575
+
576
+ # Logging arguments
577
+ if save_logs_to_csv is not None:
578
+ env["SAVE_LOGS_TO_CSV"] = str(save_logs_to_csv)
579
+ if save_logs_to_dynamodb is not None:
580
+ env["SAVE_LOGS_TO_DYNAMODB"] = str(save_logs_to_dynamodb)
581
+ if cost_code:
582
+ env["DEFAULT_COST_CODE"] = cost_code
583
+
584
+ # 3. Build command (just run app.py, no arguments needed in direct mode)
585
+ command = [sys.executable, app_abs_path]
586
+ command_str = " ".join(str(arg) for arg in command)
587
+ print(f"Executing direct mode command: {command_str}")
588
+ print(f"Direct mode task: {task}")
589
+ if input_file:
590
+ print(f"Input file: {input_abs_path}")
591
+ if text_column:
592
+ print(f"Text column: {text_column}")
593
+
594
+ # 4. Execute the command using subprocess
595
+ try:
596
+ result = subprocess.Popen(
597
+ command,
598
+ stdout=subprocess.PIPE,
599
+ stderr=subprocess.STDOUT, # Combine stderr with stdout to avoid deadlocks
600
+ text=True,
601
+ cwd=script_folder, # Important for relative paths within the script
602
+ env=env,
603
+ bufsize=0, # Unbuffered
604
+ )
605
+
606
+ # Read output in real-time to avoid deadlocks
607
+ start_time = time.time()
608
+
609
+ # For Windows, we need a different approach
610
+ if sys.platform == "win32":
611
+ # On Windows, use communicate with timeout
612
+ try:
613
+ stdout, stderr = result.communicate(timeout=timeout)
614
+ except subprocess.TimeoutExpired:
615
+ result.kill()
616
+ stdout, stderr = result.communicate()
617
+ raise subprocess.TimeoutExpired(result.args, timeout)
618
+ else:
619
+ # On Unix, we can use select for real-time reading
620
+ import select
621
+
622
+ stdout_lines = []
623
+ while result.poll() is None:
624
+ ready, _, _ = select.select([result.stdout], [], [], 0.1)
625
+ if ready:
626
+ line = result.stdout.readline()
627
+ if line:
628
+ print(line.rstrip(), flush=True)
629
+ stdout_lines.append(line)
630
+ # Check timeout
631
+ if time.time() - start_time > timeout:
632
+ result.kill()
633
+ raise subprocess.TimeoutExpired(result.args, timeout)
634
+
635
+ # Read remaining output
636
+ remaining = result.stdout.read()
637
+ if remaining:
638
+ print(remaining, end="", flush=True)
639
+ stdout_lines.append(remaining)
640
+
641
+ stdout = "".join(stdout_lines)
642
+ stderr = "" # Combined with stdout
643
+
644
+ print("--- SCRIPT STDOUT ---")
645
+ if stdout:
646
+ print(stdout)
647
+ print("--- SCRIPT STDERR ---")
648
+ if stderr:
649
+ print(stderr)
650
+ print("---------------------")
651
+
652
+ # Analyze the output for errors and success indicators
653
+ analysis = analyze_test_output(stdout, stderr)
654
+
655
+ if analysis["has_errors"]:
656
+ print("❌ Errors detected in output:")
657
+ for i, error_type in enumerate(analysis["error_types"]):
658
+ print(f" {i+1}. {error_type}")
659
+ if analysis["error_messages"]:
660
+ print(" Error messages:")
661
+ for msg in analysis["error_messages"][
662
+ :3
663
+ ]: # Show first 3 error messages
664
+ print(f" - {msg}")
665
+ return False
666
+ elif result.returncode == 0:
667
+ success_msg = "βœ… Script executed successfully."
668
+ if analysis["success_indicators"]:
669
+ success_msg += f" (Success indicators: {', '.join(analysis['success_indicators'][:3])})"
670
+ print(success_msg)
671
+ return True
672
+ else:
673
+ print(f"❌ Command failed with return code {result.returncode}")
674
+ return False
675
+
676
+ except subprocess.TimeoutExpired:
677
+ result.kill()
678
+ print(f"❌ Subprocess timed out after {timeout} seconds.")
679
+ return False
680
+ except Exception as e:
681
+ print(f"❌ An unexpected error occurred: {e}")
682
+ return False
683
+
684
+
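
For reference, a hand-driven direct-mode smoke run might look like the sketch below; the paths are placeholders relative to the repo root.

    ok = run_app_direct_mode(
        app_path="app.py",
        task="extract",
        input_file="example_data/combined_case_notes.csv",
        text_column="Case Note",
        output_dir="output",
        model_choice="test-model",
        inference_server_model="test-model",
        api_url="http://localhost:8080",
        create_xlsx_output=False,
        save_logs_to_csv=False,
    )
    print("direct mode ok" if ok else "direct mode failed")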
685
+ class TestCLITopicsExamples(unittest.TestCase):
686
+ """Test suite for CLI topic extraction examples from the epilog."""
687
+
688
+ @classmethod
689
+ def setUpClass(cls):
690
+ """Set up test environment before running tests."""
691
+ cls.script_path = os.path.join(
692
+ os.path.dirname(os.path.dirname(__file__)), "cli_topics.py"
693
+ )
694
+ cls.example_data_dir = os.path.join(
695
+ os.path.dirname(os.path.dirname(__file__)), "example_data"
696
+ )
697
+ cls.temp_output_dir = tempfile.mkdtemp(prefix="test_output_")
698
+
699
+ # Verify script exists
700
+ if not os.path.isfile(cls.script_path):
701
+ raise FileNotFoundError(f"CLI script not found: {cls.script_path}")
702
+
703
+ print(f"Test setup complete. Script: {cls.script_path}")
704
+ print(f"Example data directory: {cls.example_data_dir}")
705
+ print(f"Temp output directory: {cls.temp_output_dir}")
706
+ print("Using function mocking instead of HTTP server")
707
+
708
+ # Debug: Check if example data directory exists and list contents
709
+ if os.path.exists(cls.example_data_dir):
710
+ print("Example data directory exists. Contents:")
711
+ for item in os.listdir(cls.example_data_dir):
712
+ item_path = os.path.join(cls.example_data_dir, item)
713
+ if os.path.isfile(item_path):
714
+ print(f" File: {item} ({os.path.getsize(item_path)} bytes)")
715
+ else:
716
+ print(f" Directory: {item}")
717
+ else:
718
+ print(f"Example data directory does not exist: {cls.example_data_dir}")
719
+
720
+ @classmethod
721
+ def tearDownClass(cls):
722
+ """Clean up test environment after running tests."""
723
+ if os.path.exists(cls.temp_output_dir):
724
+ shutil.rmtree(cls.temp_output_dir)
725
+ print(f"Cleaned up temp directory: {cls.temp_output_dir}")
726
+
727
+ def test_extract_topics_default_settings(self):
728
+ """Test: Extract topics from a CSV file with default settings"""
729
+ print("\n=== Testing topic extraction with default settings ===")
730
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
731
+
732
+ if not os.path.isfile(input_file):
733
+ self.skipTest(f"Example file not found: {input_file}")
734
+
735
+ result = run_cli_topics(
736
+ script_path=self.script_path,
737
+ task="extract",
738
+ input_file=input_file,
739
+ text_column="Case Note",
740
+ output_dir=self.temp_output_dir,
741
+ model_choice="test-model",
742
+ inference_server_model="test-model",
743
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
744
+ create_xlsx_output=False,
745
+ save_logs_to_csv=False,
746
+ )
747
+
748
+ self.assertTrue(result, "Topic extraction with default settings should succeed")
749
+ print("βœ… Topic extraction with default settings passed")
750
+
751
+ def test_extract_topics_custom_model_and_context(self):
752
+ """Test: Extract topics with custom model and context"""
753
+ print("\n=== Testing topic extraction with custom model and context ===")
754
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
755
+
756
+ if not os.path.isfile(input_file):
757
+ self.skipTest(f"Example file not found: {input_file}")
758
+
759
+ result = run_cli_topics(
760
+ script_path=self.script_path,
761
+ task="extract",
762
+ input_file=input_file,
763
+ text_column="Case Note",
764
+ output_dir=self.temp_output_dir,
765
+ model_choice="test-model",
766
+ inference_server_model="test-model",
767
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
768
+ context="Social Care case notes for young people",
769
+ create_xlsx_output=False,
770
+ save_logs_to_csv=False,
771
+ )
772
+
773
+ self.assertTrue(
774
+ result, "Topic extraction with custom model and context should succeed"
775
+ )
776
+ print("βœ… Topic extraction with custom model and context passed")
777
+
778
+ def test_extract_topics_with_grouping(self):
779
+ """Test: Extract topics with grouping"""
780
+ print("\n=== Testing topic extraction with grouping ===")
781
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
782
+
783
+ if not os.path.isfile(input_file):
784
+ self.skipTest(f"Example file not found: {input_file}")
785
+
786
+ result = run_cli_topics(
787
+ script_path=self.script_path,
788
+ task="extract",
789
+ input_file=input_file,
790
+ text_column="Case Note",
791
+ output_dir=self.temp_output_dir,
792
+ group_by="Client",
793
+ model_choice="test-model",
794
+ inference_server_model="test-model",
795
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
796
+ create_xlsx_output=False,
797
+ save_logs_to_csv=False,
798
+ )
799
+
800
+ self.assertTrue(result, "Topic extraction with grouping should succeed")
801
+ print("βœ… Topic extraction with grouping passed")
802
+
803
+ def test_extract_topics_with_candidate_topics(self):
804
+ """Test: Extract topics with candidate topics (zero-shot)"""
805
+ print("\n=== Testing topic extraction with candidate topics ===")
806
+ input_file = os.path.join(
807
+ self.example_data_dir, "dummy_consultation_response.csv"
808
+ )
809
+ candidate_topics_file = os.path.join(
810
+ self.example_data_dir, "dummy_consultation_response_themes.csv"
811
+ )
812
+
813
+ if not os.path.isfile(input_file):
814
+ self.skipTest(f"Example file not found: {input_file}")
815
+ if not os.path.isfile(candidate_topics_file):
816
+ self.skipTest(f"Candidate topics file not found: {candidate_topics_file}")
817
+
818
+ result = run_cli_topics(
819
+ script_path=self.script_path,
820
+ task="extract",
821
+ input_file=input_file,
822
+ text_column="Response text",
823
+ output_dir=self.temp_output_dir,
824
+ candidate_topics=candidate_topics_file,
825
+ model_choice="test-model",
826
+ inference_server_model="test-model",
827
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
828
+ create_xlsx_output=False,
829
+ save_logs_to_csv=False,
830
+ )
831
+
832
+ self.assertTrue(result, "Topic extraction with candidate topics should succeed")
833
+ print("βœ… Topic extraction with candidate topics passed")
834
+
835
+ def test_deduplicate_topics_fuzzy(self):
836
+ """Test: Deduplicate topics using fuzzy matching"""
837
+ print("\n=== Testing topic deduplication with fuzzy matching ===")
838
+
839
+ # First, we need to create some output files by running extraction
840
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
841
+
842
+ if not os.path.isfile(input_file):
843
+ self.skipTest(f"Example file not found: {input_file}")
844
+
845
+ # Run extraction first to create output files
846
+ extract_result = run_cli_topics(
847
+ script_path=self.script_path,
848
+ task="extract",
849
+ input_file=input_file,
850
+ text_column="Case Note",
851
+ output_dir=self.temp_output_dir,
852
+ model_choice="test-model",
853
+ inference_server_model="test-model",
854
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
855
+ create_xlsx_output=False,
856
+ save_logs_to_csv=False,
857
+ )
858
+
859
+ if not extract_result:
860
+ self.skipTest("Extraction failed, cannot test deduplication")
861
+
862
+ # Find the output files (they should be in temp_output_dir)
863
+ # The file names follow a pattern like: {input_file_name}_col_{text_column}_reference_table.csv
864
+ import glob
865
+
866
+ reference_files = glob.glob(
867
+ os.path.join(self.temp_output_dir, "*reference_table.csv")
868
+ )
869
+ unique_files = glob.glob(
870
+ os.path.join(self.temp_output_dir, "*unique_topics.csv")
871
+ )
872
+
873
+ if not reference_files or not unique_files:
874
+ self.skipTest("Could not find output files from extraction")
875
+
876
+ result = run_cli_topics(
877
+ script_path=self.script_path,
878
+ task="deduplicate",
879
+ previous_output_files=[reference_files[0], unique_files[0]],
880
+ output_dir=self.temp_output_dir,
881
+ method="fuzzy",
882
+ similarity_threshold=90,
883
+ create_xlsx_output=False,
884
+ save_logs_to_csv=False,
885
+ )
886
+
887
+ self.assertTrue(
888
+ result, "Topic deduplication with fuzzy matching should succeed"
889
+ )
890
+ print("βœ… Topic deduplication with fuzzy matching passed")
891
+
892
+ def test_deduplicate_topics_llm(self):
893
+ """Test: Deduplicate topics using LLM"""
894
+ print("\n=== Testing topic deduplication with LLM ===")
895
+
896
+ # First, we need to create some output files by running extraction
897
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
898
+
899
+ if not os.path.isfile(input_file):
900
+ self.skipTest(f"Example file not found: {input_file}")
901
+
902
+ # Run extraction first to create output files
903
+ extract_result = run_cli_topics(
904
+ script_path=self.script_path,
905
+ task="extract",
906
+ input_file=input_file,
907
+ text_column="Case Note",
908
+ output_dir=self.temp_output_dir,
909
+ model_choice="test-model",
910
+ inference_server_model="test-model",
911
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
912
+ create_xlsx_output=False,
913
+ save_logs_to_csv=False,
914
+ )
915
+
916
+ if not extract_result:
917
+ self.skipTest("Extraction failed, cannot test deduplication")
918
+
919
+ # Find the output files
920
+ import glob
921
+
922
+ reference_files = glob.glob(
923
+ os.path.join(self.temp_output_dir, "*reference_table.csv")
924
+ )
925
+ unique_files = glob.glob(
926
+ os.path.join(self.temp_output_dir, "*unique_topics.csv")
927
+ )
928
+
929
+ if not reference_files or not unique_files:
930
+ self.skipTest("Could not find output files from extraction")
931
+
932
+ result = run_cli_topics(
933
+ script_path=self.script_path,
934
+ task="deduplicate",
935
+ previous_output_files=[reference_files[0], unique_files[0]],
936
+ output_dir=self.temp_output_dir,
937
+ method="llm",
938
+ model_choice="test-model",
939
+ inference_server_model="test-model",
940
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
941
+ create_xlsx_output=False,
942
+ save_logs_to_csv=False,
943
+ )
944
+
945
+ self.assertTrue(result, "Topic deduplication with LLM should succeed")
946
+ print("βœ… Topic deduplication with LLM passed")
947
+
948
+ def test_all_in_one_pipeline(self):
949
+ """Test: Run complete pipeline (extract, deduplicate, summarise)"""
950
+ print("\n=== Testing all-in-one pipeline ===")
951
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
952
+
953
+ if not os.path.isfile(input_file):
954
+ self.skipTest(f"Example file not found: {input_file}")
955
+
956
+ result = run_cli_topics(
957
+ script_path=self.script_path,
958
+ task="all_in_one",
959
+ input_file=input_file,
960
+ text_column="Case Note",
961
+ output_dir=self.temp_output_dir,
962
+ model_choice="test-model",
963
+ inference_server_model="test-model",
964
+ api_url="http://localhost:8080", # URL doesn't matter with function mocking
965
+ create_xlsx_output=False,
966
+ save_logs_to_csv=False,
967
+ timeout=120, # Shorter timeout for debugging
968
+ )
969
+
970
+ self.assertTrue(result, "All-in-one pipeline should succeed")
971
+ print("βœ… All-in-one pipeline passed")
972
+
973
+ def test_direct_mode_extract(self):
974
+ """Test: Run app in direct mode for topic extraction"""
975
+ print("\n=== Testing direct mode - topic extraction ===")
976
+ input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
977
+
978
+ if not os.path.isfile(input_file):
979
+ self.skipTest(f"Example file not found: {input_file}")
980
+
981
+ app_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "app.py")
982
+
983
+ if not os.path.isfile(app_path):
984
+ self.skipTest(f"App script not found: {app_path}")
985
+
986
+ result = run_app_direct_mode(
987
+ app_path=app_path,
988
+ task="extract",
989
+ input_file=input_file,
990
+ text_column="Case Note",
991
+ output_dir=self.temp_output_dir,
992
+ model_choice="test-model",
993
+ inference_server_model="test-model",
994
+ api_url="http://localhost:8080",
995
+ create_xlsx_output=False,
996
+ save_logs_to_csv=False,
997
+ )
998
+
999
+ self.assertTrue(result, "Direct mode topic extraction should succeed")
1000
+ print("βœ… Direct mode topic extraction passed")
1001
+
1002
+
1003
+ def run_all_tests():
1004
+ """Run all test examples and report results."""
1005
+ print("=" * 80)
1006
+ print("LLM TOPIC MODELLER TEST SUITE")
1007
+ print("=" * 80)
1008
+ print("This test suite includes:")
1009
+ print("- CLI examples from the epilog")
1010
+ print("- GUI application tests")
1011
+ print("- Tests use a mock inference-server to avoid API costs")
1012
+ print("Tests will be skipped if required example files are not found.")
1013
+ print("=" * 80)
1014
+
1015
+ # Create test suite
1016
+ loader = unittest.TestLoader()
1017
+ suite = unittest.TestSuite()
1018
+
1019
+ # Add CLI tests
1020
+ cli_suite = loader.loadTestsFromTestCase(TestCLITopicsExamples)
1021
+ suite.addTests(cli_suite)
1022
+
1023
+ # Add GUI tests
1024
+ try:
1025
+ from test.test_gui_only import TestGUIAppOnly
1026
+
1027
+ gui_suite = loader.loadTestsFromTestCase(TestGUIAppOnly)
1028
+ suite.addTests(gui_suite)
1029
+ print("GUI tests included in test suite.")
1030
+ except ImportError as e:
1031
+ print(f"Warning: Could not import GUI tests: {e}")
1032
+ print("Skipping GUI tests.")
1033
+
1034
+ # Run tests with detailed output
1035
+ runner = unittest.TextTestRunner(verbosity=2, stream=None)
1036
+ result = runner.run(suite)
1037
+
1038
+ # Print summary
1039
+ print("\n" + "=" * 80)
1040
+ print("TEST SUMMARY")
1041
+ print("=" * 80)
1042
+ print(f"Tests run: {result.testsRun}")
1043
+ print(f"Failures: {len(result.failures)}")
1044
+ print(f"Errors: {len(result.errors)}")
1045
+ print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
1046
+
1047
+ if result.failures:
1048
+ print("\nFAILURES:")
1049
+ for test, traceback in result.failures:
1050
+ print(f"- {test}: {traceback}")
1051
+
1052
+ if result.errors:
1053
+ print("\nERRORS:")
1054
+ for test, traceback in result.errors:
1055
+ print(f"- {test}: {traceback}")
1056
+
1057
+ success = len(result.failures) == 0 and len(result.errors) == 0
1058
+ print(f"\nOverall result: {'βœ… PASSED' if success else '❌ FAILED'}")
1059
+ print("=" * 80)
1060
+
1061
+ return success
1062
+
1063
+
1064
+ if __name__ == "__main__":
1065
+ # Run the test suite
1066
+ success = run_all_tests()
1067
+ exit(0 if success else 1)
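
The extract-then-deduplicate chaining used by test_deduplicate_topics_fuzzy can also be reproduced by hand. A minimal sketch, assuming the bundled example data and an output/ directory:

    import glob

    from test.test import run_cli_topics

    run_cli_topics(
        script_path="cli_topics.py",
        task="extract",
        input_file="example_data/combined_case_notes.csv",
        text_column="Case Note",
        output_dir="output",
        model_choice="test-model",
        inference_server_model="test-model",
        api_url="http://localhost:8080",
        create_xlsx_output=False,
        save_logs_to_csv=False,
    )
    reference = glob.glob("output/*reference_table.csv")[0]
    unique = glob.glob("output/*unique_topics.csv")[0]
    run_cli_topics(
        script_path="cli_topics.py",
        task="deduplicate",
        previous_output_files=[reference, unique],
        output_dir="output",
        method="fuzzy",
        similarity_threshold=90,
        create_xlsx_output=False,
        save_logs_to_csv=False,
    )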
test/test_gui_only.py ADDED
@@ -0,0 +1,189 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone GUI test script for the LLM topic modeller application.
4
+
5
+ This script tests only the GUI functionality of app.py to ensure it loads correctly.
6
+ Run this script to verify that the Gradio interface can be imported and initialized.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import threading
12
+ import unittest
13
+
14
+ # Add the parent directory to the path so we can import the app
15
+ parent_dir = os.path.dirname(os.path.dirname(__file__))
16
+ if parent_dir not in sys.path:
17
+ sys.path.insert(0, parent_dir)
18
+
19
+
20
+ class TestGUIAppOnly(unittest.TestCase):
21
+ """Test suite for GUI application loading and basic functionality."""
22
+
23
+ @classmethod
24
+ def setUpClass(cls):
25
+ """Set up test environment for GUI tests."""
26
+ cls.app_path = os.path.join(parent_dir, "app.py")
27
+
28
+ # Verify app.py exists
29
+ if not os.path.isfile(cls.app_path):
30
+ raise FileNotFoundError(f"App file not found: {cls.app_path}")
31
+
32
+ print(f"GUI test setup complete. App: {cls.app_path}")
33
+
34
+ def test_app_import_and_initialization(self):
35
+ """Test: Import app.py and check if the Gradio app object is created successfully."""
36
+ print("\n=== Testing GUI app import and initialization ===")
37
+
38
+ try:
39
+ # Import the app module
40
+ import app
41
+
42
+ # Check if the app object exists and is a Gradio Blocks object
43
+ self.assertTrue(
44
+ hasattr(app, "app"), "App object should exist in the module"
45
+ )
46
+
47
+ # Check if it's a Gradio Blocks instance
48
+ import gradio as gr
49
+
50
+ self.assertIsInstance(
51
+ app.app, gr.Blocks, "App should be a Gradio Blocks instance"
52
+ )
53
+
54
+ print("βœ… GUI app import and initialization passed")
55
+
56
+ except ImportError as e:
57
+ error_msg = f"Failed to import app module: {e}"
58
+ self.fail(error_msg)
59
+ except Exception as e:
60
+ self.fail(f"Unexpected error during app initialization: {e}")
61
+
62
+ def test_app_launch_headless(self):
63
+ """Test: Launch the app in headless mode to verify it starts without errors."""
64
+ print("\n=== Testing GUI app launch in headless mode ===")
65
+
66
+ try:
67
+ # Import the app module
68
+ import app
69
+
70
+ # Set up a flag to track if the app launched successfully
71
+ app_launched = threading.Event()
72
+ launch_error = []  # list, so the nested launch thread can record a failure
73
+
74
+ def launch_app():
75
+ try:
76
+ # Launch the app in headless mode with a short timeout
77
+ app.app.launch(
78
+ show_error=True,
79
+ inbrowser=False, # Don't open browser
80
+ server_port=0, # Use any available port
81
+ quiet=True, # Suppress output
82
+ prevent_thread_lock=True, # Don't block the main thread
83
+ )
84
+ app_launched.set()
85
+ except Exception as e:
86
+ launch_error.append(e); app_launched.set()
87
+
88
+ # Start the app in a separate thread
89
+ launch_thread = threading.Thread(target=launch_app)
90
+ launch_thread.daemon = True
91
+ launch_thread.start()
92
+
93
+ # Wait for the app to launch (with timeout)
94
+ if app_launched.wait(timeout=10): # 10 second timeout
95
+ if launch_error:
96
+ self.fail(f"App launch failed: {launch_error[0]}")
97
+ else:
98
+ print("βœ… GUI app launch in headless mode passed")
99
+ else:
100
+ self.fail("App launch timed out after 10 seconds")
101
+
102
+ except Exception as e:
103
+ error_msg = f"Unexpected error during app launch test: {e}"
104
+ self.fail(error_msg)
105
+
106
+ def test_app_configuration_loading(self):
107
+ """Test: Verify that the app can load its configuration without errors."""
108
+ print("\n=== Testing GUI app configuration loading ===")
109
+
110
+ try:
111
+ # Check if key configuration variables are accessible
112
+ # These should be imported from tools.config
113
+ from tools.config import (
114
+ DEFAULT_COST_CODE,
115
+ GRADIO_SERVER_PORT,
116
+ MAX_FILE_SIZE,
117
+ default_model_choice,
118
+ model_name_map,
119
+ )
120
+
121
+ # Verify these are not None/empty
122
+ self.assertIsNotNone(
123
+ GRADIO_SERVER_PORT, "GRADIO_SERVER_PORT should be configured"
124
+ )
125
+ self.assertIsNotNone(MAX_FILE_SIZE, "MAX_FILE_SIZE should be configured")
126
+ self.assertIsNotNone(
127
+ DEFAULT_COST_CODE, "DEFAULT_COST_CODE should be configured"
128
+ )
129
+ self.assertIsNotNone(
130
+ default_model_choice, "default_model_choice should be configured"
131
+ )
132
+ self.assertIsNotNone(model_name_map, "model_name_map should be configured")
133
+
134
+ print("βœ… GUI app configuration loading passed")
135
+
136
+ except ImportError as e:
137
+ error_msg = f"Failed to import configuration: {e}"
138
+ self.fail(error_msg)
139
+ except Exception as e:
140
+ error_msg = f"Unexpected error during configuration test: {e}"
141
+ self.fail(error_msg)
142
+
143
+
144
+ def run_gui_tests():
145
+ """Run GUI tests and report results."""
146
+ print("=" * 80)
147
+ print("LLM TOPIC MODELLER GUI TEST SUITE")
148
+ print("=" * 80)
149
+ print("This test suite verifies that the GUI application loads correctly.")
150
+ print("=" * 80)
151
+
152
+ # Create test suite
153
+ loader = unittest.TestLoader()
154
+ suite = loader.loadTestsFromTestCase(TestGUIAppOnly)
155
+
156
+ # Run tests with detailed output
157
+ runner = unittest.TextTestRunner(verbosity=2, stream=None)
158
+ result = runner.run(suite)
159
+
160
+ # Print summary
161
+ print("\n" + "=" * 80)
162
+ print("GUI TEST SUMMARY")
163
+ print("=" * 80)
164
+ print(f"Tests run: {result.testsRun}")
165
+ print(f"Failures: {len(result.failures)}")
166
+ print(f"Errors: {len(result.errors)}")
167
+ print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
168
+
169
+ if result.failures:
170
+ print("\nFAILURES:")
171
+ for test, traceback in result.failures:
172
+ print(f"- {test}: {traceback}")
173
+
174
+ if result.errors:
175
+ print("\nERRORS:")
176
+ for test, traceback in result.errors:
177
+ print(f"- {test}: {traceback}")
178
+
179
+ success = len(result.failures) == 0 and len(result.errors) == 0
180
+ print(f"\nOverall result: {'βœ… PASSED' if success else '❌ FAILED'}")
181
+ print("=" * 80)
182
+
183
+ return success
184
+
185
+
186
+ if __name__ == "__main__":
187
+ # Run the GUI test suite
188
+ success = run_gui_tests()
189
+ exit(0 if success else 1)
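
To run only these GUI checks from an interactive session rather than as a script:

    from test.test_gui_only import run_gui_tests

    if not run_gui_tests():
        raise SystemExit(1)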
tools/__init__.py ADDED
File without changes
tools/auth.py ADDED
@@ -0,0 +1,85 @@
1
+ import base64
2
+ import hashlib
3
+ import hmac
4
+
5
+ import boto3
6
+
7
+ from tools.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_REGION, AWS_USER_POOL_ID
8
+
9
+
10
+ def calculate_secret_hash(client_id: str, client_secret: str, username: str):
11
+ message = username + client_id
12
+ dig = hmac.new(
13
+ str(client_secret).encode("utf-8"),
14
+ msg=str(message).encode("utf-8"),
15
+ digestmod=hashlib.sha256,
16
+ ).digest()
17
+ secret_hash = base64.b64encode(dig).decode()
18
+ return secret_hash
19
+
20
+
21
+ def authenticate_user(
22
+ username: str,
23
+ password: str,
24
+ user_pool_id: str = AWS_USER_POOL_ID,
25
+ client_id: str = AWS_CLIENT_ID,
26
+ client_secret: str = AWS_CLIENT_SECRET,
27
+ ):
28
+ """Authenticates a user against an AWS Cognito user pool.
29
+
30
+ Args:
31
+ user_pool_id (str): The ID of the Cognito user pool.
32
+ client_id (str): The ID of the Cognito user pool client.
33
+ username (str): The username of the user.
34
+ password (str): The password of the user.
35
+ client_secret (str): The client secret of the app client
36
+
37
+ Returns:
38
+ bool: True if the user is authenticated, False otherwise.
39
+ """
40
+
41
+ client = boto3.client(
42
+ "cognito-idp", region_name=AWS_REGION
43
+ ) # Cognito Identity Provider client
44
+
45
+ # Compute the secret hash
46
+ secret_hash = calculate_secret_hash(client_id, client_secret, username)
47
+
48
+ try:
49
+
50
+ if client_secret == "":
51
+ response = client.initiate_auth(
52
+ AuthFlow="USER_PASSWORD_AUTH",
53
+ AuthParameters={
54
+ "USERNAME": username,
55
+ "PASSWORD": password,
56
+ },
57
+ ClientId=client_id,
58
+ )
59
+
60
+ else:
61
+ response = client.initiate_auth(
62
+ AuthFlow="USER_PASSWORD_AUTH",
63
+ AuthParameters={
64
+ "USERNAME": username,
65
+ "PASSWORD": password,
66
+ "SECRET_HASH": secret_hash,
67
+ },
68
+ ClientId=client_id,
69
+ )
70
+
71
+ # If successful, you'll receive an AuthenticationResult in the response
72
+ if response.get("AuthenticationResult"):
73
+ return True
74
+ else:
75
+ return False
76
+
77
+ except client.exceptions.NotAuthorizedException:
78
+ return False
79
+ except client.exceptions.UserNotFoundException:
80
+ return False
81
+ except Exception as e:
82
+ out_message = f"An error occurred: {e}"
83
+ print(out_message)
84
+ raise Exception(out_message) from e
85
+
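
A hedged usage sketch: the username, password, pool and client IDs below are placeholders (in the app they come from tools.config), and a real call needs reachable Cognito credentials. Passing an empty client_secret selects the no-SECRET_HASH branch above.

    from tools.auth import authenticate_user

    is_valid = authenticate_user(
        username="jane.doe",
        password="example-password",
        user_pool_id="eu-west-2_EXAMPLE",
        client_id="example-client-id",
        client_secret="",
    )
    print("authenticated" if is_valid else "rejected")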
tools/aws_functions.py ADDED
@@ -0,0 +1,387 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import boto3
5
+
6
+ from tools.config import (
7
+ AWS_ACCESS_KEY,
8
+ AWS_REGION,
9
+ AWS_SECRET_KEY,
10
+ PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS,
11
+ RUN_AWS_FUNCTIONS,
12
+ S3_LOG_BUCKET,
13
+ S3_OUTPUTS_BUCKET,
14
+ )
15
+
16
+ # Default bucket name (the S3 log bucket); may be empty if AWS is not configured
17
+ bucket_name = S3_LOG_BUCKET
18
+
19
+
20
+ def connect_to_bedrock_runtime(
21
+ model_name_map: dict,
22
+ model_choice: str,
23
+ aws_access_key_textbox: str = "",
24
+ aws_secret_key_textbox: str = "",
25
+ aws_region_textbox: str = "",
26
+ ):
27
+ # Determine the model's source; only AWS-sourced models need a Bedrock client
28
+ model_source = model_name_map[model_choice]["source"]
29
+
30
+ # Use aws_region_textbox if provided, otherwise fall back to AWS_REGION from config
31
+ region = aws_region_textbox if aws_region_textbox else AWS_REGION
32
+
33
+ if "AWS" in model_source:
34
+ if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
35
+ print("Connecting to Bedrock via existing SSO connection")
36
+ bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
37
+ elif aws_access_key_textbox and aws_secret_key_textbox:
38
+ print(
39
+ "Connecting to Bedrock using the AWS access key and secret key from user input."
40
+ )
41
+ bedrock_runtime = boto3.client(
42
+ "bedrock-runtime",
43
+ aws_access_key_id=aws_access_key_textbox,
44
+ aws_secret_access_key=aws_secret_key_textbox,
45
+ region_name=region,
46
+ )
47
+ elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
48
+ print("Getting Bedrock credentials from environment variables")
49
+ bedrock_runtime = boto3.client(
50
+ "bedrock-runtime",
51
+ aws_access_key_id=AWS_ACCESS_KEY,
52
+ aws_secret_access_key=AWS_SECRET_KEY,
53
+ region_name=region,
54
+ )
55
+ elif RUN_AWS_FUNCTIONS == "1":
56
+ print("Connecting to Bedrock via existing SSO connection")
57
+ bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
58
+ else:
59
+ bedrock_runtime = ""
60
+ out_message = "Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type."
61
+ print(out_message)
62
+ raise Exception(out_message)
63
+ else:
64
+ bedrock_runtime = None
65
+
66
+ return bedrock_runtime
67
+
68
+
69
+ def connect_to_s3_client(
70
+ aws_access_key_textbox: str = "",
71
+ aws_secret_key_textbox: str = "",
72
+ aws_region_textbox: str = "",
73
+ ):
74
+ # Build an S3 client, preferring user-supplied keys, then SSO, then environment keys
75
+ s3_client = None
76
+
77
+ # Use aws_region_textbox if provided, otherwise fall back to AWS_REGION from config
78
+ region = aws_region_textbox if aws_region_textbox else AWS_REGION
79
+
80
+ if aws_access_key_textbox and aws_secret_key_textbox:
81
+ print("Connecting to s3 using the AWS access key and secret key from user input.")
82
+ s3_client = boto3.client(
83
+ "s3",
84
+ aws_access_key_id=aws_access_key_textbox,
85
+ aws_secret_access_key=aws_secret_key_textbox,
86
+ region_name=region,
87
+ )
88
+ elif RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
89
+ print("Connecting to s3 via existing SSO connection")
90
+ s3_client = boto3.client("s3", region_name=region)
91
+ elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
92
+ print("Getting s3 credentials from environment variables")
93
+ s3_client = boto3.client(
94
+ "s3",
95
+ aws_access_key_id=AWS_ACCESS_KEY,
96
+ aws_secret_access_key=AWS_SECRET_KEY,
97
+ region_name=region,
98
+ )
99
+ elif RUN_AWS_FUNCTIONS == "1":
100
+ print("Connecting to s3 via existing SSO connection")
101
+ s3_client = boto3.client("s3", region_name=region)
102
+ else:
103
+ s3_client = ""
104
+ out_message = "Cannot connect to S3 service. Please provide access keys under LLM settings, or choose another model type."
105
+ print(out_message)
106
+ raise Exception(out_message)
107
+
108
+ return s3_client
109
+
110
+
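
As a usage sketch, resolving a client through the credential fallback chain above and listing a prefix; the bucket and prefix names are placeholders.

    s3 = connect_to_s3_client()  # user keys, SSO, or env keys, per the chain above
    response = s3.list_objects_v2(Bucket="my-log-bucket", Prefix="logs/usage/")
    for obj in response.get("Contents", []):
        print(obj["Key"], obj["Size"])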
111
+ # Download direct from S3 - requires login credentials
112
+ def download_file_from_s3(
113
+ bucket_name: str,
114
+ key: str,
115
+ local_file_path: str,
116
+ aws_access_key_textbox: str = "",
117
+ aws_secret_key_textbox: str = "",
118
+ aws_region_textbox: str = "",
119
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
120
+ ):
121
+
122
+ if RUN_AWS_FUNCTIONS == "1":
123
+
124
+ s3 = connect_to_s3_client(
125
+ aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
126
+ )
127
+ # boto3.client('s3')
128
+ s3.download_file(bucket_name, key, local_file_path)
129
+ print(f"File downloaded from S3: s3://{bucket_name}/{key} to {local_file_path}")
130
+
131
+
132
+ def download_folder_from_s3(
133
+ bucket_name: str,
134
+ s3_folder: str,
135
+ local_folder: str,
136
+ aws_access_key_textbox: str = "",
137
+ aws_secret_key_textbox: str = "",
138
+ aws_region_textbox: str = "",
139
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
140
+ ):
141
+ """
142
+ Download all files from an S3 folder to a local folder.
143
+ """
144
+ if RUN_AWS_FUNCTIONS == "1":
145
+ s3 = connect_to_s3_client(
146
+ aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
147
+ )
148
+ # boto3.client('s3')
149
+
150
+ # List objects in the specified S3 folder
151
+ response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
152
+
153
+ # Download each object
154
+ for obj in response.get("Contents", []):
155
+ # Extract object key and construct local file path
156
+ object_key = obj["Key"]
157
+ local_file_path = os.path.join(
158
+ local_folder, os.path.relpath(object_key, s3_folder)
159
+ )
160
+
161
+ # Create directories if necessary
162
+ os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
163
+
164
+ # Download the object
165
+ try:
166
+ s3.download_file(bucket_name, object_key, local_file_path)
167
+ print(
168
+ f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'"
169
+ )
170
+ except Exception as e:
171
+ print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
172
+
173
+
174
+ def download_files_from_s3(
175
+ bucket_name: str,
176
+ s3_folder: str,
177
+ local_folder: str,
178
+ filenames: list[str],
179
+ aws_access_key_textbox: str = "",
180
+ aws_secret_key_textbox: str = "",
181
+ aws_region_textbox: str = "",
182
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
183
+ ):
184
+ """
185
+ Download specific files from an S3 folder to a local folder. Pass the string "*" as filenames to download every object under the prefix.
186
+ """
187
+ if RUN_AWS_FUNCTIONS == "1":
188
+ s3 = connect_to_s3_client(
189
+ aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
190
+ )
191
+ # boto3.client('s3')
192
+
193
+ print("Trying to download file: ", filenames)
194
+
195
+ if filenames == "*":
196
+ # List all objects in the S3 folder
197
+ print("Trying to download all files in AWS folder: ", s3_folder)
198
+ response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
199
+
200
+ print("Found files in AWS folder: ", response.get("Contents", []))
201
+
202
+ filenames = [
203
+ obj["Key"].split("/")[-1] for obj in response.get("Contents", [])
204
+ ]
205
+
206
+ print("Found filenames in AWS folder: ", filenames)
207
+
208
+ for filename in filenames:
209
+ object_key = os.path.join(s3_folder, filename)
210
+ local_file_path = os.path.join(local_folder, filename)
211
+
212
+ # Create directories if necessary
213
+ os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
214
+
215
+ # Download the object
216
+ try:
217
+ s3.download_file(bucket_name, object_key, local_file_path)
218
+ print(
219
+ f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'"
220
+ )
221
+ except Exception as e:
222
+ print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
223
+
224
+
225
+ def upload_file_to_s3(
226
+ local_file_paths: List[str],
227
+ s3_key: str,
228
+ s3_bucket: str = bucket_name,
229
+ aws_access_key_textbox: str = "",
230
+ aws_secret_key_textbox: str = "",
231
+ aws_region_textbox: str = "",
232
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
233
+ ):
234
+ """
235
+ Uploads a file from local machine to Amazon S3.
236
+
237
+ Args:
238
+ - local_file_paths: Local file path(s) of the file(s) to upload.
239
+ - s3_key: Key (path) to the file in the S3 bucket.
240
+ - s3_bucket: Name of the S3 bucket.
241
+
242
+ Returns:
243
+ - Message as variable/printed to console
244
+ """
245
+ if RUN_AWS_FUNCTIONS == "1":
246
+
247
+ final_out_message = list()
248
+
249
+ s3_client = connect_to_s3_client(
250
+ aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
251
+ )
252
+ # boto3.client('s3')
253
+
254
+ if isinstance(local_file_paths, str):
255
+ local_file_paths = [local_file_paths]
256
+
257
+ for file in local_file_paths:
258
+ try:
259
+ # Get file name off file path
260
+ file_name = os.path.basename(file)
261
+
262
+ s3_key_full = s3_key + file_name
263
+ print("S3 key: ", s3_key_full)
264
+
265
+ s3_client.upload_file(file, s3_bucket, s3_key_full)
266
+ out_message = "File " + file_name + " uploaded successfully!"
267
+ print(out_message)
268
+
269
+ except Exception as e:
270
+ out_message = f"Error uploading file(s): {e}"
271
+ print(out_message)
272
+
273
+ final_out_message.append(out_message)
274
+ final_out_message_str = "\n".join(final_out_message)
275
+
276
+ else:
277
+ final_out_message_str = "Not connected to AWS, no files uploaded."
278
+
279
+ return final_out_message_str
280
+
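
A usage sketch (bucket and paths are hypothetical). Note that s3_key is treated as a prefix: the file's basename is appended to it to form the full key:

    msg = upload_file_to_s3(
        local_file_paths=["output/topic_analysis.xlsx"],
        s3_key="outputs/session123/",
        s3_bucket="my-example-bucket",
    )
    print(msg)  # e.g. "File topic_analysis.xlsx uploaded successfully!"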
281
+
282
+ # Helper to upload outputs to S3 when enabled in config.
283
+ def export_outputs_to_s3(
284
+ file_list_state,
285
+ s3_output_folder_state_value: str,
286
+ save_outputs_to_s3_flag: bool,
287
+ base_file_state=None,
288
+ s3_bucket: str = S3_OUTPUTS_BUCKET,
289
+ ):
290
+ """
291
+ Upload a list of local output files to the configured S3 outputs folder.
292
+
293
+ - file_list_state: Gradio dropdown state that holds a list of file paths or a
294
+ single path/string. If blank/empty, no action is taken.
295
+ - s3_output_folder_state_value: Final S3 key prefix (including any session hash)
296
+ to use as the destination folder for uploads.
297
+ - save_outputs_to_s3_flag: Runtime toggle; when False, no upload is attempted.
+ - base_file_state: Optional original input file (path, list of paths, or
+ Gradio File object) used to derive a per-file subfolder name.
+ - s3_bucket: Name of the S3 bucket.
298
+ """
299
+ try:
300
+
301
+ # Respect the runtime toggle as well as environment configuration
302
+ if not save_outputs_to_s3_flag:
303
+ return
304
+
305
+ if not s3_output_folder_state_value:
306
+ # No configured S3 outputs folder – nothing to do
307
+ return
308
+
309
+ # Normalise input to a Python list of strings
310
+ file_paths = file_list_state
311
+ if not file_paths:
312
+ return
313
+
314
+ # Gradio dropdown may return a single string or a list
315
+ if isinstance(file_paths, str):
316
+ file_paths = [file_paths]
317
+
318
+ # Filter out any non-truthy values
319
+ file_paths = [p for p in file_paths if p]
320
+ if not file_paths:
321
+ return
322
+
323
+ # Derive a base file stem (name without extension) from the original
324
+ # file(s) being analysed, if provided. This is used to create an
325
+ # additional subfolder layer so that outputs are grouped under the
326
+ # analysed file name rather than under each output file name.
327
+ base_stem = None
328
+ if base_file_state:
329
+ base_path = None
330
+
331
+ # Gradio File components typically provide a list of objects with a `.name` attribute
332
+ if isinstance(base_file_state, str):
333
+ base_path = base_file_state
334
+ elif isinstance(base_file_state, list) and base_file_state:
335
+ first_item = base_file_state[0]
336
+ base_path = getattr(first_item, "name", None) or str(first_item)
337
+ else:
338
+ base_path = getattr(base_file_state, "name", None) or str(
339
+ base_file_state
340
+ )
341
+
342
+ if base_path:
343
+ base_name = os.path.basename(base_path)
344
+ base_stem, _ = os.path.splitext(base_name)
345
+
346
+ # Ensure base S3 prefix (session/date) ends with a trailing slash
347
+ base_prefix = s3_output_folder_state_value
348
+ if not base_prefix.endswith("/"):
349
+ base_prefix = base_prefix + "/"
350
+
351
+ # For each file, append a subfolder. If we have a derived base_stem
352
+ # from the input being analysed, use that; otherwise, fall back to
353
+ # the individual output file name stem. Final pattern:
354
+ # <session_output_folder>/<date>/<base_file_stem>/<file_name>
355
+ # or, if base_file_stem is not available:
356
+ # <session_output_folder>/<date>/<output_file_stem>/<file_name>
357
+ for file in file_paths:
358
+ file_name = os.path.basename(file)
359
+
360
+ if base_stem:
361
+ folder_stem = base_stem
362
+ else:
363
+ folder_stem, _ = os.path.splitext(file_name)
364
+
365
+ per_file_prefix = base_prefix + folder_stem + "/"
366
+
367
+ out_message = upload_file_to_s3(
368
+ local_file_paths=[file],
369
+ s3_key=per_file_prefix,
370
+ s3_bucket=s3_bucket,
371
+ )
372
+
373
+ # Log any issues to console so failures are visible in logs/stdout
374
+ if (
375
+ "Error uploading file" in out_message
376
+ or "could not upload" in out_message.lower()
377
+ ):
378
+ print("export_outputs_to_s3 encountered issues:", out_message)
379
+
380
+ print("Successfully uploaded outputs to S3")
381
+
382
+ except Exception as e:
383
+ # Do not break the app flow if S3 upload fails – just report to console
384
+ print(f"export_outputs_to_s3 failed with error: {e}")
385
+
386
+ # No GUI outputs to update
387
+ return
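
A worked example of the key layout described in the comments above (all names hypothetical): with s3_output_folder_state_value = "outputs/session123/2025-01-01", base_file_state = "input/consultation.csv" and an output file "output/summary.xlsx", the upload destination is:

    s3://<S3_OUTPUTS_BUCKET>/outputs/session123/2025-01-01/consultation/summary.xlsx

If no base file is supplied, the output file's own stem is used for the subfolder instead (".../2025-01-01/summary/summary.xlsx").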
tools/combine_sheets_into_xlsx.py ADDED
@@ -0,0 +1,615 @@
1
+ import os
2
+ from datetime import date, datetime
3
+ from typing import List
4
+
5
+ import pandas as pd
6
+ from openpyxl import Workbook
7
+ from openpyxl.styles import Alignment, Font
8
+ from openpyxl.utils import get_column_letter
9
+ from openpyxl.utils.dataframe import dataframe_to_rows
10
+
11
+ from tools.config import OUTPUT_FOLDER
12
+ from tools.config import model_name_map as global_model_name_map
13
+ from tools.helper_functions import (
14
+ clean_column_name,
15
+ convert_reference_table_to_pivot_table,
16
+ ensure_model_in_map,
17
+ get_basic_response_data,
18
+ load_in_data_file,
19
+ )
20
+
21
+
22
+ def add_cover_sheet(
23
+ wb: Workbook,
24
+ intro_paragraphs: list[str],
25
+ model_name: str,
26
+ analysis_date: str,
27
+ analysis_cost: str,
28
+ number_of_responses: int,
29
+ number_of_responses_with_text: int,
30
+ number_of_responses_with_text_five_plus_words: int,
31
+ llm_call_number: int,
32
+ input_tokens: int,
33
+ output_tokens: int,
34
+ time_taken: float,
35
+ file_name: str,
36
+ column_name: str,
37
+ number_of_responses_with_topic_assignment: int,
38
+ custom_title: str = "Cover Sheet",
39
+ ):
40
+ ws = wb.create_sheet(title=custom_title, index=0)
41
+
42
+ # Freeze top row
43
+ ws.freeze_panes = "A2"
44
+
45
+ # Write title
46
+ ws["A1"] = "Large Language Model Topic analysis"
47
+ ws["A1"].font = Font(size=14, bold=True)
48
+ ws["A1"].alignment = Alignment(wrap_text=True, vertical="top")
49
+
50
+ # Add intro paragraphs
51
+ row = 3
52
+ for paragraph in intro_paragraphs:
53
+ ws.merge_cells(start_row=row, start_column=1, end_row=row, end_column=2)
54
+ cell = ws.cell(row=row, column=1, value=paragraph)
55
+ cell.alignment = Alignment(wrap_text=True, vertical="top")
56
+ ws.row_dimensions[row].height = 60 # Adjust height as needed
57
+ row += 2
58
+
59
+ # Add metadata
60
+ meta_start = row + 1
61
+ metadata = {
62
+ "Date Excel file created": date.today().strftime("%Y-%m-%d"),
63
+ "File name": file_name,
64
+ "Column name": column_name,
65
+ "Model name": model_name,
66
+ "Analysis date": analysis_date,
67
+ # "Analysis cost": analysis_cost,
68
+ "Number of responses": number_of_responses,
69
+ "Number of responses with text": number_of_responses_with_text,
70
+ "Number of responses with text five plus words": number_of_responses_with_text_five_plus_words,
71
+ "Number of responses with at least one assigned topic": number_of_responses_with_topic_assignment,
72
+ "Number of LLM calls": llm_call_number,
73
+ "Total number of input tokens from LLM calls": input_tokens,
74
+ "Total number of output tokens from LLM calls": output_tokens,
75
+ "Total time taken for all LLM calls (seconds)": round(float(time_taken), 1),
76
+ }
77
+
78
+ for i, (label, value) in enumerate(metadata.items()):
79
+ row_num = meta_start + i
80
+ ws[f"A{row_num}"] = label
81
+ ws[f"A{row_num}"].font = Font(bold=True)
82
+
83
+ cell = ws[f"B{row_num}"]
84
+ cell.value = value
85
+ cell.alignment = Alignment(wrap_text=True)
86
+ # Optional: Adjust column widths
87
+ ws.column_dimensions["A"].width = 25
88
+ ws.column_dimensions["B"].width = 75
89
+
90
+ # Ensure first row cells are wrapped on the cover sheet
91
+ for col_idx in range(1, ws.max_column + 1):
92
+ header_cell = ws.cell(row=1, column=col_idx)
93
+ header_cell.alignment = Alignment(wrap_text=True, vertical="center")
94
+
95
+
96
+ def csvs_to_excel(
97
+ csv_files: list[str],
98
+ output_filename: str,
99
+ sheet_names: list[str] = None,
100
+ column_widths: dict = None, # Dict of {sheet_name: {col_letter: width}}
101
+ wrap_text_columns: dict = None, # Dict of {sheet_name: [col_letters]}
102
+ intro_text: list[str] = None,
103
+ model_name: str = "",
104
+ analysis_date: str = "",
105
+ analysis_cost: str = "",
106
+ llm_call_number: int = 0,
107
+ input_tokens: int = 0,
108
+ output_tokens: int = 0,
109
+ time_taken: float = 0,
110
+ number_of_responses: int = 0,
111
+ number_of_responses_with_text: int = 0,
112
+ number_of_responses_with_text_five_plus_words: int = 0,
113
+ column_name: str = "",
114
+ number_of_responses_with_topic_assignment: int = 0,
115
+ file_name: str = "",
116
+ unique_reference_numbers: list = [],
117
+ ):
118
+ if intro_text is None:
119
+ intro_text = list()
120
+
121
+ wb = Workbook()
122
+ # Remove default sheet
123
+ wb.remove(wb.active)
124
+
125
+ for idx, csv_path in enumerate(csv_files):
126
+ # Use provided sheet name or derive from file name
127
+ sheet_name = (
128
+ sheet_names[idx]
129
+ if sheet_names and idx < len(sheet_names)
130
+ else os.path.splitext(os.path.basename(csv_path))[0]
131
+ )
132
+ df = pd.read_csv(csv_path)
133
+
134
+ if sheet_name == "Original data":
135
+ try:
136
+ # Create a copy to avoid modifying the original
137
+ df_copy = df.copy()
138
+ # Insert the Reference column at position 0 (first column)
139
+ df_copy.insert(0, "Reference", unique_reference_numbers)
140
+ df = df_copy
141
+ except Exception as e:
142
+ print("Could not add reference number to original data due to:", e)
143
+
144
+ ws = wb.create_sheet(title=sheet_name)
145
+
146
+ for r_idx, row in enumerate(
147
+ dataframe_to_rows(df, index=False, header=True), start=1
148
+ ):
149
+ ws.append(row)
150
+
151
+ for col_idx, value in enumerate(row, start=1):
152
+ cell = ws.cell(row=r_idx, column=col_idx)
153
+
154
+ # Bold header row
155
+ if r_idx == 1:
156
+ cell.font = Font(bold=True)
157
+
158
+ # Set vertical alignment to middle by default
159
+ cell.alignment = Alignment(vertical="center")
160
+
161
+ # Apply wrap text if needed
162
+ if wrap_text_columns and sheet_name in wrap_text_columns:
163
+ for col_letter in wrap_text_columns[sheet_name]:
164
+ cell = ws[f"{col_letter}{r_idx}"]
165
+ cell.alignment = Alignment(vertical="center", wrap_text=True)
166
+
167
+ # Freeze top row for all data sheets
168
+ ws.freeze_panes = "A2"
169
+
170
+ # Ensure all header cells (first row) are wrapped
171
+ for col_idx in range(1, ws.max_column + 1):
172
+ header_cell = ws.cell(row=1, column=col_idx)
173
+ header_cell.alignment = Alignment(vertical="center", wrap_text=True)
174
+
175
+ # Set column widths
176
+ if column_widths and sheet_name in column_widths:
177
+ for col_letter, width in column_widths[sheet_name].items():
178
+ ws.column_dimensions[col_letter].width = width
179
+
180
+ add_cover_sheet(
181
+ wb,
182
+ intro_paragraphs=intro_text,
183
+ model_name=model_name,
184
+ analysis_date=analysis_date,
185
+ analysis_cost=analysis_cost,
186
+ number_of_responses=number_of_responses,
187
+ number_of_responses_with_text=number_of_responses_with_text,
188
+ number_of_responses_with_text_five_plus_words=number_of_responses_with_text_five_plus_words,
189
+ llm_call_number=llm_call_number,
190
+ input_tokens=input_tokens,
191
+ output_tokens=output_tokens,
192
+ time_taken=time_taken,
193
+ file_name=file_name,
194
+ column_name=column_name,
195
+ number_of_responses_with_topic_assignment=number_of_responses_with_topic_assignment,
196
+ )
197
+
198
+ wb.save(output_filename)
199
+
200
+ print(f"Output xlsx summary saved as '{output_filename}'")
201
+
202
+ return output_filename
203
+
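
A minimal call sketch for the function above (file paths, sheet names and widths are hypothetical; the cover-sheet metadata arguments are left at their defaults):

    output_path = csvs_to_excel(
        csv_files=["output/topic_summary.csv", "output/reference.csv"],
        output_filename="output/combined.xlsx",
        sheet_names=["Topic summary", "Response level data"],
        column_widths={"Topic summary": {"A": 25, "B": 100}},
        wrap_text_columns={"Topic summary": ["B"]},
        intro_text=["Example cover sheet paragraph."],
    )
    print(output_path)  # "output/combined.xlsx"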
204
+
205
+ ###
206
+ # Run the functions
207
+ ###
208
+ def collect_output_csvs_and_create_excel_output(
209
+ in_data_files: List,
210
+ chosen_cols: list[str],
211
+ reference_data_file_name_textbox: str,
212
+ in_group_col: str,
213
+ model_choice: str,
214
+ master_reference_df_state: pd.DataFrame,
215
+ master_unique_topics_df_state: pd.DataFrame,
216
+ summarised_output_df: pd.DataFrame,
217
+ missing_df_state: pd.DataFrame,
218
+ excel_sheets: str = "",
219
+ usage_logs_location: str = "",
220
+ model_name_map: dict = dict(),
221
+ output_folder: str = OUTPUT_FOLDER,
222
+ structured_summaries: str = "No",
223
+ ):
224
+ """
225
+ Collect together output CSVs from various output boxes and combine them into a single output Excel file.
226
+
227
+ Args:
228
+ in_data_files (List): A list of paths to the input data files.
229
+ chosen_cols (list[str]): A list of column names selected for analysis.
230
+ reference_data_file_name_textbox (str): The name of the reference data file.
231
+ in_group_col (str): The column used for grouping the data.
232
+ model_choice (str): The LLM model chosen for the analysis.
233
+ master_reference_df_state (pd.DataFrame): The master DataFrame containing reference data.
234
+ master_unique_topics_df_state (pd.DataFrame): The master DataFrame containing unique topics data.
235
+ summarised_output_df (pd.DataFrame): DataFrame containing the summarised output.
236
+ missing_df_state (pd.DataFrame): DataFrame containing information about missing data.
237
+ excel_sheets (str): Information regarding Excel sheets, typically sheet names or structure.
238
+ usage_logs_location (str, optional): Path to the usage logs CSV file. Defaults to "".
239
+ model_name_map (dict, optional): A dictionary mapping model choices to their display names. Defaults to {}.
240
+ output_folder (str, optional): The directory where the output Excel file will be saved. Defaults to OUTPUT_FOLDER.
241
+ structured_summaries (str, optional): Indicates whether structured summaries are being produced ("Yes" or "No"). Defaults to "No".
242
+
243
+ Returns:
244
+ tuple: A tuple containing:
245
+ - list: A list of paths to the generated Excel output files.
246
+ - list: A duplicate of the list of paths to the generated Excel output files (for UI compatibility).
247
+ """
248
+ # Use passed model_name_map if provided and not empty, otherwise use global one
249
+ if not model_name_map:
250
+ model_name_map = global_model_name_map
251
+
252
+ # Ensure custom model_choice is registered in model_name_map
253
+ ensure_model_in_map(model_choice, model_name_map)
254
+
255
+ if structured_summaries == "Yes":
256
+ structured_summaries = True
257
+ else:
258
+ structured_summaries = False
259
+
260
+ if not chosen_cols:
261
+ raise Exception("Could not find chosen column")
262
+
263
+ today_date = datetime.today().strftime("%Y-%m-%d")
264
+ original_data_file_path = os.path.abspath(in_data_files[0])
265
+
266
+ csv_files = list()
267
+ sheet_names = list()
268
+ column_widths = dict()
269
+ wrap_text_columns = dict()
270
+ short_file_name = os.path.basename(reference_data_file_name_textbox)
271
+ reference_pivot_table = pd.DataFrame()
272
+ reference_table_csv_path = ""
273
+ reference_pivot_table_csv_path = ""
274
+ unique_topic_table_csv_path = ""
275
+ missing_df_state_csv_path = ""
276
+ overall_summary_csv_path = ""
277
+ number_of_responses_with_topic_assignment = 0
+ unique_reference_numbers = list()  # fallback when no reference data is available
278
+
279
+ if in_group_col:
280
+ group = in_group_col
281
+ else:
282
+ group = "All"
283
+
284
+ overall_summary_csv_path = output_folder + "overall_summary_for_xlsx.csv"
285
+
286
+ if structured_summaries is True and not master_unique_topics_df_state.empty:
287
+ print("Producing overall summary based on structured summaries.")
288
+ # Create structured summary from master_unique_topics_df_state
289
+ structured_summary_data = list()
290
+
291
+ # Group by 'Group' column
292
+ for group_name, group_df in master_unique_topics_df_state.groupby("Group"):
293
+ group_summary = f"## {group_name}\n\n"
294
+
295
+ # Group by 'General topic' within each group
296
+ for general_topic, topic_df in group_df.groupby("General topic"):
297
+ group_summary += f"### {general_topic}\n\n"
298
+
299
+ # Add subtopics under each general topic
300
+ for _, row in topic_df.iterrows():
301
+ subtopic = row["Subtopic"]
302
+ summary = row["Summary"]
303
+ # sentiment = row.get('Sentiment', '')
304
+ # num_responses = row.get('Number of responses', '')
305
+
306
+ # Create subtopic entry
307
+ subtopic_entry = f"**{subtopic}**"
308
+ # if sentiment:
309
+ # subtopic_entry += f" ({sentiment})"
310
+ # if num_responses:
311
+ # subtopic_entry += f" - {num_responses} responses"
312
+ subtopic_entry += "\n\n"
313
+
314
+ if summary and pd.notna(summary):
315
+ subtopic_entry += f"{summary}\n\n"
316
+
317
+ group_summary += subtopic_entry
318
+
319
+ # Add to structured summary data
320
+ structured_summary_data.append(
321
+ {"Group": group_name, "Summary": group_summary.strip()}
322
+ )
323
+
324
+ # Create DataFrame for structured summary
325
+ structured_summary_df = pd.DataFrame(structured_summary_data)
326
+ structured_summary_df.to_csv(overall_summary_csv_path, index=False)
327
+ else:
328
+ # Use original summarised_output_df
329
+ structured_summary_df = summarised_output_df
330
+ structured_summary_df.to_csv(overall_summary_csv_path, index=None)
331
+
332
+ if not structured_summary_df.empty:
333
+ csv_files.append(overall_summary_csv_path)
334
+ sheet_names.append("Overall summary")
335
+ column_widths["Overall summary"] = {"A": 20, "B": 100}
336
+ wrap_text_columns["Overall summary"] = ["B"]
337
+
338
+ if not master_reference_df_state.empty:
339
+ # Simplify table to just responses column and the Response reference number
340
+ file_data, file_name, num_batches = load_in_data_file(
341
+ in_data_files, chosen_cols, 1, in_excel_sheets=excel_sheets
342
+ )
343
+ basic_response_data = get_basic_response_data(
344
+ file_data, chosen_cols, verify_titles="No"
345
+ )
346
+ reference_pivot_table = convert_reference_table_to_pivot_table(
347
+ master_reference_df_state, basic_response_data
348
+ )
349
+
350
+ unique_reference_numbers = basic_response_data["Reference"].tolist()
351
+
352
+ try:
353
+ master_reference_df_state.rename(
354
+ columns={"Topic_number": "Topic number"}, inplace=True, errors="ignore"
355
+ )
356
+ master_reference_df_state.drop(
357
+ columns=["1", "2", "3"], inplace=True, errors="ignore"
358
+ )
359
+ except Exception as e:
360
+ print("Could not rename Topic_number due to", e)
361
+
362
+ number_of_responses_with_topic_assignment = len(
363
+ master_reference_df_state["Response References"].unique()
364
+ )
365
+
366
+ reference_table_csv_path = output_folder + "reference_df_for_xlsx.csv"
367
+ master_reference_df_state.to_csv(reference_table_csv_path, index=None)
368
+
369
+ reference_pivot_table_csv_path = (
370
+ output_folder + "reference_pivot_df_for_xlsx.csv"
371
+ )
372
+ reference_pivot_table.to_csv(reference_pivot_table_csv_path, index=None)
373
+
374
+ short_file_name = os.path.basename(file_name)
375
+
376
+ if not master_unique_topics_df_state.empty:
377
+
378
+ master_unique_topics_df_state.drop(
379
+ columns=["1", "2", "3"], inplace=True, errors="ignore"
380
+ )
381
+
382
+ unique_topic_table_csv_path = (
383
+ output_folder + "unique_topic_table_df_for_xlsx.csv"
384
+ )
385
+ master_unique_topics_df_state.to_csv(unique_topic_table_csv_path, index=None)
386
+
387
+ if unique_topic_table_csv_path:
388
+ csv_files.append(unique_topic_table_csv_path)
389
+ sheet_names.append("Topic summary")
390
+ column_widths["Topic summary"] = {"A": 25, "B": 25, "C": 15, "D": 15, "F": 100}
391
+ wrap_text_columns["Topic summary"] = ["B", "F"]
392
+ else:
393
+ print("Relevant unique topic files not found, excluding from xlsx output.")
394
+
395
+ if reference_table_csv_path:
396
+ if structured_summaries:
397
+ print(
398
+ "Structured summaries are being produced, excluding response level data from xlsx output."
399
+ )
400
+ else:
401
+ csv_files.append(reference_table_csv_path)
402
+ sheet_names.append("Response level data")
403
+ column_widths["Response level data"] = {"A": 15, "B": 30, "C": 40, "H": 100}
404
+ wrap_text_columns["Response level data"] = ["C", "G"]
405
+ else:
406
+ print("Relevant reference files not found, excluding from xlsx output.")
407
+
408
+ if reference_pivot_table_csv_path:
409
+ if structured_summaries:
410
+ print(
411
+ "Structured summaries are being produced, excluding topic response pivot table from xlsx output."
412
+ )
413
+ else:
414
+ csv_files.append(reference_pivot_table_csv_path)
415
+ sheet_names.append("Topic response pivot table")
416
+
417
+ if reference_pivot_table.empty:
418
+ reference_pivot_table = pd.read_csv(reference_pivot_table_csv_path)
419
+
420
+ # Base widths and wrap
421
+ column_widths["Topic response pivot table"] = {"A": 25, "B": 100}
422
+ wrap_text_columns["Topic response pivot table"] = ["B"]
423
+
424
+ num_cols = len(reference_pivot_table.columns)
425
+ col_letters = [get_column_letter(i) for i in range(3, num_cols + 1)]
426
+
427
+ for col_letter in col_letters:
428
+ column_widths["Topic response pivot table"][col_letter] = 25
429
+
430
+ wrap_text_columns["Topic response pivot table"].extend(col_letters)
431
+ else:
432
+ print(
433
+ "Relevant reference pivot table files not found, excluding from xlsx output."
434
+ )
435
+
436
+ if not missing_df_state.empty:
437
+ missing_df_state_csv_path = output_folder + "missing_df_state_df_for_xlsx.csv"
438
+ missing_df_state.to_csv(missing_df_state_csv_path, index=None)
439
+
440
+ if missing_df_state_csv_path:
441
+ if structured_summaries:
442
+ print(
443
+ "Structured summaries are being produced, excluding missing responses from xlsx output."
444
+ )
445
+ else:
446
+ csv_files.append(missing_df_state_csv_path)
447
+ sheet_names.append("Missing responses")
448
+ column_widths["Missing responses"] = {"A": 25, "B": 30, "C": 50}
449
+ wrap_text_columns["Missing responses"] = ["C"]
450
+ else:
451
+ print("Relevant missing responses files not found, excluding from xlsx output.")
452
+
453
+ new_csv_files = csv_files.copy()
454
+
455
+ # Original data file
456
+ original_ext = os.path.splitext(original_data_file_path)[1].lower()
457
+ if original_ext == ".csv":
458
+ csv_files.append(original_data_file_path)
459
+ else:
460
+ # Read and convert to CSV
461
+ if original_ext == ".xlsx":
462
+ if excel_sheets:
463
+ df = pd.read_excel(original_data_file_path, sheet_name=excel_sheets)
464
+ else:
465
+ df = pd.read_excel(original_data_file_path)
466
+ elif original_ext == ".parquet":
467
+ df = pd.read_parquet(original_data_file_path)
468
+ else:
469
+ raise Exception(f"Unsupported file type for original data: {original_ext}")
470
+
471
+ # Save as CSV in output folder
472
+ original_data_csv_path = os.path.join(
473
+ output_folder,
474
+ os.path.splitext(os.path.basename(original_data_file_path))[0]
475
+ + "_for_xlsx.csv",
476
+ )
477
+ df.to_csv(original_data_csv_path, index=False)
478
+ csv_files.append(original_data_csv_path)
479
+
480
+ sheet_names.append("Original data")
481
+ column_widths["Original data"] = {"A": 20, "B": 20, "C": 20}
482
+ wrap_text_columns["Original data"] = ["C"]
483
+ if isinstance(chosen_cols, list) and chosen_cols:
484
+ chosen_cols = chosen_cols[0]
485
+ else:
486
+ chosen_cols = str(chosen_cols) if chosen_cols else ""
487
+
488
+ # Intro page text
489
+ intro_text = [
490
+ "This workbook contains outputs from the large language model topic analysis of open text data. Each sheet corresponds to a different CSV report included in the analysis.",
491
+ f"The file analysed was {short_file_name}, the column analysed was '{chosen_cols}' and the data was grouped by column '{group}'."
492
+ " Please contact the LLM Topic Modelling app administrator if you need any explanation on how to use the results."
493
+ "Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this analysis **need to be checked by a human** to check for harmful outputs, false information, and bias.",
494
+ ]
495
+
496
+ # Get values for number of rows, number of responses, and number of responses longer than five words
497
+ number_of_responses = basic_response_data.shape[0]
498
+ # number_of_responses_with_text = basic_response_data["Response"].str.strip().notnull().sum()
499
+ number_of_responses_with_text = (
500
+ basic_response_data["Response"].str.strip().notnull()
501
+ & (basic_response_data["Response"].str.split().str.len() >= 1)
502
+ ).sum()
503
+ number_of_responses_with_text_five_plus_words = (
504
+ basic_response_data["Response"].str.strip().notnull()
505
+ & (basic_response_data["Response"].str.split().str.len() >= 5)
506
+ ).sum()
507
+
508
+ # Get number of LLM calls, input and output tokens
509
+ if usage_logs_location:
510
+ try:
511
+ usage_logs = pd.read_csv(usage_logs_location)
512
+
513
+ relevant_logs = usage_logs.loc[
514
+ (
515
+ usage_logs["Reference data file name"]
516
+ == reference_data_file_name_textbox
517
+ )
518
+ & (
519
+ usage_logs[
520
+ "Large language model for topic extraction and summarisation"
521
+ ]
522
+ == model_choice
523
+ )
524
+ & (
525
+ usage_logs[
526
+ "Select the open text column of interest. In an Excel file, this shows columns across all sheets."
527
+ ]
528
+ == (
529
+ chosen_cols[0]
530
+ if isinstance(chosen_cols, list) and chosen_cols
531
+ else chosen_cols
532
+ )
533
+ ),
534
+ :,
535
+ ]
536
+
537
+ llm_call_number = sum(relevant_logs["Total LLM calls"].astype(int))
538
+ input_tokens = sum(relevant_logs["Total input tokens"].astype(int))
539
+ output_tokens = sum(relevant_logs["Total output tokens"].astype(int))
540
+ time_taken = sum(
541
+ relevant_logs["Estimated time taken (seconds)"].astype(float)
542
+ )
543
+ except Exception as e:
544
+ print("Could not obtain usage logs due to:", e)
545
+ usage_logs = pd.DataFrame()
546
+ llm_call_number = 0
547
+ input_tokens = 0
548
+ output_tokens = 0
549
+ time_taken = 0
550
+ else:
551
+ print("LLM call logs location not provided")
552
+ usage_logs = pd.DataFrame()
553
+ llm_call_number = 0
554
+ input_tokens = 0
555
+ output_tokens = 0
556
+ time_taken = 0
557
+
558
+ # Create short filename:
559
+ model_choice_clean_short = clean_column_name(
560
+ model_name_map[model_choice]["short_name"],
561
+ max_length=20,
562
+ front_characters=False,
563
+ )
564
+ # Extract first column name as string for cleaning and Excel output
565
+ chosen_col_str = (
566
+ chosen_cols[0]
567
+ if isinstance(chosen_cols, list) and chosen_cols
568
+ else str(chosen_cols) if chosen_cols else ""
569
+ )
570
+ in_column_cleaned = clean_column_name(chosen_col_str, max_length=20)
571
+ file_name_cleaned = clean_column_name(
572
+ file_name, max_length=20, front_characters=True
573
+ )
574
+
575
+ # Save outputs for each batch. If master file created, label file as master
576
+ file_path_details = (
577
+ f"{file_name_cleaned}_col_{in_column_cleaned}_{model_choice_clean_short}"
578
+ )
579
+ output_xlsx_filename = (
580
+ output_folder
581
+ + file_path_details
582
+ + ("_structured_summaries" if structured_summaries else "_topic_analysis")
583
+ + ".xlsx"
584
+ )
585
+
586
+ xlsx_output_filename = csvs_to_excel(
587
+ csv_files=csv_files,
588
+ output_filename=output_xlsx_filename,
589
+ sheet_names=sheet_names,
590
+ column_widths=column_widths,
591
+ wrap_text_columns=wrap_text_columns,
592
+ intro_text=intro_text,
593
+ model_name=model_choice,
594
+ analysis_date=today_date,
595
+ analysis_cost="",
596
+ llm_call_number=llm_call_number,
597
+ input_tokens=input_tokens,
598
+ output_tokens=output_tokens,
599
+ time_taken=time_taken,
600
+ number_of_responses=number_of_responses,
601
+ number_of_responses_with_text=number_of_responses_with_text,
602
+ number_of_responses_with_text_five_plus_words=number_of_responses_with_text_five_plus_words,
603
+ column_name=chosen_col_str,
604
+ number_of_responses_with_topic_assignment=number_of_responses_with_topic_assignment,
605
+ file_name=short_file_name,
606
+ unique_reference_numbers=unique_reference_numbers,
607
+ )
608
+
609
+ xlsx_output_filenames = [xlsx_output_filename]
610
+
611
+ # Delete intermediate csv files
612
+ for csv_file in new_csv_files:
613
+ os.remove(csv_file)
614
+
615
+ return xlsx_output_filenames, xlsx_output_filenames
tools/config.py ADDED
@@ -0,0 +1,950 @@
1
+ import codecs
2
+ import logging
3
+ import os
4
+ import socket
5
+ import tempfile
6
+ from datetime import datetime
7
+ from typing import List
8
+
9
+ from dotenv import load_dotenv
10
+
11
+ today_rev = datetime.now().strftime("%Y%m%d")
12
+ HOST_NAME = socket.gethostname()
13
+
14
+ # Set or retrieve configuration variables for the topic modelling app
15
+
16
+
17
+ def get_or_create_env_var(var_name: str, default_value: str, print_val: bool = False):
18
+ """
19
+ Get an environmental variable, and set it to a default value if it doesn't exist
20
+ """
21
+ # Get the environment variable if it exists
22
+ value = os.environ.get(var_name)
23
+
24
+ # If it doesn't exist, set the environment variable to the default value
25
+ if value is None:
26
+ os.environ[var_name] = default_value
27
+ value = default_value
28
+
29
+ if print_val is True:
30
+ print(f"The value of {var_name} is {value}")
31
+
32
+ return value
33
+
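
A quick example of the helper's behaviour (MY_FLAG is a hypothetical variable name):

    import os
    os.environ.pop("MY_FLAG", None)
    get_or_create_env_var("MY_FLAG", "0")   # returns "0" and sets os.environ["MY_FLAG"] = "0"
    os.environ["MY_FLAG"] = "1"
    get_or_create_env_var("MY_FLAG", "0")   # returns "1" - an existing value wins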
34
+
35
+ def add_folder_to_path(folder_path: str):
36
+ """
37
+ Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
38
+ """
39
+
40
+ if os.path.exists(folder_path) and os.path.isdir(folder_path):
41
+ print(folder_path, "folder exists.")
42
+
43
+ # Resolve relative path to absolute path
44
+ absolute_path = os.path.abspath(folder_path)
45
+
46
+ current_path = os.environ["PATH"]
47
+ if absolute_path not in current_path.split(os.pathsep):
48
+ full_path_extension = absolute_path + os.pathsep + current_path
49
+ os.environ["PATH"] = full_path_extension
50
+ # print(f"Updated PATH with: ", full_path_extension)
51
+ else:
52
+ print(f"Directory {folder_path} already exists in PATH.")
53
+ else:
54
+ print(f"Folder not found at {folder_path} - not added to PATH")
55
+
56
+
57
+ def convert_string_to_boolean(value: str) -> bool:
58
+ """Convert string to boolean, handling various formats."""
59
+ if isinstance(value, bool):
60
+ return value
61
+ elif value in ["True", "1", "true", "TRUE"]:
62
+ return True
63
+ elif value in ["False", "0", "false", "FALSE"]:
64
+ return False
65
+ else:
66
+ raise ValueError(f"Invalid boolean value: {value}")
67
+
68
+
69
+ ###
70
+ # LOAD CONFIG FROM ENV FILE
71
+ ###
72
+
73
+ CONFIG_FOLDER = get_or_create_env_var("CONFIG_FOLDER", "config/")
74
+
75
+ # If you have an app_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
76
+ APP_CONFIG_PATH = get_or_create_env_var(
77
+ "APP_CONFIG_PATH", CONFIG_FOLDER + "app_config.env"
78
+ ) # e.g. config/app_config.env
79
+
80
+ if APP_CONFIG_PATH:
81
+ if os.path.exists(APP_CONFIG_PATH):
82
+ print(f"Loading app variables from config file {APP_CONFIG_PATH}")
83
+ load_dotenv(APP_CONFIG_PATH)
84
+ else:
85
+ print("App config file not found at location:", APP_CONFIG_PATH)
86
+
87
+ ###
88
+ # AWS OPTIONS
89
+ ###
90
+
91
+ # If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'config/aws_config.env'
92
+ AWS_CONFIG_PATH = get_or_create_env_var(
93
+ "AWS_CONFIG_PATH", ""
94
+ ) # e.g. config/aws_config.env
95
+
96
+ if AWS_CONFIG_PATH:
97
+ if os.path.exists(AWS_CONFIG_PATH):
98
+ print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
99
+ load_dotenv(AWS_CONFIG_PATH)
100
+ else:
101
+ print("AWS config file not found at location:", AWS_CONFIG_PATH)
102
+
103
+ RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
104
+
105
+ AWS_REGION = get_or_create_env_var("AWS_REGION", "")
106
+
107
+ AWS_CLIENT_ID = get_or_create_env_var("AWS_CLIENT_ID", "")
108
+
109
+ AWS_CLIENT_SECRET = get_or_create_env_var("AWS_CLIENT_SECRET", "")
110
+
111
+ AWS_USER_POOL_ID = get_or_create_env_var("AWS_USER_POOL_ID", "")
112
+
113
+ AWS_ACCESS_KEY = get_or_create_env_var("AWS_ACCESS_KEY", "")
114
+ # if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
115
+
116
+ AWS_SECRET_KEY = get_or_create_env_var("AWS_SECRET_KEY", "")
117
+ # if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
118
+
119
+ # Should the app prioritise using AWS SSO over using API keys stored in environment variables/secrets (defaults to yes)
120
+ PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS = get_or_create_env_var(
121
+ "PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS", "1"
122
+ )
123
+
124
+ S3_LOG_BUCKET = get_or_create_env_var("S3_LOG_BUCKET", "")
125
+
126
+ # Custom headers e.g. if routing traffic through Cloudfront
127
+ # Retrieving or setting CUSTOM_HEADER
128
+ CUSTOM_HEADER = get_or_create_env_var("CUSTOM_HEADER", "")
129
+
130
+ # Retrieving or setting CUSTOM_HEADER_VALUE
131
+ CUSTOM_HEADER_VALUE = get_or_create_env_var("CUSTOM_HEADER_VALUE", "")
132
+
133
+ ###
134
+ # File I/O
135
+ ###
136
+ SESSION_OUTPUT_FOLDER = get_or_create_env_var(
137
+ "SESSION_OUTPUT_FOLDER", "False"
138
+ ) # i.e. should input and output files be saved within a subfolder, named after the session hash value, inside the input/output folders
139
+
140
+ OUTPUT_FOLDER = get_or_create_env_var("GRADIO_OUTPUT_FOLDER", "output/") # 'output/'
141
+ INPUT_FOLDER = get_or_create_env_var("GRADIO_INPUT_FOLDER", "input/") # 'input/'
142
+
143
+
144
+ # Allow for files to be saved in a temporary folder for increased security in some instances
145
+ if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
146
+ # Create a persistent temporary directory. mkdtemp is used rather than
+ # tempfile.TemporaryDirectory, which would delete the folder as soon as
+ # the context manager exits, leaving the folder variables pointing at a
+ # removed directory
147
+ temp_dir = tempfile.mkdtemp()
148
+ print(f"Temporary directory created at: {temp_dir}")
149
+
150
+ if OUTPUT_FOLDER == "TEMP":
151
+ OUTPUT_FOLDER = temp_dir + "/"
152
+ if INPUT_FOLDER == "TEMP":
153
+ INPUT_FOLDER = temp_dir + "/"
154
+
155
+
156
+ GRADIO_TEMP_DIR = get_or_create_env_var(
157
+ "GRADIO_TEMP_DIR", "tmp/gradio_tmp/"
158
+ ) # Default Gradio temp folder
159
+ MPLCONFIGDIR = get_or_create_env_var(
160
+ "MPLCONFIGDIR", "tmp/matplotlib_cache/"
161
+ ) # Matplotlib cache folder
162
+
163
+ S3_OUTPUTS_BUCKET = get_or_create_env_var("S3_OUTPUTS_BUCKET", "")
164
+ S3_OUTPUTS_FOLDER = get_or_create_env_var("S3_OUTPUTS_FOLDER", "")
165
+ SAVE_OUTPUTS_TO_S3 = get_or_create_env_var("SAVE_OUTPUTS_TO_S3", "False")
166
+
167
+ ###
168
+ # LOGGING OPTIONS
169
+ ###
170
+
171
+ # By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
172
+ # Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
173
+
174
+ SAVE_LOGS_TO_CSV = get_or_create_env_var("SAVE_LOGS_TO_CSV", "True")
175
+
176
+ USE_LOG_SUBFOLDERS = get_or_create_env_var("USE_LOG_SUBFOLDERS", "True")
177
+
178
+ FEEDBACK_LOGS_FOLDER = get_or_create_env_var("FEEDBACK_LOGS_FOLDER", "feedback/")
179
+ ACCESS_LOGS_FOLDER = get_or_create_env_var("ACCESS_LOGS_FOLDER", "logs/")
180
+ USAGE_LOGS_FOLDER = get_or_create_env_var("USAGE_LOGS_FOLDER", "usage/")
181
+
182
+ # Initialize full_log_subfolder based on USE_LOG_SUBFOLDERS setting
183
+ if USE_LOG_SUBFOLDERS == "True":
184
+ day_log_subfolder = today_rev + "/"
185
+ host_name_subfolder = HOST_NAME + "/"
186
+ full_log_subfolder = day_log_subfolder + host_name_subfolder
187
+
188
+ FEEDBACK_LOGS_FOLDER = FEEDBACK_LOGS_FOLDER + full_log_subfolder
189
+ ACCESS_LOGS_FOLDER = ACCESS_LOGS_FOLDER + full_log_subfolder
190
+ USAGE_LOGS_FOLDER = USAGE_LOGS_FOLDER + full_log_subfolder
191
+ else:
192
+ full_log_subfolder = "" # Empty string when subfolders are not used
193
+
194
+ S3_FEEDBACK_LOGS_FOLDER = get_or_create_env_var(
195
+ "S3_FEEDBACK_LOGS_FOLDER", "feedback/" + full_log_subfolder
196
+ )
197
+ S3_ACCESS_LOGS_FOLDER = get_or_create_env_var(
198
+ "S3_ACCESS_LOGS_FOLDER", "logs/" + full_log_subfolder
199
+ )
200
+ S3_USAGE_LOGS_FOLDER = get_or_create_env_var(
201
+ "S3_USAGE_LOGS_FOLDER", "usage/" + full_log_subfolder
202
+ )
203
+
204
+ LOG_FILE_NAME = get_or_create_env_var("LOG_FILE_NAME", "log.csv")
205
+ USAGE_LOG_FILE_NAME = get_or_create_env_var("USAGE_LOG_FILE_NAME", LOG_FILE_NAME)
206
+ FEEDBACK_LOG_FILE_NAME = get_or_create_env_var("FEEDBACK_LOG_FILE_NAME", LOG_FILE_NAME)
207
+
208
+ # Should the input file name be included in the logs? In some instances, the names of the files themselves could be sensitive and should not be disclosed beyond the app, so by default this is false.
209
+ DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var(
210
+ "DISPLAY_FILE_NAMES_IN_LOGS", "False"
211
+ )
212
+
213
+ # Further customisation options for CSV logs
214
+
215
+ CSV_ACCESS_LOG_HEADERS = get_or_create_env_var(
216
+ "CSV_ACCESS_LOG_HEADERS", ""
217
+ ) # If blank, uses component labels
218
+ CSV_FEEDBACK_LOG_HEADERS = get_or_create_env_var(
219
+ "CSV_FEEDBACK_LOG_HEADERS", ""
220
+ ) # If blank, uses component labels
221
+ CSV_USAGE_LOG_HEADERS = get_or_create_env_var(
222
+ "CSV_USAGE_LOG_HEADERS", ""
223
+ ) # If blank, uses component labels
224
+
225
+ ### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
226
+ SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var("SAVE_LOGS_TO_DYNAMODB", "False")
227
+
228
+ ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
229
+ "ACCESS_LOG_DYNAMODB_TABLE_NAME", "llm_topic_model_access_log"
230
+ )
231
+ DYNAMODB_ACCESS_LOG_HEADERS = get_or_create_env_var("DYNAMODB_ACCESS_LOG_HEADERS", "")
232
+
233
+ FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
234
+ "FEEDBACK_LOG_DYNAMODB_TABLE_NAME", "llm_topic_model_feedback"
235
+ )
236
+ DYNAMODB_FEEDBACK_LOG_HEADERS = get_or_create_env_var(
237
+ "DYNAMODB_FEEDBACK_LOG_HEADERS", ""
238
+ )
239
+
240
+ USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
241
+ "USAGE_LOG_DYNAMODB_TABLE_NAME", "llm_topic_model_usage"
242
+ )
243
+ DYNAMODB_USAGE_LOG_HEADERS = get_or_create_env_var("DYNAMODB_USAGE_LOG_HEADERS", "")
244
+
245
+ # Report logging to console?
246
+ LOGGING = get_or_create_env_var("LOGGING", "False")
247
+
248
+ if LOGGING == "True":
249
+ # Configure logging
250
+ logging.basicConfig(
251
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
252
+ )
253
+
254
+ ###
255
+ # App run variables
256
+ ###
257
+ OUTPUT_DEBUG_FILES = get_or_create_env_var(
258
+ "OUTPUT_DEBUG_FILES", "False"
259
+ ) # Whether to output debug files
260
+ SHOW_ADDITIONAL_INSTRUCTION_TEXTBOXES = get_or_create_env_var(
261
+ "SHOW_ADDITIONAL_INSTRUCTION_TEXTBOXES", "True"
262
+ ) # Whether to show additional instruction textboxes in the GUI
263
+
264
+ TIMEOUT_WAIT = int(
265
+ get_or_create_env_var("TIMEOUT_WAIT", "30")
266
+ ) # Maximum number of seconds to wait for a response from the LLM
267
+ NUMBER_OF_RETRY_ATTEMPTS = int(
268
+ get_or_create_env_var("NUMBER_OF_RETRY_ATTEMPTS", "5")
269
+ ) # Maximum number of times to retry a request to the LLM
270
+ # Try up to 3 times to get a valid markdown table response with LLM calls, otherwise retry with temperature changed
271
+ MAX_OUTPUT_VALIDATION_ATTEMPTS = int(
272
+ get_or_create_env_var("MAX_OUTPUT_VALIDATION_ATTEMPTS", "3")
273
+ )
274
+ ENABLE_VALIDATION = get_or_create_env_var(
275
+ "ENABLE_VALIDATION", "False"
276
+ ) # Whether to run validation loop after initial topic extraction
277
+ MAX_TIME_FOR_LOOP = int(
278
+ get_or_create_env_var("MAX_TIME_FOR_LOOP", "99999")
279
+ ) # Maximum number of seconds to run the loop for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there)
280
+
281
+ MAX_COMMENT_CHARS = int(
282
+ get_or_create_env_var("MAX_COMMENT_CHARS", "14000")
283
+ ) # Maximum number of characters in a comment
284
+ MAX_ROWS = int(
285
+ get_or_create_env_var("MAX_ROWS", "5000")
286
+ ) # Maximum number of rows to process
287
+ MAX_GROUPS = int(
288
+ get_or_create_env_var("MAX_GROUPS", "99")
289
+ ) # Maximum number of groups to process
290
+ BATCH_SIZE_DEFAULT = int(
291
+ get_or_create_env_var("BATCH_SIZE_DEFAULT", "5")
292
+ ) # Default batch size for LLM calls
293
+ MAXIMUM_ZERO_SHOT_TOPICS = int(
294
+ get_or_create_env_var("MAXIMUM_ZERO_SHOT_TOPICS", "120")
295
+ ) # Maximum number of zero shot topics to process
296
+ MAX_SPACES_GPU_RUN_TIME = int(
297
+ get_or_create_env_var("MAX_SPACES_GPU_RUN_TIME", "240")
298
+ ) # Maximum number of seconds to run on GPU on Hugging Face Spaces
299
+
300
+ DEDUPLICATION_THRESHOLD = int(
301
+ get_or_create_env_var("DEDUPLICATION_THRESHOLD", "90")
302
+ ) # Deduplication threshold for topic summary tables
303
+
304
+ ###
305
+ # Model options
306
+ ###
307
+
308
+ RUN_LOCAL_MODEL = get_or_create_env_var("RUN_LOCAL_MODEL", "0")
309
+
310
+ RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
311
+
312
+ RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
313
+ GEMINI_API_KEY = get_or_create_env_var("GEMINI_API_KEY", "")
314
+
315
+ INTRO_TEXT = get_or_create_env_var(
316
+ "INTRO_TEXT",
317
+ """# Large language model topic modelling
318
+
319
+ Extract topics and summarise outputs using Large Language Models (LLMs, Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure/OpenAI, or AWS Bedrock models (e.g. Claude, Nova models). The app will query the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets below. API keys for AWS, Azure/OpenAI, and Gemini services can be entered on the settings page (note that Gemini has a free public API).
320
+
321
+ NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** to check for harmful outputs, hallucinations, and accuracy.""",
322
+ )
323
+
324
+ # Read in intro text from a text file if it is a path to a text file
325
+ if INTRO_TEXT.endswith(".txt"):
326
+ INTRO_TEXT = open(INTRO_TEXT, "r").read()
327
+
328
+ INTRO_TEXT = INTRO_TEXT.strip('"').strip("'")
329
+
330
+ # Azure/OpenAI AI Inference settings
331
+ RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "1")
332
+ AZURE_OPENAI_API_KEY = get_or_create_env_var("AZURE_OPENAI_API_KEY", "")
333
+ AZURE_OPENAI_INFERENCE_ENDPOINT = get_or_create_env_var(
334
+ "AZURE_OPENAI_INFERENCE_ENDPOINT", ""
335
+ )
336
+
337
+ # Llama-server settings
338
+ RUN_INFERENCE_SERVER = get_or_create_env_var("RUN_INFERENCE_SERVER", "0")
339
+ API_URL = get_or_create_env_var("API_URL", "http://localhost:8080")
340
+
341
+ RUN_MCP_SERVER = convert_string_to_boolean(
342
+ get_or_create_env_var("RUN_MCP_SERVER", "False")
343
+ )
344
+
345
+ # Build up options for models
346
+ model_full_names = list()
347
+ model_short_names = list()
348
+ model_source = list()
349
+
350
+ CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var(
351
+ "CHOSEN_LOCAL_MODEL_TYPE", "Qwen 3 4B"
352
+ ) # Alternatives include "Gemma 3 1B", "Gemma 2b", "Gemma 3 4B"
353
+
354
+ USE_LLAMA_SWAP = get_or_create_env_var("USE_LLAMA_SWAP", "False")
355
+ if USE_LLAMA_SWAP == "True":
356
+ USE_LLAMA_SWAP = True
357
+ else:
358
+ USE_LLAMA_SWAP = False
359
+
360
+ if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
361
+ model_full_names.append(CHOSEN_LOCAL_MODEL_TYPE)
362
+ model_short_names.append(CHOSEN_LOCAL_MODEL_TYPE)
363
+ model_source.append("Local")
364
+
365
+ if RUN_AWS_BEDROCK_MODELS == "1":
366
+ amazon_models = [
367
+ "anthropic.claude-3-haiku-20240307-v1:0",
368
+ "anthropic.claude-3-7-sonnet-20250219-v1:0",
369
+ "anthropic.claude-sonnet-4-5-20250929-v1:0",
370
+ "amazon.nova-micro-v1:0",
371
+ "amazon.nova-lite-v1:0",
372
+ "amazon.nova-pro-v1:0",
373
+ "deepseek.v3-v1:0",
374
+ "openai.gpt-oss-20b-1:0",
375
+ "openai.gpt-oss-120b-1:0",
376
+ "google.gemma-3-12b-it",
377
+ "mistral.ministral-3-14b-instruct",
378
+ ]
379
+ model_full_names.extend(amazon_models)
380
+ model_short_names.extend(
381
+ [
382
+ "haiku",
383
+ "sonnet_3_7",
384
+ "sonnet_4_5",
385
+ "nova_micro",
386
+ "nova_lite",
387
+ "nova_pro",
388
+ "deepseek_v3",
389
+ "gpt_oss_20b_aws",
390
+ "gpt_oss_120b_aws",
391
+ "gemma_3_12b_it",
392
+ "ministral_3_14b_instruct",
393
+ ]
394
+ )
395
+ model_source.extend(["AWS"] * len(amazon_models))
396
+
397
+ if RUN_GEMINI_MODELS == "1":
398
+ gemini_models = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"]
399
+ model_full_names.extend(gemini_models)
400
+ model_short_names.extend(
401
+ ["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"]
402
+ )
403
+ model_source.extend(["Gemini"] * len(gemini_models))
404
+
405
+ # Register Azure/OpenAI AI models (model names must match your Azure/OpenAI deployments)
406
+ if RUN_AZURE_MODELS == "1":
407
+ # Example deployments; adjust to the deployments you actually create in Azure/OpenAI
408
+ azure_models = ["gpt-5-mini", "gpt-4o-mini"]
409
+ model_full_names.extend(azure_models)
410
+ model_short_names.extend(["gpt-5-mini", "gpt-4o-mini"])
411
+ model_source.extend(["Azure/OpenAI"] * len(azure_models))
412
+
413
+ # Register inference-server models
414
+ CHOSEN_INFERENCE_SERVER_MODEL = ""
415
+ if RUN_INFERENCE_SERVER == "1":
416
+ # Example inference-server models; adjust to the models you have available on your server
417
+ inference_server_models = [
418
+ "unnamed-inference-server-model",
419
+ "qwen_3_4b_it",
420
+ "qwen_3_4b_think",
421
+ "gpt_oss_20b",
422
+ "gemma_3_12b",
423
+ "ministral_3_14b_it",
424
+ ]
425
+ model_full_names.extend(inference_server_models)
426
+ model_short_names.extend(inference_server_models)
427
+ model_source.extend(["inference-server"] * len(inference_server_models))
428
+
429
+ CHOSEN_INFERENCE_SERVER_MODEL = get_or_create_env_var(
430
+ "CHOSEN_INFERENCE_SERVER_MODEL", inference_server_models[0]
431
+ )
432
+
433
+ if CHOSEN_INFERENCE_SERVER_MODEL not in inference_server_models:
434
+ model_full_names.append(CHOSEN_INFERENCE_SERVER_MODEL)
435
+ model_short_names.append(CHOSEN_INFERENCE_SERVER_MODEL)
436
+ model_source.append("inference-server")
437
+
438
+ model_name_map = {
439
+ full: {"short_name": short, "source": source}
440
+ for full, short, source in zip(model_full_names, model_short_names, model_source)
441
+ }
442
+
443
+ if RUN_LOCAL_MODEL == "1":
444
+ default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
445
+ elif RUN_INFERENCE_SERVER == "1":
446
+ default_model_choice = CHOSEN_INFERENCE_SERVER_MODEL
447
+ elif RUN_AWS_FUNCTIONS == "1":
448
+ default_model_choice = amazon_models[0]
449
+ else:
450
+ default_model_choice = gemini_models[0]
451
+
452
+ default_model_source = model_name_map[default_model_choice]["source"]
453
+ model_sources = list(
454
+ set([model_name_map[model]["source"] for model in model_full_names])
455
+ )
456
+
457
+
458
+ def update_model_choice_config(default_model_source, model_name_map):
459
+ # Filter models by source and return the first matching model name
460
+ matching_models = [
461
+ model_name
462
+ for model_name, model_info in model_name_map.items()
463
+ if model_info["source"] == default_model_source
464
+ ]
465
+
466
+ output_model = matching_models[0] if matching_models else model_full_names[0]
467
+
468
+ return output_model, matching_models
469
+
470
+
471
+ default_model_choice, default_source_models = update_model_choice_config(
472
+ default_model_source, model_name_map
473
+ )
+
+ # print("model_name_map:", model_name_map)
+
+ # HF token may or may not be needed for downloading models from Hugging Face
+ HF_TOKEN = get_or_create_env_var("HF_TOKEN", "")
+
+ LOAD_LOCAL_MODEL_AT_START = get_or_create_env_var("LOAD_LOCAL_MODEL_AT_START", "False")
+
+ # If you are using a system with low VRAM, you can set this to True to reduce the memory requirements
+ LOW_VRAM_SYSTEM = get_or_create_env_var("LOW_VRAM_SYSTEM", "False")
+
+ MULTIMODAL_PROMPT_FORMAT = get_or_create_env_var("MULTIMODAL_PROMPT_FORMAT", "False")
+
+ if LOW_VRAM_SYSTEM == "True":
+     print("Using settings for low VRAM system")
+     USE_LLAMA_CPP = get_or_create_env_var("USE_LLAMA_CPP", "True")
+     LLM_MAX_NEW_TOKENS = int(get_or_create_env_var("LLM_MAX_NEW_TOKENS", "4096"))
+     LLM_CONTEXT_LENGTH = int(get_or_create_env_var("LLM_CONTEXT_LENGTH", "16384"))
+     LLM_BATCH_SIZE = int(get_or_create_env_var("LLM_BATCH_SIZE", "512"))
+     K_QUANT_LEVEL = int(
+         get_or_create_env_var("K_QUANT_LEVEL", "2")
+     )  # 2 = q4_0, 8 = q8_0, 4 = fp16
+     V_QUANT_LEVEL = int(
+         get_or_create_env_var("V_QUANT_LEVEL", "2")
+     )  # 2 = q4_0, 8 = q8_0, 4 = fp16
+
+ USE_LLAMA_CPP = get_or_create_env_var(
+     "USE_LLAMA_CPP", "True"
+ )  # Llama.cpp or transformers with unsloth
+
+ LOCAL_REPO_ID = get_or_create_env_var("LOCAL_REPO_ID", "")
+ LOCAL_MODEL_FILE = get_or_create_env_var("LOCAL_MODEL_FILE", "")
+ LOCAL_MODEL_FOLDER = get_or_create_env_var("LOCAL_MODEL_FOLDER", "")
+
+ GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "unsloth/gemma-2-it-GGUF")
+ GEMMA2_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "GEMMA2_2B_REPO_TRANSFORMERS_ID", "unsloth/gemma-2-2b-it-bnb-4bit"
+ )
+ if USE_LLAMA_CPP == "False":
+     GEMMA2_REPO_ID = GEMMA2_REPO_TRANSFORMERS_ID
+ GEMMA2_MODEL_FILE = get_or_create_env_var(
+     "GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it.q8_0.gguf"
+ )
+ GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
+
+ GEMMA3_4B_REPO_ID = get_or_create_env_var(
+     "GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF"
+ )
+ GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "GEMMA3_4B_REPO_TRANSFORMERS_ID", "unsloth/gemma-3-4b-it-bnb-4bit"
+ )
+ if USE_LLAMA_CPP == "False":
+     GEMMA3_4B_REPO_ID = GEMMA3_4B_REPO_TRANSFORMERS_ID
+ GEMMA3_4B_MODEL_FILE = get_or_create_env_var(
+     "GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-UD-Q4_K_XL.gguf"
+ )
+ GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var(
+     "GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b"
+ )
+
+ GEMMA3_12B_REPO_ID = get_or_create_env_var(
+     "GEMMA3_12B_REPO_ID", "unsloth/gemma-3-12b-it-GGUF"
+ )
+ GEMMA3_12B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "GEMMA3_12B_REPO_TRANSFORMERS_ID", "unsloth/gemma-3-12b-it-bnb-4bit"
+ )
+ if USE_LLAMA_CPP == "False":
+     GEMMA3_12B_REPO_ID = GEMMA3_12B_REPO_TRANSFORMERS_ID
+ GEMMA3_12B_MODEL_FILE = get_or_create_env_var(
+     "GEMMA3_12B_MODEL_FILE", "gemma-3-12b-it-UD-Q4_K_XL.gguf"
+ )
+ GEMMA3_12B_MODEL_FOLDER = get_or_create_env_var(
+     "GEMMA3_12B_MODEL_FOLDER", "model/gemma3_12b"
+ )
+
+ GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
+ GPT_OSS_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "GPT_OSS_REPO_TRANSFORMERS_ID", "unsloth/gpt-oss-20b-unsloth-bnb-4bit"
+ )
+ if USE_LLAMA_CPP == "False":
+     GPT_OSS_REPO_ID = GPT_OSS_REPO_TRANSFORMERS_ID
+ GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
+ GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
+
+ QWEN3_4B_REPO_ID = get_or_create_env_var(
+     "QWEN3_4B_REPO_ID", "unsloth/Qwen3-4B-Instruct-2507-GGUF"
+ )
+ QWEN3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "QWEN3_4B_REPO_TRANSFORMERS_ID", "unsloth/Qwen3-4B-unsloth-bnb-4bit"
+ )
+ if USE_LLAMA_CPP == "False":
+     QWEN3_4B_REPO_ID = QWEN3_4B_REPO_TRANSFORMERS_ID
+
+ QWEN3_4B_MODEL_FILE = get_or_create_env_var(
+     "QWEN3_4B_MODEL_FILE", "Qwen3-4B-Instruct-2507-UD-Q4_K_XL.gguf"
+ )
+ QWEN3_4B_MODEL_FOLDER = get_or_create_env_var("QWEN3_4B_MODEL_FOLDER", "model/qwen")
+
+ GRANITE_4_TINY_REPO_ID = get_or_create_env_var(
+     "GRANITE_4_TINY_REPO_ID", "unsloth/granite-4.0-h-tiny-GGUF"
+ )
+ GRANITE_4_TINY_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "GRANITE_4_TINY_REPO_TRANSFORMERS_ID", "unsloth/granite-4.0-h-tiny-FP8-Dynamic"
+ )
+ if USE_LLAMA_CPP == "False":
+     GRANITE_4_TINY_REPO_ID = GRANITE_4_TINY_REPO_TRANSFORMERS_ID
+ GRANITE_4_TINY_MODEL_FILE = get_or_create_env_var(
+     "GRANITE_4_TINY_MODEL_FILE", "granite-4.0-h-tiny-UD-Q4_K_XL.gguf"
+ )
+ GRANITE_4_TINY_MODEL_FOLDER = get_or_create_env_var(
+     "GRANITE_4_TINY_MODEL_FOLDER", "model/granite"
+ )
+
+ GRANITE_4_3B_REPO_ID = get_or_create_env_var(
+     "GRANITE_4_3B_REPO_ID", "unsloth/granite-4.0-h-micro-GGUF"
+ )
+ GRANITE_4_3B_REPO_TRANSFORMERS_ID = get_or_create_env_var(
+     "GRANITE_4_3B_REPO_TRANSFORMERS_ID", "unsloth/granite-4.0-micro-unsloth-bnb-4bit"
+ )
+ if USE_LLAMA_CPP == "False":
+     GRANITE_4_3B_REPO_ID = GRANITE_4_3B_REPO_TRANSFORMERS_ID
+ GRANITE_4_3B_MODEL_FILE = get_or_create_env_var(
+     "GRANITE_4_3B_MODEL_FILE", "granite-4.0-h-micro-UD-Q4_K_XL.gguf"
+ )
+ GRANITE_4_3B_MODEL_FOLDER = get_or_create_env_var(
+     "GRANITE_4_3B_MODEL_FOLDER", "model/granite"
+ )
+
+ if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 2b":
+     LOCAL_REPO_ID = GEMMA2_REPO_ID
+     LOCAL_MODEL_FILE = GEMMA2_MODEL_FILE
+     LOCAL_MODEL_FOLDER = GEMMA2_MODEL_FOLDER
+
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
+     LOCAL_REPO_ID = GEMMA3_4B_REPO_ID
+     LOCAL_MODEL_FILE = GEMMA3_4B_MODEL_FILE
+     LOCAL_MODEL_FOLDER = GEMMA3_4B_MODEL_FOLDER
+     MULTIMODAL_PROMPT_FORMAT = "True"
+
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 12B":
+     LOCAL_REPO_ID = GEMMA3_12B_REPO_ID
+     LOCAL_MODEL_FILE = GEMMA3_12B_MODEL_FILE
+     LOCAL_MODEL_FOLDER = GEMMA3_12B_MODEL_FOLDER
+     MULTIMODAL_PROMPT_FORMAT = "True"
+
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Qwen 3 4B":
+     LOCAL_REPO_ID = QWEN3_4B_REPO_ID
+     LOCAL_MODEL_FILE = QWEN3_4B_MODEL_FILE
+     LOCAL_MODEL_FOLDER = QWEN3_4B_MODEL_FOLDER
+
+ elif CHOSEN_LOCAL_MODEL_TYPE == "gpt-oss-20b":
+     LOCAL_REPO_ID = GPT_OSS_REPO_ID
+     LOCAL_MODEL_FILE = GPT_OSS_MODEL_FILE
+     LOCAL_MODEL_FOLDER = GPT_OSS_MODEL_FOLDER
+
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Granite 4 Tiny":
+     LOCAL_REPO_ID = GRANITE_4_TINY_REPO_ID
+     LOCAL_MODEL_FILE = GRANITE_4_TINY_MODEL_FILE
+     LOCAL_MODEL_FOLDER = GRANITE_4_TINY_MODEL_FOLDER
+
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Granite 4 Micro":
+     LOCAL_REPO_ID = GRANITE_4_3B_REPO_ID
+     LOCAL_MODEL_FILE = GRANITE_4_3B_MODEL_FILE
+     LOCAL_MODEL_FOLDER = GRANITE_4_3B_MODEL_FOLDER
+
+ elif not CHOSEN_LOCAL_MODEL_TYPE:
+     print("No local model type chosen")
+     LOCAL_REPO_ID = ""
+     LOCAL_MODEL_FILE = ""
+     LOCAL_MODEL_FOLDER = ""
+ else:
+     print("CHOSEN_LOCAL_MODEL_TYPE not found")
+     LOCAL_REPO_ID = ""
+     LOCAL_MODEL_FILE = ""
+     LOCAL_MODEL_FOLDER = ""
+
+ USE_SPECULATIVE_DECODING = get_or_create_env_var("USE_SPECULATIVE_DECODING", "False")
+
+ ASSISTANT_MODEL = get_or_create_env_var("ASSISTANT_MODEL", "")
+ if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
+     ASSISTANT_MODEL = get_or_create_env_var(
+         "ASSISTANT_MODEL", "unsloth/gemma-3-270m-it"
+     )
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Qwen 3 4B":
+     ASSISTANT_MODEL = get_or_create_env_var("ASSISTANT_MODEL", "unsloth/Qwen3-0.6B")
+
+ DRAFT_MODEL_LOC = get_or_create_env_var("DRAFT_MODEL_LOC", ".cache/llama.cpp/")
+
+ GEMMA3_DRAFT_MODEL_LOC = get_or_create_env_var(
+     "GEMMA3_DRAFT_MODEL_LOC",
+     DRAFT_MODEL_LOC + "unsloth_gemma-3-270m-it-qat-GGUF_gemma-3-270m-it-qat-F16.gguf",
+ )
+ GEMMA3_4B_DRAFT_MODEL_LOC = get_or_create_env_var(
+     "GEMMA3_4B_DRAFT_MODEL_LOC",
+     DRAFT_MODEL_LOC + "unsloth_gemma-3-4b-it-qat-GGUF_gemma-3-4b-it-qat-Q4_K_M.gguf",
+ )
+
+ QWEN3_DRAFT_MODEL_LOC = get_or_create_env_var(
+     "QWEN3_DRAFT_MODEL_LOC", DRAFT_MODEL_LOC + "Qwen3-0.6B-Q8_0.gguf"
+ )
+ QWEN3_4B_DRAFT_MODEL_LOC = get_or_create_env_var(
+     "QWEN3_4B_DRAFT_MODEL_LOC",
+     DRAFT_MODEL_LOC + "Qwen3-4B-Instruct-2507-UD-Q4_K_XL.gguf",
+ )
+
+
+ LLM_MAX_GPU_LAYERS = int(
+     get_or_create_env_var("LLM_MAX_GPU_LAYERS", "-1")
+ )  # -1 = offload the maximum possible number of layers
+ LLM_TEMPERATURE = float(get_or_create_env_var("LLM_TEMPERATURE", "0.6"))
+ LLM_TOP_K = int(
+     get_or_create_env_var("LLM_TOP_K", "64")
+ )  # https://docs.unsloth.ai/basics/gemma-3-how-to-run-and-fine-tune
+ LLM_MIN_P = float(get_or_create_env_var("LLM_MIN_P", "0"))
+ LLM_TOP_P = float(get_or_create_env_var("LLM_TOP_P", "0.95"))
+ LLM_REPETITION_PENALTY = float(get_or_create_env_var("LLM_REPETITION_PENALTY", "1.0"))
+
+ LLM_LAST_N_TOKENS = int(get_or_create_env_var("LLM_LAST_N_TOKENS", "512"))
+ LLM_MAX_NEW_TOKENS = int(get_or_create_env_var("LLM_MAX_NEW_TOKENS", "4096"))
+ LLM_SEED = int(get_or_create_env_var("LLM_SEED", "42"))
+ LLM_RESET = get_or_create_env_var("LLM_RESET", "False")
+ LLM_STREAM = get_or_create_env_var("LLM_STREAM", "True")
+ LLM_THREADS = int(get_or_create_env_var("LLM_THREADS", "-1"))
+ LLM_BATCH_SIZE = int(get_or_create_env_var("LLM_BATCH_SIZE", "2048"))
+ LLM_CONTEXT_LENGTH = int(get_or_create_env_var("LLM_CONTEXT_LENGTH", "24576"))
+ LLM_SAMPLE = get_or_create_env_var("LLM_SAMPLE", "True")
+ LLM_STOP_STRINGS = get_or_create_env_var("LLM_STOP_STRINGS", r"['\n\n\n\n\n\n']")
+
+ SPECULATIVE_DECODING = get_or_create_env_var("SPECULATIVE_DECODING", "False")
+ NUM_PRED_TOKENS = int(get_or_create_env_var("NUM_PRED_TOKENS", "2"))
+ K_QUANT_LEVEL = get_or_create_env_var(
+     "K_QUANT_LEVEL", ""
+ )  # 2 = q4_0, 8 = q8_0, 4 = fp16
+ V_QUANT_LEVEL = get_or_create_env_var(
+     "V_QUANT_LEVEL", ""
+ )  # 2 = q4_0, 8 = q8_0, 4 = fp16
+
+ if not K_QUANT_LEVEL:
+     K_QUANT_LEVEL = None
+ else:
+     K_QUANT_LEVEL = int(K_QUANT_LEVEL)
+ if not V_QUANT_LEVEL:
+     V_QUANT_LEVEL = None
+ else:
+     V_QUANT_LEVEL = int(V_QUANT_LEVEL)
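These runtime settings map naturally onto the `llama-cpp-python` `Llama` constructor. The sketch below is illustrative only: the real loader presumably lives in `tools/llm_funcs.py`, the model path is simply the Gemma 3 4B default from above, and prompt-lookup decoding stands in for speculative decoding (the separate draft-model GGUF files configured above are not loaded here):

```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

from tools.config import (
    K_QUANT_LEVEL,
    LLM_BATCH_SIZE,
    LLM_CONTEXT_LENGTH,
    LLM_MAX_GPU_LAYERS,
    LLM_SEED,
    NUM_PRED_TOKENS,
    SPECULATIVE_DECODING,
    V_QUANT_LEVEL,
)

# Hypothetical load call wiring the config values above into llama.cpp.
llm = Llama(
    model_path="model/gemma3_4b/gemma-3-4b-it-qat-UD-Q4_K_XL.gguf",
    n_gpu_layers=LLM_MAX_GPU_LAYERS,  # -1 offloads as many layers as will fit
    n_ctx=LLM_CONTEXT_LENGTH,
    n_batch=LLM_BATCH_SIZE,
    seed=LLM_SEED,
    type_k=K_QUANT_LEVEL,  # KV-cache quantisation; None keeps llama.cpp's default
    type_v=V_QUANT_LEVEL,
    draft_model=(
        LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS)
        if SPECULATIVE_DECODING == "True"
        else None
    ),
)
```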
+
+ # If you are using e.g. gpt-oss, you can add a reasoning suffix to set the reasoning level, or to turn reasoning off in the case of Qwen 3 4B
+ if CHOSEN_LOCAL_MODEL_TYPE == "gpt-oss-20b":
+     REASONING_SUFFIX = get_or_create_env_var("REASONING_SUFFIX", "Reasoning: low")
+ elif CHOSEN_LOCAL_MODEL_TYPE == "Qwen 3 4B" and USE_LLAMA_CPP == "False":
+     REASONING_SUFFIX = get_or_create_env_var("REASONING_SUFFIX", "/nothink")
+ else:
+     REASONING_SUFFIX = get_or_create_env_var("REASONING_SUFFIX", "")
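The suffix is presumably appended to the prompt before generation; a trivial sketch with a hypothetical helper:

```python
from tools.config import REASONING_SUFFIX

def build_prompt(user_prompt: str) -> str:
    # Hypothetical helper: append the model-specific reasoning control phrase, if any.
    return f"{user_prompt}\n{REASONING_SUFFIX}" if REASONING_SUFFIX else user_prompt

# gpt-oss-20b                -> "...\nReasoning: low"
# Qwen 3 4B via transformers -> "...\n/nothink"
```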
+
+ # Transformers variables
+ COMPILE_TRANSFORMERS = get_or_create_env_var(
+     "COMPILE_TRANSFORMERS", "False"
+ )  # Whether to compile transformers models
+ USE_BITSANDBYTES = get_or_create_env_var(
+     "USE_BITSANDBYTES", "True"
+ )  # Whether to use bitsandbytes for quantization
+ COMPILE_MODE = get_or_create_env_var(
+     "COMPILE_MODE", "reduce-overhead"
+ )  # alternatively 'max-autotune'
+ MODEL_DTYPE = get_or_create_env_var(
+     "MODEL_DTYPE", "bfloat16"
+ )  # alternatively 'float16'
+ INT8_WITH_OFFLOAD_TO_CPU = get_or_create_env_var(
+     "INT8_WITH_OFFLOAD_TO_CPU", "False"
+ )  # Whether to offload to CPU
+
+ DEFAULT_SAMPLED_SUMMARIES = int(
+     get_or_create_env_var("DEFAULT_SAMPLED_SUMMARIES", "75")
+ )
+
+ ###
+ # Gradio app variables
+ ###
+
+ # Get some environment variables and launch the Gradio app
+ COGNITO_AUTH = get_or_create_env_var("COGNITO_AUTH", "0")
+
+ RUN_DIRECT_MODE = get_or_create_env_var("RUN_DIRECT_MODE", "0")
+
+ # Direct mode environment variables
+ DIRECT_MODE_TASK = get_or_create_env_var("DIRECT_MODE_TASK", "extract")
+ DIRECT_MODE_INPUT_FILE = get_or_create_env_var("DIRECT_MODE_INPUT_FILE", "")
+ DIRECT_MODE_OUTPUT_DIR = get_or_create_env_var("DIRECT_MODE_OUTPUT_DIR", OUTPUT_FOLDER)
+ DIRECT_MODE_S3_OUTPUT_BUCKET = get_or_create_env_var(
+     "DIRECT_MODE_S3_OUTPUT_BUCKET", S3_OUTPUTS_BUCKET
+ )
+ DIRECT_MODE_TEXT_COLUMN = get_or_create_env_var("DIRECT_MODE_TEXT_COLUMN", "")
+ DIRECT_MODE_PREVIOUS_OUTPUT_FILES = get_or_create_env_var(
+     "DIRECT_MODE_PREVIOUS_OUTPUT_FILES", ""
+ )
+ DIRECT_MODE_USERNAME = get_or_create_env_var("DIRECT_MODE_USERNAME", "")
+ DIRECT_MODE_GROUP_BY = get_or_create_env_var("DIRECT_MODE_GROUP_BY", "")
+ DIRECT_MODE_EXCEL_SHEETS = get_or_create_env_var("DIRECT_MODE_EXCEL_SHEETS", "")
+ DIRECT_MODE_MODEL_CHOICE = get_or_create_env_var(
+     "DIRECT_MODE_MODEL_CHOICE", default_model_choice
+ )
+ DIRECT_MODE_TEMPERATURE = get_or_create_env_var(
+     "DIRECT_MODE_TEMPERATURE", str(LLM_TEMPERATURE)
+ )
+ DIRECT_MODE_BATCH_SIZE = get_or_create_env_var(
+     "DIRECT_MODE_BATCH_SIZE", str(BATCH_SIZE_DEFAULT)
+ )
+ DIRECT_MODE_MAX_TOKENS = get_or_create_env_var(
+     "DIRECT_MODE_MAX_TOKENS", str(LLM_MAX_NEW_TOKENS)
+ )
+ DIRECT_MODE_CONTEXT = get_or_create_env_var("DIRECT_MODE_CONTEXT", "")
+ DIRECT_MODE_CANDIDATE_TOPICS = get_or_create_env_var("DIRECT_MODE_CANDIDATE_TOPICS", "")
+ DIRECT_MODE_FORCE_ZERO_SHOT = get_or_create_env_var("DIRECT_MODE_FORCE_ZERO_SHOT", "No")
+ DIRECT_MODE_FORCE_SINGLE_TOPIC = get_or_create_env_var(
+     "DIRECT_MODE_FORCE_SINGLE_TOPIC", "No"
+ )
+ DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY = get_or_create_env_var(
+     "DIRECT_MODE_PRODUCE_STRUCTURED_SUMMARY", "No"
+ )
+ DIRECT_MODE_SENTIMENT = get_or_create_env_var(
+     "DIRECT_MODE_SENTIMENT", "Negative or Positive"
+ )
+ DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS = get_or_create_env_var(
+     "DIRECT_MODE_ADDITIONAL_SUMMARY_INSTRUCTIONS", ""
+ )
+ DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES = get_or_create_env_var(
+     "DIRECT_MODE_ADDITIONAL_VALIDATION_ISSUES", ""
+ )
+ DIRECT_MODE_SHOW_PREVIOUS_TABLE = get_or_create_env_var(
+     "DIRECT_MODE_SHOW_PREVIOUS_TABLE", "Yes"
+ )
+ DIRECT_MODE_MAX_TIME_FOR_LOOP = get_or_create_env_var(
+     "DIRECT_MODE_MAX_TIME_FOR_LOOP", str(MAX_TIME_FOR_LOOP)
+ )
+ DIRECT_MODE_DEDUP_METHOD = get_or_create_env_var("DIRECT_MODE_DEDUP_METHOD", "fuzzy")
+ DIRECT_MODE_SIMILARITY_THRESHOLD = get_or_create_env_var(
+     "DIRECT_MODE_SIMILARITY_THRESHOLD", str(DEDUPLICATION_THRESHOLD)
+ )
+ DIRECT_MODE_MERGE_SENTIMENT = get_or_create_env_var("DIRECT_MODE_MERGE_SENTIMENT", "No")
+ DIRECT_MODE_MERGE_GENERAL_TOPICS = get_or_create_env_var(
+     "DIRECT_MODE_MERGE_GENERAL_TOPICS", "Yes"
+ )
+ DIRECT_MODE_SUMMARY_FORMAT = get_or_create_env_var(
+     "DIRECT_MODE_SUMMARY_FORMAT", "two_paragraph"
+ )
+ DIRECT_MODE_SAMPLE_REFERENCE_TABLE = get_or_create_env_var(
+     "DIRECT_MODE_SAMPLE_REFERENCE_TABLE", "True"
+ )
+ DIRECT_MODE_NO_OF_SAMPLED_SUMMARIES = get_or_create_env_var(
+     "DIRECT_MODE_NO_OF_SAMPLED_SUMMARIES", str(DEFAULT_SAMPLED_SUMMARIES)
+ )
+ DIRECT_MODE_RANDOM_SEED = get_or_create_env_var(
+     "DIRECT_MODE_RANDOM_SEED", str(LLM_SEED)
+ )
+ DIRECT_MODE_CREATE_XLSX_OUTPUT = get_or_create_env_var(
+     "DIRECT_MODE_CREATE_XLSX_OUTPUT", "True"
+ )
+ # DIRECT_MODE_INFERENCE_SERVER_MODEL is set at the end of this file, once the log header lists have been parsed
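Because all of these values are resolved at import time, direct mode is configured by exporting the variables before `tools.config` is first imported. A minimal sketch (the input file is one of the bundled examples; the column name is an assumption based on the example output filenames):

```python
import os

os.environ["RUN_DIRECT_MODE"] = "1"
os.environ["DIRECT_MODE_TASK"] = "extract"
os.environ["DIRECT_MODE_INPUT_FILE"] = "example_data/dummy_consultation_response.csv"
os.environ["DIRECT_MODE_TEXT_COLUMN"] = "Response text"  # assumed column name

from tools import config  # import only after the environment is set

print(config.DIRECT_MODE_TASK, config.DIRECT_MODE_TEXT_COLUMN)
```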
+
+ MAX_QUEUE_SIZE = int(get_or_create_env_var("MAX_QUEUE_SIZE", "5"))
+
+ MAX_FILE_SIZE = get_or_create_env_var("MAX_FILE_SIZE", "250mb")
+
+ GRADIO_SERVER_PORT = int(get_or_create_env_var("GRADIO_SERVER_PORT", "7860"))
+
+ ROOT_PATH = get_or_create_env_var("ROOT_PATH", "")
+
+ DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var("DEFAULT_CONCURRENCY_LIMIT", "3")
+
+ GET_DEFAULT_ALLOW_LIST = get_or_create_env_var("GET_DEFAULT_ALLOW_LIST", "")
+
+ ALLOW_LIST_PATH = get_or_create_env_var(
+     "ALLOW_LIST_PATH", ""
+ )  # e.g. config/default_allow_list.csv
+
+ S3_ALLOW_LIST_PATH = get_or_create_env_var(
+     "S3_ALLOW_LIST_PATH", ""
+ )  # e.g. default_allow_list.csv; this is a path within the named S3 bucket
+
+ if ALLOW_LIST_PATH:
+     OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
+ else:
+     OUTPUT_ALLOW_LIST_PATH = "config/default_allow_list.csv"
+
+ FILE_INPUT_HEIGHT = int(get_or_create_env_var("FILE_INPUT_HEIGHT", "125"))
+
+ SHOW_EXAMPLES = get_or_create_env_var("SHOW_EXAMPLES", "True")
+
+ ###
+ # COST CODE OPTIONS
+ ###
+
+ SHOW_COSTS = get_or_create_env_var("SHOW_COSTS", "False")
+
+ GET_COST_CODES = get_or_create_env_var("GET_COST_CODES", "False")
+
+ DEFAULT_COST_CODE = get_or_create_env_var("DEFAULT_COST_CODE", "")
+
+ COST_CODES_PATH = get_or_create_env_var(
+     "COST_CODES_PATH", ""
+ )  # e.g. 'config/COST_CENTRES.csv'. The file should be a csv file containing a single table with two columns and a header: the first column should contain cost codes, the second a name or description for each cost code
+
+ S3_COST_CODES_PATH = get_or_create_env_var(
+     "S3_COST_CODES_PATH", ""
+ )  # e.g. COST_CENTRES.csv; this is a path within the named S3 bucket
+
+ # A default path in case an S3 cost code location is provided but no local cost code location is given
+ if COST_CODES_PATH:
+     OUTPUT_COST_CODES_PATH = COST_CODES_PATH
+ else:
+     OUTPUT_COST_CODES_PATH = "config/cost_codes.csv"
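For reference, a file matching the described cost code format can be generated like this (the codes and descriptions are invented):

```python
import pandas as pd

# Two columns with a header: cost codes first, then a name/description.
pd.DataFrame(
    {
        "Cost code": ["A100", "B200"],
        "Description": ["Consultation analysis", "Case note review"],
    }
).to_csv("config/cost_codes.csv", index=False)
```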
+
+ ENFORCE_COST_CODES = get_or_create_env_var(
+     "ENFORCE_COST_CODES", "False"
+ )  # If you have cost codes listed, is it compulsory to choose one before running an analysis?
+
+ if ENFORCE_COST_CODES == "True":
+     GET_COST_CODES = "True"
+
+ ###
+ # VALIDATE FOLDERS AND CONFIG OPTIONS
+ ###
+
+
+ def ensure_folder_exists(output_folder: str):
+     """Creates the specified folder if it does not already exist."""
+
+     if not os.path.exists(output_folder):
+         os.makedirs(output_folder, exist_ok=True)
+         print(f"Created the {output_folder} folder.")
+
+
+ def _get_env_list(env_var_value: str, strip_strings: bool = True) -> List[str]:
+     """Parses a bracketed, comma-separated environment variable value into a list of strings."""
+     # Drop the surrounding brackets and quotes, e.g. "['a', 'b']" -> "a, b"
+     value = env_var_value[1:-1].strip().replace('"', "").replace("'", "")
+     if not value:
+         return []
+     # Split by comma and filter out any empty strings that might result from extra commas
+     if strip_strings:
+         return [s.strip() for s in value.split(",") if s.strip()]
+     else:
+         return [codecs.decode(s, "unicode_escape") for s in value.split(",") if s]
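For example, given the kind of bracketed value the header variables below typically hold, the helper behaves like this:

```python
headers_env_value = "['session_id', 'file_name', 'model_choice']"
print(_get_env_list(headers_env_value))
# -> ['session_id', 'file_name', 'model_choice']
```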
+
+
+ # Convert string environment variables to bool or list
+ SAVE_LOGS_TO_CSV = SAVE_LOGS_TO_CSV == "True"
+ SAVE_LOGS_TO_DYNAMODB = SAVE_LOGS_TO_DYNAMODB == "True"
+
+ if CSV_ACCESS_LOG_HEADERS:
+     CSV_ACCESS_LOG_HEADERS = _get_env_list(CSV_ACCESS_LOG_HEADERS)
+ if CSV_FEEDBACK_LOG_HEADERS:
+     CSV_FEEDBACK_LOG_HEADERS = _get_env_list(CSV_FEEDBACK_LOG_HEADERS)
+ if CSV_USAGE_LOG_HEADERS:
+     CSV_USAGE_LOG_HEADERS = _get_env_list(CSV_USAGE_LOG_HEADERS)
+
+ if DYNAMODB_ACCESS_LOG_HEADERS:
+     DYNAMODB_ACCESS_LOG_HEADERS = _get_env_list(DYNAMODB_ACCESS_LOG_HEADERS)
+ if DYNAMODB_FEEDBACK_LOG_HEADERS:
+     DYNAMODB_FEEDBACK_LOG_HEADERS = _get_env_list(DYNAMODB_FEEDBACK_LOG_HEADERS)
+ if DYNAMODB_USAGE_LOG_HEADERS:
+     DYNAMODB_USAGE_LOG_HEADERS = _get_env_list(DYNAMODB_USAGE_LOG_HEADERS)
+
+ # Set DIRECT_MODE_INFERENCE_SERVER_MODEL now that CHOSEN_INFERENCE_SERVER_MODEL is defined
+ DIRECT_MODE_INFERENCE_SERVER_MODEL = get_or_create_env_var(
+     "DIRECT_MODE_INFERENCE_SERVER_MODEL",
+     CHOSEN_INFERENCE_SERVER_MODEL if CHOSEN_INFERENCE_SERVER_MODEL else "",
+ )
tools/custom_csvlogger.py ADDED
@@ -0,0 +1,333 @@
+ from __future__ import annotations
+
+ import csv
+ import os
+ import re
+ import time
+ import uuid
+ from collections.abc import Sequence
+ from datetime import datetime
+
+ # from multiprocessing import Lock
+ from pathlib import Path
+ from threading import Lock
+ from typing import TYPE_CHECKING, Any
+
+ import boto3
+ import botocore
+ from gradio import utils
+ from gradio.flagging import FlaggingCallback
+ from gradio_client import utils as client_utils
+
+ from tools.config import AWS_ACCESS_KEY, AWS_REGION, AWS_SECRET_KEY, RUN_AWS_FUNCTIONS
+
+ # Lock and FlaggingCallback are needed at runtime, so only the Component
+ # type hint import lives under TYPE_CHECKING.
+ if TYPE_CHECKING:
+     from gradio.components import Component
+
+
+ class CSVLogger_custom(FlaggingCallback):
+     """
+     The default implementation of the FlaggingCallback abstract class in gradio>=5.0. Each flagged
+     sample (both the input and output data) is logged to a CSV file with headers on the machine running
+     the gradio app. Unlike ClassicCSVLogger, this implementation is concurrent-safe and it creates a new
+     dataset file every time the headers of the CSV (derived from the labels of the components) change. It also
+     only creates columns for "username" and "flag" if the flag_option and username are provided, respectively.
+
+     Example:
+         import gradio as gr
+         def image_classifier(inp):
+             return {'cat': 0.3, 'dog': 0.7}
+         demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
+                             flagging_callback=CSVLogger_custom())
+     Guides: using-flagging
+     """
+
+     def __init__(
+         self,
+         simplify_file_data: bool = True,
+         verbose: bool = True,
+         dataset_file_name: str | None = None,
+     ):
+         """
+         Parameters:
+             simplify_file_data: If True, the file data will be simplified before being written to the CSV file. If CSVLogger is being used to cache examples, this is set to False to preserve the original FileData class
+             verbose: If True, prints messages to the console about the dataset file creation
+             dataset_file_name: The name of the dataset file to be created (should end in ".csv"). If None, the dataset file will be named "dataset1.csv" or the next available number.
+         """
+         self.simplify_file_data = simplify_file_data
+         self.verbose = verbose
+         self.dataset_file_name = dataset_file_name
+         self.lock = Lock()
+
+     def setup(
+         self,
+         components: Sequence[Component],
+         flagging_dir: str | Path,
+     ):
+         self.components = components
+         self.flagging_dir = Path(flagging_dir)
+         self.first_time = True
+
+     def _create_dataset_file(
+         self,
+         additional_headers: list[str] | None = None,
+         replacement_headers: list[str] | None = None,
+     ):
+         os.makedirs(self.flagging_dir, exist_ok=True)
+
+         if replacement_headers:
+             if additional_headers is None:
+                 additional_headers = []
+
+             if len(replacement_headers) != len(self.components):
+                 raise ValueError(
+                     f"replacement_headers must have the same length as components "
+                     f"({len(replacement_headers)} provided, {len(self.components)} expected)"
+                 )
+             headers = replacement_headers + additional_headers + ["timestamp"]
+         else:
+             if additional_headers is None:
+                 additional_headers = []
+             headers = (
+                 [
+                     getattr(component, "label", None) or f"component {idx}"
+                     for idx, component in enumerate(self.components)
+                 ]
+                 + additional_headers
+                 + ["timestamp"]
+             )
+
+         headers = utils.sanitize_list_for_csv(headers)
+         dataset_files = list(Path(self.flagging_dir).glob("dataset*.csv"))
+
+         if self.dataset_file_name:
+             self.dataset_filepath = self.flagging_dir / self.dataset_file_name
+         elif dataset_files:
+             try:
+                 # Reuse the most recent dataset file if its headers still match;
+                 # otherwise start a new file with an incremented number.
+                 latest_file = max(
+                     dataset_files, key=lambda f: int(re.findall(r"\d+", f.stem)[0])
+                 )
+                 latest_num = int(re.findall(r"\d+", latest_file.stem)[0])
+
+                 with open(latest_file, newline="", encoding="utf-8-sig") as csvfile:
+                     reader = csv.reader(csvfile)
+                     existing_headers = next(reader, None)
+
+                 if existing_headers != headers:
+                     new_num = latest_num + 1
+                     self.dataset_filepath = self.flagging_dir / f"dataset{new_num}.csv"
+                 else:
+                     self.dataset_filepath = latest_file
+             except Exception:
+                 self.dataset_filepath = self.flagging_dir / "dataset1.csv"
+         else:
+             self.dataset_filepath = self.flagging_dir / "dataset1.csv"
+
+         if not Path(self.dataset_filepath).exists():
+             with open(
+                 self.dataset_filepath, "w", newline="", encoding="utf-8-sig"
+             ) as csvfile:
+                 writer = csv.writer(csvfile)
+                 writer.writerow(utils.sanitize_list_for_csv(headers))
+             if self.verbose:
+                 print("Created dataset file at:", self.dataset_filepath)
+         elif self.verbose:
+             print("Using existing dataset file at:", self.dataset_filepath)
+
+     def flag(
+         self,
+         flag_data: list[Any],
+         flag_option: str | None = None,
+         username: str | None = None,
+         save_to_csv: bool = True,
+         save_to_dynamodb: bool = False,
+         dynamodb_table_name: str | None = None,
+         dynamodb_headers: list[str] | None = None,  # Headers to use for the DynamoDB item
+         replacement_headers: list[str] | None = None,
+     ) -> int:
+         if self.first_time:
+             additional_headers = []
+             if flag_option is not None:
+                 additional_headers.append("flag")
+             if username is not None:
+                 additional_headers.append("username")
+             additional_headers.append("id")
+             self._create_dataset_file(
+                 additional_headers=additional_headers,
+                 replacement_headers=replacement_headers,
+             )
+             self.first_time = False
+
+         csv_data = []
+         for idx, (component, sample) in enumerate(
+             zip(self.components, flag_data, strict=False)
+         ):
+             save_dir = (
+                 self.flagging_dir
+                 / client_utils.strip_invalid_filename_characters(
+                     getattr(component, "label", None) or f"component {idx}"
+                 )
+             )
+             if utils.is_prop_update(sample):
+                 csv_data.append(str(sample))
+             else:
+                 data = (
+                     component.flag(sample, flag_dir=save_dir)
+                     if sample is not None
+                     else ""
+                 )
+                 if self.simplify_file_data:
+                     data = utils.simplify_file_data_in_str(data)
+                 csv_data.append(data)
+
+         if flag_option is not None:
+             csv_data.append(flag_option)
+         if username is not None:
+             csv_data.append(username)
+
+         generated_id = str(uuid.uuid4())
+         csv_data.append(generated_id)
+
+         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
+             :-3
+         ]  # Millisecond precision, the format expected by Amazon Athena
+         csv_data.append(timestamp)
+
+         # Build the headers in the same order as csv_data
+         headers = [
+             getattr(component, "label", None) or f"component {idx}"
+             for idx, component in enumerate(self.components)
+         ]
+         if flag_option is not None:
+             headers.append("flag")
+         if username is not None:
+             headers.append("username")
+         headers.append("id")
+         headers.append("timestamp")
+
+         line_count = -1
+
+         if save_to_csv:
+             with self.lock:
+                 with open(
+                     self.dataset_filepath, "a", newline="", encoding="utf-8-sig"
+                 ) as csvfile:
+                     writer = csv.writer(csvfile)
+                     writer.writerow(utils.sanitize_list_for_csv(csv_data))
+                 with open(self.dataset_filepath, encoding="utf-8-sig") as csvfile:
+                     line_count = len(list(csv.reader(csvfile))) - 1
+
+         if save_to_dynamodb:
+             print("Saving to DynamoDB")
+
+             if RUN_AWS_FUNCTIONS == "1":
+                 try:
+                     print("Connecting to DynamoDB via existing SSO connection")
+                     dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION)
+
+                     dynamodb.meta.client.list_tables()
+
+                 except Exception as e:
+                     print("No SSO credentials found:", e)
+                     if AWS_ACCESS_KEY and AWS_SECRET_KEY:
+                         print("Trying DynamoDB credentials from environment variables")
+                         dynamodb = boto3.resource(
+                             "dynamodb",
+                             aws_access_key_id=AWS_ACCESS_KEY,
+                             aws_secret_access_key=AWS_SECRET_KEY,
+                             region_name=AWS_REGION,
+                         )
+                     else:
+                         raise Exception(
+                             "AWS credentials for DynamoDB logging not found"
+                         )
+             else:
+                 raise Exception(
+                     "RUN_AWS_FUNCTIONS is not enabled, so DynamoDB logging is unavailable"
+                 )
+
+             if dynamodb_table_name is None:
+                 raise ValueError(
+                     "You must provide a dynamodb_table_name if save_to_dynamodb is True"
+                 )
+
+             # Prefer explicitly supplied headers, then replacement_headers, then the CSV headers
+             if not dynamodb_headers:
+                 if replacement_headers:
+                     dynamodb_headers = replacement_headers
+                 elif headers:
+                     dynamodb_headers = headers
+                 else:
+                     raise ValueError(
+                         "Headers not found. You must provide dynamodb_headers or replacement_headers to create a new table."
+                     )
+
+             # Make sure the bookkeeping columns are present (in the same order as csv_data)
+             if flag_option is not None and "flag" not in dynamodb_headers:
+                 dynamodb_headers.append("flag")
+             if username is not None and "username" not in dynamodb_headers:
+                 dynamodb_headers.append("username")
+             if "id" not in dynamodb_headers:
+                 dynamodb_headers.append("id")
+             if "timestamp" not in dynamodb_headers:
+                 dynamodb_headers.append("timestamp")
+
+             # Create the table if it doesn't exist yet
+             try:
+                 table = dynamodb.Table(dynamodb_table_name)
+                 table.load()
+             except botocore.exceptions.ClientError as e:
+                 if e.response["Error"]["Code"] == "ResourceNotFoundException":
+                     attribute_definitions = [
+                         {
+                             "AttributeName": "id",
+                             "AttributeType": "S",
+                         }  # Only key attributes are defined up front
+                     ]
+
+                     table = dynamodb.create_table(
+                         TableName=dynamodb_table_name,
+                         KeySchema=[
+                             {"AttributeName": "id", "KeyType": "HASH"}  # Partition key
+                         ],
+                         AttributeDefinitions=attribute_definitions,
+                         BillingMode="PAY_PER_REQUEST",
+                     )
+                     # Wait until the table exists
+                     table.meta.client.get_waiter("table_exists").wait(
+                         TableName=dynamodb_table_name
+                     )
+                     time.sleep(5)
+                     print(f"Table '{dynamodb_table_name}' created successfully.")
+                 else:
+                     raise
+
+             # Prepare the DynamoDB item to upload
+             try:
+                 item = {
+                     "id": str(generated_id),  # UUID primary key
+                     "timestamp": timestamp,
+                 }
+
+                 # Map the headers to values
+                 item.update(
+                     {
+                         header: str(value)
+                         for header, value in zip(dynamodb_headers, csv_data)
+                     }
+                 )
+
+                 table.put_item(Item=item)
+
+                 print("Successfully uploaded log to DynamoDB")
+             except Exception as e:
+                 print("Could not upload log to DynamoDB due to", e)
+
+         return line_count
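A minimal usage sketch for the logger (the component labels, folder and file names are invented; DynamoDB logging additionally needs valid AWS credentials and a `dynamodb_table_name`):

```python
import gradio as gr

from tools.custom_csvlogger import CSVLogger_custom

logger = CSVLogger_custom(dataset_file_name="usage_log.csv")
logger.setup(
    components=[gr.Textbox(label="session_id"), gr.Textbox(label="file_name")],
    flagging_dir="logs/usage",
)
rows_written = logger.flag(
    ["abc-123", "dummy_consultation_response.csv"],
    username="demo_user",
    save_to_csv=True,
    save_to_dynamodb=False,  # set True (plus dynamodb_table_name=...) to also log to DynamoDB
)
print(rows_written)  # number of data rows now in the CSV
```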
tools/dedup_summaries.py ADDED
The diff for this file is too large to render. See raw diff
 
tools/example_table_outputs.py ADDED
@@ -0,0 +1,94 @@
+ dummy_consultation_table = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
+ |:---------------------|:-----------------------|:------------|:--------|----------------------:|:--------------------------------------------------------------------------------------------------------------|
+ | Development proposal | Affordable housing | Positive | All | 6 | The proposed development is overwhelmingly viewed favorably by the local community, primarily due to<br>its potential to address significant needs. residents express strong support for the provision of<br>amenities, particularly much-needed housing for young people and families. crucially, the<br>development is also considered vital for increasing green space, with some suggesting this could<br>even contribute to further housing opportunities. <br>furthermore, a key theme emerging from the<br>responses is the des... |
+ | Development proposal | Environmental damage | Negative | All | 5 | A primary concern expressed across the dataset relates to the potential negative consequences of a<br>proposed development on the local environment and wildlife. multiple respondents highlighted worries<br>about environmental damage, suggesting a significant apprehension regarding the project's impact.<br>specifically, there's a shared concern about the detrimental effects on local wildlife, indicating a<br>potential disruption to the natural ecosystem.<br>furthermore, the proposed development is perceived<br>to ... |
+ | Community impact | Loss of local business | Negative | All | 4 | A significant concern expressed within the dataset relates to the potential negative consequences of<br>a development project on local businesses and the wider community. multiple respondents voiced<br>worries about increased traffic congestion, directly impacting the viability of local businesses and<br>leading to apprehension about their future. specifically, there is a palpable sadness surrounding<br>the possibility of a beloved local cafe closing, emphasizing the detrimental effect on community<br>connect... |
+ | Development proposal | Architectural style | Negative | All | 4 | The primary concern regarding the proposed development is its incompatibility with the established<br>character of the area. residents express a strong feeling that the design is fundamentally at odds<br>with the existing aesthetic and atmosphere, suggesting a lack of sensitivity to the local context.<br>this sentiment highlights a significant apprehension about disrupting the area's unique<br>identity.<br>furthermore, significant anxieties center on the potential negative impact of the<br>development on the surr... |
+ | Economic impact | Investment and jobs | Positive | All | 4 | The proposed development is widely anticipated to generate significant positive economic impacts<br>within the local community. residents believe it will lead to substantial investment and the<br>creation of numerous job opportunities, directly boosting the local economy and revitalizing the<br>town centre. there's a strong consensus that this development represents a key step towards economic<br>growth and prosperity for the area.<br>specifically, the anticipated benefits include the creation<br>of jobs for loca... |
+ | Development proposal | Height of building | Negative | All | 3 | Residents expressed significant concerns regarding the proposed development's height, specifically<br>highlighting the five-storey structure as a major issue. this height was perceived as excessively<br>tall and likely to cause overshadowing of existing buildings in the area, directly impacting the<br>views enjoyed by current residents. the potential for diminished sunlight and altered visual<br>landscapes was a central point of contention.<br>furthermore, the impact on the existing character<br>of the neighborho... |
+ | Development proposal | Infrastructure impact | Negative | All | 3 | Analysis of the provided text reveals significant concerns regarding the potential consequences of a<br>proposed project on existing infrastructure. specifically, there is a notable worry about the<br>detrimental effects on local infrastructure, with the possibility of widespread disruption as a key<br>consequence. this suggests a need for careful assessment and mitigation strategies to avoid<br>negatively impacting essential services and community operations.<br>furthermore, the repeated<br>emphasis on infrastru... |
+ | Development proposal | Noise pollution | Negative | All | 3 | The primary concern highlighted within the dataset relates to anticipated noise pollution stemming<br>from the proposed development. multiple responses explicitly express this worry, emphasizing it as a<br>"significant concern" and a key area of apprehension. there's a clear understanding that the<br>development will likely exacerbate existing noise levels within the surrounding area, suggesting a<br>potential negative impact on residents and the local environment.<br>several respondents reiterate<br>this concer... |
+ | Housing needs | Supply of housing | Positive | All | 3 | The proposed development is viewed as a crucial solution to address the town's significant housing<br>shortage, reflecting a clear desire for increased housing supply within the community. residents<br>express a strong need for more homes, and this development is seen as a key step towards meeting<br>that demand. <br>furthermore, the project is anticipated to alleviate existing parking issues, which<br>is considered a valuable contribution to the overall housing supply. the provision of additional<br>parking spac... |
+ | Development proposal | Height of building | Neutral | All | 2 | The analysis of the provided text reveals a nuanced perspective regarding a development project,<br>with no explicit sentiment expressed concerning the building's height itself. however, significant<br>concerns are raised about the potential impact of the development on local schools. these concerns<br>appear to be linked, at least in part, to the building's height, suggesting a worry that the<br>increased scale could strain existing resources and infrastructure within the school system.<br><br>further investigat... |
+ | Community impact | Community facilities | Negative | All | 1 | Concerns exist regarding the negative impact on local amenities. |
+ | Community impact | Community facilities | Positive | All | 1 | The development will provide much-needed community facilities, enhancing the local area. |
+ | Development proposal | Architectural style | Neutral | All | 1 | The development is expected to provide facilities for young people, but no specific architectural<br>concerns. |
+ | Development proposal | Noise pollution | Neutral | All | 1 | Potential for increased noise pollution due to the development is a concern. |
+ | Economic impact | Economic decline | Negative | All | 1 | Worries about a negative impact on the local economy are expressed, suggesting potential harm. |"""
+
+ dummy_consultation_table_zero_shot = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
+ |:---------------------------|:------------------------------------|:------------|:--------|----------------------:|:--------------------------------------------------------------------------------------------------------------|
+ | Planning & development | Impact on the character of the area | Negative | All | 10 | Residents overwhelmingly express strong objections to the proposed development, primarily focusing<br>on its incompatibility with the established character of the area. A central concern is the<br>development's height and design, which they believe clashes significantly with the existing<br>aesthetic and creates a sense of being overshadowed by taller structures, leading to a feeling of<br>crampedness. Many respondents specifically highlighted the potential for the development to<br>negatively impact Main Stre... |
+ | Environmental impact | Impact on the local environment | Negative | All | 8 | Several concerns have been raised regarding the potential negative impacts of a development on the<br>local environment. Multiple respondents expressed worry about the development's possible detrimental<br>effects on the surrounding environment and quality of life, highlighting a significant area of<br>concern. These anxieties include potential damage to the environment and a general feeling of unease<br>about the development's consequences.<br><br>Despite a single positive note regarding the provision<br>of green s... |
+ | Infrastructure & transport | Traffic congestion | Negative | All | 7 | Concerns regarding increased traffic congestion are prevalent in the dataset, largely stemming from<br>the anticipated impact of the proposed development. Specifically, Main Street is predicted to<br>experience heightened congestion due to the increased volume of traffic it will attract. Multiple<br>responses repeatedly highlight this anticipation as a key issue associated with the<br>project.<br><br>Despite the consistent apprehension about traffic congestion, no direct responses<br>offer specific solutions or miti... |
+ | Planning & development | Need for family housing | Positive | All | 7 | The proposed development is overwhelmingly viewed as a crucial solution to the need for family<br>housing within the community. Multiple sources highlight its significance in providing much-needed<br>homes, particularly for families, and specifically addressing the demand for affordable family<br>housing options. Several respondents emphasized the beneficial impact on local residents, with the<br>development also anticipated to create jobs and offer facilities geared towards young people<br>alongside housing. ... |
+ | Quality of life | Impact on quality of life | Negative | All | 7 | Analysis of the provided text reveals significant concerns regarding a proposed development's<br>potential negative impact on the quality of life within the area. Residents are particularly worried<br>that the development will overshadow existing buildings, creating a sense of crampedness and<br>diminishing their living experience. Furthermore, anxieties extend beyond immediate residential<br>impacts, encompassing broader concerns about the development's effects on local businesses, schools,<br>and crucial inf... |
+ | Economic impact | Investment and job creation | Positive | All | 6 | The proposed development is overwhelmingly viewed positively, with significant anticipation for its<br>economic impact on the area. Residents and observers alike believe it will stimulate considerable<br>investment and generate numerous job opportunities, particularly for local residents. Furthermore,<br>the project is expected to revitalize the town center and provide crucial affordable housing,<br>potentially benefiting young people seeking to establish themselves in the<br>community.<br><br>Specifically, the deve... |
+ | Infrastructure & transport | Parking | Negative | All | 6 | Analysis of the '{column_name}' column reveals significant concerns regarding the potential impact<br>of a new development on Main Street. The primary issue identified is increased traffic congestion,<br>directly linked to the development's activity. Furthermore, there is widespread apprehension that<br>the project will worsen existing parking problems, with multiple respondents explicitly stating a<br>lack of adequate parking provisions as a key worry. <br><br>Specifically, numerous individuals<br>expressed concern... |
+ | Community & local life | Amenities for the local community | Positive | All | 5 | The proposed development is anticipated to significantly benefit the local community, offering a<br>range of amenities and a positive contribution to the area. Specifically, the project will deliver<br>crucial green space alongside facilities designed to cater to the needs of young people and the<br>broader community.<br><br>Furthermore, the development is expected to address critical social needs<br>by providing much-needed community facilities and social housing, indicating a commitment to<br>supporting local resi... |
+ | Environmental impact | Impact on local wildlife | Neutral | All | 4 | No specific responses were provided, and the dataset contained no information relevant to the<br>specified consultation context. Consequently, a summary cannot be generated based on the provided<br>data. <br><br>Due to the absence of any textual data within the dataset, there is no content to<br>consolidate and summarize. |
+ | Improvement of main street | Improvement of main street | Positive | All | 4 | This development is being hailed as a positive step for the revitalization of Main Street, primarily<br>due to its anticipated improvement in the street's appearance. Stakeholders view this initiative as<br>a crucial element in breathing new life into the area, suggesting a significant upgrade to the<br>existing landscape.<br><br>Specifically, the project aims to enhance the visual appeal of Main<br>Street, representing a tangible advancement in its overall attractiveness and desirability. The<br>development is wide... |
+ | Planning & development | Impact on views | Negative | All | 4 | A primary concern expressed regarding the proposed development is its potential negative impact on<br>existing views. Multiple respondents voiced worries about how the development might obstruct or<br>diminish the current vistas, alongside specific concerns about its effect on views from neighboring<br>properties. This suggests a significant sensitivity to the visual landscape and its value within the<br>community.<br><br>Furthermore, the potential aesthetic consequences of the development are<br>highlighted, with s... |
+ | Community & local life | Amenities for the local community | Negative | All | 2 | Residents are voicing significant concerns regarding a proposed development, primarily focusing on<br>its anticipated detrimental effects on local amenities. A key point of contention is the planned<br>removal of the existing cafe, which is being viewed as a substantial loss to the community's social<br>fabric and a vital local resource.<br><br>The overall sentiment suggests a strong apprehension that<br>the development will diminish the quality of life for those living nearby, highlighting a desire to<br>preserve c... |
+ | Impact on local businesses | Impact on local businesses | Negative | All | 2 | A primary concern expressed relates to the potential detrimental effects of the development on local<br>businesses. There's a clear worry that the project will negatively impact these businesses,<br>suggesting a potential loss of revenue, customer base, or even business closure. The repeated<br>emphasis on a "negative impact" highlights a significant apprehension regarding the economic<br>repercussions for the existing business community.<br><br>The sentiment underscores a desire to<br>mitigate potential harm and li... |
+ | Impact on local heritage | Impact on local heritage | Negative | All | 2 | There are growing concerns regarding the potential negative impact of the development on the local<br>heritage. While specific details and references haven't been explicitly stated, the underlying<br>sentiment suggests a worry about the development's effects on historically significant elements<br>within the area. This implies a recognition that the proposed project could, perhaps inadvertently,<br>threaten or diminish the cultural value and character of the local environment.<br><br>The presence<br>of these concern... |
+ | Environmental impact | Impact on local wildlife | Negative | All | 1 | Concerns regarding the negative impact of the development on local wildlife. |
+ | Impact on local heritage | Impact on local heritage | Neutral | All | 1 | No specific responses mention this topic. |
+ | Impact on local schools | Impact on local schools | Negative | All | 1 | Concerns about the negative impact on the local schools. |
+ | Impact on local schools | Impact on local schools | Neutral | All | 1 | No specific responses mention this topic. |
+ | Infrastructure & transport | Parking | Positive | All | 1 | The development is expected to provide much-needed parking spaces. |"""
+
+ case_notes_table = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
+ |:------------------|:----------------------------|:------------|:--------|----------------------:|:--------------------------------------------------------------------------------------------------------------|
+ | Family dynamics | Parental conflict | Negative | All | 6 | Several parents expressed significant concerns regarding the well-being of their children, primarily<br>focusing on escalating aggression and withdrawal. alex's mother specifically highlighted a pattern<br>of arguments at home and attributed the aggressive behavior to external provocation, suggesting a<br>destabilizing family environment. furthermore, parents voiced a lack of confidence in existing<br>interventions for their children, particularly jamie, indicating a perceived need for supplemental<br>support ... |
+ | Mental health | Feelings of isolation | Negative | All | 6 | Several individuals expressed significant emotional distress, primarily centered around feelings of<br>isolation and hopelessness. jamie's withdrawal and reported emptiness suggest a deep-seated sense of<br>disconnection, while alex powerfully articulated his feeling of being misunderstood with the<br>statement "no one gets me." these experiences appear to be impacting daily functioning, as evidenced<br>by jamie's struggles and disruption of his sleep patterns. <br>the observations from parents further<br>indicat... |
+ | Overall mood | Agitation | Negative | All | 6 | Several individuals expressed concerns regarding heightened emotional distress within the subject<br>group, primarily focusing on alex. alex repeatedly demonstrated significant frustration and<br>aggression, necessitating anger management support, and appeared increasingly agitated in recent<br>meetings, suggesting a worsening emotional state. parents specifically highlighted jamie's ongoing<br>struggles and agitation, emphasizing the need for a comprehensive assessment and subsequent<br>intervention.<br>while so... |
+ | Family dynamics | Peer influence | Negative | All | 5 | Concerns centered around alex's social circle and behavior highlighted potential negative peer<br>influence, specifically due to his new friends and frequent late-night activities. further<br>investigation revealed troubling admissions from alex himself, including alcohol use and<br>participation in a physical altercation with a fellow student, indicating a concerning pattern of<br>risk-taking behavior. <br>simultaneously, jamie's situation presented a separate area of concern,<br>characterized by isolation and l... |
+ | Substance use | Potential substance abuse | Negative | All | 5 | Alex has disclosed alcohol use, raising significant concerns about potential ongoing substance<br>abuse. the situation is further complicated by reports from alex's mother, who has observed<br>potential signs of substance abuse and expressed her worries regarding this matter. these<br>observations highlight a need for further assessment and support to address the individual's<br>substance use patterns and ensure their well-being.<br>the situation requires careful monitoring and<br>intervention. the mother's repor... |
+ | Mental health | Depression | Positive | All | 3 | Jamie is diagnosed with major depressive disorder and initiated on antidepressant medication, with<br>initial positive feedback on mood and energy. |
+ | Mental health | Self-harm | Negative | All | 3 | The assessment revealed a complex picture regarding the individual's mental state. while initial<br>observations did not indicate any active self-harm, a thorough evaluation is strongly recommended to<br>identify potential underlying issues contributing to the risk. this proactive approach is crucial<br>for a complete understanding of the individual's needs.<br>more concerningly, alex presented with<br>visible self-harm indicators on his arms and explicitly communicated thoughts of self-harm,<br>signifying a sign... |
+ | School engagement | Absenteeism | Negative | All | 3 | Recent reports highlight a concerning trend of declining student engagement within the school<br>environment. specifically, there has been an increase in absences alongside a decrease in academic<br>performance, suggesting a fundamental lack of connection with schoolwork and learning. several<br>students, including jamie, are exhibiting problematic behaviors that further underscore this issue,<br>such as consistent tardiness and reduced participation in classroom activities.<br>furthermore,<br>observations indica... |
+ | Mental health | Depression | Negative | All | 2 | Jamie exhibits symptoms of moderate depression, requiring further evaluation and intervention. |
+ | Mental health | Self-harm | Neutral | All | 2 | The psychiatrist's assessment centered on the potential advantages and drawbacks of antidepressant<br>medication, with a notable emphasis on evaluating the possibility of self-harm risk. this indicates<br>a proactive approach to patient safety and a recognition of the complex interplay between medication<br>and mental health. the discussion highlights a careful consideration of the potential for increased<br>suicidal ideation, suggesting a thorough risk assessment was undertaken.<br>furthermore, the<br>analysis o... |
+ | School engagement | Academic performance | Negative | All | 2 | Analysis of the provided text reveals concerns regarding student engagement and academic<br>performance. specifically, jamie's reduced involvement in class is flagged as a potential indicator<br>of negative consequences, with declining grades reported as a direct result. this suggests a<br>concerning downward trend in alex's academic progress, highlighting a need for further investigation<br>into the underlying causes of this shift.<br>the combined observations point to a possible<br>correlation between decreased... |
+ | Substance use | Substance use (unspecified) | Negative | All | 2 | Concerns regarding ongoing substance use prompted discussion about the possibility of a short-term<br>residential treatment program. alex's involvement highlighted a potential issue, as they reported<br>occasional substance use, though the specific substances involved were not detailed during the<br>consultation. this lack of specificity regarding the substances used raises a need for further<br>investigation into the nature and frequency of alex's substance use.<br>the consultation focused on<br>assessing the ri... |
+ | Family dynamics | Stepfather relationship | Negative | All | 1 | Alex displayed sudden outbursts of anger when discussing his new stepfather, indicating significant<br>distress related to this family change. |
+ | School engagement | Academic performance | Positive | All | 1 | Jamie's academic performance has slightly improved, indicating a potential positive change. |"""
+
+ case_notes_table_grouped = """| General topic | Subtopic | Sentiment | Group | Number of responses | Revised summary |
+ |:--------------------|:---------------------------|:------------|:---------|----------------------:|:--------------------------------------------------------------------------------------------------------------|
+ | Trends over time | Trends over time | Negative | Alex D. | 7 | Alex's case note reveals a troubling deterioration in his well-being marked by a gradual escalation<br>of issues. Initially, the record details an incident involving a physical altercation, which quickly<br>spiraled into increasingly concerning behaviours at home, specifically escalating aggression. Over<br>subsequent meetings, observations consistently pointed towards heightened agitation and expressions<br>of hopelessness, indicating a worsening emotional state and a significant decline in his overall<br>con... |
+ | Physical health | Substance misuse | Negative | Alex D. | 6 | Alex's substance use remains a significant concern, necessitating continued vigilance and support<br>despite recent positive developments in group therapy. While Alex has acknowledged instances of<br>substance use, the details surrounding these occurrences have not been shared, raising questions<br>about the extent and nature of the problem. Concerns were specifically noted regarding potential<br>substance abuse, highlighting a need for further investigation and assessment.<br><br>Ongoing<br>monitoring is crucial to... |
+ | Behaviour at school | Behaviour at school | Negative | Alex D. | 3 | A recent case note details a troubling incident involving a physical altercation at school,<br>alongside concerning admissions from Alex regarding alcohol use. This event has sparked worries<br>about potential behavioural issues within the school setting, suggesting a need for further<br>investigation and support. Alex's demeanor was notably problematic, characterized by sullen behavior<br>and a deliberate avoidance of eye contact, indicating a possible struggle with emotional<br>regulation.<br><br>Furthermore, Alex... |
+ | Mental health | Anger | Negative | Alex D. | 3 | Alex exhibits a pronounced anger issue, characterized by frustration and a tendency to blame others<br>for triggering his aggressive behavior. He demonstrated this significantly when discussing his<br>personal life, particularly relating to his new stepfather, suggesting a volatile emotional response<br>to this change. The observed outbursts highlight a need for immediate intervention to manage his<br>escalating anger.<br><br>Further investigation reveals that Alex's anger is closely linked to his<br>home environmen... |
+ | Mental health | Self-harm | Negative | Alex D. | 3 | The analysis reveals significant concerns regarding Alex's mental health, centering around potential<br>self-harm behaviors. Indications suggest a possible diagnosis of Oppositional Defiant Disorder<br>alongside a co-occurring substance use disorder, warranting a comprehensive treatment plan. Alex<br>demonstrated visible signs of self-harm and openly confessed to experiencing thoughts of self-harm,<br>highlighting a critical need for immediate intervention.<br><br>Following this disclosure, an<br>immediate referral ... |
+ | Mental health | Social issues | Negative | Alex D. | 3 | Alex exhibits a pattern of blaming others for his problematic behavior, indicating underlying<br>challenges in social interaction and conflict resolution. This behavior appears to be contributing<br>to further instability in his life. Specifically, his mother voiced concerns regarding his new<br>social circle and increasingly frequent late-night activities, suggesting she perceives these<br>relationships and outings as potentially risky.<br><br>The mother's observations highlight a<br>potential area of concern for A... |
+ | Mental health | Depression | Negative | Jamie L. | 6 | Jamie is currently experiencing concerning symptoms indicative of depression, as noted by both<br>Jamie's behavior and parental observations. Specifically, he demonstrates limited social<br>interaction, struggles with his mood, and has difficulty engaging with his schoolwork. These<br>difficulties appear persistent, with parents reporting ongoing struggles despite occasional positive<br>moments. <br><br>Further assessment suggests a more pronounced picture, with indications of moderate<br>depression characterized by... |
+ | Mental health | Social isolation | Negative | Jamie L. | 4 | Jamie is experiencing significant social isolation, which is negatively affecting both his academic<br>performance and his general well-being. He has expressed feelings of loneliness and difficulty<br>sleeping, strongly suggesting a core social issue is contributing to his distress. Current efforts<br>are focused on promoting increased social interaction to address these challenges.<br><br>The report<br>highlights the urgency of this situation, emphasizing the need for intervention to mitigate Jamie's<br>isolation a... |
+ | Mental health | Medication | Neutral | Jamie L. | 3 | Consideration is being given to medication as a potential intervention alongside therapy to manage<br>depressive symptoms. Initial feedback on the antidepressant is positive. |"""
69
+ | Mental health | Withdrawal & sadness | Negative | Jamie L. | 3 | Jamie is experiencing a significant downturn in his emotional state, characterized by withdrawal,<br>sadness, and a pervasive sense of emptiness and hopelessness. These negative feelings appear to be<br>triggered by recent reports of tardiness and decreased participation, suggesting a possible link<br>between his behavior and external pressures or expectations. The combination of these symptoms<br>points to a low mood and a feeling of struggle, indicating a potentially serious situation requiring<br>attention.... |
70
+ | Mental health | Low self-worth | Negative | Jamie L. | 2 | Parents are increasingly concerned about Jamie’s well-being due to observed difficulties and a<br>potential lack of self-worth. These concerns are primarily fueled by Jamie’s own statements, where<br>he articulated feelings of low self-esteem and a significant struggle to find<br>motivation.<br><br>Further investigation revealed a direct link between Jamie’s emotional state and<br>recent family financial hardships. The pressures of these struggles appear to have deeply impacted<br>his self-perception and ability to ... |
71
+ | Trends over time | Increasing withdrawal | Negative | Jamie L. | 2 | A significant and worrying trend is emerging regarding withdrawal, necessitating continuous<br>observation and targeted intervention strategies. Specifically, Jamie is exhibiting a noticeable<br>decline in engagement with family activities, representing a key indicator of this broader issue.<br>This withdrawal suggests a potential underlying problem requiring careful assessment and proactive<br>support.<br><br>The observed pattern of withdrawal highlights the importance of sustained monitoring<br>to understand its p... |
72
+ | Behaviour at school | Attendance issues | Negative | Jamie L. | 1 | Jamie’s consistent tardiness was a concern leading to a meeting. |
73
+ | Behaviour at school | Reduced participation | Negative | Jamie L. | 1 | Jamie’s decreased participation in class was noted. |
74
+ | Behaviour at school | Social engagement | Negative | Jamie L. | 1 | Jamie's withdrawal from family activities and hobbies was highlighted. |
75
+ | Behaviour at school | Social engagement | Positive | Jamie L. | 1 | Encouraging Jamie to join school clubs and groups is a strategy to foster social connection and<br>improve his social engagement. |
76
+ | Family & social | Family communication | Negative | Jamie L. | 1 | Parents expressed concerns about Jamie’s withdrawal and lack of communication within the family. |
77
+ | Family & social | Family communication | Neutral | Jamie L. | 1 | Parents are actively involved in Jamie's care and are communicating their observations to the care<br>team. |
78
+ | Family & social | Family financial struggles | Negative | Jamie L. | 1 | Jamie's low motivation is attributed to recent family financial difficulties. |"""
79
+
80
+ case_notes_table_structured_summary = """| Main heading | Subheading | Summary | Group |
81
+ |:--------------------|:--------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------|
82
+ | Behaviour at school | Behaviour at school | Several cases involved disruptions at school, including increased absences, declining grades, and a<br>physical altercation. Alex displayed sullenness, avoidance, and agitation, sometimes reacting with<br>frustration. A key theme was isolation and a lack of connection with peers and school staff. | Alex D. |
83
+ | Mental health | Anger | Anger was a prominent feature across multiple cases, particularly when discussing home life and<br>family dynamics. Outbursts of anger were observed, especially related to a new stepfather, and Alex<br>displayed defensiveness when questioned about his actions. | Alex D. |
84
+ | Mental health | Social issues | Alex experienced feelings of isolation and difficulty connecting with others. He had a new group of<br>friends and engaged in late-night outings, which raised concerns about potential risky behaviours<br>and social influences. | Alex D. |
85
+ | Physical health | General | Signs of self-harm were present on Alex’s arms, indicating a heightened level of distress and<br>potentially a need for immediate support. He displayed visible agitation and defensive behaviour<br>during questioning. | Alex D. |
86
+ | Physical health | Substance misuse | Substance use was a recurring concern, with Alex admitting to occasional substance use and his<br>mother reporting potential signs of abuse. Alcohol use was noted in several instances, leading to<br>recommendations for assessment and potential intervention. | Alex D. |
87
+ | Trends over time | Trends over time | There was a gradual escalation of concerning behaviours over time. Early interventions focused on<br>initial meetings and observation, progressing to more intensive interventions like referrals to<br>mental health professionals, residential treatment programs, and family counseling. | Alex D. |
88
+ | Behaviour at school | Behaviour at school | Jamie exhibited concerning behaviours at school, including consistent tardiness and decreased<br>participation in class. This was accompanied by withdrawn behaviour and signs of sadness, suggesting<br>a need for immediate intervention to address potential underlying issues impacting his academic<br>performance. | Jamie L. |
89
+ | Mental health | Anger | There is no direct indication of anger in Jamie's case notes. | Jamie L. |
90
+ | Mental health | Mental health | Jamie displayed concerning signs of mental health difficulties, including feelings of emptiness,<br>hopelessness, low self-worth, and isolation. He reported difficulty sleeping and a lack of<br>motivation. The need for a comprehensive mental health assessment was highlighted to fully<br>understand the nature and severity of his condition. | Jamie L. |
91
+ | Mental health | Social issues | Jamie experienced significant social difficulties, including limited social interactions, feelings<br>of isolation, and a lack of engagement with family activities and hobbies. He spends a lot of time<br>alone in his room. Recommendations focused on fostering connection through school clubs and family<br>therapy were made. | Jamie L. |
92
+ | Physical health | General | While no direct physical health concerns were explicitly stated, Jamie's emotional state and<br>associated symptoms (difficulty sleeping) warrant consideration of his overall well-being and<br>potential physical manifestations of his mental health challenges. | Jamie L. |
93
+ | Physical health | Substance misuse | There is no indication of substance misuse in the provided case notes. | Jamie L. |
94
+ | Trends over time | Trends over time | Jamie’s case demonstrates fluctuating progress. Initial feedback indicated slight improvements in<br>mood on some days, but overall he continues to struggle. A shift occurred with the commencement of<br>antidepressant medication, showing initial positive feedback in terms of mood and energy levels,<br>requiring continued monitoring and adjustment. | Jamie L. |"""
tools/helper_functions.py ADDED
@@ -0,0 +1,1245 @@
1
+ import codecs
2
+ import math
3
+ import os
4
+ import re
5
+ from typing import List
6
+
7
+ import boto3
8
+ import gradio as gr
9
+ import numpy as np
10
+ import pandas as pd
11
+ from botocore.exceptions import ClientError
12
+
13
+ from tools.config import (
14
+ AWS_USER_POOL_ID,
15
+ CUSTOM_HEADER,
16
+ CUSTOM_HEADER_VALUE,
17
+ INPUT_FOLDER,
18
+ MAXIMUM_ZERO_SHOT_TOPICS,
19
+ OUTPUT_FOLDER,
20
+ SESSION_OUTPUT_FOLDER,
21
+ model_full_names,
22
+ model_name_map,
23
+ )
24
+
25
+
26
+ def empty_output_vars_extract_topics():
27
+ # Empty output objects before processing a new file
28
+
29
+ master_topic_df_state = pd.DataFrame()
30
+ master_topic_summary_df_state = pd.DataFrame()
31
+ master_reference_df_state = pd.DataFrame()
32
+ text_output_file = list()
33
+ text_output_file_list_state = list()
34
+ latest_batch_completed = 0
35
+ log_files_output = list()
36
+ log_files_output_list_state = list()
37
+ conversation_metadata_textbox = ""
38
+ estimated_time_taken_number = 0
39
+ file_data_state = pd.DataFrame()
40
+ reference_data_file_name_textbox = ""
41
+ display_topic_table_markdown = ""
42
+ summary_output_file_list = list()
43
+ summary_input_file_list = list()
44
+ overall_summarisation_input_files = list()
45
+ overall_summary_output_files = list()
46
+
47
+ return (
48
+ master_topic_df_state,
49
+ master_topic_summary_df_state,
50
+ master_reference_df_state,
51
+ text_output_file,
52
+ text_output_file_list_state,
53
+ latest_batch_completed,
54
+ log_files_output,
55
+ log_files_output_list_state,
56
+ conversation_metadata_textbox,
57
+ estimated_time_taken_number,
58
+ file_data_state,
59
+ reference_data_file_name_textbox,
60
+ display_topic_table_markdown,
61
+ summary_output_file_list,
62
+ summary_input_file_list,
63
+ overall_summarisation_input_files,
64
+ overall_summary_output_files,
65
+ )
66
+
67
+
68
+ def empty_output_vars_summarise():
69
+ # Empty output objects before summarising files
70
+
71
+ summary_reference_table_sample_state = pd.DataFrame()
72
+ master_topic_summary_df_revised_summaries_state = pd.DataFrame()
73
+ master_reference_df_revised_summaries_state = pd.DataFrame()
74
+ summary_output_files = list()
75
+ summarised_outputs_list = list()
76
+ latest_summary_completed_num = 0
77
+ overall_summarisation_input_files = list()
78
+
79
+ return (
80
+ summary_reference_table_sample_state,
81
+ master_topic_summary_df_revised_summaries_state,
82
+ master_reference_df_revised_summaries_state,
83
+ summary_output_files,
84
+ summarised_outputs_list,
85
+ latest_summary_completed_num,
86
+ overall_summarisation_input_files,
87
+ )
88
+
89
+
90
+ def get_or_create_env_var(var_name: str, default_value: str):
91
+ # Get the environment variable if it exists
92
+ value = os.environ.get(var_name)
93
+
94
+ # If it doesn't exist, set it to the default value
95
+ if value is None:
96
+ os.environ[var_name] = default_value
97
+ value = default_value
98
+
99
+ return value
100
+
101
+
102
+ def get_file_path_with_extension(file_path: str):
103
+ # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
104
+ basename = os.path.basename(file_path)
105
+
106
+ # Return the basename with its extension
107
+ return basename
108
+
109
+
110
+ def get_file_name_no_ext(file_path: str):
111
+ # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
112
+ basename = os.path.basename(file_path)
113
+
114
+ # Then, split the basename and its extension and return only the basename without the extension
115
+ filename_without_extension, _ = os.path.splitext(basename)
116
+
117
+ # print(filename_without_extension)
118
+
119
+ return filename_without_extension
120
+
121
+
122
+ def detect_file_type(filename: str):
123
+ """Detect the file type based on its extension."""
124
+
125
+ # Strip quotes and whitespace that might have been accidentally included
126
+ filename = filename.strip().strip("'\"")
127
+
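+ # Note: .csv.gz and single-file .zip archives are grouped with csv below because
+ # pandas.read_csv can read both compression formats directly.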
128
+ if (
129
+ (filename.endswith(".csv"))
130
+ | (filename.endswith(".csv.gz"))
131
+ | (filename.endswith(".zip"))
132
+ ):
133
+ return "csv"
134
+ elif filename.endswith(".xlsx"):
135
+ return "xlsx"
136
+ elif filename.endswith(".parquet"):
137
+ return "parquet"
138
+ elif filename.endswith(".pdf"):
139
+ return "pdf"
140
+ elif filename.endswith(".jpg"):
141
+ return "jpg"
142
+ elif filename.endswith(".jpeg"):
143
+ return "jpeg"
144
+ elif filename.endswith(".png"):
145
+ return "png"
146
+ else:
147
+ raise ValueError("Unsupported file type.")
148
+
149
+
150
+ def read_file(filename: str, sheet: str = ""):
151
+ """Read the file based on its detected type."""
152
+ # Strip quotes and whitespace that might have been accidentally included
153
+ filename = filename.strip().strip("'\"")
154
+ file_type = detect_file_type(filename)
155
+
156
+ if file_type == "csv":
157
+ return pd.read_csv(filename, low_memory=False)
158
+ elif file_type == "xlsx":
159
+ if sheet:
160
+ return pd.read_excel(filename, sheet_name=sheet)
161
+ else:
162
+ return pd.read_excel(filename)
163
+ elif file_type == "parquet":
164
+ return pd.read_parquet(filename)
+ else:
+ raise ValueError(f"read_file does not support reading '{file_type}' files.")
165
+
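+ # Illustrative usage (file and sheet names are assumptions):
+ # read_file("data.xlsx", sheet="Sheet1") loads one named sheet via pandas.read_excel.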
166
+
167
+ def load_in_file(file_path: str, colnames: List[str] = None, excel_sheet: str = ""):
168
+ """
169
+ Loads in a tabular data file and returns data and file name.
170
+
171
+ Parameters:
172
+ - file_path (str): The path to the file to be processed.
173
+ - colnames (List[str], optional): list of column names to load in
+ - excel_sheet (str, optional): name of the Excel sheet to read, if applicable
174
+ """
175
+
176
+ file_name = get_file_name_no_ext(file_path)
177
+ file_data = read_file(file_path, excel_sheet)
178
+
179
+ if colnames and isinstance(colnames, list):
180
+ col_list = colnames
181
+ else:
182
+ col_list = list(file_data.columns)
183
+
184
+ if not isinstance(col_list, list):
185
+ col_list = [col_list]
186
+
187
+ col_list = [item for item in col_list if item not in ["", "NA"]]
188
+
189
+ for col in col_list:
190
+ file_data[col] = file_data[col].fillna("")
191
+ file_data[col] = (
192
+ file_data[col].astype(str).str.replace("\bnan\b", "", regex=True)
193
+ )
194
+
195
+ # print(file_data[colnames])
196
+
197
+ return file_data, file_name
198
+
199
+
200
+ def load_in_data_file(
201
+ file_paths: List[str],
202
+ in_colnames: List[str],
203
+ batch_size: int = 5,
204
+ in_excel_sheets: str = "",
205
+ ):
206
+ """Load in data table, work out how many batches needed."""
207
+
208
+ if not isinstance(in_colnames, list):
209
+ in_colnames = [in_colnames]
210
+
211
+ # print("in_colnames:", in_colnames)
212
+
213
+ try:
214
+ file_data, file_name = load_in_file(
215
+ file_paths[0], colnames=in_colnames, excel_sheet=in_excel_sheets
216
+ )
217
+ num_batches = math.ceil(len(file_data) / batch_size)
218
+ print(
219
+ f"File {file_name} loaded successfully. Number of rows: {len(file_data)}. Total number of batches: {num_batches}"
220
+ )
221
+
222
+ except Exception as e:
223
+ print("Could not load data file due to:", e)
224
+ file_data = pd.DataFrame()
225
+ file_name = ""
226
+ num_batches = 1
227
+
228
+ return file_data, file_name, num_batches
229
+
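+ # Worked example of the batch arithmetic above (illustrative numbers):
+ # a 23-row file with batch_size=5 gives math.ceil(23 / 5) == 5 batches.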
230
+
231
+ def clean_column_name(
232
+ column_name: str, max_length: int = 20, front_characters: bool = True
233
+ ):
234
+ # Convert to string
235
+ column_name = str(column_name)
236
+ # Replace non-alphanumeric characters (except underscores) with underscores
237
+ column_name = re.sub(r"\W+", "_", column_name)
238
+ # Remove leading/trailing underscores
239
+ column_name = column_name.strip("_")
240
+ # Ensure the result is not empty; fall back to "column" if necessary
241
+ column_name = column_name if column_name else "column"
242
+ # Truncate to max_length
243
+ if front_characters is True:
244
+ output_text = column_name[:max_length]
245
+ else:
246
+ output_text = column_name[-max_length:]
247
+ return output_text
248
+
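+ # Illustrative example (assumed input): clean_column_name("Response text (2024)!")
+ # -> "Response_text_2024"; runs of non-alphanumeric characters become single
+ # underscores, edges are trimmed, and the result is capped at max_length characters.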
249
+
250
+ def load_in_previous_reference_file(file: str):
251
+ """Load in data table from a partially completed consultation summary to continue it."""
252
+
253
+ reference_file_data = pd.DataFrame()
254
+ reference_file_name = ""
255
+ out_message = ""
256
+
257
+ # for file in file_paths:
258
+
259
+ print("file:", file)
260
+
261
+ # If reference table
262
+ if "reference_table" in file:
263
+ try:
264
+ reference_file_data, reference_file_name = load_in_file(file)
265
+ # print("reference_file_data:", reference_file_data.head(2))
266
+ out_message = out_message + " Reference file load successful."
267
+ except Exception as e:
268
+ out_message = "Could not load reference file data:" + str(e)
269
+ raise Exception("Could not load reference file data:", e)
270
+
271
+ if reference_file_data.empty:
272
+ out_message = out_message + " No reference data table provided."
273
+ raise Exception(out_message)
274
+
275
+ print(out_message)
276
+
277
+ return reference_file_data, reference_file_name
278
+
279
+
280
+ def load_in_previous_data_files(
281
+ file_paths_partial_output: List[str], for_modified_table: bool = False
282
+ ):
283
+ """Load in data table from a partially completed consultation summary to continue it."""
284
+
285
+ reference_file_data = pd.DataFrame()
286
+ reference_file_name = ""
287
+ unique_file_data = pd.DataFrame()
288
+ unique_file_name = ""
289
+ out_message = ""
290
+ latest_batch = 0
291
+
292
+ if not file_paths_partial_output:
293
+ out_message = out_message + " No reference or unique data table provided."
294
+ return (
295
+ reference_file_data,
296
+ unique_file_data,
297
+ latest_batch,
298
+ out_message,
299
+ reference_file_name,
300
+ unique_file_name,
301
+ )
302
+
303
+ if not isinstance(file_paths_partial_output, list):
304
+ file_paths_partial_output = [file_paths_partial_output]
305
+
306
+ for file in file_paths_partial_output:
307
+
308
+ if isinstance(file, gr.FileData):
309
+ name = file.name
310
+ else:
311
+ name = file
312
+
313
+ # If reference table
314
+ if "reference_table" in name:
315
+ try:
316
+ reference_file_data, reference_file_name = load_in_file(file)
317
+ # print("reference_file_data:", reference_file_data.head(2))
318
+ out_message = out_message + " Reference file load successful."
319
+
320
+ except Exception as e:
321
+ out_message = "Could not load reference file data:" + str(e)
322
+ raise Exception("Could not load reference file data:", e)
323
+ # If unique table
324
+ if "unique_topic" in name:
325
+ try:
326
+ unique_file_data, unique_file_name = load_in_file(file)
327
+ # print("unique_topics_file:", unique_file_data.head(2))
328
+ out_message = out_message + " Unique table file load successful."
329
+ except Exception as e:
330
+ out_message = "Could not load unique table file data:" + str(e)
331
+ raise Exception("Could not load unique table file data:", e)
332
+ if "batch_" in name:
333
+ latest_batch = re.search(r"batch_(\d+)", name).group(1)
334
+ print("latest batch:", latest_batch)
335
+ latest_batch = int(latest_batch)
336
+
337
+ if latest_batch == 0:
338
+ out_message = out_message + " Latest batch number not found."
339
+ if reference_file_data.empty:
340
+ out_message = out_message + " No reference data table provided."
341
+ # raise Exception(out_message)
342
+ if unique_file_data.empty:
343
+ out_message = out_message + " No unique data table provided."
344
+
345
+ print(out_message)
346
+
347
+ # Return all data if using for deduplication task. Return just modified unique table if using just for table modification
348
+ if for_modified_table is False:
349
+ return (
350
+ reference_file_data,
351
+ unique_file_data,
352
+ latest_batch,
353
+ out_message,
354
+ reference_file_name,
355
+ unique_file_name,
356
+ )
357
+ else:
358
+ reference_file_data.drop("Topic number", axis=1, inplace=True, errors="ignore")
359
+
360
+ unique_file_data = create_topic_summary_df_from_reference_table(
361
+ reference_file_data
362
+ )
363
+
364
+ unique_file_data.drop("Summary", axis=1, inplace=True)
365
+
366
+ # Then merge the topic numbers back to the original dataframe
367
+ reference_file_data = reference_file_data.merge(
368
+ unique_file_data[
369
+ ["General topic", "Subtopic", "Sentiment", "Topic number"]
370
+ ],
371
+ on=["General topic", "Subtopic", "Sentiment"],
372
+ how="left",
373
+ )
374
+
375
+ out_file_names = [reference_file_name + ".csv"]
376
+ out_file_names.append(unique_file_name + ".csv")
377
+
378
+ return (
379
+ unique_file_data,
380
+ reference_file_data,
381
+ unique_file_data,
382
+ reference_file_name,
383
+ unique_file_name,
384
+ out_file_names,
385
+ ) # gr.Dataframe(value=unique_file_data, headers=None, column_count=(unique_file_data.shape[1], "fixed"), row_count = (unique_file_data.shape[0], "fixed"), visible=True, type="pandas")
386
+
387
+
388
+ def join_cols_onto_reference_df(
389
+ reference_df: pd.DataFrame,
390
+ original_data_df: pd.DataFrame,
391
+ join_columns: List[str],
392
+ original_file_name: str,
393
+ output_folder: str = OUTPUT_FOLDER,
394
+ ):
395
+
396
+ # print("original_data_df columns:", original_data_df.columns)
397
+ # print("original_data_df:", original_data_df)
398
+
399
+ original_data_df.reset_index(names="Response References", inplace=True)
400
+ original_data_df["Response References"] += 1
401
+
402
+ # print("reference_df columns:", reference_df.columns)
403
+ # print("reference_df:", reference_df)
404
+
405
+ join_columns.append("Response References")
406
+
407
+ reference_df["Response References"] = (
408
+ reference_df["Response References"].fillna("-1").astype(int)
409
+ )
410
+
411
+ save_file_name = output_folder + original_file_name + "_j.csv"
412
+
413
+ out_reference_df = reference_df.merge(
414
+ original_data_df[join_columns], on="Response References", how="left"
415
+ )
416
+ out_reference_df.to_csv(save_file_name, index=None)
417
+
418
+ file_data_outputs = [save_file_name]
419
+
420
+ return out_reference_df, file_data_outputs
421
+
422
+
423
+ def get_basic_response_data(
424
+ file_data: pd.DataFrame, chosen_cols: List[str], verify_titles: bool = False
425
+ ) -> pd.DataFrame:
426
+
427
+ if not isinstance(chosen_cols, list):
428
+ chosen_cols = [chosen_cols]
429
+
430
+ if chosen_cols[0] not in file_data.columns:
431
+ error_msg = (
432
+ f"Column '{chosen_cols[0]}' not found in file_data columns. "
433
+ f"Available columns: {list(file_data.columns)}"
434
+ )
435
+ print(error_msg)
436
+ raise KeyError(error_msg)
437
+
438
+ # If verify_titles is True, we need to check and include the second column
439
+ if verify_titles is True:
440
+ if len(chosen_cols) < 2:
441
+ error_msg = (
442
+ "verify_titles is True but only one column provided. "
443
+ "Need at least 2 columns: one for response text and one for title."
444
+ )
445
+ print(error_msg)
446
+ raise ValueError(error_msg)
447
+ if chosen_cols[1] not in file_data.columns:
448
+ error_msg = (
449
+ f"Column '{chosen_cols[1]}' not found in file_data columns for title. "
450
+ f"Available columns: {list(file_data.columns)}"
451
+ )
452
+ print(error_msg)
453
+ raise KeyError(error_msg)
454
+ # Include both columns when verify_titles is True
455
+ basic_response_data = file_data[[chosen_cols[0], chosen_cols[1]]]
456
+ basic_response_data = basic_response_data.rename(
457
+ columns={
458
+ basic_response_data.columns[0]: "Response",
459
+ basic_response_data.columns[1]: "Title",
460
+ }
461
+ )
462
+ else:
463
+ basic_response_data = file_data[[chosen_cols[0]]]
464
+ basic_response_data = basic_response_data.rename(
465
+ columns={basic_response_data.columns[0]: "Response"}
466
+ )
467
+ basic_response_data = basic_response_data.reset_index(
468
+ names="Original Reference"
469
+ )
470
+ # Try to convert to int, if it fails, return a range of 1 to last row + 1
471
+ try:
472
+ basic_response_data["Original Reference"] = (
473
+ basic_response_data["Original Reference"].astype(int) + 1
474
+ )
475
+ except (ValueError, TypeError):
476
+ basic_response_data["Original Reference"] = range(
477
+ 1, len(basic_response_data) + 1
478
+ )
479
+
480
+ basic_response_data["Reference"] = basic_response_data.index.astype(int) + 1
481
+
482
+ if verify_titles is True:
483
+ basic_response_data["Title"] = basic_response_data["Title"].str.strip()
484
+ basic_response_data["Title"] = basic_response_data["Title"].apply(initial_clean)
485
+ else:
486
+ basic_response_data = basic_response_data[
487
+ ["Reference", "Response", "Original Reference"]
488
+ ]
489
+
490
+ basic_response_data["Response"] = basic_response_data["Response"].str.strip()
491
+ basic_response_data["Response"] = basic_response_data["Response"].apply(
492
+ initial_clean
493
+ )
494
+
495
+ return basic_response_data
496
+
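+ # Minimal usage sketch (the column name is an assumption, not from real data):
+ # get_basic_response_data(pd.DataFrame({"Response text": [" Yes. ", "No."]}), ["Response text"])
+ # returns the columns ["Reference", "Response", "Original Reference"], with
+ # references numbered from 1 and responses stripped and cleaned.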
497
+
498
+ def convert_reference_table_to_pivot_table(
499
+ df: pd.DataFrame, basic_response_data: pd.DataFrame = pd.DataFrame()
500
+ ):
501
+
502
+ df_in = df[["Response References", "General topic", "Subtopic", "Sentiment"]].copy()
503
+
504
+ df_in["Response References"] = df_in["Response References"].astype(int)
505
+
506
+ # Create a combined category column
507
+ df_in["Category"] = (
508
+ df_in["General topic"] + " - " + df_in["Subtopic"] + " - " + df_in["Sentiment"]
509
+ )
510
+
511
+ # Create pivot table counting occurrences of each unique combination
512
+ pivot_table = pd.crosstab(
513
+ index=df_in["Response References"],
514
+ columns=[df_in["General topic"], df_in["Subtopic"], df_in["Sentiment"]],
515
+ margins=True,
516
+ )
517
+
518
+ # Flatten column names to make them more readable
519
+ pivot_table.columns = [" - ".join(col) for col in pivot_table.columns]
520
+
521
+ pivot_table.reset_index(inplace=True)
522
+
523
+ if not basic_response_data.empty:
524
+ pivot_table = basic_response_data.merge(
525
+ pivot_table, right_on="Response References", left_on="Reference", how="left"
526
+ )
527
+
528
+ pivot_table.drop("Response References", axis=1, inplace=True)
529
+
530
+ pivot_table.columns = pivot_table.columns.str.replace(
531
+ "Not assessed - ", ""
532
+ ).str.replace("- Not assessed", "")
533
+
534
+ return pivot_table
535
+
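+ # Sketch of the output shape (values assumed): each flattened column is named
+ # "General topic - Subtopic - Sentiment", e.g. "Mental health - Anger - Negative",
+ # with one row per response reference and crosstab counts (plus "All" margins) as values.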
536
+
537
+ def create_topic_summary_df_from_reference_table(reference_df: pd.DataFrame):
538
+
539
+ if "Group" not in reference_df.columns:
540
+ reference_df["Group"] = "All"
541
+
542
+ # Ensure 'Start row of group' column is numeric to avoid comparison errors
543
+ if "Start row of group" in reference_df.columns:
544
+ reference_df["Start row of group"] = pd.to_numeric(
545
+ reference_df["Start row of group"], errors="coerce"
546
+ )
547
+
548
+ out_topic_summary_df = (
549
+ reference_df.groupby(["General topic", "Subtopic", "Sentiment", "Group"])
550
+ .agg(
551
+ {
552
+ "Response References": "size", # Count the number of references
553
+ "Summary": lambda x: "<br>".join(
554
+ sorted(
555
+ set(x),
556
+ key=lambda summary: reference_df.loc[
557
+ reference_df["Summary"] == summary, "Start row of group"
558
+ ].min(),
559
+ )
560
+ ),
561
+ }
562
+ )
563
+ .reset_index()
564
+ # .sort_values('Response References', ascending=False) # Sort by size, biggest first
565
+ )
566
+
567
+ out_topic_summary_df = out_topic_summary_df.rename(
568
+ columns={"Response References": "Number of responses"}, errors="ignore"
569
+ )
570
+
571
+ # Sort the dataframe first
572
+ out_topic_summary_df = out_topic_summary_df.sort_values(
573
+ ["Group", "Number of responses", "General topic", "Subtopic", "Sentiment"],
574
+ ascending=[True, False, True, True, True],
575
+ )
576
+
577
+ # Then assign Topic number based on the final sorted order
578
+ out_topic_summary_df = out_topic_summary_df.assign(
579
+ Topic_number=lambda df: np.arange(1, len(df) + 1)
580
+ )
581
+
582
+ out_topic_summary_df.rename(columns={"Topic_number": "Topic number"}, inplace=True)
583
+
584
+ return out_topic_summary_df
585
+
586
+
587
+ # Wrap text in each column to the specified max width, including whole words
588
+ def wrap_text(text: str, max_width=80, max_text_length=None):
589
+ if not isinstance(text, str):
590
+ return text
591
+
592
+ # If max_text_length is set, truncate the text and add ellipsis
593
+ if max_text_length and len(text) > max_text_length:
594
+ text = text[:max_text_length] + "..."
595
+
596
+ text = text.replace("\r\n", "<br>").replace("\n", "<br>")
597
+
598
+ words = text.split()
599
+ if not words:
600
+ return text
601
+
602
+ # First pass: initial word wrapping
603
+ wrapped_lines = list()
604
+ current_line = list()
605
+ current_length = 0
606
+
607
+ def add_line():
608
+ if current_line:
609
+ wrapped_lines.append(" ".join(current_line))
610
+ current_line.clear()
611
+
612
+ for i, word in enumerate(words):
613
+ word_length = len(word)
614
+
615
+ # Handle words longer than max_width
616
+ if word_length > max_width:
617
+ add_line()
618
+ wrapped_lines.append(word)
619
+ current_length = 0
620
+ continue
621
+
622
+ # Calculate space needed for this word
623
+ space_needed = word_length if not current_line else word_length + 1
624
+
625
+ # Check if adding this word would exceed max_width
626
+ if current_length + space_needed > max_width:
627
+ add_line()
628
+ current_line.append(word)
629
+ current_length = word_length
630
+ else:
631
+ current_line.append(word)
632
+ current_length += space_needed
633
+
634
+ add_line() # Add any remaining text
635
+
636
+ # Second pass: redistribute words from lines following single-word lines
637
+ def can_fit_in_previous_line(prev_line, word):
638
+ return len(prev_line) + 1 + len(word) <= max_width
639
+
640
+ i = 0
641
+ while i < len(wrapped_lines) - 1:
642
+ words_in_line = wrapped_lines[i].split()
643
+ next_line_words = wrapped_lines[i + 1].split()
644
+
645
+ # If current line has only one word and isn't too long
646
+ if len(words_in_line) == 1 and len(words_in_line[0]) < max_width * 0.8:
647
+ # Try to bring words back from the next line
648
+ words_to_bring_back = list()
649
+ remaining_words = list()
650
+ current_length = len(words_in_line[0])
651
+
652
+ for word in next_line_words:
653
+ if current_length + len(word) + 1 <= max_width:
654
+ words_to_bring_back.append(word)
655
+ current_length += len(word) + 1
656
+ else:
657
+ remaining_words.append(word)
658
+
659
+ if words_to_bring_back:
660
+ # Update current line with additional words
661
+ wrapped_lines[i] = " ".join(words_in_line + words_to_bring_back)
662
+
663
+ # Update next line with remaining words
664
+ if remaining_words:
665
+ wrapped_lines[i + 1] = " ".join(remaining_words)
666
+ else:
667
+ wrapped_lines.pop(i + 1)
668
+ continue # Don't increment i if we removed a line
669
+ i += 1
670
+
671
+ return "<br>".join(wrapped_lines)
672
+
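+ # Worked example of the first pass above (illustrative input):
+ # wrap_text("one two three", max_width=7) -> "one two<br>three", since "one two"
+ # exactly fills 7 characters and "three" no longer fits on that line. The second
+ # pass then tries to top up any single-word line with words from the line below.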
673
+
674
+ def initial_clean(text: str):
675
+ #### Some of my cleaning functions
676
+ html_pattern_regex = r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0|&nbsp;"
677
+ html_start_pattern_end_dots_regex = r"<(.*?)\.\."
678
+ non_ascii_pattern = r"[^\x00-\x7F]+"
679
+ multiple_spaces_regex = r"\s{2,}"
680
+
681
+ # Define a list of patterns and their replacements
682
+ patterns = [
683
+ (html_pattern_regex, " "),
684
+ (html_start_pattern_end_dots_regex, " "),
685
+ (non_ascii_pattern, " "),
686
+ (multiple_spaces_regex, " "),
687
+ ]
688
+
689
+ # Apply each regex replacement
690
+ for pattern, replacement in patterns:
691
+ text = re.sub(pattern, replacement, text)
692
+
693
+ return text
694
+
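+ # Illustrative example (assumed input): initial_clean("Caf\xe9 <b>opening</b>  times")
+ # -> "Caf opening times"; HTML tags and non-ASCII characters are replaced with
+ # spaces, then runs of spaces are collapsed to one.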
695
+
696
+ def view_table(file_path: str):
697
+ df = pd.read_csv(file_path)
698
+
699
+ df_cleaned = df.replace("\n", " ", regex=True)
700
+
701
+ # Use apply with axis=1 to apply wrap_text to each element
702
+ df_cleaned = df_cleaned.apply(lambda col: col.map(wrap_text))
703
+
704
+ table_out = df_cleaned.to_markdown(index=False)
705
+
706
+ return table_out
707
+
708
+
709
+ def ensure_output_folder_exists():
710
+ """Checks if the 'output/' folder exists, creates it if not."""
711
+
712
+ folder_name = "output/"
713
+
714
+ if not os.path.exists(folder_name):
715
+ # Create the folder if it doesn't exist
716
+ os.makedirs(folder_name)
717
+ print("Created the 'output/' folder.")
718
+ else:
719
+ print("The 'output/' folder already exists.")
720
+
721
+
722
+ def put_columns_in_df(in_file: List[str]):
723
+ new_choices = list()
724
+ concat_choices = list()
725
+ all_sheet_names = list()
726
+ number_of_excel_files = 0
727
+
728
+ if not in_file:
729
+ return (
730
+ gr.Dropdown(choices=list()),
731
+ gr.Dropdown(choices=list()),
732
+ "",
733
+ gr.Dropdown(choices=list()),
734
+ gr.Dropdown(choices=list()),
735
+ )
736
+
737
+ for file in in_file:
738
+ file_name = file.name
739
+ file_type = detect_file_type(file_name)
740
+ # print("File type is:", file_type)
741
+
742
+ file_end = get_file_path_with_extension(file_name)
743
+
744
+ if file_type == "xlsx":
745
+ number_of_excel_files += 1
746
+ new_choices = list()
747
+ print("Running through all xlsx sheets")
748
+ anon_xlsx = pd.ExcelFile(file_name)
749
+ new_sheet_names = anon_xlsx.sheet_names
750
+ # Iterate through the sheet names
751
+ for sheet_name in new_sheet_names:
752
+ # Read each sheet into a DataFrame
753
+ df = pd.read_excel(file_name, sheet_name=sheet_name)
754
+
755
+ new_choices.extend(list(df.columns))
756
+
757
+ all_sheet_names.extend(new_sheet_names)
758
+
759
+ else:
760
+ df = read_file(file_name)
761
+ new_choices = list(df.columns)
762
+
763
+ concat_choices.extend(new_choices)
764
+
765
+ # Drop duplicate columns
766
+ concat_choices = sorted(set(concat_choices))
767
+
768
+ if number_of_excel_files > 0:
769
+ return (
770
+ gr.Dropdown(choices=concat_choices, value=concat_choices[0]),
771
+ gr.Dropdown(
772
+ choices=all_sheet_names,
773
+ value=all_sheet_names[0],
774
+ visible=True,
775
+ interactive=True,
776
+ ),
777
+ file_end,
778
+ gr.Dropdown(choices=concat_choices),
779
+ gr.Dropdown(choices=concat_choices),
780
+ )
781
+ else:
782
+ return (
783
+ gr.Dropdown(choices=concat_choices, value=concat_choices[0]),
784
+ gr.Dropdown(visible=False),
785
+ file_end,
786
+ gr.Dropdown(choices=concat_choices),
787
+ gr.Dropdown(choices=concat_choices),
788
+ )
789
+
790
+
791
+ # Following function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
792
+ def add_folder_to_path(folder_path: str):
793
+ """
794
+ Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist.
795
+ """
796
+
797
+ if os.path.exists(folder_path) and os.path.isdir(folder_path):
798
+ print(folder_path, "folder exists.")
799
+
800
+ # Resolve relative path to absolute path
801
+ absolute_path = os.path.abspath(folder_path)
802
+
803
+ current_path = os.environ["PATH"]
804
+ if absolute_path not in current_path.split(os.pathsep):
805
+ full_path_extension = absolute_path + os.pathsep + current_path
806
+ os.environ["PATH"] = full_path_extension
807
+ # print(f"Updated PATH with: ", full_path_extension)
808
+ else:
809
+ print(f"Directory {folder_path} already exists in PATH.")
810
+ else:
811
+ print(f"Folder not found at {folder_path} - not added to PATH")
812
+
813
+
814
+ # Upon running a process, the feedback buttons are revealed
815
+ def reveal_feedback_buttons():
816
+ return (
817
+ gr.Radio(visible=True),
818
+ gr.Textbox(visible=True),
819
+ gr.Button(visible=True),
820
+ gr.Markdown(visible=True),
821
+ )
822
+
823
+
824
+ def wipe_logs(feedback_logs_loc: str, usage_logs_loc: str):
825
+ try:
826
+ os.remove(feedback_logs_loc)
827
+ except Exception as e:
828
+ print("Could not remove feedback logs file", e)
829
+ try:
830
+ os.remove(usage_logs_loc)
831
+ except Exception as e:
832
+ print("Could not remove usage logs file", e)
833
+
834
+
835
+ async def get_connection_params(
836
+ request: gr.Request,
837
+ output_folder_textbox: str = OUTPUT_FOLDER,
838
+ input_folder_textbox: str = INPUT_FOLDER,
839
+ session_output_folder: str = SESSION_OUTPUT_FOLDER,
840
+ ):
841
+
842
+ # print("Session hash:", request.session_hash)
843
+
844
+ if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
845
+ if CUSTOM_HEADER in request.headers:
846
+ supplied_custom_header_value = request.headers[CUSTOM_HEADER]
847
+ if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
848
+ print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
849
+ else:
850
+ print("Custom header value does not match expected value.")
851
+ raise ValueError("Custom header value does not match expected value.")
852
+ else:
853
+ print("Custom header value not found.")
854
+ raise ValueError("Custom header value not found.")
855
+
856
+ # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
857
+
858
+ if request.username:
859
+ out_session_hash = request.username
860
+ # print("Request username found:", out_session_hash)
861
+
862
+ elif "x-cognito-id" in request.headers:
863
+ out_session_hash = request.headers["x-cognito-id"]
864
+ # print("Cognito ID found:", out_session_hash)
865
+
866
+ elif "x-amzn-oidc-identity" in request.headers:
867
+ out_session_hash = request.headers["x-amzn-oidc-identity"]
868
+
869
+ # Fetch email address using Cognito client
870
+ cognito_client = boto3.client("cognito-idp")
871
+ try:
872
+ response = cognito_client.admin_get_user(
873
+ UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
874
+ Username=out_session_hash,
875
+ )
876
+ email = next(
877
+ attr["Value"]
878
+ for attr in response["UserAttributes"]
879
+ if attr["Name"] == "email"
880
+ )
881
+ # print("Email address found:", email)
882
+
883
+ out_session_hash = email
884
+ except ClientError as e:
885
+ print("Error fetching user details:", e)
886
+ email = None
887
+
888
+ print("Cognito ID found:", out_session_hash)
889
+
890
+ else:
891
+ out_session_hash = request.session_hash
892
+
893
+ if session_output_folder == "True" or session_output_folder is True:
894
+ output_folder = output_folder_textbox + out_session_hash + "/"
895
+ input_folder = input_folder_textbox + out_session_hash + "/"
896
+ else:
897
+ output_folder = output_folder_textbox
898
+ input_folder = input_folder_textbox
899
+
900
+ os.makedirs(output_folder, exist_ok=True)
+ os.makedirs(input_folder, exist_ok=True)
904
+
905
+ return out_session_hash, output_folder, out_session_hash, input_folder
906
+
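+ # Note on the lookup order above: a username from direct Cognito login wins, then
+ # the "x-cognito-id" header, then "x-amzn-oidc-identity" (resolved to an email
+ # address where possible), and finally Gradio's anonymous session hash.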
907
+
908
+ def load_in_default_cost_codes(cost_codes_path: str, default_cost_code: str = ""):
909
+ """
910
+ Load in the cost codes list from file.
911
+ """
912
+ cost_codes_df = pd.read_csv(cost_codes_path)
913
+ dropdown_choices = cost_codes_df.iloc[:, 0].astype(str).tolist()
914
+
915
+ # Avoid inserting duplicate or empty cost code values
916
+ if default_cost_code and default_cost_code not in dropdown_choices:
917
+ dropdown_choices.insert(0, default_cost_code)
918
+
919
+ # Always have a blank option at the top
920
+ if "" not in dropdown_choices:
921
+ dropdown_choices.insert(0, "")
922
+
923
+ out_dropdown = gr.Dropdown(
924
+ value=default_cost_code if default_cost_code in dropdown_choices else "",
925
+ label="Choose cost code for analysis",
926
+ choices=dropdown_choices,
927
+ allow_custom_value=False,
928
+ )
929
+
930
+ return cost_codes_df, cost_codes_df, out_dropdown
931
+
932
+
933
+ def df_select_callback_cost(df: pd.DataFrame, evt: gr.SelectData):
934
+ row_value_code = evt.row_value[0] # This is the value for cost code
935
+
936
+ return row_value_code
937
+
938
+
939
+ def update_cost_code_dataframe_from_dropdown_select(
940
+ cost_dropdown_selection: str, cost_code_df: pd.DataFrame
941
+ ):
942
+ cost_code_df = cost_code_df.loc[
943
+ cost_code_df.iloc[:, 0] == cost_dropdown_selection, :
944
+ ]
945
+ return cost_code_df
946
+
947
+
948
+ def reset_base_dataframe(df: pd.DataFrame):
949
+ return df
950
+
951
+
952
+ def enforce_cost_codes(
953
+ enforce_cost_code_textbox: str,
954
+ cost_code_choice: str,
955
+ cost_code_df: pd.DataFrame,
956
+ verify_cost_codes: bool = True,
957
+ ):
958
+ """
959
+ Check if the enforce cost codes variable is set to true, and then check that a cost code has been chosen. If not, raise an error. Then check the chosen value against the values in the cost code dataframe to ensure that the cost code exists.
960
+ """
961
+
962
+ if enforce_cost_code_textbox == "True":
963
+ if not cost_code_choice:
964
+ raise Exception("Please choose a cost code before continuing")
965
+
966
+ if verify_cost_codes is True:
967
+ if cost_code_df.empty:
968
+ # Warn but don't block - cost code is still required above
969
+ print(
970
+ "Warning: Cost code dataframe is empty. Verification skipped. Please ensure cost codes are loaded for full validation."
971
+ )
972
+ else:
973
+ valid_cost_codes_list = list(cost_code_df.iloc[:, 0].unique())
974
+
975
+ if cost_code_choice not in valid_cost_codes_list:
976
+ raise Exception(
977
+ "Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions."
978
+ )
979
+ return
980
+
981
+
982
+ def _get_env_list(env_var_name: str, strip_strings: bool = True) -> List[str]:
983
+ """Parses a comma-separated environment variable into a list of strings."""
984
+ value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
985
+ if not value:
986
+ return []
987
+ # Split by comma and filter out any empty strings that might result from extra commas
988
+ if strip_strings:
989
+ return [s.strip() for s in value.split(",") if s.strip()]
990
+ else:
991
+ return [codecs.decode(s, "unicode_escape") for s in value.split(",") if s]
992
+
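+ # Illustrative example (assumed value): _get_env_list("['a', 'b']") -> ["a", "b"];
+ # the leading "[" and trailing "]" are sliced off before splitting on commas.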
993
+
994
+ def create_batch_file_path_details(
995
+ reference_data_file_name: str,
996
+ latest_batch_completed: int = None,
997
+ batch_size_number: int = None,
998
+ in_column: str = None,
999
+ ) -> str:
1000
+ """
1001
+ Creates a standardised batch file path detail string from a reference data filename.
1002
+
1003
+ Args:
1004
+ reference_data_file_name (str): Name of the reference data file
1005
+ latest_batch_completed (int, optional): Latest batch completed. Defaults to None.
1006
+ batch_size_number (int, optional): Batch size number. Defaults to None.
1007
+ in_column (str, optional): In column. Defaults to None.
1008
+ Returns:
1009
+ str: Formatted batch file path detail string
1010
+ """
1011
+
1012
+ # Extract components from filename using regex
1013
+ file_name = (
1014
+ re.search(
1015
+ r"(.*?)(?:_all_|_final_|_batch_|_col_)", reference_data_file_name
1016
+ ).group(1)
1017
+ if re.search(r"(.*?)(?:_all_|_final_|_batch_|_col_)", reference_data_file_name)
1018
+ else reference_data_file_name
1019
+ )
1020
+ latest_batch_completed = (
1021
+ int(re.search(r"batch_(\d+)_", reference_data_file_name).group(1))
1022
+ if "batch_" in reference_data_file_name
1023
+ else latest_batch_completed
1024
+ )
1025
+ batch_size_number = (
1026
+ int(re.search(r"size_(\d+)_", reference_data_file_name).group(1))
1027
+ if "size_" in reference_data_file_name
1028
+ else batch_size_number
1029
+ )
1030
+ in_column = (
1031
+ re.search(r"col_(.*?)_reference", reference_data_file_name).group(1)
1032
+ if "col_" in reference_data_file_name
1033
+ else in_column
1034
+ )
1035
+
1036
+ # Clean the extracted names
1037
+ file_name_cleaned = clean_column_name(file_name, max_length=20)
1038
+ in_column_cleaned = clean_column_name(in_column, max_length=20)
1039
+
1040
+ # Create batch file path details string
1041
+ if latest_batch_completed:
1042
+ return f"{file_name_cleaned}_batch_{latest_batch_completed}_size_{batch_size_number}_col_{in_column_cleaned}"
1043
+ return f"{file_name_cleaned}_col_{in_column_cleaned}"
1044
+
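+ # Worked example (the file name is an assumption): for
+ # "responses_batch_2_size_5_col_Response_reference_table.csv" the regexes above
+ # recover file_name="responses", batch 2, size 5 and column "Response", producing
+ # "responses_batch_2_size_5_col_Response".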
1045
+
1046
+ def move_overall_summary_output_files_to_front_page(
1047
+ overall_summary_output_files_xlsx: List[str],
1048
+ ):
1049
+ return overall_summary_output_files_xlsx
1050
+
1051
+
1052
+ def generate_zero_shot_topics_df(
1053
+ zero_shot_topics: pd.DataFrame,
1054
+ force_zero_shot_radio: str = "No",
1055
+ create_revised_general_topics: bool = False,
1056
+ max_topic_no: int = MAXIMUM_ZERO_SHOT_TOPICS,
1057
+ ):
1058
+ """
1059
+ Preprocesses a DataFrame of zero-shot topics, cleaning and formatting them
1060
+ for use with a large language model. It handles different column configurations
1061
+ (e.g., only subtopics, general topics and subtopics, or subtopics with descriptions)
1062
+ and enforces a maximum number of topics.
1063
+
1064
+ Args:
1065
+ zero_shot_topics (pd.DataFrame): A DataFrame containing the initial zero-shot topics.
1066
+ Expected columns can vary, but typically include
1067
+ "General topic", "Subtopic", and/or "Description".
1068
+ force_zero_shot_radio (str, optional): A string indicating whether to force
1069
+ the use of zero-shot topics. Defaults to "No".
1070
+ (Currently not used in the function logic, but kept for signature consistency).
1071
+ create_revised_general_topics (bool, optional): A boolean indicating whether to
1072
+ create revised general topics. Defaults to False.
1073
+ (Currently not used in the function logic, but kept for signature consistency).
1074
+ max_topic_no (int, optional): The maximum number of topics allowed to fit within
1075
+ LLM context limits. If `zero_shot_topics` exceeds this,
1076
+ an exception is raised. Defaults to MAXIMUM_ZERO_SHOT_TOPICS.
1077
+
1078
+ Returns:
1079
+ pd.DataFrame: A DataFrame with "General topic", "Subtopic" and "Description"
+ columns, deduplicated on General topic/Subtopic and sorted by both.
1083
+ """
1084
+
1085
+ zero_shot_topics_gen_topics_list = list()
1086
+ zero_shot_topics_subtopics_list = list()
1087
+ zero_shot_topics_description_list = list()
1088
+
1089
+ # Enforce the configured maximum number of zero-shot topics
1090
+ if zero_shot_topics.shape[0] > max_topic_no:
1091
+ out_message = (
1092
+ "Maximum "
1093
+ + str(max_topic_no)
1094
+ + " zero-shot topics allowed according to application configuration."
1095
+ )
1096
+ print(out_message)
1097
+ raise Exception(out_message)
1098
+
1099
+ # Forward slashes in the topic names seems to confuse the model
1100
+ if zero_shot_topics.shape[1] >= 1: # Check if there is at least one column
1101
+ for x in zero_shot_topics.columns:
1102
+ if not zero_shot_topics[x].isnull().all():
1103
+ zero_shot_topics[x] = zero_shot_topics[x].apply(initial_clean)
1104
+
1105
+ zero_shot_topics.loc[:, x] = (
1106
+ zero_shot_topics.loc[:, x]
1107
+ .str.strip()
1108
+ .str.replace("\n", " ")
1109
+ .str.replace("\r", " ")
1110
+ .str.replace("/", " or ")
1111
+ .str.replace("&", " and ")
1112
+ .str.replace(" s ", "s ")
1113
+ .str.lower()
1114
+ .str.capitalize()
1115
+ )
1116
+
1117
+ # If number of columns is 1, keep only subtopics
1118
+ if (
1119
+ zero_shot_topics.shape[1] == 1
1120
+ and "General topic" not in zero_shot_topics.columns
1121
+ ):
1122
+ print("Found only Subtopic in zero shot topics")
1123
+ zero_shot_topics_gen_topics_list = [""] * zero_shot_topics.shape[0]
1124
+ zero_shot_topics_subtopics_list = list(zero_shot_topics.iloc[:, 0])
1125
+ # Allow for possibility that the user only wants to set general topics and not subtopics
1126
+ elif (
1127
+ zero_shot_topics.shape[1] == 1
1128
+ and "General topic" in zero_shot_topics.columns
1129
+ ):
1130
+ print("Found only General topic in zero shot topics")
1131
+ zero_shot_topics_gen_topics_list = list(zero_shot_topics["General topic"])
1132
+ zero_shot_topics_subtopics_list = [""] * zero_shot_topics.shape[0]
1133
+ # If general topic and subtopic are specified
1134
+ elif set(["General topic", "Subtopic"]).issubset(zero_shot_topics.columns):
1135
+ print("Found General topic and Subtopic in zero shot topics")
1136
+ zero_shot_topics_gen_topics_list = list(zero_shot_topics["General topic"])
1137
+ zero_shot_topics_subtopics_list = list(zero_shot_topics["Subtopic"])
1138
+ # If subtopic and description are specified
1139
+ elif set(["Subtopic", "Description"]).issubset(zero_shot_topics.columns):
1140
+ print("Found Subtopic and Description in zero shot topics")
1141
+ zero_shot_topics_gen_topics_list = [""] * zero_shot_topics.shape[0]
1142
+ zero_shot_topics_subtopics_list = list(zero_shot_topics["Subtopic"])
1143
+ zero_shot_topics_description_list = list(zero_shot_topics["Description"])
1144
+
1145
+ # If number of columns is at least 2, keep general topics and subtopics
1146
+ elif (
1147
+ zero_shot_topics.shape[1] >= 2
1148
+ and "Description" not in zero_shot_topics.columns
1149
+ ):
1150
+ zero_shot_topics_gen_topics_list = list(zero_shot_topics.iloc[:, 0])
1151
+ zero_shot_topics_subtopics_list = list(zero_shot_topics.iloc[:, 1])
1152
+ else:
1153
+ # If there are more columns, just assume that the first column was meant to be a subtopic
1154
+ zero_shot_topics_gen_topics_list = [""] * zero_shot_topics.shape[0]
1155
+ zero_shot_topics_subtopics_list = list(zero_shot_topics.iloc[:, 0])
1156
+
1157
+ # Add a description if column is present
1158
+ if not zero_shot_topics_description_list:
1159
+ if "Description" in zero_shot_topics.columns:
1160
+ zero_shot_topics_description_list = list(
1161
+ zero_shot_topics["Description"]
1162
+ )
1163
+ # print("Description found in topic title. List is:", zero_shot_topics_description_list)
1164
+ elif zero_shot_topics.shape[1] >= 3:
1165
+ zero_shot_topics_description_list = list(
1166
+ zero_shot_topics.iloc[:, 2]
1167
+ ) # Assume the third column is description
1168
+ else:
1169
+ zero_shot_topics_description_list = [""] * zero_shot_topics.shape[0]
1170
+
1171
+ # If the responses are being forced into zero shot topics, allow an option for nothing relevant
1172
+ if force_zero_shot_radio == "Yes":
1173
+ zero_shot_topics_gen_topics_list.append("")
1174
+ zero_shot_topics_subtopics_list.append("No relevant topic")
1175
+ zero_shot_topics_description_list.append("")
1176
+
1177
+ # Add description or not
1178
+ zero_shot_topics_df = pd.DataFrame(
1179
+ data={
1180
+ "General topic": zero_shot_topics_gen_topics_list,
1181
+ "Subtopic": zero_shot_topics_subtopics_list,
1182
+ "Description": zero_shot_topics_description_list,
1183
+ }
1184
+ )
1185
+
1186
+ # Filter out duplicate General topic and subtopic names
1187
+ zero_shot_topics_df = zero_shot_topics_df.drop_duplicates(
1188
+ ["General topic", "Subtopic"], keep="first"
1189
+ )
1190
+
1191
+ # Sort the dataframe by General topic and subtopic
1192
+ zero_shot_topics_df = zero_shot_topics_df.sort_values(
1193
+ ["General topic", "Subtopic"], ascending=[True, True]
1194
+ )
1195
+
1196
+ return zero_shot_topics_df
1197
+
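+ # Minimal usage sketch (topic names assumed): a one-column frame such as
+ # pd.DataFrame({"Subtopic": ["Housing", "Transport"]}) comes back as a DataFrame
+ # with blank "General topic" and "Description" columns alongside the cleaned subtopics.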
1198
+
1199
+ def update_model_choice(model_source):
+     # Filter models by source and return the first matching model name
+     matching_models = [
+         model_name
+         for model_name, model_info in model_name_map.items()
+         if model_info["source"] == model_source
+     ]
+
+     output_model = matching_models[0] if matching_models else model_full_names[0]
+
+     return gr.Dropdown(
+         value=output_model,
+         choices=matching_models,
+         label="Large language model for topic extraction and summarisation",
+         multiselect=False,
+     )
+
+
+ def ensure_model_in_map(model_choice: str, model_name_map_dict: dict = None) -> dict:
+     """
+     Ensures that a model_choice is registered in model_name_map.
+     If the model_choice is not found, it assumes it's an inference-server model
+     and adds it to the map with source "inference-server".
+
+     Args:
+         model_choice (str): The model name to check/register
+         model_name_map_dict (dict, optional): The model_name_map dictionary to update.
+             If None, uses the global model_name_map from config.
+
+     Returns:
+         dict: The model_name_map dictionary (updated if needed)
+     """
+     # Use provided dict or global one
+     if model_name_map_dict is None:
+         from tools.config import model_name_map
+
+         model_name_map_dict = model_name_map
+
+     # If model_choice is not in the map, assume it's an inference-server model
+     if model_choice not in model_name_map_dict:
+         model_name_map_dict[model_choice] = {
+             "short_name": model_choice,
+             "source": "inference-server",
+         }
+         print(f"Registered custom model '{model_choice}' as inference-server model")
+
+     return model_name_map_dict
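A minimal usage sketch for the helper above ("my-served-model" is a hypothetical name):

    model_name_map = ensure_model_in_map("my-served-model")
    # Unrecognised names are registered on the fly:
    # model_name_map["my-served-model"] == {"short_name": "my-served-model", "source": "inference-server"}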
tools/llm_api_call.py ADDED
The diff for this file is too large to render. See raw diff
 
tools/llm_funcs.py ADDED
@@ -0,0 +1,1999 @@
+ import json
+ import os
+ import re
+ import time
+ from typing import List, Tuple
+
+ import boto3
+ import pandas as pd
+ import requests
+
+ # Import mock patches if in test mode
+ if os.environ.get("USE_MOCK_LLM") == "1" or os.environ.get("TEST_MODE") == "1":
+     try:
+         # Try to import and apply mock patches
+         import sys
+
+         # Add project root to sys.path so we can import test.mock_llm_calls
+         project_root = os.path.dirname(os.path.dirname(__file__))
+         if project_root not in sys.path:
+             sys.path.insert(0, project_root)
+         try:
+             from test.mock_llm_calls import apply_mock_patches
+
+             apply_mock_patches()
+         except ImportError:
+             # If mock module not found, continue without mocking
+             pass
+     except Exception:
+         # If anything fails, continue without mocking
+         pass
+ from google import genai as ai
+ from google.genai import types
+ from gradio import Progress
+ from huggingface_hub import hf_hub_download
+ from openai import OpenAI
+ from tqdm import tqdm
+
+ model_type = None  # global variable setup
+ full_text = (
+     ""  # Define dummy source text (full text) just to enable highlight function to load
+ )
+
+ # Global variables for model and tokenizer
+ _model = None
+ _tokenizer = None
+ _assistant_model = None
+
+ from tools.config import (
+     ASSISTANT_MODEL,
+     BATCH_SIZE_DEFAULT,
+     CHOSEN_LOCAL_MODEL_TYPE,
+     COMPILE_MODE,
+     COMPILE_TRANSFORMERS,
+     DEDUPLICATION_THRESHOLD,
+     HF_TOKEN,
+     INT8_WITH_OFFLOAD_TO_CPU,
+     K_QUANT_LEVEL,
+     LLM_BATCH_SIZE,
+     LLM_CONTEXT_LENGTH,
+     LLM_LAST_N_TOKENS,
+     LLM_MAX_GPU_LAYERS,
+     LLM_MAX_NEW_TOKENS,
+     LLM_MIN_P,
+     LLM_REPETITION_PENALTY,
+     LLM_RESET,
+     LLM_SAMPLE,
+     LLM_SEED,
+     LLM_STOP_STRINGS,
+     LLM_STREAM,
+     LLM_TEMPERATURE,
+     LLM_THREADS,
+     LLM_TOP_K,
+     LLM_TOP_P,
+     LOAD_LOCAL_MODEL_AT_START,
+     LOCAL_MODEL_FILE,
+     LOCAL_MODEL_FOLDER,
+     LOCAL_REPO_ID,
+     MAX_COMMENT_CHARS,
+     MAX_TIME_FOR_LOOP,
+     MODEL_DTYPE,
+     MULTIMODAL_PROMPT_FORMAT,
+     NUM_PRED_TOKENS,
+     NUMBER_OF_RETRY_ATTEMPTS,
+     RUN_LOCAL_MODEL,
+     SPECULATIVE_DECODING,
+     TIMEOUT_WAIT,
+     USE_BITSANDBYTES,
+     USE_LLAMA_CPP,
+     USE_LLAMA_SWAP,
+     V_QUANT_LEVEL,
+ )
+ from tools.helper_functions import _get_env_list
+
+ if SPECULATIVE_DECODING == "True":
+     SPECULATIVE_DECODING = True
+ else:
+     SPECULATIVE_DECODING = False
+
+
+ if isinstance(NUM_PRED_TOKENS, str):
+     NUM_PRED_TOKENS = int(NUM_PRED_TOKENS)
+ if isinstance(LLM_MAX_GPU_LAYERS, str):
+     LLM_MAX_GPU_LAYERS = int(LLM_MAX_GPU_LAYERS)
+ if isinstance(LLM_THREADS, str):
+     LLM_THREADS = int(LLM_THREADS)
+
+ if LLM_RESET == "True":
+     reset = True
+ else:
+     reset = False
+
+ if LLM_STREAM == "True":
+     stream = True
+ else:
+     stream = False
+
+ if LLM_SAMPLE == "True":
+     sample = True
+ else:
+     sample = False
+
+ if LLM_STOP_STRINGS:
+     LLM_STOP_STRINGS = _get_env_list(LLM_STOP_STRINGS, strip_strings=False)
+
+ max_tokens = LLM_MAX_NEW_TOKENS
+ timeout_wait = TIMEOUT_WAIT
+ number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS
+ max_time_for_loop = MAX_TIME_FOR_LOOP
+ batch_size_default = BATCH_SIZE_DEFAULT
+ deduplication_threshold = DEDUPLICATION_THRESHOLD
+ max_comment_character_length = MAX_COMMENT_CHARS
+
+ temperature = LLM_TEMPERATURE
+ top_k = LLM_TOP_K
+ top_p = LLM_TOP_P
+ min_p = LLM_MIN_P
+ repetition_penalty = LLM_REPETITION_PENALTY
+ last_n_tokens = LLM_LAST_N_TOKENS
+ LLM_MAX_NEW_TOKENS: int = LLM_MAX_NEW_TOKENS
+ seed: int = LLM_SEED
+ reset: bool = reset
+ stream: bool = stream
+ batch_size: int = LLM_BATCH_SIZE
+ context_length: int = LLM_CONTEXT_LENGTH
+ sample: bool = sample  # keep the boolean parsed from LLM_SAMPLE above
+ stop_strings = LLM_STOP_STRINGS
+ speculative_decoding = SPECULATIVE_DECODING
+ if LLM_MAX_GPU_LAYERS != 0:
+     gpu_layers = int(LLM_MAX_GPU_LAYERS)
+     torch_device = "cuda"
+ else:
+     gpu_layers = 0
+     torch_device = "cpu"
+
+ if not LLM_THREADS:
+     threads = 1
+ else:
+     threads = LLM_THREADS
+
+
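A short sketch (illustrative, not from the module) of the coercion idiom used above: config values arrive as strings from the environment, so each is normalised to a typed Python value exactly once at import time.

    # Hypothetical helpers showing the same normalisation pattern
    def _as_bool(value) -> bool:
        return value == "True"

    def _as_int(value) -> int:
        return int(value) if isinstance(value, str) else value

    # e.g. stream = _as_bool(LLM_STREAM); threads = _as_int(LLM_THREADS) or 1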
+ class llama_cpp_init_config_gpu:
+     def __init__(
+         self,
+         last_n_tokens=last_n_tokens,
+         seed=seed,
+         n_threads=threads,
+         n_batch=batch_size,
+         n_ctx=context_length,
+         n_gpu_layers=gpu_layers,
+         reset=reset,
+     ):
+
+         self.last_n_tokens = last_n_tokens
+         self.seed = seed
+         self.n_threads = n_threads
+         self.n_batch = n_batch
+         self.n_ctx = n_ctx
+         self.n_gpu_layers = n_gpu_layers
+         self.reset = reset
+         # self.stop: list[str] = field(default_factory=lambda: [stop_string])
+
+     def update_gpu(self, new_value):
+         self.n_gpu_layers = new_value
+
+     def update_context(self, new_value):
+         self.n_ctx = new_value
+
+
+ class llama_cpp_init_config_cpu(llama_cpp_init_config_gpu):
+     def __init__(self):
+         super().__init__()
+         self.n_gpu_layers = gpu_layers
+         self.n_ctx = context_length
+
+
+ gpu_config = llama_cpp_init_config_gpu()
+ cpu_config = llama_cpp_init_config_cpu()
+
+
+ class LlamaCPPGenerationConfig:
+     def __init__(
+         self,
+         temperature=temperature,
+         top_k=top_k,
+         min_p=min_p,
+         top_p=top_p,
+         repeat_penalty=repetition_penalty,
+         seed=seed,
+         stream=stream,
+         max_tokens=LLM_MAX_NEW_TOKENS,
+         reset=reset,
+     ):
+         self.temperature = temperature
+         self.top_k = top_k
+         self.top_p = top_p
+         self.min_p = min_p  # store min_p so the accepted parameter is not silently dropped
+         self.repeat_penalty = repeat_penalty
+         self.seed = seed
+         self.max_tokens = max_tokens
+         self.stream = stream
+         self.reset = reset
+
+     def update_temp(self, new_value):
+         self.temperature = new_value
+
+
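A minimal usage sketch for the generation config (values illustrative):

    gen_config = LlamaCPPGenerationConfig()  # defaults come from tools.config
    gen_config.update_temp(0.2)              # override temperature per call
    # vars(gpu_config) is later unpacked into the Llama() constructor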
+ # ResponseObject class for AWS Bedrock calls
+ class ResponseObject:
+     def __init__(self, text, usage_metadata):
+         self.text = text
+         self.usage_metadata = usage_metadata
+
+
+ ###
+ # LOCAL MODEL FUNCTIONS
+ ###
+
+
+ def get_model_path(
+     repo_id=LOCAL_REPO_ID,
+     model_filename=LOCAL_MODEL_FILE,
+     model_dir=LOCAL_MODEL_FOLDER,
+     hf_token=HF_TOKEN,
+ ):
+     # Construct the expected local path
+     local_path = os.path.join(model_dir, model_filename)
+
+     print("local path for model load:", local_path)
+
+     try:
+         if os.path.exists(local_path):
+             print(f"Model already exists at: {local_path}")
+
+             return local_path
+         else:
+             if hf_token:
+                 print("Downloading model from Hugging Face Hub with HF token")
+                 downloaded_model_path = hf_hub_download(
+                     repo_id=repo_id, token=hf_token, filename=model_filename
+                 )
+
+                 return downloaded_model_path
+             else:
+                 print(
+                     "No HF token found, downloading model from Hugging Face Hub without token"
+                 )
+                 downloaded_model_path = hf_hub_download(
+                     repo_id=repo_id, filename=model_filename
+                 )
+
+                 return downloaded_model_path
+
+     except Exception as e:
+         print("Error loading model:", e)
+         raise Warning("Error loading model:", e)
+
+
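A hedged sketch of the resolution order above (paths are hypothetical): a local file wins; otherwise the file is fetched from the Hub.

    # model_dir="models", model_filename="model.gguf"
    # 1. "models/model.gguf" exists -> returned directly
    # 2. otherwise                  -> hf_hub_download(repo_id=..., filename="model.gguf")
    path = get_model_path(model_dir="models", model_filename="model.gguf")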
+ def load_model(
+     local_model_type: str = CHOSEN_LOCAL_MODEL_TYPE,
+     gpu_layers: int = gpu_layers,
+     max_context_length: int = context_length,
+     gpu_config: llama_cpp_init_config_gpu = gpu_config,
+     cpu_config: llama_cpp_init_config_cpu = cpu_config,
+     torch_device: str = torch_device,
+     repo_id=LOCAL_REPO_ID,
+     model_filename=LOCAL_MODEL_FILE,
+     model_dir=LOCAL_MODEL_FOLDER,
+     compile_mode=COMPILE_MODE,
+     model_dtype=MODEL_DTYPE,
+     hf_token=HF_TOKEN,
+     speculative_decoding=speculative_decoding,
+     model=None,
+     tokenizer=None,
+     assistant_model=None,
+ ):
+     """
+     Load in a model from Hugging Face Hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Hugging Face Hub.
+
+     Args:
+         local_model_type (str): The type of local model to load (e.g., "llama-cpp").
+         gpu_layers (int): The number of GPU layers to offload to the GPU.
+         max_context_length (int): The maximum context length for the model.
+         gpu_config (llama_cpp_init_config_gpu): Configuration object for GPU-specific Llama.cpp parameters.
+         cpu_config (llama_cpp_init_config_cpu): Configuration object for CPU-specific Llama.cpp parameters.
+         torch_device (str): The device to load the model on ("cuda" for GPU, "cpu" for CPU).
+         repo_id (str): The Hugging Face repository ID where the model is located.
+         model_filename (str): The specific filename of the model to download from the repository.
+         model_dir (str): The local directory where the model will be stored or downloaded.
+         compile_mode (str): The compilation mode to use for the model.
+         model_dtype (str): The data type to use for the model.
+         hf_token (str): The Hugging Face token to use for the model.
+         speculative_decoding (bool): Whether to use speculative decoding.
+         model (Llama/transformers model): The model to load.
+         tokenizer (list/transformers tokenizer): The tokenizer to load.
+         assistant_model (transformers model): The assistant model for speculative decoding.
+     Returns:
+         tuple: A tuple containing:
+             - model (Llama/transformers model): The loaded Llama.cpp/transformers model instance.
+             - tokenizer (list/transformers tokenizer): An empty list (tokenizer is not used with Llama.cpp directly in this setup), or a transformers tokenizer.
+             - assistant_model (transformers model): The assistant model for speculative decoding (if speculative_decoding is True).
+     """
+
+     if model:
+         return model, tokenizer, assistant_model
+
+     print("Loading model:", local_model_type)
+
+     # Verify the device and CUDA settings
+     # Check if CUDA is enabled
+
+     import torch
+
+     torch.cuda.empty_cache()
+     print("Is CUDA enabled? ", torch.cuda.is_available())
+     print("Is a CUDA device available on this computer?", torch.backends.cudnn.enabled)
+     if torch.cuda.is_available():
+         torch_device = "cuda"
+         gpu_layers = int(LLM_MAX_GPU_LAYERS)
+         print("CUDA version:", torch.version.cuda)
+         # try:
+         #     os.system("nvidia-smi")
+         # except Exception as e:
+         #     print("Could not print nvidia-smi settings due to:", e)
+     else:
+         torch_device = "cpu"
+         gpu_layers = 0
+
+     print("Running on device:", torch_device)
+     print("GPU layers assigned to cuda:", gpu_layers)
+
+     if not LLM_THREADS:
+         threads = torch.get_num_threads()
+     else:
+         threads = LLM_THREADS
+     print("CPU threads:", threads)
+
+     # GPU mode
+     if torch_device == "cuda":
+         torch.cuda.empty_cache()
+         gpu_config.update_gpu(gpu_layers)
+         gpu_config.update_context(max_context_length)
+
+         if USE_LLAMA_CPP == "True":
+             from llama_cpp import Llama
+             from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+
+             model_path = get_model_path(
+                 repo_id=repo_id, model_filename=model_filename, model_dir=model_dir
+             )
+
+             try:
+                 print("GPU load variables:", vars(gpu_config))
+                 if speculative_decoding:
+                     model = Llama(
+                         model_path=model_path,
+                         type_k=K_QUANT_LEVEL,
+                         type_v=V_QUANT_LEVEL,
+                         flash_attn=True,
+                         draft_model=LlamaPromptLookupDecoding(
+                             num_pred_tokens=NUM_PRED_TOKENS
+                         ),
+                         **vars(gpu_config),
+                     )
+                 else:
+                     model = Llama(
+                         model_path=model_path,
+                         type_k=K_QUANT_LEVEL,
+                         type_v=V_QUANT_LEVEL,
+                         flash_attn=True,
+                         **vars(gpu_config),
+                     )
+
+             except Exception as e:
+                 print("GPU load failed due to:", e, "Loading model in CPU mode")
+                 # If fails, go to CPU mode
+                 model = Llama(model_path=model_path, **vars(cpu_config))
+
+         else:
+             from transformers import (
+                 AutoModelForCausalLM,
+                 BitsAndBytesConfig,
+             )
+             from unsloth import FastLanguageModel
+
+             print("Loading model from transformers")
+             # Use the official model ID for Gemma 3 4B
+             model_id = (
+                 repo_id.split("https://huggingface.co/")[-1]
+                 if "https://huggingface.co/" in repo_id
+                 else repo_id
+             )
+             # 1. Set Data Type (dtype)
+             # For H200/Hopper: 'bfloat16'
+             # For RTX 3060/Ampere: 'float16'
+             dtype_str = model_dtype  # os.environ.get("MODEL_DTYPE", "bfloat16").lower()
+             if dtype_str == "bfloat16":
+                 torch_dtype = torch.bfloat16
+             elif dtype_str == "float16":
+                 torch_dtype = torch.float16
+             else:
+                 torch_dtype = torch.float32  # A safe fallback
+
+             # 2. Set Compilation Mode
+             # 'max-autotune' is great for both but can be slow initially.
+             # 'reduce-overhead' is a faster alternative for compiling.
+
+             print("--- System Configuration ---")
+             print(f"Using model id: {model_id}")
+             print(f"Using dtype: {torch_dtype}")
+             print(f"Using compile mode: {compile_mode}")
+             print(f"Using bitsandbytes: {USE_BITSANDBYTES}")
+             print("--------------------------\n")
+
+             # --- Load Tokenizer and Model ---
+
+             try:
+
+                 # Load Tokenizer and Model
+                 # tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+                 if USE_BITSANDBYTES == "True":
+
+                     if INT8_WITH_OFFLOAD_TO_CPU == "True":
+                         # This will be very slow. Requires at least 4GB of VRAM and 32GB of RAM
+                         print(
+                             "Using bitsandbytes for quantisation to 8 bits, with offloading to CPU"
+                         )
+                         max_memory = {0: "4GB", "cpu": "32GB"}
+                         quantisation_config = BitsAndBytesConfig(
+                             load_in_8bit=True,
+                             max_memory=max_memory,
+                             llm_int8_enable_fp32_cpu_offload=True,  # Note: if bitsandbytes has to offload to CPU, inference will be slow
+                         )
+                     else:
+                         # For Gemma 4B, requires at least 6GB of VRAM
+                         print("Using bitsandbytes for quantisation to 4 bits")
+                         quantisation_config = BitsAndBytesConfig(
+                             load_in_4bit=True,
+                             bnb_4bit_quant_type="nf4",  # Use the modern NF4 quantisation for better performance
+                             bnb_4bit_compute_dtype=torch_dtype,
+                             bnb_4bit_use_double_quant=True,  # Optional: uses a second quantisation step to save even more memory
+                         )
+
+                     # print("Loading model with bitsandbytes quantisation config:", quantisation_config)
+
+                     model, tokenizer = FastLanguageModel.from_pretrained(
+                         model_id,
+                         max_seq_length=max_context_length,
+                         dtype=torch_dtype,
+                         device_map="auto",
+                         load_in_4bit=True,
+                         # quantization_config=quantisation_config,  # Not actually used in Unsloth
+                         token=hf_token,
+                     )
+
+                     FastLanguageModel.for_inference(model)
+                 else:
+                     print("Loading model without bitsandbytes quantisation")
+                     model, tokenizer = FastLanguageModel.from_pretrained(
+                         model_id,
+                         max_seq_length=max_context_length,
+                         dtype=torch_dtype,
+                         device_map="auto",
+                         token=hf_token,
+                     )
+
+                     FastLanguageModel.for_inference(model)
+
+                 if not tokenizer.pad_token:
+                     tokenizer.pad_token = tokenizer.eos_token
+
+             except Exception as e:
+                 print("Error loading model with bitsandbytes quantisation config:", e)
+                 raise Warning(
+                     "Error loading model with bitsandbytes quantisation config:", e
+                 )
+
+             # Compile the Model with the selected mode πŸš€
+             if COMPILE_TRANSFORMERS == "True":
+                 try:
+                     model = torch.compile(model, mode=compile_mode, fullgraph=True)
+                 except Exception as e:
+                     print(f"Could not compile model: {e}. Running in eager mode.")
+
+         print(
+             "Loading with",
+             gpu_config.n_gpu_layers,
+             "model layers sent to GPU and a maximum context length of",
+             gpu_config.n_ctx,
+         )
+
+     # CPU mode
+     else:
+         if USE_LLAMA_CPP == "False":
+             raise Warning(
+                 "Using transformers model in CPU mode is not supported. Please change your config variable USE_LLAMA_CPP to True if you want to do CPU inference."
+             )
+
+         # llama-cpp imports (needed here as well, since the GPU-branch imports may not have run)
+         from llama_cpp import Llama
+         from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+
+         model_path = get_model_path(
+             repo_id=repo_id, model_filename=model_filename, model_dir=model_dir
+         )
+
+         # gpu_config.update_gpu(gpu_layers)
+         cpu_config.update_gpu(gpu_layers)
+
+         # Update context length according to slider
+         # gpu_config.update_context(max_context_length)
+         cpu_config.update_context(max_context_length)
+
+         if speculative_decoding:
+             model = Llama(
+                 model_path=model_path,
+                 draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS),
+                 **vars(cpu_config),
+             )
+         else:
+             model = Llama(model_path=model_path, **vars(cpu_config))
+
+         print(
+             "Loading with",
+             cpu_config.n_gpu_layers,
+             "model layers sent to GPU and a maximum context length of",
+             cpu_config.n_ctx,
+         )
+
+     print("Finished loading model:", local_model_type)
+     print("GPU layers assigned to cuda:", gpu_layers)
+
+     # Load assistant model for speculative decoding if enabled
+     if speculative_decoding and USE_LLAMA_CPP == "False" and torch_device == "cuda":
+         print("Loading assistant model for speculative decoding:", ASSISTANT_MODEL)
+         try:
+             from transformers import AutoModelForCausalLM
+
+             # Load the assistant model with the same configuration as the main model
+             assistant_model = AutoModelForCausalLM.from_pretrained(
+                 ASSISTANT_MODEL, dtype=torch_dtype, device_map="auto", token=hf_token
+             )
+
+             # assistant_model.config._name_or_path = model.config._name_or_path
+
+             # Compile the assistant model if compilation is enabled
+             if COMPILE_TRANSFORMERS == "True":
+                 try:
+                     assistant_model = torch.compile(
+                         assistant_model, mode=compile_mode, fullgraph=True
+                     )
+                 except Exception as e:
+                     print(
+                         f"Could not compile assistant model: {e}. Running in eager mode."
+                     )
+
+             print("Successfully loaded assistant model for speculative decoding")
+
+         except Exception as e:
+             print(f"Error loading assistant model: {e}")
+             assistant_model = None
+     else:
+         assistant_model = None
+
+     return model, tokenizer, assistant_model
+
+
+ def get_model():
+     """Get the globally loaded model. Load it if not already loaded."""
+     global _model, _tokenizer, _assistant_model
+     if _model is None:
+         _model, _tokenizer, _assistant_model = load_model(
+             local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
+             gpu_layers=gpu_layers,
+             max_context_length=context_length,
+             gpu_config=gpu_config,
+             cpu_config=cpu_config,
+             torch_device=torch_device,
+             repo_id=LOCAL_REPO_ID,
+             model_filename=LOCAL_MODEL_FILE,
+             model_dir=LOCAL_MODEL_FOLDER,
+             compile_mode=COMPILE_MODE,
+             model_dtype=MODEL_DTYPE,
+             hf_token=HF_TOKEN,
+             model=_model,
+             tokenizer=_tokenizer,
+             assistant_model=_assistant_model,
+         )
+     return _model
+
+
+ def get_tokenizer():
+     """Get the globally loaded tokenizer. Load it if not already loaded."""
+     global _model, _tokenizer, _assistant_model
+     if _tokenizer is None:
+         _model, _tokenizer, _assistant_model = load_model(
+             local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
+             gpu_layers=gpu_layers,
+             max_context_length=context_length,
+             gpu_config=gpu_config,
+             cpu_config=cpu_config,
+             torch_device=torch_device,
+             repo_id=LOCAL_REPO_ID,
+             model_filename=LOCAL_MODEL_FILE,
+             model_dir=LOCAL_MODEL_FOLDER,
+             compile_mode=COMPILE_MODE,
+             model_dtype=MODEL_DTYPE,
+             hf_token=HF_TOKEN,
+             model=_model,
+             tokenizer=_tokenizer,
+             assistant_model=_assistant_model,
+         )
+     return _tokenizer
+
+
+ def get_assistant_model():
+     """Get the globally loaded assistant model. Load it if not already loaded."""
+     global _model, _tokenizer, _assistant_model
+     if _assistant_model is None:
+         _model, _tokenizer, _assistant_model = load_model(
+             local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
+             gpu_layers=gpu_layers,
+             max_context_length=context_length,
+             gpu_config=gpu_config,
+             cpu_config=cpu_config,
+             torch_device=torch_device,
+             repo_id=LOCAL_REPO_ID,
+             model_filename=LOCAL_MODEL_FILE,
+             model_dir=LOCAL_MODEL_FOLDER,
+             compile_mode=COMPILE_MODE,
+             model_dtype=MODEL_DTYPE,
+             hf_token=HF_TOKEN,
+             model=_model,
+             tokenizer=_tokenizer,
+             assistant_model=_assistant_model,
+         )
+     return _assistant_model
+
+
+ def set_model(model, tokenizer, assistant_model=None):
+     """Set the global model, tokenizer, and assistant model."""
+     global _model, _tokenizer, _assistant_model
+     _model = model
+     _tokenizer = tokenizer
+     _assistant_model = assistant_model
+
+
+ # Initialize model at startup if configured
+ if LOAD_LOCAL_MODEL_AT_START == "True" and RUN_LOCAL_MODEL == "1":
+     get_model()  # This will trigger loading
+
+
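A minimal sketch of the lazy-singleton pattern above (illustrative):

    model = get_model()          # loads on first call, cached afterwards
    tokenizer = get_tokenizer()  # reuses the same load_model() result
    # set_model(model, tokenizer) can inject pre-loaded objects, e.g. in tests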
+ def call_llama_cpp_model(
+     formatted_string: str, gen_config: LlamaCPPGenerationConfig, model=None
+ ):
+     """
+     Calls your generation model with parameters from the LlamaCPPGenerationConfig object.
+
+     Args:
+         formatted_string (str): The formatted input text for the model.
+         gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
+         model: Optional model instance. If None, will use the globally loaded model.
+     """
+     if model is None:
+         model = get_model()
+
+     if model is None:
+         raise ValueError(
+             "No model available. Either pass a model parameter or ensure LOAD_LOCAL_MODEL_AT_START is True."
+         )
+
+     # Extracting parameters from the gen_config object
+     temperature = gen_config.temperature
+     top_k = gen_config.top_k
+     top_p = gen_config.top_p
+     repeat_penalty = gen_config.repeat_penalty
+     seed = gen_config.seed
+     max_tokens = gen_config.max_tokens
+     stream = gen_config.stream
+
+     # Now you can call your model directly, passing the parameters:
+     output = model(
+         formatted_string,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repeat_penalty=repeat_penalty,
+         seed=seed,
+         max_tokens=max_tokens,
+         stream=stream,
+         # stop=["<|eot_id|>", "\n\n"]
+     )
+
+     return output
+
+
+ def call_llama_cpp_chatmodel(
+     formatted_string: str,
+     system_prompt: str,
+     gen_config: LlamaCPPGenerationConfig,
+     model=None,
+ ):
+     """
+     Calls your Llama.cpp chat model with a formatted user message and system prompt,
+     using generation parameters from the LlamaCPPGenerationConfig object.
+
+     Args:
+         formatted_string (str): The formatted input text for the user's message.
+         system_prompt (str): The system-level instructions for the model.
+         gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
+         model: Optional model instance. If None, will use the globally loaded model.
+     """
+     if model is None:
+         model = get_model()
+
+     if model is None:
+         raise ValueError(
+             "No model available. Either pass a model parameter or ensure LOAD_LOCAL_MODEL_AT_START is True."
+         )
+
+     # Extracting parameters from the gen_config object
+     temperature = gen_config.temperature
+     top_k = gen_config.top_k
+     top_p = gen_config.top_p
+     repeat_penalty = gen_config.repeat_penalty
+     seed = gen_config.seed
+     max_tokens = gen_config.max_tokens
+     stream = gen_config.stream
+     reset = gen_config.reset
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": formatted_string},
+     ]
+
+     input_tokens = len(
+         model.tokenize(
+             (system_prompt + "\n" + formatted_string).encode("utf-8"), special=True
+         )
+     )
+
+     if stream:
+         final_tokens = list()
+         output_tokens = 0
+         for chunk in model.create_chat_completion(
+             messages=messages,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repeat_penalty=repeat_penalty,
+             seed=seed,
+             max_tokens=max_tokens,
+             stream=True,
+             stop=stop_strings,
+         ):
+             delta = chunk["choices"][0].get("delta", {})
+             token = delta.get("content") or chunk["choices"][0].get("text") or ""
+             if token:
+                 print(token, end="", flush=True)
+                 final_tokens.append(token)
+                 output_tokens += 1
+         print()  # newline after stream finishes
+
+         text = "".join(final_tokens)
+
+         if reset:
+             model.reset()
+
+         return {
+             "choices": [
+                 {
+                     "index": 0,
+                     "finish_reason": "stop",
+                     "message": {"role": "assistant", "content": text},
+                 }
+             ],
+             # Provide a usage object so downstream code can read it
+             "usage": {
+                 "prompt_tokens": input_tokens,  # counted via model.tokenize above
+                 "completion_tokens": output_tokens,  # counted per streamed chunk
+                 "total_tokens": input_tokens + output_tokens,
+             },
+         }
+
+     else:
+         response = model.create_chat_completion(
+             messages=messages,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repeat_penalty=repeat_penalty,
+             seed=seed,
+             max_tokens=max_tokens,
+             stream=False,
+             stop=stop_strings,
+         )
+
+         if reset:
+             model.reset()
+
+         return response
+
+
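A minimal usage sketch (prompt text illustrative):

    gen_config = LlamaCPPGenerationConfig(stream=False)
    response = call_llama_cpp_chatmodel(
        "Summarise the responses in the table.",
        "You are a topic-modelling assistant.",
        gen_config,
    )
    text = response["choices"][0]["message"]["content"]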
+ def call_inference_server_api(
+     formatted_string: str,
+     system_prompt: str,
+     gen_config: LlamaCPPGenerationConfig,
+     api_url: str = "http://localhost:8080",
+     model_name: str = None,
+     use_llama_swap: bool = USE_LLAMA_SWAP,
+ ):
+     """
+     Calls an inference-server API endpoint with a formatted user message and system prompt,
+     using generation parameters from the LlamaCPPGenerationConfig object.
+
+     This function provides the same interface as call_llama_cpp_chatmodel but calls
+     a remote inference-server instance instead of a local model.
+
+     Args:
+         formatted_string (str): The formatted input text for the user's message.
+         system_prompt (str): The system-level instructions for the model.
+         gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
+         api_url (str): The base URL of the inference-server API (default: "http://localhost:8080").
+         model_name (str): Optional model name to use. If None, uses the default model.
+         use_llama_swap (bool): Whether to use llama-swap for the model.
+     Returns:
+         dict: Response in the same format as call_llama_cpp_chatmodel
+
+     Example:
+         # Create generation config
+         gen_config = LlamaCPPGenerationConfig(temperature=0.7, max_tokens=100)
+
+         # Call the API
+         response = call_inference_server_api(
+             formatted_string="Hello, how are you?",
+             system_prompt="You are a helpful assistant.",
+             gen_config=gen_config,
+             api_url="http://localhost:8080"
+         )
+
+         # Extract the response text
+         response_text = response['choices'][0]['message']['content']
+
+     Integration Example:
+         # To use inference-server instead of local model:
+         # 1. Set model_source to "inference-server"
+         # 2. Provide api_url parameter
+         # 3. Call your existing functions as normal
+
+         responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(
+             batch_prompts=["Your prompt here"],
+             system_prompt="Your system prompt",
+             conversation_history=[],
+             whole_conversation=[],
+             whole_conversation_metadata=[],
+             client=None,  # Not used for inference-server
+             client_config=None,  # Not used for inference-server
+             model_choice="your-model-name",  # Model name on the server
+             temperature=0.7,
+             reported_batch_no=1,
+             local_model=None,  # Not used for inference-server
+             tokenizer=None,  # Not used for inference-server
+             bedrock_runtime=None,  # Not used for inference-server
+             model_source="inference-server",
+             MAX_OUTPUT_VALIDATION_ATTEMPTS=3,
+             api_url="http://localhost:8080"
+         )
+     """
+     # Extract parameters from the gen_config object
+     temperature = gen_config.temperature
+     top_k = gen_config.top_k
+     top_p = gen_config.top_p
+     repeat_penalty = gen_config.repeat_penalty
+     seed = gen_config.seed
+     max_tokens = gen_config.max_tokens
+     stream = gen_config.stream
+
+     # Prepare the request payload
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": formatted_string},
+     ]
+
+     payload = {
+         "messages": messages,
+         "temperature": temperature,
+         "top_k": top_k,
+         "top_p": top_p,
+         "repeat_penalty": repeat_penalty,
+         "seed": seed,
+         "max_tokens": max_tokens,
+         "stream": stream,
+         "stop": LLM_STOP_STRINGS if LLM_STOP_STRINGS else [],
+     }
+     # Add model name if specified and use llama-swap
+     if model_name and use_llama_swap:
+         payload["model"] = model_name
+
+     # Streaming and non-streaming requests share the same chat completions endpoint
+     endpoint = f"{api_url}/v1/chat/completions"
+
+     try:
+         if stream:
+             # Handle streaming response
+             response = requests.post(
+                 endpoint,
+                 json=payload,
+                 headers={"Content-Type": "application/json"},
+                 stream=True,
+                 timeout=timeout_wait,
+             )
+             response.raise_for_status()
+
+             final_tokens = []
+             output_tokens = 0
+
+             for line in response.iter_lines():
+                 if line:
+                     line = line.decode("utf-8")
+                     if line.startswith("data: "):
+                         data = line[6:]  # Remove 'data: ' prefix
+                         if data.strip() == "[DONE]":
+                             break
+                         try:
+                             chunk = json.loads(data)
+                             if "choices" in chunk and len(chunk["choices"]) > 0:
+                                 delta = chunk["choices"][0].get("delta", {})
+                                 token = delta.get("content", "")
+                                 if token:
+                                     print(token, end="", flush=True)
+                                     final_tokens.append(token)
+                                     output_tokens += 1
+                         except json.JSONDecodeError:
+                             continue
+
+             print()  # newline after stream finishes
+
+             text = "".join(final_tokens)
+
+             # Estimate input tokens (rough approximation)
+             input_tokens = len((system_prompt + "\n" + formatted_string).split())
+
+             return {
+                 "choices": [
+                     {
+                         "index": 0,
+                         "finish_reason": "stop",
+                         "message": {"role": "assistant", "content": text},
+                     }
+                 ],
+                 "usage": {
+                     "prompt_tokens": input_tokens,
+                     "completion_tokens": output_tokens,
+                     "total_tokens": input_tokens + output_tokens,
+                 },
+             }
+         else:
+             # Handle non-streaming response
+             response = requests.post(
+                 endpoint,
+                 json=payload,
+                 headers={"Content-Type": "application/json"},
+                 timeout=timeout_wait,
+             )
+             response.raise_for_status()
+
+             result = response.json()
+
+             # Ensure the response has the expected format
+             if "choices" not in result:
+                 raise ValueError("Invalid response format from inference-server")
+
+             return result
+
+     except requests.exceptions.RequestException as e:
+         raise ConnectionError(
+             f"Failed to connect to inference-server at {api_url}: {str(e)}"
+         )
+     except json.JSONDecodeError as e:
+         raise ValueError(f"Invalid JSON response from inference-server: {str(e)}")
+     except Exception as e:
+         raise RuntimeError(f"Error calling inference-server API: {str(e)}")
+
+
+ ###
+ # LLM FUNCTIONS
+ ###
+
+
+ def construct_gemini_generative_model(
+     in_api_key: str,
+     temperature: float,
+     model_choice: str,
+     system_prompt: str,
+     max_tokens: int,
+     random_seed=seed,
+ ) -> Tuple[object, dict]:
+     """
+     Constructs a GenerativeModel for Gemini API calls.
+     ...
+     """
+     # Construct a GenerativeModel
+     try:
+         if in_api_key:
+             # print("Getting API key from textbox")
+             api_key = in_api_key
+             client = ai.Client(api_key=api_key)
+         elif "GOOGLE_API_KEY" in os.environ:
+             # print("Searching for API key in environmental variables")
+             api_key = os.environ["GOOGLE_API_KEY"]
+             client = ai.Client(api_key=api_key)
+         else:
+             print("No Gemini API key found")
+             raise Warning("No Gemini API key found.")
+     except Exception as e:
+         print("Error constructing Gemini generative model:", e)
+         raise Warning("Error constructing Gemini generative model:", e)
+
+     config = types.GenerateContentConfig(
+         temperature=temperature, max_output_tokens=max_tokens, seed=random_seed
+     )
+
+     return client, config
+
+
+ def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
+     """
+     Constructs an OpenAI client for Azure/OpenAI AI Inference.
+     """
+     try:
+         key = None
+         if in_api_key:
+             key = in_api_key
+         elif os.environ.get("AZURE_OPENAI_API_KEY"):
+             key = os.environ["AZURE_OPENAI_API_KEY"]
+         if not key:
+             raise Warning("No Azure/OpenAI API key found.")
+
+         if not endpoint:
+             endpoint = os.environ.get("AZURE_OPENAI_INFERENCE_ENDPOINT", "")
+         if not endpoint:
+             # Assume using OpenAI API
+             client = OpenAI(
+                 api_key=key,
+             )
+         else:
+             # Use the provided endpoint
+             client = OpenAI(
+                 api_key=key,
+                 base_url=f"{endpoint}",
+             )
+
+         return client, dict()
+     except Exception as e:
+         print("Error constructing Azure/OpenAI client:", e)
+         raise
+
+
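A hedged usage sketch: with no endpoint the client targets the public OpenAI API; with an endpoint it targets your Azure deployment (key and URL below are hypothetical).

    client, _ = construct_azure_client(in_api_key="sk-...", endpoint="")
    client, _ = construct_azure_client(
        in_api_key="azure-key",
        endpoint="https://my-resource.openai.azure.com/openai/v1",
    )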
+ def call_aws_bedrock(
+     prompt: str,
+     system_prompt: str,
+     temperature: float,
+     max_tokens: int,
+     model_choice: str,
+     bedrock_runtime: boto3.Session.client,
+     assistant_prefill: str = "",
+ ) -> ResponseObject:
+     """
+     This function sends a request to AWS Claude with the following parameters:
+     - prompt: The user's input prompt to be processed by the model.
+     - system_prompt: A system-defined prompt that provides context or instructions for the model.
+     - temperature: A value that controls the randomness of the model's output, with higher values resulting in more diverse responses.
+     - max_tokens: The maximum number of tokens (words or characters) in the model's response.
+     - model_choice: The specific model to use for processing the request.
+     - bedrock_runtime: The client object for boto3 Bedrock runtime
+     - assistant_prefill: A string indicating the text that the response should start with.
+
+     The function constructs the request configuration, invokes the model, extracts the response text, and returns a ResponseObject containing the text and metadata.
+     """
+
+     inference_config = {
+         "maxTokens": max_tokens,
+         "topP": 0.999,
+         "temperature": temperature,
+     }
+
+     # Using an assistant prefill only works for Anthropic models.
+     if assistant_prefill and "anthropic" in model_choice:
+         assistant_prefill_added = True
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"text": prompt},
+                 ],
+             },
+             {
+                 "role": "assistant",
+                 # Pre-filling with '|'
+                 "content": [{"text": assistant_prefill}],
+             },
+         ]
+     else:
+         assistant_prefill_added = False
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"text": prompt},
+                 ],
+             }
+         ]
+
+     system_prompt_list = [{"text": system_prompt}]
+
+     # The converse API call.
+     api_response = bedrock_runtime.converse(
+         modelId=model_choice,
+         messages=messages,
+         system=system_prompt_list,
+         inferenceConfig=inference_config,
+     )
+
+     output_message = api_response["output"]["message"]
+
+     if "reasoningContent" in output_message["content"][0]:
+         # Extract the reasoning text (currently unused beyond this point)
+         reasoning_text = output_message["content"][0]["reasoningContent"][
+             "reasoningText"
+         ]["text"]
+
+         # Extract the output text
+         if assistant_prefill_added:
+             text = assistant_prefill + output_message["content"][1]["text"]
+         else:
+             text = output_message["content"][1]["text"]
+     else:
+         if assistant_prefill_added:
+             text = assistant_prefill + output_message["content"][0]["text"]
+         else:
+             text = output_message["content"][0]["text"]
+
+     # The usage statistics are neatly provided in the 'usage' key.
+     usage = api_response["usage"]
+
+     # The full API response metadata is in api_response["ResponseMetadata"] if you still need it.
+
+     # Create ResponseObject with the cleanly extracted data.
+     response = ResponseObject(text=text, usage_metadata=usage)
+
+     return response
+
+
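A hedged usage sketch for the Bedrock path (region and model ID are illustrative):

    import boto3

    bedrock_runtime = boto3.client("bedrock-runtime", region_name="eu-west-2")
    response = call_aws_bedrock(
        prompt="Extract topics from these responses...",
        system_prompt="You are a topic-modelling assistant.",
        temperature=0.1,
        max_tokens=4096,
        model_choice="anthropic.claude-3-haiku-20240307-v1:0",
        bedrock_runtime=bedrock_runtime,
    )
    print(response.text, response.usage_metadata)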
+ def call_transformers_model(
+     prompt: str,
+     system_prompt: str,
+     gen_config: LlamaCPPGenerationConfig,
+     model=None,
+     tokenizer=None,
+     assistant_model=None,
+     speculative_decoding=speculative_decoding,
+ ):
+     """
+     This function sends a request to a transformers model (through Unsloth) with the given prompt, system prompt, and generation configuration.
+     """
+     from transformers import TextStreamer
+
+     if model is None:
+         model = get_model()
+     if tokenizer is None:
+         tokenizer = get_tokenizer()
+     if assistant_model is None and speculative_decoding:
+         assistant_model = get_assistant_model()
+
+     if model is None or tokenizer is None:
+         raise ValueError(
+             "No model or tokenizer available. Either pass them as parameters or ensure LOAD_LOCAL_MODEL_AT_START is True."
+         )
+
+     # 1. Define the conversation as a list of dictionaries
+     # Note: The multimodal format [{"type": "text", "text": text}] is only needed for actual multimodal models
+     # with images/videos. For text-only content, even multimodal models expect plain strings.
+
+     # Always use string format for text-only content, regardless of MULTIMODAL_PROMPT_FORMAT setting
+     # MULTIMODAL_PROMPT_FORMAT should only be used when you actually have multimodal inputs (images, etc.)
+     if MULTIMODAL_PROMPT_FORMAT == "True":
+         conversation = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": str(system_prompt)}],
+             },
+             {"role": "user", "content": [{"type": "text", "text": str(prompt)}]},
+         ]
+     else:
+         conversation = [
+             {"role": "system", "content": str(system_prompt)},
+             {"role": "user", "content": str(prompt)},
+         ]
+
+     # 2. Apply the chat template
+     try:
+         # Try applying chat template
+         input_ids = tokenizer.apply_chat_template(
+             conversation,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_tensors="pt",
+         ).to("cuda")
+     except (TypeError, KeyError, IndexError) as e:
+         # If chat template fails, try manual formatting
+         print(f"Chat template failed ({e}), using manual tokenization")
+         # Combine system and user prompts manually
+         full_prompt = f"{system_prompt}\n\n{prompt}"
+         # Tokenize manually with special tokens
+         encoded = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=True)
+         if encoded is None:
+             raise ValueError(
+                 "Tokenizer returned None - tokenizer may not be properly initialized"
+             )
+         if not hasattr(encoded, "input_ids") or encoded.input_ids is None:
+             raise ValueError("Tokenizer output does not contain input_ids")
+         input_ids = encoded.input_ids.to("cuda")
+     except Exception as e:
+         print("Error applying chat template:", e)
+         import traceback
+
+         traceback.print_exc()
+         raise
+
+     # Map LlamaCPP parameters to transformers parameters
+     generation_kwargs = {
+         "max_new_tokens": gen_config.max_tokens,
+         "temperature": gen_config.temperature,
+         "top_p": gen_config.top_p,
+         "top_k": gen_config.top_k,
+         "do_sample": True,
+         # 'pad_token_id': tokenizer.eos_token_id
+     }
+
+     if gen_config.stream:
+         streamer = TextStreamer(tokenizer, skip_prompt=True)
+     else:
+         streamer = None
+
+     # Map repeat_penalty to the transformers parameter name, repetition_penalty
+     if hasattr(gen_config, "repeat_penalty"):
+         generation_kwargs["repetition_penalty"] = gen_config.repeat_penalty
+
+     # --- Timed Inference Test ---
+     print("\nStarting model inference...")
+     start_time = time.time()
+
+     # Use speculative decoding if assistant model is available
+     try:
+         if speculative_decoding and assistant_model is not None:
+             # print("Using speculative decoding with assistant model")
+             outputs = model.generate(
+                 input_ids,
+                 assistant_model=assistant_model,
+                 **generation_kwargs,
+                 streamer=streamer,
+             )
+         else:
+             # print("Generating without speculative decoding")
+             outputs = model.generate(input_ids, **generation_kwargs, streamer=streamer)
+     except Exception as e:
+         error_msg = str(e)
+         # Check if this is a CUDA compilation error
+         if (
+             "sm_120" in error_msg
+             or "LLVM ERROR" in error_msg
+             or "Cannot select" in error_msg
+         ):
+             print("\n" + "=" * 80)
+             print("CUDA COMPILATION ERROR DETECTED")
+             print("=" * 80)
+             print(
+                 "\nThe error is caused by torch.compile() trying to compile CUDA kernels"
+             )
+             print(
+                 "with incompatible settings. This is a known issue with certain CUDA/PyTorch"
+             )
+             print("combinations.\n")
+             print(
+                 "SOLUTION: Disable model compilation by setting COMPILE_TRANSFORMERS=False"
+             )
+             print("in your config file (config/app_config.env).")
+             print(
+                 "\nThe model will still work without compilation, just slightly slower."
+             )
+             print("=" * 80 + "\n")
+             raise RuntimeError(
+                 "CUDA compilation error detected. Please set COMPILE_TRANSFORMERS=False "
+                 "in your config file to disable model compilation and avoid this error."
+             ) from e
+         else:
+             # Re-raise other errors as-is
+             raise
+
+     end_time = time.time()
+
+     # --- Decode and Display Results ---
+     new_tokens = outputs[0][input_ids.shape[-1]:]
+     assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+     num_input_tokens = input_ids.shape[-1]  # This gets the sequence length (number of tokens)
+     num_generated_tokens = len(new_tokens)
+     duration = end_time - start_time
+     tokens_per_second = num_generated_tokens / duration
+
+     print("\n--- Performance ---")
+     print(f"Time taken: {duration:.2f} seconds")
+     print(f"Generated tokens: {num_generated_tokens}")
+     print(f"Tokens per second: {tokens_per_second:.2f}")
+
+     return assistant_reply, num_input_tokens, num_generated_tokens
+
+
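A hedged usage sketch (prompt text illustrative; assumes a CUDA-loaded transformers model):

    gen_config = LlamaCPPGenerationConfig(stream=False)
    reply, n_in, n_out = call_transformers_model(
        "List the main topics in these case notes.",
        "You are a topic-modelling assistant.",
        gen_config,
    )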
1337
+ # Function to send a request and update history
1338
+ def send_request(
1339
+ prompt: str,
1340
+ conversation_history: List[dict],
1341
+ client: ai.Client | OpenAI,
1342
+ config: types.GenerateContentConfig,
1343
+ model_choice: str,
1344
+ system_prompt: str,
1345
+ temperature: float,
1346
+ bedrock_runtime: boto3.Session.client,
1347
+ model_source: str,
1348
+ local_model=list(),
1349
+ tokenizer=None,
1350
+ assistant_model=None,
1351
+ assistant_prefill="",
1352
+ progress=Progress(track_tqdm=True),
1353
+ api_url: str = None,
1354
+ ) -> Tuple[str, List[dict]]:
1355
+ """Sends a request to a language model and manages the conversation history.
1356
+
1357
+ This function constructs the full prompt by appending the new user prompt to the conversation history,
1358
+ generates a response from the model, and updates the conversation history with the new prompt and response.
1359
+ It handles different model sources (Gemini, AWS, Local, inference-server) and includes retry logic for API calls.
1360
+
1361
+ Args:
1362
+ prompt (str): The user's input prompt to be sent to the model.
1363
+ conversation_history (List[dict]): A list of dictionaries representing the ongoing conversation.
1364
+ Each dictionary should have 'role' and 'parts' keys.
1365
+ client (ai.Client): The API client object for the chosen model (e.g., Gemini `ai.Client`, or Azure/OpenAI `OpenAI`).
1366
+ config (types.GenerateContentConfig): Configuration settings for content generation (e.g., Gemini `types.GenerateContentConfig`).
1367
+ model_choice (str): The specific model identifier to use (e.g., "gemini-pro", "claude-v2").
1368
+ system_prompt (str): An optional system-level instruction or context for the model.
1369
+ temperature (float): Controls the randomness of the model's output, with higher values leading to more diverse responses.
1370
+ bedrock_runtime (boto3.Session.client): The boto3 Bedrock runtime client object for AWS models.
1371
+ model_source (str): Indicates the source/provider of the model (e.g., "Gemini", "AWS", "Local", "inference-server").
1372
+ local_model (list, optional): A list containing the local model and its tokenizer (if `model_source` is "Local"). Defaults to [].
1373
+ tokenizer (object, optional): The tokenizer object for local models. Defaults to None.
1374
+ assistant_model (object, optional): An optional assistant model used for speculative decoding with local models. Defaults to None.
1375
+ assistant_prefill (str, optional): A string to pre-fill the assistant's response, useful for certain models like Claude. Defaults to "".
1376
+ progress (Progress, optional): A progress object for tracking the operation, typically from `tqdm`. Defaults to Progress(track_tqdm=True).
1377
+ api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.
1378
+
1379
+ Returns:
1380
+ Tuple[str, List[dict]]: A tuple containing the model's response text and the updated conversation history.
1381
+ """
1382
+ # Constructing the full prompt from the conversation history
1383
+ full_prompt = "Conversation history:\n"
1384
+ num_transformer_input_tokens = 0
1385
+ num_transformer_generated_tokens = 0
1386
+ response_text = ""
1387
+
1388
+ for entry in conversation_history:
1389
+ role = entry[
1390
+ "role"
1391
+ ].capitalize() # Assuming the history is stored with 'role' and 'parts'
1392
+ message = " ".join(entry["parts"]) # Combining all parts of the message
1393
+ full_prompt += f"{role}: {message}\n"
1394
+
1395
+ # Adding the new user prompt
1396
+ full_prompt += f"\nUser: {prompt}"
1397
+
1398
+ # Clear any existing progress bars
1399
+ tqdm._instances.clear()
1400
+
1401
+ progress_bar = range(0, number_of_api_retry_attempts)
1402
+
1403
+ # Generate the model's response
1404
+ if "Gemini" in model_source:
1405
+
1406
+ for i in progress_bar:
1407
+ try:
1408
+ print("Calling Gemini model, attempt", i + 1)
1409
+
1410
+ response = client.models.generate_content(
1411
+ model=model_choice, contents=full_prompt, config=config
1412
+ )
1413
+
1414
+ # print("Successful call to Gemini model.")
1415
+ break
1416
+ except Exception as e:
1417
+ # If fails, try again after X seconds in case there is a throttle limit
1418
+ print(
1419
+ "Call to Gemini model failed:",
1420
+ e,
1421
+ " Waiting for ",
1422
+ str(timeout_wait),
1423
+ "seconds and trying again.",
1424
+ )
1425
+
1426
+ time.sleep(timeout_wait)
1427
+
1428
+ if i == number_of_api_retry_attempts:
1429
+ return (
1430
+ ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
1431
+ conversation_history,
1432
+ response_text,
1433
+ num_transformer_input_tokens,
1434
+ num_transformer_generated_tokens,
1435
+ )
1436
+
1437
+ elif "AWS" in model_source:
1438
+ for i in progress_bar:
1439
+ try:
1440
+ print("Calling AWS Bedrock model, attempt", i + 1)
1441
+ response = call_aws_bedrock(
1442
+ prompt,
1443
+ system_prompt,
1444
+ temperature,
1445
+ max_tokens,
1446
+ model_choice,
1447
+ bedrock_runtime=bedrock_runtime,
1448
+ assistant_prefill=assistant_prefill,
1449
+ )
1450
+
1451
+ # print("Successful call to Claude model.")
1452
+ break
1453
+ except Exception as e:
1454
+ # If fails, try again after X seconds in case there is a throttle limit
1455
+ print(
1456
+ "Call to Bedrock model failed:",
1457
+ e,
1458
+ " Waiting for ",
1459
+ str(timeout_wait),
1460
+ "seconds and trying again.",
1461
+ )
1462
+ time.sleep(timeout_wait)
1463
+
1464
+ if i == number_of_api_retry_attempts:
1465
+ return (
1466
+ ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
1467
+ conversation_history,
1468
+ response_text,
1469
+ num_transformer_input_tokens,
1470
+ num_transformer_generated_tokens,
1471
+ )
1472
+ elif "Azure/OpenAI" in model_source:
1473
+ for i in progress_bar:
1474
+ try:
1475
+ print("Calling Azure/OpenAI inference model, attempt", i + 1)
1476
+
1477
+ messages = [
1478
+ {
1479
+ "role": "system",
1480
+ "content": system_prompt,
1481
+ },
1482
+ {
1483
+ "role": "user",
1484
+ "content": prompt,
1485
+ },
1486
+ ]
1487
+
1488
+ response_raw = client.chat.completions.create(
1489
+ messages=messages,
1490
+ model=model_choice,
1491
+ temperature=temperature,
1492
+ max_completion_tokens=max_tokens,
1493
+ )
1494
+
1495
+ response_text = response_raw.choices[0].message.content
1496
+ usage = getattr(response_raw, "usage", None)
1497
+ input_tokens = 0
1498
+ output_tokens = 0
1499
+ if usage is not None:
1500
+ input_tokens = getattr(
1501
+ usage, "input_tokens", getattr(usage, "prompt_tokens", 0)
1502
+ )
1503
+ output_tokens = getattr(
1504
+ usage, "output_tokens", getattr(usage, "completion_tokens", 0)
1505
+ )
1506
+ response = ResponseObject(
1507
+ text=response_text,
1508
+ usage_metadata={
1509
+ "inputTokens": input_tokens,
1510
+ "outputTokens": output_tokens,
1511
+ },
1512
+ )
1513
+ break
1514
+ except Exception as e:
1515
+ print(
1516
+ "Call to Azure/OpenAI model failed:",
1517
+ e,
1518
+ " Waiting for ",
1519
+ str(timeout_wait),
1520
+ "seconds and trying again.",
1521
+ )
1522
+ time.sleep(timeout_wait)
1523
+ if i == number_of_api_retry_attempts:
1524
+ return (
1525
+ ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
1526
+ conversation_history,
1527
+ response_text,
1528
+ num_transformer_input_tokens,
1529
+ num_transformer_generated_tokens,
1530
+ )
1531
+ elif "Local" in model_source:
1532
+ # This is the local model
1533
+ for i in progress_bar:
1534
+ try:
1535
+ print("Calling local model, attempt", i + 1)
1536
+
1537
+ gen_config = LlamaCPPGenerationConfig()
1538
+ gen_config.update_temp(temperature)
1539
+
1540
+ if USE_LLAMA_CPP == "True":
1541
+ response = call_llama_cpp_chatmodel(
1542
+ prompt, system_prompt, gen_config, model=local_model
1543
+ )
1544
+
1545
+ else:
1546
+ (
1547
+ response,
1548
+ num_transformer_input_tokens,
1549
+ num_transformer_generated_tokens,
1550
+ ) = call_transformers_model(
1551
+ prompt,
1552
+ system_prompt,
1553
+ gen_config,
1554
+ model=local_model,
1555
+ tokenizer=tokenizer,
1556
+ assistant_model=assistant_model,
1557
+ )
1558
+ response_text = response
1559
+
1560
+ break
1561
+ except Exception as e:
1562
+ # If fails, try again after X seconds in case there is a throttle limit
1563
+ print(
1564
+ "Call to local model failed:",
1565
+ e,
1566
+ " Waiting for ",
1567
+ str(timeout_wait),
1568
+ "seconds and trying again.",
1569
+ )
1570
+
1571
+ time.sleep(timeout_wait)
1572
+
1573
+ if i == number_of_api_retry_attempts:
1574
+ return (
1575
+ ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
1576
+ conversation_history,
1577
+ response_text,
1578
+ num_transformer_input_tokens,
1579
+ num_transformer_generated_tokens,
1580
+ )
1581
+ elif "inference-server" in model_source:
1582
+ # This is the inference-server API
1583
+ for i in progress_bar:
1584
+ try:
1585
+ print("Calling inference-server API, attempt", i + 1)
1586
+
1587
+ if api_url is None:
1588
+ raise ValueError(
1589
+ "api_url is required when model_source is 'inference-server'"
1590
+ )
1591
+
1592
+ gen_config = LlamaCPPGenerationConfig()
1593
+ gen_config.update_temp(temperature)
1594
+
1595
+ response = call_inference_server_api(
1596
+ prompt,
1597
+ system_prompt,
1598
+ gen_config,
1599
+ api_url=api_url,
1600
+ model_name=model_choice,
1601
+ )
1602
+
1603
+ break
1604
+ except Exception as e:
1605
+ # If fails, try again after X seconds in case there is a throttle limit
1606
+ print(
1607
+ "Call to inference-server API failed:",
1608
+ e,
1609
+ " Waiting for ",
1610
+ str(timeout_wait),
1611
+ "seconds and trying again.",
1612
+ )
1613
+
1614
+ time.sleep(timeout_wait)
1615
+
1616
+ if i == number_of_api_retry_attempts:
1617
+ return (
1618
+ ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
1619
+ conversation_history,
1620
+ response_text,
1621
+ num_transformer_input_tokens,
1622
+ num_transformer_generated_tokens,
1623
+ )
1624
+ else:
1625
+ print("Model source not recognised")
1626
+ return (
1627
+ ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
1628
+ conversation_history,
1629
+ response_text,
1630
+ num_transformer_input_tokens,
1631
+ num_transformer_generated_tokens,
1632
+ )
1633
+
1634
+ # Update the conversation history with the new prompt and response
1635
+ conversation_history.append({"role": "user", "parts": [prompt]})
1636
+
1637
+ # Check whether this is a llama.cpp model response or an inference-server response
1638
+ if isinstance(response, ResponseObject):
1639
+ response_text = response.text
1640
+ elif "choices" in response: # LLama.cpp model response or inference-server response
1641
+ if "gpt-oss" in model_choice:
1642
+ response_text = response["choices"][0]["message"]["content"].split(
1643
+ "<|start|>assistant<|channel|>final<|message|>"
1644
+ )[1]
1645
+ else:
1646
+ response_text = response["choices"][0]["message"]["content"]
1647
+ elif model_source == "Gemini":
1648
+ response_text = response.text
1649
+ else: # Assume transformers model response
1650
+ if "gpt-oss" in model_choice:
1651
+ response_text = response.split(
1652
+ "<|start|>assistant<|channel|>final<|message|>"
1653
+ )[1]
1654
+ else:
1655
+ response_text = response
1656
+
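# Note (illustrative, not part of the commit): gpt-oss models emit OpenAI
# "harmony" channel markers, and the splits above keep only the text after
# the final-channel marker. On a raw string this looks like:
#   raw = "<|start|>assistant<|channel|>final<|message|>The answer."
#   raw.split("<|start|>assistant<|channel|>final<|message|>")[1]  # -> "The answer."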
1657
+ # Replace multiple spaces with single space
1658
+ response_text = re.sub(r" {2,}", " ", response_text)
1659
+ response_text = response_text.strip()
1660
+
1661
+ conversation_history.append({"role": "assistant", "parts": [response_text]})
1662
+
1663
+ return (
1664
+ response,
1665
+ conversation_history,
1666
+ response_text,
1667
+ num_transformer_input_tokens,
1668
+ num_transformer_generated_tokens,
1669
+ )
1670
+
1671
+
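Every model source above follows the same retry shape: a bounded attempt loop that sleeps `timeout_wait` seconds between failures and falls back to a FAILED `ResponseObject` after the last attempt. A minimal, hedged sketch of that shared pattern (the helper name is illustrative, not from the repo):

import time

def _retry_call(call_fn, attempts: int, wait_seconds: int):
    # Generic form of the per-source retry loops above (sketch only).
    for i in range(attempts):
        try:
            return call_fn()
        except Exception as e:
            print(f"Attempt {i + 1} failed: {e}. Waiting {wait_seconds} seconds.")
            time.sleep(wait_seconds)
    # The real code returns a ResponseObject with RequestId "FAILED" here.
    raise RuntimeError("All retry attempts failed")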
1672
+ def process_requests(
1673
+ prompts: List[str],
1674
+ system_prompt: str,
1675
+ conversation_history: List[dict],
1676
+ whole_conversation: List[str],
1677
+ whole_conversation_metadata: List[str],
1678
+ client: ai.Client | OpenAI,
1679
+ config: types.GenerateContentConfig,
1680
+ model_choice: str,
1681
+ temperature: float,
1682
+ bedrock_runtime: boto3.Session.client,
1683
+ model_source: str,
1684
+ batch_no: int = 1,
1685
+ local_model=list(),
1686
+ tokenizer=None,
1687
+ assistant_model=None,
1688
+ master: bool = False,
1689
+ assistant_prefill="",
1690
+ api_url: str = None,
1691
+ ) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
1692
+ """
1693
+ Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
1694
+
1695
+ Args:
1696
+ prompts (List[str]): A list of prompts to be processed.
1697
+ system_prompt (str): The system prompt.
1698
+ conversation_history (List[dict]): The history of the conversation.
1699
+ whole_conversation (List[str]): The complete conversation including prompts and responses.
1700
+ whole_conversation_metadata (List[str]): Metadata about the whole conversation.
1701
+ client (object): The client to use for processing the prompts, from either Gemini or OpenAI client.
1702
+ config (dict): Configuration for the model.
1703
+ model_choice (str): The choice of model to use.
1704
+ temperature (float): The temperature parameter for the model.
1705
+ model_source (str): Source of the model, whether local, AWS, Gemini, or inference-server
1706
+ batch_no (int): Batch number of the large language model request.
1707
+ local_model: Local gguf model (if loaded)
1708
+ master (bool): Is this request for the master table.
1709
+ assistant_prefill (str, optional): Optional text to prefill the assistant response. Currently only supported for AWS model calls.
1710
+ bedrock_runtime: The client object for boto3 Bedrock runtime
1711
+ api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.
1712
+
1713
+ Returns:
1714
+ Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
1715
+ """
1716
+ responses = list()
1717
+
1718
+ # Clear any existing progress bars
1719
+ tqdm._instances.clear()
1720
+
1721
+ for prompt in prompts:
1722
+
1723
+ (
1724
+ response,
1725
+ conversation_history,
1726
+ response_text,
1727
+ num_transformer_input_tokens,
1728
+ num_transformer_generated_tokens,
1729
+ ) = send_request(
1730
+ prompt,
1731
+ conversation_history,
1732
+ client=client,
1733
+ config=config,
1734
+ model_choice=model_choice,
1735
+ system_prompt=system_prompt,
1736
+ temperature=temperature,
1737
+ local_model=local_model,
1738
+ tokenizer=tokenizer,
1739
+ assistant_model=assistant_model,
1740
+ assistant_prefill=assistant_prefill,
1741
+ bedrock_runtime=bedrock_runtime,
1742
+ model_source=model_source,
1743
+ api_url=api_url,
1744
+ )
1745
+
1746
+ responses.append(response)
1747
+ whole_conversation.append(system_prompt)
1748
+ whole_conversation.append(prompt)
1749
+ whole_conversation.append(response_text)
1750
+
1751
+ whole_conversation_metadata.append(f"Batch {batch_no}:")
1752
+
1753
+ try:
1754
+ if "AWS" in model_source:
1755
+ output_tokens = response.usage_metadata.get("outputTokens", 0)
1756
+ input_tokens = response.usage_metadata.get("inputTokens", 0)
1757
+
1758
+ elif "Gemini" in model_source:
1759
+ output_tokens = response.usage_metadata.candidates_token_count
1760
+ input_tokens = response.usage_metadata.prompt_token_count
1761
+
1762
+ elif "Azure/OpenAI" in model_source:
1763
+ input_tokens = response.usage_metadata.get("inputTokens", 0)
1764
+ output_tokens = response.usage_metadata.get("outputTokens", 0)
1765
+
1766
+ elif "Local" in model_source:
1767
+ if USE_LLAMA_CPP == "True":
1768
+ output_tokens = response["usage"].get("completion_tokens", 0)
1769
+ input_tokens = response["usage"].get("prompt_tokens", 0)
1770
+
1771
+ if USE_LLAMA_CPP == "False":
1772
+ input_tokens = num_transformer_input_tokens
1773
+ output_tokens = num_transformer_generated_tokens
1774
+
1775
+ elif "inference-server" in model_source:
1776
+ # inference-server returns the same format as llama-cpp
1777
+ output_tokens = response["usage"].get("completion_tokens", 0)
1778
+ input_tokens = response["usage"].get("prompt_tokens", 0)
1779
+
1780
+ else:
1781
+ input_tokens = 0
1782
+ output_tokens = 0
1783
+
1784
+ whole_conversation_metadata.append(
1785
+ "input_tokens: "
1786
+ + str(input_tokens)
1787
+ + " output_tokens: "
1788
+ + str(output_tokens)
1789
+ )
1790
+
1791
+ except KeyError as e:
1792
+ print(f"Key error: {e} - Check the structure of response.usage_metadata")
1793
+
1794
+ return (
1795
+ responses,
1796
+ conversation_history,
1797
+ whole_conversation,
1798
+ whole_conversation_metadata,
1799
+ response_text,
1800
+ )
1801
+
1802
+
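For orientation, a hedged sketch of a single-batch call to process_requests (argument values are illustrative, not from the repo):

responses, history, conversation, metadata, last_text = process_requests(
    prompts=["Create the topic table for this batch."],
    system_prompt=generic_system_prompt,
    conversation_history=[],
    whole_conversation=[],
    whole_conversation_metadata=[],
    client=client,                    # Gemini or Azure/OpenAI client
    config=client_config,
    model_choice="gemini-2.5-flash",  # illustrative model name
    temperature=0.0,
    bedrock_runtime=None,             # assumption: not needed for Gemini calls
    model_source="Gemini",
    batch_no=1,
)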
1803
+ def call_llm_with_markdown_table_checks(
1804
+ batch_prompts: List[str],
1805
+ system_prompt: str,
1806
+ conversation_history: List[dict],
1807
+ whole_conversation: List[str],
1808
+ whole_conversation_metadata: List[str],
1809
+ client: ai.Client | OpenAI,
1810
+ client_config: types.GenerateContentConfig,
1811
+ model_choice: str,
1812
+ temperature: float,
1813
+ reported_batch_no: int,
1814
+ local_model: object,
1815
+ tokenizer: object,
1816
+ bedrock_runtime: boto3.Session.client,
1817
+ model_source: str,
1818
+ MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
1819
+ assistant_prefill: str = "",
1820
+ master: bool = False,
1821
+ CHOSEN_LOCAL_MODEL_TYPE: str = CHOSEN_LOCAL_MODEL_TYPE,
1822
+ random_seed: int = seed,
1823
+ api_url: str = None,
1824
+ ) -> Tuple[List[ResponseObject], List[dict], List[str], List[str], str]:
1825
+ """
1826
+ Call the large language model with checks for a valid markdown table.
1827
+
1828
+ Parameters:
1829
+ - batch_prompts (List[str]): A list of prompts to be processed.
1830
+ - system_prompt (str): The system prompt.
1831
+ - conversation_history (List[dict]): The history of the conversation.
1832
+ - whole_conversation (List[str]): The complete conversation including prompts and responses.
1833
+ - whole_conversation_metadata (List[str]): Metadata about the whole conversation.
1834
+ - client (ai.Client | OpenAI): The client object for running Gemini or Azure/OpenAI API calls.
1835
+ - client_config (types.GenerateContentConfig): Configuration for the model.
1836
+ - model_choice (str): The choice of model to use.
1837
+ - temperature (float): The temperature parameter for the model.
1838
+ - reported_batch_no (int): The reported batch number.
1839
+ - local_model (object): The local model to use.
1840
+ - tokenizer (object): The tokenizer to use.
1841
+ - bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
1842
+ - model_source (str): The source of the model, whether in AWS, Gemini, local, or inference-server.
1843
+ - MAX_OUTPUT_VALIDATION_ATTEMPTS (int): The maximum number of attempts to validate the output.
1844
+ - assistant_prefill (str, optional): The text to prefill the LLM response. Currently only working with AWS Claude calls.
1845
+ - master (bool, optional): Boolean to determine whether this call is for the master output table.
1846
+ - CHOSEN_LOCAL_MODEL_TYPE (str, optional): String to determine model type loaded.
1847
+ - random_seed (int, optional): The random seed used for LLM generation.
1848
+ - api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.
1849
+
1850
+ Returns:
1851
+ - Tuple[List[ResponseObject], List[dict], List[str], List[str], str]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, the updated whole conversation metadata, and the response text.
1852
+ """
1853
+
1854
+ call_temperature = temperature  # start from the requested temperature; increased on retries
1855
+
1856
+ # Update Gemini config with the new temperature settings
1857
+ client_config = types.GenerateContentConfig(
1858
+ temperature=call_temperature, max_output_tokens=max_tokens, seed=random_seed
1859
+ )
1860
+
1861
+ for attempt in range(MAX_OUTPUT_VALIDATION_ATTEMPTS):
1862
+ # Process requests to large language model
1863
+ (
1864
+ responses,
1865
+ conversation_history,
1866
+ whole_conversation,
1867
+ whole_conversation_metadata,
1868
+ response_text,
1869
+ ) = process_requests(
1870
+ batch_prompts,
1871
+ system_prompt,
1872
+ conversation_history,
1873
+ whole_conversation,
1874
+ whole_conversation_metadata,
1875
+ client,
1876
+ client_config,
1877
+ model_choice,
1878
+ call_temperature,
1879
+ bedrock_runtime,
1880
+ model_source,
1881
+ reported_batch_no,
1882
+ local_model,
1883
+ tokenizer=tokenizer,
1884
+ master=master,
1885
+ assistant_prefill=assistant_prefill,
1886
+ api_url=api_url,
1887
+ )
1888
+
1889
+ stripped_response = response_text.strip()
1890
+
1891
+ # Check if response meets our criteria (length and contains table) OR is "No change"
1892
+ if (
1893
+ len(stripped_response) > 120 and "|" in stripped_response
1894
+ ) or stripped_response.lower().startswith("no change"):
1895
+ if stripped_response.lower().startswith("no change"):
1896
+ print(f"Attempt {attempt + 1} produced 'No change' response.")
1897
+ else:
1898
+ print(f"Attempt {attempt + 1} produced response with markdown table.")
1899
+ break # Success - exit loop
1900
+
1901
+ # Increase temperature for next attempt
1902
+ call_temperature = temperature + (0.1 * (attempt + 1))
1903
+ print(
1904
+ f"Attempt {attempt + 1} resulted in invalid table: {stripped_response}. "
1905
+ f"Trying again with temperature: {call_temperature}"
1906
+ )
1907
+
1908
+ else: # This runs if no break occurred (all attempts failed)
1909
+ print(
1910
+ f"Failed to get valid response after {MAX_OUTPUT_VALIDATION_ATTEMPTS} attempts"
1911
+ )
1912
+
1913
+ return (
1914
+ responses,
1915
+ conversation_history,
1916
+ whole_conversation,
1917
+ whole_conversation_metadata,
1918
+ stripped_response,
1919
+ )
1920
+
1921
+
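The acceptance test in the loop above can be read as a small predicate; a hedged restatement (helper name illustrative, not from the repo):

def looks_like_markdown_table(text: str) -> bool:
    # A response passes if it is long enough and contains a pipe character,
    # or if it explicitly reports "No change".
    s = text.strip()
    return (len(s) > 120 and "|" in s) or s.lower().startswith("no change")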
1922
+ def create_missing_references_df(
1923
+ basic_response_df: pd.DataFrame, existing_reference_df: pd.DataFrame
1924
+ ) -> pd.DataFrame:
1925
+ """
1926
+ Identifies references in basic_response_df that are not present in existing_reference_df.
1927
+ Returns a DataFrame with the missing references and the character count of their responses.
1928
+
1929
+ Args:
1930
+ basic_response_df (pd.DataFrame): DataFrame containing 'Reference' and 'Response' columns.
1931
+ existing_reference_df (pd.DataFrame): DataFrame containing 'Response References' column.
1932
+
1933
+ Returns:
1934
+ pd.DataFrame: A DataFrame with 'Missing Reference' and 'Response Character Count' columns.
1935
+ 'Response Character Count' will be 0 for empty strings and NaN for actual missing data.
1936
+ """
1937
+ # Ensure columns are treated as strings for robust comparison
1938
+ existing_references_unique = (
1939
+ existing_reference_df["Response References"].astype(str).unique()
1940
+ )
1941
+
1942
+ # Step 1: Identify all rows from basic_response_df that correspond to missing references
1943
+ # We want the entire row to access the 'Response' column later
1944
+ missing_data_rows = basic_response_df[
1945
+ ~basic_response_df["Reference"].astype(str).isin(existing_references_unique)
1946
+ ].copy() # .copy() to avoid SettingWithCopyWarning
1947
+
1948
+ # Step 2: Create the new DataFrame
1949
+ # Populate the 'Missing Reference' column directly
1950
+ missing_df = pd.DataFrame({"Missing Reference": missing_data_rows["Reference"]})
1951
+
1952
+ # Step 3: Calculate and add 'Response Character Count'
1953
+ # .str.len() works on Series of strings, handling empty strings (0) and NaN (NaN)
1954
+ missing_df["Response Character Count"] = missing_data_rows["Response"].str.len()
1955
+
1956
+ # Optional: Add the actual response text for easier debugging/inspection if needed
1957
+ # missing_df['Response Text'] = missing_data_rows['Response']
1958
+
1959
+ # Reset index to have a clean, sequential index for the new DataFrame
1960
+ missing_df = missing_df.reset_index(drop=True)
1961
+
1962
+ return missing_df
1963
+
1964
+
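A hedged usage sketch with toy data (values illustrative):

import pandas as pd

basic = pd.DataFrame({"Reference": ["1", "2", "3"],
                      "Response": ["Good service", "", "Too slow"]})
existing = pd.DataFrame({"Response References": ["1"]})

create_missing_references_df(basic, existing)
#   Missing Reference  Response Character Count
# 0                 2                         0
# 1                 3                         8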
1965
+ def calculate_tokens_from_metadata(
1966
+ metadata_string: str, model_choice: str, model_name_map: dict
1967
+ ):
1968
+ """
1969
+ Calculate the number of input and output tokens for given queries based on metadata strings.
1970
+
1971
+ Args:
1972
+ metadata_string (str): A string containing all relevant metadata from the string.
1973
+ model_choice (str): A string describing the model name
1974
+ model_name_map (dict): A dictionary mapping model name to source
1975
+ """
1976
+
1977
+ # The model source from model_name_map is not needed here; token totals are parsed directly from the metadata string.
1978
+
1979
+ # Regex to find the numbers following the keys in the "Query summary metadata" section
1980
+ # This ensures we get the final, aggregated totals for the whole query.
1981
+ input_regex = r"input_tokens: (\d+)"
1982
+ output_regex = r"output_tokens: (\d+)"
1983
+
1984
+ # re.findall returns a list of all matching strings (the captured groups).
1985
+ input_token_strings = re.findall(input_regex, metadata_string)
1986
+ output_token_strings = re.findall(output_regex, metadata_string)
1987
+
1988
+ # Convert the lists of strings to lists of integers and sum them up
1989
+ total_input_tokens = sum([int(token) for token in input_token_strings])
1990
+ total_output_tokens = sum([int(token) for token in output_token_strings])
1991
+
1992
+ number_of_calls = len(input_token_strings)
1993
+
1994
+ print(f"Found {number_of_calls} LLM call entries in metadata.")
1995
+ print("-" * 20)
1996
+ print(f"Total Input Tokens: {total_input_tokens}")
1997
+ print(f"Total Output Tokens: {total_output_tokens}")
1998
+
1999
+ return total_input_tokens, total_output_tokens, number_of_calls
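For example (hedged; the metadata string format follows the appends in process_requests above, and model_choice/model_name_map are assumed to be in scope):

meta = "Batch 1: input_tokens: 1200 output_tokens: 350 Batch 2: input_tokens: 900 output_tokens: 410"
calculate_tokens_from_metadata(meta, model_choice, model_name_map)
# -> (2100, 760, 2): two LLM calls, 2,100 input tokens, 760 output tokens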
tools/prompts.py ADDED
@@ -0,0 +1,260 @@
1
+ ###
2
+ # System prompt
3
+ ###
4
+
5
+ generic_system_prompt = """You are a researcher analysing responses from an open text dataset. You are analysing a single column from this dataset."""
6
+
7
+ system_prompt = """You are a researcher analysing responses from an open text dataset. You are analysing a single column from this dataset called '{column_name}'. {consultation_context}"""
8
+
9
+ markdown_additional_prompt = """ You will be given a request for a markdown table. You must respond with ONLY the markdown table. Do not include any introduction, explanation, or concluding text."""
10
+
11
+ ###
12
+ # Initial topic table prompt
13
+ ###
14
+ initial_table_system_prompt = system_prompt + markdown_additional_prompt
15
+
16
+ initial_table_assistant_prefill = "|"
17
+
18
+ default_response_reference_format = "In the next column named 'Response References', list each specific Response reference number that is relevant to the Subtopic, separated by commas. Do not write any other text in this column."
19
+
20
+ initial_table_prompt = """{validate_prompt_prefix}Your task is to create one new markdown table based on open text responses in the reponse table below.
21
+ In the first column named 'General topic', identify general topics relevant to responses. Create as many general topics as you can.
22
+ In the second column named 'Subtopic', list subtopics relevant to responses. Make the subtopics as specific as possible and make sure they cover every issue mentioned. The subtopic should never be empty.
23
+ {sentiment_choices}{response_reference_format}
24
+ In the final column named 'Summary', write a summary of the subtopic based on relevant responses - highlight specific issues that appear. {add_existing_topics_summary_format}
25
+ Do not add any other columns. Do not add any other text to your response. Only mention topics that are relevant to at least one response.
26
+
27
+ Response table:
28
+ {response_table}
29
+
30
+ New table:{previous_table_introduction}{previous_table}{validate_prompt_suffix}"""
31
+
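For orientation, a hedged sketch of how this template is presumably filled at call time (values illustrative; the real call sites live elsewhere in the repo, and response_table_markdown and the sentiment lead-in are assumptions):

prompt = initial_table_prompt.format(
    validate_prompt_prefix="",
    sentiment_choices="In the third column named 'Sentiment', " + default_sentiment_prompt + ".",
    response_reference_format=default_response_reference_format,
    add_existing_topics_summary_format="",
    response_table=response_table_markdown,  # assumed: markdown table of batch responses
    previous_table_introduction="",
    previous_table="",
    validate_prompt_suffix="",
)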
32
+ ###
33
+ # Adding existing topics to consultation responses
34
+ ###
35
+
36
+ add_existing_topics_system_prompt = system_prompt + markdown_additional_prompt
37
+
38
+ add_existing_topics_assistant_prefill = "|"
39
+
40
+ force_existing_topics_prompt = """Create a new markdown table. In the first column named 'Placeholder', write 'Not assessed'. In the second column named 'Subtopics', assign Topics from the above table to Responses. Assign topics only if they are very relevant to the text of the Response. The assigned Subtopics should be chosen from the topics table above, exactly as written. Do not add any new topics, or modify existing topic names."""
41
+
42
+ allow_new_topics_prompt = """Create a new markdown table. In the first column named 'General topic', and the second column named 'Subtopic', assign General Topics and Subtopics to Responses. Assign topics from the Topics table above only if they are very relevant to the text of the Response. Fill in the General topic and Subtopic for the Topic if they do not already exist. If you find a new topic that does not exist in the Topics table, add a new row to the new table. Make the General topic and Subtopic as specific as possible. The subtopic should never be blank or empty."""
43
+
44
+ force_single_topic_prompt = """ Assign each response to one single topic only."""
45
+
46
+ add_existing_topics_prompt = """{validate_prompt_prefix}Your task is to create one new markdown table, assigning responses from the Response table below to topics.
47
+ {topic_assignment}{force_single_topic}
48
+ {sentiment_choices}{response_reference_format}
49
+ In the final column named 'Summary', write a summary of the Subtopic based on relevant responses - highlight specific issues that appear. {add_existing_topics_summary_format}
50
+ Do not add any other columns. Do not add any other text to your response. Only mention topics that are relevant to at least one response.
51
+
52
+ Choose from among the following topic names to assign to the responses, only if they are directly relevant to responses from the response table below:
53
+ {topics}
54
+
55
+ {response_table}
56
+
57
+ New table:{previous_table_introduction}{previous_table}{validate_prompt_suffix}"""
58
+
59
+ ###
60
+ # VALIDATION PROMPTS
61
+ ###
62
+ # These are prompts used to validate previous LLM outputs, and create corrected versions of the outputs if errors are found.
63
+ validation_system_prompt = system_prompt
64
+
65
+ validation_prompt_prefix_default = """The following instructions were previously provided to create an output table:\n'"""
66
+
67
+ previous_table_introduction_default = (
68
+ """'\n\nThe following output table was created based on the above instructions:\n"""
69
+ )
70
+
71
+ validation_prompt_suffix_default = """\n\nBased on the above information, you need to create a corrected version of the output table. Examples of issues to correct include:
72
+
73
+ - Remove rows where responses are not relevant to the assigned topic, or where responses are not relevant to any topic.
74
+ - Remove rows where a topic is not assigned to any specific response.
75
+ - If the current topic assignment does not cover all information in a response, assign responses to relevant topics from the suggested topics table, or create a new topic if necessary.
76
+ - Correct any false information in the summary column, which is a summary of the relevant response text.
77
+ {additional_validation_issues}
78
+ - Any other obvious errors that you can identify.
79
+
80
+ With the above issues in mind, create a new, corrected version of the markdown table below. If there are no issues to correct, write simply "No change". Return only the corrected table without additional text, or 'no change' alone."""
81
+
82
+ validation_prompt_suffix_struct_summary_default = """\n\nBased on the above information, you need to create a corrected version of the output table. Examples of issues to correct include:
83
+
84
+ - Any misspellings in the Main heading or Subheading columns
85
+ - Correct any false information in the summary column, which is a summary of the relevant response text.
86
+ {additional_validation_issues}
87
+ - Any other obvious errors that you can identify.
88
+
89
+ With the above issues in mind, create a new, corrected version of the markdown table below. If there are no issues to correct, write simply "No change". Return only the corrected table without additional text, or 'no change' alone."""
90
+
91
+ ###
92
+ # SENTIMENT CHOICES
93
+ ###
94
+
95
+ negative_neutral_positive_sentiment_prompt = (
96
+ "write the sentiment of the Subtopic: Negative, Neutral, or Positive"
97
+ )
98
+ negative_or_positive_sentiment_prompt = (
99
+ "write the sentiment of the Subtopic: Negative or Positive"
100
+ )
101
+ do_not_assess_sentiment_prompt = "write the text 'Not assessed'" # Not used anymore. Instead, the column is filled in automatically with 'Not assessed'
102
+ default_sentiment_prompt = (
103
+ "write the sentiment of the Subtopic: Negative, Neutral, or Positive"
104
+ )
105
+
106
+ ###
107
+ # STRUCTURED SUMMARY PROMPT
108
+ ###
109
+
110
+ structured_summary_prompt = """Your task is to write a structured summary for open text responses.
111
+
112
+ Create a new markdown table based on the response table below with the headings 'Main heading', 'Subheading' and 'Summary'.
113
+
114
+ For each of the responses in the Response table, you will create a row for each summary associated with each of the Main headings and Subheadings from the Headings table. If there is no Headings table, create your own headings. In the first and second columns, write a Main heading and Subheading from the Headings table. Then in Summary, write a detailed and comprehensive summary that covers all information relevant to the Main heading and Subheading on the same row.
115
+ {summary_format}
116
+
117
+ Do not add any other columns. Do not add any other text to your response.
118
+
119
+ {response_table}
120
+
121
+ Headings to structure the summary are in the following table:
122
+ {topics}
123
+
124
+ New table:"""
125
+
126
+ ###
127
+ # SUMMARISE TOPICS PROMPT
128
+ ###
129
+
130
+ summary_assistant_prefill = ""
131
+
132
+ summarise_topic_descriptions_system_prompt = system_prompt
133
+
134
+ summarise_topic_descriptions_prompt = """Your task is to make a consolidated summary of the text below. {summary_format}
135
+
136
+ Return only the summary and no other text:
137
+
138
+ {summaries}
139
+
140
+ Summary:"""
141
+
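Likewise, a hedged example of filling the topic-summary template (the summaries text is illustrative):

prompt = summarise_topic_descriptions_prompt.format(
    summary_format=single_para_summary_format_prompt,
    summaries="- Residents report long waits for repairs.\n- Several responses mention rising costs.",
)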
142
+ single_para_summary_format_prompt = "Return a concise summary up to one paragraph long that summarises only the most important themes from the original text"
143
+
144
+ two_para_summary_format_prompt = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text"
145
+
146
+ ###
147
+ # OVERALL SUMMARY PROMPTS
148
+ ###
149
+
150
+ summarise_everything_system_prompt = system_prompt
151
+
152
+ summarise_everything_prompt = """Below is a table that gives an overview of the main topics from a dataset of open text responses along with a description of each topic, and the number of responses that mentioned each topic:
153
+
154
+ '{topic_summary_table}'
155
+
156
+ Your task is to summarise the above table. {summary_format}. Return only the summary and no other text.
157
+
158
+ Summary:"""
159
+
160
+ comprehensive_summary_format_prompt = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Format the output for Excel display using: **bold text** for main headings, β€’ bullet points for sub-items, and line breaks between sections. Avoid markdown symbols like # or ##."
161
+
162
+ comprehensive_summary_format_prompt_by_group = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Compare and contrast differences between the topics and themes from each Group. Format the output for Excel display using: **bold text** for main headings, β€’ bullet points for sub-items, and line breaks between sections. Avoid markdown symbols like # or ##."
163
+
164
+ # Alternative Excel formatting options
165
+ excel_rich_text_format_prompt = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Format for Excel using: BOLD for main headings, bullet points (β€’) for sub-items, and line breaks between sections. Use simple text formatting that Excel can interpret."
166
+
167
+ excel_plain_text_format_prompt = "Return a comprehensive summary that covers all the important topics and themes described in the table. Structure the summary with General Topics as headings, with significant Subtopics described in bullet points below them in order of relative significance. Do not explicitly mention the Sentiment, Number of responses, or Group values. Do not use the words 'General topic' or 'Subtopic' directly in the summary. Format as plain text with clear structure: use ALL CAPS for main headings, bullet points (β€’) for sub-items, and line breaks between sections. Avoid any special formatting symbols."
168
+
169
+ ###
170
+ # LLM-BASED TOPIC DEDUPLICATION PROMPTS
171
+ ###
172
+
173
+ llm_deduplication_system_prompt = """You are an expert at analysing and consolidating topic categories. Your task is to identify semantically similar topics that should be merged together, even if they use different wording or synonyms."""
174
+
175
+ llm_deduplication_prompt = """You are given a table of topics with their General topics, Subtopics, and Sentiment classifications. Your task is to identify topics that are semantically similar and should be merged together. Only merge topics that are almost identical in terms of meaning - if in doubt, do not merge.
176
+
177
+ Analyse the following topics table and identify groups of topics that describe essentially the same concept but may use different words or phrases. For example:
178
+ - "Transportation issues" and "Public transport problems"
179
+ - "Housing costs" and "Rent prices"
180
+ - "Environmental concerns" and "Green issues"
181
+
182
+ Create a markdown table with the following columns:
183
+ 1. 'Original General topic' - The current general topic name
184
+ 2. 'Original Subtopic' - The current subtopic name
185
+ 3. 'Original Sentiment' - The current sentiment
186
+ 4. 'Merged General topic' - The consolidated general topic name (use the most descriptive)
187
+ 5. 'Merged Subtopic' - The consolidated subtopic name (use the most descriptive)
188
+ 6. 'Merged Sentiment' - The consolidated sentiment (use 'Mixed' if sentiments differ)
189
+ 7. 'Merge Reason' - Brief explanation of why these topics should be merged
190
+
191
+ Only include rows where topics should actually be merged. If a topic has no semantic duplicates, do not include it in the table. Produce only a markdown table in the format described above. Do not add any other text to your response.
192
+
193
+ Topics to analyse:
194
+ {topics_table}
195
+
196
+ Merged topics table:"""
197
+
198
+ llm_deduplication_prompt_with_candidates = """You are given a table of topics with their General topics, Subtopics, and Sentiment classifications. Your task is to identify topics that are semantically similar and should be merged together, even if they use different wording.
199
+
200
+ Additionally, you have been provided with a list of candidate topics that represent preferred topic categories. When merging topics, prioritise fitting similar topics into these existing candidate categories rather than creating new ones. Only merge topics that are almost identical in terms of meaning - if in doubt, do not merge.
201
+
202
+ Analyse the following topics table and identify groups of topics that describe essentially the same concept but may use different words or phrases. For example:
203
+ - "Transportation issues" and "Public transport problems"
204
+ - "Housing costs" and "Rent prices"
205
+ - "Environmental concerns" and "Green issues"
206
+
207
+ When merging topics, consider the candidate topics provided below and try to map similar topics to these preferred categories when possible.
208
+
209
+ Create a markdown table with the following columns:
210
+ 1. 'Original General topic' - The current general topic name
211
+ 2. 'Original Subtopic' - The current subtopic name
212
+ 3. 'Original Sentiment' - The current sentiment
213
+ 4. 'Merged General topic' - The consolidated general topic name (prefer candidate topics when similar)
214
+ 5. 'Merged Subtopic' - The consolidated subtopic name (prefer candidate topics when similar)
215
+ 6. 'Merged Sentiment' - The consolidated sentiment (use 'Mixed' if sentiments differ)
216
+ 7. 'Merge Reason' - Brief explanation of why these topics should be merged
217
+
218
+ Only include rows where topics should actually be merged. If a topic has no semantic duplicates, do not include it in the table. Produce only a markdown table in the format described above. Do not add any other text to your response.
219
+
220
+ Topics to analyse:
221
+ {topics_table}
222
+
223
+ Candidate topics to consider for mapping:
224
+ {candidate_topics_table}
225
+
226
+ Merged topics table:"""
227
+
228
+ ###
229
+ # VERIFY EXISTING DESCRIPTIONS/TITLES - Currently not used
230
+ ###
231
+
232
+ verify_assistant_prefill = "|"
233
+
234
+ verify_titles_system_prompt = system_prompt
235
+
236
+ verify_titles_prompt = """Response numbers alongside the Response text and assigned descriptions are shown in the table below:
237
+ {response_table}
238
+
239
+ The criteria for a suitable description for these responses are that it should be readable, concise, and fully encapsulate the main subject of the response.
240
+
241
+ Create a markdown table with four columns.
242
+ The first column is 'Response References', and should contain just the response number under consideration.
243
+ The second column is 'Is this a suitable description', answer the question with 'Yes' or 'No', with no other text.
244
+ The third column is 'Explanation', give a short explanation for your response in the second column.
245
+ The fourth column is 'Alternative description', suggest an alternative description for the response that meets the criteria stated above.
246
+ Do not add any other text to your response.
247
+
248
+ Output markdown table:"""
249
+
250
+
251
+ ## The following didn't work well in testing and so is not currently used
252
+
253
+ create_general_topics_system_prompt = system_prompt
254
+
255
+ create_general_topics_prompt = """Subtopics known to be relevant to this dataset are shown in the following Topics table:
256
+ {topics}
257
+
258
+ Your task is to create a General topic name for each Subtopic. The new Topics table should have the columns 'General topic' and 'Subtopic' only. Write a 'General topic' text label relevant to the Subtopic next to it in the new table. The text label should describe the general theme of the Subtopic. Do not add any other text, thoughts, or notes to your response.
259
+
260
+ New Topics table:"""
windows_install_llama-cpp-python.txt ADDED
@@ -0,0 +1,111 @@
1
+ ---
2
+
3
+ # How to build llama-cpp-python on Windows: Step-by-Step Guide
4
+
5
+ First, you need to set up a proper C++ development environment.
6
+
7
+ # Step 1: Install the C++ Compiler
8
+ On the Visual Studio downloads page, scroll down past the main programs to "Tools for Visual Studio" and download the "Build Tools for Visual Studio". This is a standalone installer that gives you the C++ compiler and libraries without installing the full Visual Studio IDE.
9
+
10
+ Run the installer. In the "Workloads" tab, check the box for "Desktop development with C++".
11
+
12
+ MSVC v143
13
+ C++ ATL
14
+ C++ Profiling tools
15
+ C++ CMake tools for Windows
16
+ C++ MFC
17
+ C++ Modules
18
+ Windows 10 SDK (10.0.20348.0)
19
+
20
+ Proceed with the installation.
21
+
22
+ You will need to use the 'x64 Native Tools Command Prompt for VS 2022' (run as administrator) to install the packages below.
23
+
24
+ # Step 2: Install CMake
25
+ Go to the CMake download page: https://cmake.org/download
26
+
27
+ Download the latest Windows installer (e.g., cmake-x.xx.x-windows-x86_64.msi).
28
+
29
+ Run the installer. Crucially, when prompted, select the option to "Add CMake to the system PATH for all users" or "for the current user." This allows you to run cmake from any command prompt.
30
+
31
+
32
+ # Step 3: (FOR CPU INFERENCE ONLY) Download and Place OpenBLAS
33
+ This is often the trickiest part.
34
+
35
+ Go to the OpenBLAS releases on GitHub.
36
+
37
+ Find a recent release and download the pre-compiled version for Windows. It will typically be a file named something like OpenBLAS-0.3.21-x64.zip (the version number will change). Make sure you get the 64-bit (x64) version if you are using 64-bit Python.
38
+
39
+ Create a folder somewhere easily accessible, for example, C:\libs\.
40
+
41
+ Extract the contents of the OpenBLAS zip file into that folder. Your final directory structure should look something like this:
42
+
43
+ C:\libs\OpenBLAS\
44
+ β”œβ”€β”€ bin\
45
+ β”œβ”€β”€ include\
46
+ └── lib\
47
+
48
+ ## 3.b. Install Chocolatey
49
+ https://chocolatey.org/install
50
+
51
+ Step 1: Install Chocolatey (if you don't already have it)
52
+ Open PowerShell as an Administrator. (Right-click the Start Menu -> "Windows PowerShell (Admin)" or "Terminal (Admin)").
53
+
54
+ Run the following command to install Chocolatey. It's a single, long line:
55
+
56
+ Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
57
+
58
+ Once it's done, close the Administrator PowerShell window.
59
+
60
+ Step 2: Install pkg-config-lite using Chocolatey
61
+ IMPORTANT: Open a NEW command prompt or PowerShell window (as a regular user is fine). This is necessary so it recognises the new choco command.
62
+
63
+ Run the following command in console to install a lightweight version of pkg-config:
64
+
65
+ choco install pkgconfiglite
66
+
67
+ Approve the installation by typing Y or A if prompted.
68
+
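To confirm pkg-config is now on your PATH, open a new console and run:

pkg-config --version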
69
+ # Step 4: Run the Installation Command
70
+ Now you have all the pieces. The final step is to run the command in a terminal that is aware of your new build environment.
71
+
72
+ Open the "Developer Command Prompt for VS" from your Start Menu. This is important! This special command prompt automatically configures all the necessary paths for the C++ compiler.
73
+
74
+ ## For CPU
75
+
76
+ set PKG_CONFIG_PATH=C:\<path-to-openblas>\OpenBLAS\lib\pkgconfig
+
+ (You can also set PKG_CONFIG_PATH permanently in your environment variables.)
77
+
78
+ pip install llama-cpp-python==0.3.16 --force-reinstall --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib"
79
+
80
+ Alternatively, passing PKG_CONFIG_PATH directly through the CMake arguments:
+
+ pip install llama-cpp-python==0.3.16 --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib;-DPKG_CONFIG_PATH=C:/<path-to-openblas>/OpenBLAS/lib/pkgconfig"
81
+
82
+ or to make a wheel:
83
+
84
+ pip wheel llama-cpp-python==0.3.16 --wheel-dir dist --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib"
85
+
86
+ pip wheel llama-cpp-python==0.3.16 --wheel-dir dist --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/Users/<user>/libs/OpenBLAS/include;-DBLAS_LIBRARIES=C:/Users/<user>/libs/OpenBLAS/lib/libopenblas.lib"
87
+
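To sanity-check a CPU build, try importing the package from a fresh console (a quick smoke test, not a full benchmark):

python -c "import llama_cpp; print(llama_cpp.__version__)"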
88
+
89
+
90
+ ## With Cuda (NVIDIA GPUs only)
91
+
92
+ Make sure that you have the CUDA 12.4 toolkit for Windows installed: https://developer.nvidia.com/cuda-12-4-0-download-archive
93
+
94
+ ### Make sure you are using the x64 version of Developer command tools for the below, e.g. 'x64 Native Tools Command Prompt for VS 2022' ###
95
+
96
+ Use NVIDIA GPU (cuBLAS): If you have an NVIDIA GPU, using cuBLAS is often easier because the CUDA Toolkit installer handles most of the setup.
97
+
98
+ Install the NVIDIA CUDA Toolkit.
99
+
100
+ Run the install command specifying cuBLAS (for faster inference):
101
+
102
+ pip install llama-cpp-python==0.3.16 --force-reinstall --verbose -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
103
+
104
+ If you want to create a new wheel to help with future installs, you can run:
105
+
106
+ First, cd to a folder that you have write access to, then run:
107
+
108
+ pip wheel llama-cpp-python==0.3.16 --wheel-dir dist --verbose -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
109
+
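To confirm a CUDA build actually offloads to the GPU, load any GGUF model with GPU layers enabled and watch the verbose log for CUDA/offload messages (the model path here is illustrative):

python -c "from llama_cpp import Llama; Llama(model_path='model.gguf', n_gpu_layers=-1, verbose=True)"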
110
+
111
+