Akshat Bhatt commited on
Commit
e2e0c18
·
0 Parent(s):

added code

Browse files
.dockerignore ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+
11
+ # Virtual environments
12
+ venv/
13
+ env/
14
+ ENV/
15
+
16
+ # IDE
17
+ .vscode/
18
+ .idea/
19
+ *.swp
20
+ *.swo
21
+ *~
22
+
23
+ # OS
24
+ .DS_Store
25
+ Thumbs.db
26
+
27
+ # Logs
28
+ *.log
29
+ logs/
30
+
31
+ # Git
32
+ .git/
33
+ .gitignore
34
+
35
+ # Environment files (will be set via GCP Secrets Manager or env vars)
36
+ .env
37
+ .env.local
38
+
39
+ # Data files (if too large, consider mounting as volume or using GCS)
40
+ # Uncomment if you want to exclude data files
41
+ # data/*.csv
42
+
43
+ # Reports (optional - uncomment if you don't need reports in container)
44
+ # reports/
45
+
46
+ # Test files
47
+ test_*.py
48
+ *_test.py
49
+
50
+ # Jupyter notebooks
51
+ *.ipynb
52
+ *.ipynb_checkpoints/
53
+
54
+ # Documentation
55
+ *.md
56
+ README.md
57
+ docs/
58
+
59
+ # Docker files
60
+ Dockerfile*
61
+ docker-compose*.yml
62
+ .dockerignore
.env ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ GEMINI_API_KEY=your-gemini-api-key-here
2
+ GROQ_API_KEY=your-groq-api-key-here
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ reports/
DEPLOYMENT.md ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GCP Deployment Guide for Phishing Detection API
2
+
3
+ This guide explains how to deploy the Phishing Detection API to Google Cloud Platform (GCP) using Docker and Cloud Run.
4
+
5
+ ## Prerequisites
6
+
7
+ 1. **Google Cloud Platform Account**: Ensure you have a GCP account and billing enabled
8
+ 2. **gcloud CLI**: Install the [Google Cloud SDK](https://cloud.google.com/sdk/docs/install)
9
+ 3. **Docker**: Install [Docker Desktop](https://www.docker.com/products/docker-desktop) or Docker Engine
10
+ 4. **GROQ API Key**: Obtain your API key from [Groq](https://console.groq.com/)
11
+
12
+ ## Quick Start
13
+
14
+ ### Option 1: Using Cloud Build (Recommended)
15
+
16
+ 1. **Set your GCP project**:
17
+ ```bash
18
+ gcloud config set project YOUR_PROJECT_ID
19
+ ```
20
+
21
+ 2. **Run the deployment script**:
22
+ ```bash
23
+ # For Linux/Mac
24
+ chmod +x deploy.sh
25
+ ./deploy.sh YOUR_PROJECT_ID us-central1
26
+
27
+ # For Windows PowerShell
28
+ .\deploy.ps1 -ProjectId YOUR_PROJECT_ID -Region us-central1
29
+ ```
30
+
31
+ The script will:
32
+ - Enable required APIs
33
+ - Create the GROQ_API_KEY secret in Secret Manager
34
+ - Build the Docker image
35
+ - Deploy to Cloud Run
36
+
37
+ ### Option 2: Manual Deployment
38
+
39
+ #### Step 1: Set Up GCP Project
40
+
41
+ ```bash
42
+ # Set your project
43
+ gcloud config set project YOUR_PROJECT_ID
44
+
45
+ # Enable required APIs
46
+ gcloud services enable cloudbuild.googleapis.com
47
+ gcloud services enable run.googleapis.com
48
+ gcloud services enable containerregistry.googleapis.com
49
+ gcloud services enable secretmanager.googleapis.com
50
+ ```
51
+
52
+ #### Step 2: Create Secret for GROQ API Key
53
+
54
+ ```bash
55
+ # Create the secret
56
+ echo -n "your-groq-api-key" | gcloud secrets create GROQ_API_KEY \
57
+ --data-file=- \
58
+ --replication-policy="automatic"
59
+
60
+ # Grant Cloud Run access to the secret
61
+ PROJECT_NUMBER=$(gcloud projects describe YOUR_PROJECT_ID --format="value(projectNumber)")
62
+ gcloud secrets add-iam-policy-binding GROQ_API_KEY \
63
+ --member="serviceAccount:$PROJECT_NUMBER-compute@developer.gserviceaccount.com" \
64
+ --role="roles/secretmanager.secretAccessor"
65
+ ```
66
+
67
+ #### Step 3: Build and Deploy
68
+
69
+ ```bash
70
+ # Using Cloud Build (recommended)
71
+ gcloud builds submit --config=cloudbuild.yaml
72
+
73
+ # Or build locally and push
74
+ docker build -t gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest .
75
+ docker push gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest
76
+
77
+ # Deploy to Cloud Run
78
+ gcloud run deploy phishing-detection-api \
79
+ --image gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest \
80
+ --region us-central1 \
81
+ --platform managed \
82
+ --allow-unauthenticated \
83
+ --memory 4Gi \
84
+ --cpu 2 \
85
+ --timeout 300 \
86
+ --max-instances 10 \
87
+ --set-env-vars PYTHONUNBUFFERED=1 \
88
+ --set-secrets GROQ_API_KEY=GROQ_API_KEY:latest
89
+ ```
90
+
91
+ ## Configuration Options
92
+
93
+ ### Cloud Run Settings
94
+
95
+ The deployment uses these default settings (configured in `cloudbuild.yaml`):
96
+
97
+ - **Memory**: 4GB (required for ML models)
98
+ - **CPU**: 2 vCPUs
99
+ - **Timeout**: 300 seconds (5 minutes)
100
+ - **Max Instances**: 10
101
+ - **Region**: us-central1 (change in cloudbuild.yaml or deploy command)
102
+
103
+ ### Adjusting Resources
104
+
105
+ If you need more resources, modify `cloudbuild.yaml`:
106
+
107
+ ```yaml
108
+ - '--memory'
109
+ - '8Gi' # Increase memory for larger models
110
+ - '--cpu'
111
+ - '4' # Increase CPU for faster inference
112
+ ```
113
+
114
+ Or deploy with custom settings:
115
+
116
+ ```bash
117
+ gcloud run deploy phishing-detection-api \
118
+ --image gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest \
119
+ --memory 8Gi \
120
+ --cpu 4 \
121
+ --timeout 600 \
122
+ --max-instances 20
123
+ ```
124
+
125
+ ## Verifying Deployment
126
+
127
+ 1. **Check service status**:
128
+ ```bash
129
+ gcloud run services describe phishing-detection-api --region us-central1
130
+ ```
131
+
132
+ 2. **Get service URL**:
133
+ ```bash
134
+ SERVICE_URL=$(gcloud run services describe phishing-detection-api \
135
+ --region us-central1 \
136
+ --format="value(status.url)")
137
+ echo $SERVICE_URL
138
+ ```
139
+
140
+ 3. **Test health endpoint**:
141
+ ```bash
142
+ curl $SERVICE_URL/health
143
+ ```
144
+
145
+ 4. **View logs**:
146
+ ```bash
147
+ gcloud run services logs read phishing-detection-api --region us-central1
148
+ ```
149
+
150
+ ## API Endpoints
151
+
152
+ Once deployed, your service will have these endpoints:
153
+
154
+ - **Root**: `GET /` - API information
155
+ - **Health Check**: `GET /health` - Service health and model status
156
+ - **Prediction**: `POST /predict` - Main prediction endpoint
157
+ - **API Docs**: `GET /docs` - Interactive API documentation (Swagger UI)
158
+
159
+ ## Testing the API
160
+
161
+ ### Using curl:
162
+
163
+ ```bash
164
+ curl -X POST "$SERVICE_URL/predict" \
165
+ -H "Content-Type: application/json" \
166
+ -d '{
167
+ "sender": "test@example.com",
168
+ "subject": "Urgent Action Required",
169
+ "text": "Please verify your account at http://suspicious-site.com",
170
+ "metadata": {}
171
+ }'
172
+ ```
173
+
174
+ ### Using Python:
175
+
176
+ ```python
177
+ import requests
178
+
179
+ url = "YOUR_SERVICE_URL/predict"
180
+ payload = {
181
+ "sender": "test@example.com",
182
+ "subject": "Urgent Action Required",
183
+ "text": "Please verify your account at http://suspicious-site.com",
184
+ "metadata": {}
185
+ }
186
+
187
+ response = requests.post(url, json=payload)
188
+ print(response.json())
189
+ ```
190
+
191
+ ## Monitoring and Logging
192
+
193
+ ### View Logs in Console
194
+
195
+ 1. Go to [Cloud Run Console](https://console.cloud.google.com/run)
196
+ 2. Click on your service: `phishing-detection-api`
197
+ 3. Navigate to the "Logs" tab
198
+
199
+ ### View Logs via CLI
200
+
201
+ ```bash
202
+ gcloud run services logs read phishing-detection-api --region us-central1 --limit 50
203
+ ```
204
+
205
+ ### Set Up Alerts
206
+
207
+ 1. Go to [Cloud Monitoring](https://console.cloud.google.com/monitoring)
208
+ 2. Create alerts for:
209
+ - Error rate
210
+ - Request latency
211
+ - Memory usage
212
+ - CPU utilization
213
+
214
+ ## Troubleshooting
215
+
216
+ ### Container fails to start
217
+
218
+ - Check logs: `gcloud run services logs read phishing-detection-api --region us-central1`
219
+ - Verify models are present in the container
220
+ - Check memory limits (may need to increase)
221
+
222
+ ### Models not loading
223
+
224
+ - Ensure all model files are included in the Docker image
225
+ - Check model paths in `config.py`
226
+ - Verify model files exist in `models/`, `finetuned_bert/`, and `Message_model/final_semantic_model/`
227
+
228
+ ### GROQ API errors
229
+
230
+ - Verify the secret is correctly set: `gcloud secrets versions access latest --secret="GROQ_API_KEY"`
231
+ - Check IAM permissions for the Cloud Run service account
232
+ - Verify the API key is valid
233
+
234
+ ### High memory usage
235
+
236
+ - Increase memory allocation in Cloud Run settings
237
+ - Consider using model quantization
238
+ - Check for memory leaks in the application
239
+
240
+ ## Cost Optimization
241
+
242
+ 1. **Set minimum instances to 0**: Scales to zero when not in use
243
+ ```bash
244
+ gcloud run services update phishing-detection-api \
245
+ --min-instances 0 \
246
+ --max-instances 10
247
+ ```
248
+
249
+ 2. **Use appropriate instance sizes**: Start with smaller instances and scale up if needed
250
+
251
+ 3. **Enable request concurrency**: Reduce number of instances needed
252
+ ```bash
253
+ gcloud run services update phishing-detection-api \
254
+ --concurrency 10
255
+ ```
256
+
257
+ ## Security Considerations
258
+
259
+ 1. **Authentication**: Currently deployed as public. Consider adding authentication:
260
+ ```bash
261
+ gcloud run services update phishing-detection-api \
262
+ --no-allow-unauthenticated
263
+ ```
264
+
265
+ 2. **API Keys**: Store sensitive keys in Secret Manager (already configured)
266
+
267
+ 3. **VPC**: Consider deploying in a VPC for additional network isolation
268
+
269
+ ## Updating the Service
270
+
271
+ To update after code changes:
272
+
273
+ ```bash
274
+ # Rebuild and deploy
275
+ gcloud builds submit --config=cloudbuild.yaml
276
+
277
+ # Or manually
278
+ docker build -t gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest .
279
+ docker push gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest
280
+ gcloud run deploy phishing-detection-api \
281
+ --image gcr.io/YOUR_PROJECT_ID/phishing-detection-api:latest \
282
+ --region us-central1
283
+ ```
284
+
285
+ ## Local Testing with Docker
286
+
287
+ Before deploying to GCP, test locally:
288
+
289
+ ```bash
290
+ # Build the image
291
+ docker build -t phishing-detection-api:local .
292
+
293
+ # Run with environment variable
294
+ docker run -p 8000:8000 \
295
+ -e GROQ_API_KEY=your-api-key \
296
+ phishing-detection-api:local
297
+
298
+ # Test
299
+ curl http://localhost:8000/health
300
+ ```
301
+
302
+ ## Support
303
+
304
+ For issues or questions:
305
+ 1. Check the logs first
306
+ 2. Verify all prerequisites are met
307
+ 3. Ensure models are properly loaded
308
+ 4. Review the API documentation at `/docs` endpoint
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use Python 3.10 slim image as base
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies for SSL, whois, and other required tools.
# gcc/g++ are needed to compile Python wheels that ship no prebuilt binary;
# curl is required by the HEALTHCHECK below; the `curl --version` call at the
# end fails the build early if curl did not install correctly.
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    curl \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/* \
    && curl --version

# Copy requirements first for better caching (dependency layer is rebuilt
# only when requirements.txt changes, not on every source edit)
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy the entire Model_sprint_2 directory (.dockerignore excludes .env,
# tests, docs, and other non-runtime files)
COPY . .

# Set environment variables: unbuffered stdout for live container logs,
# and no .pyc files inside the image
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Expose the port the app runs on
EXPOSE 8000

# Health check; start-period gives the ML models time to load before
# failures start counting
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the application using uvicorn
CMD ["python", "-m", "uvicorn", "app1:app", "--host", "0.0.0.0", "--port", "8000"]
Message_model/final_semantic_model/added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "[MASK]": 128000
3
+ }
Message_model/final_semantic_model/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DebertaV2ForSequenceClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "dtype": "float32",
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-07,
13
+ "legacy": true,
14
+ "max_position_embeddings": 512,
15
+ "max_relative_positions": -1,
16
+ "model_type": "deberta-v2",
17
+ "norm_rel_ebd": "layer_norm",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 0,
21
+ "pooler_dropout": 0,
22
+ "pooler_hidden_act": "gelu",
23
+ "pooler_hidden_size": 768,
24
+ "pos_att_type": [
25
+ "p2c",
26
+ "c2p"
27
+ ],
28
+ "position_biased_input": false,
29
+ "position_buckets": 256,
30
+ "relative_attention": true,
31
+ "share_att_key": true,
32
+ "transformers_version": "4.57.1",
33
+ "type_vocab_size": 0,
34
+ "vocab_size": 128100
35
+ }
Message_model/final_semantic_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a960912824e3268b5fdf0217b8b5d7ea5a3817a6437c01b5be36e5aa94808745
3
+ size 737719272
Message_model/final_semantic_model/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[CLS]",
3
+ "cls_token": "[CLS]",
4
+ "eos_token": "[SEP]",
5
+ "mask_token": "[MASK]",
6
+ "pad_token": "[PAD]",
7
+ "sep_token": "[SEP]",
8
+ "unk_token": {
9
+ "content": "[UNK]",
10
+ "lstrip": false,
11
+ "normalized": true,
12
+ "rstrip": false,
13
+ "single_word": false
14
+ }
15
+ }
Message_model/final_semantic_model/spm.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
3
+ size 2464616
Message_model/final_semantic_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
Message_model/final_semantic_model/tokenizer_config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[CLS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[SEP]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[UNK]",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "128000": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "[CLS]",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": "[CLS]",
47
+ "do_lower_case": false,
48
+ "eos_token": "[SEP]",
49
+ "extra_special_tokens": {},
50
+ "mask_token": "[MASK]",
51
+ "model_max_length": 1000000000000000019884624838656,
52
+ "pad_token": "[PAD]",
53
+ "sep_token": "[SEP]",
54
+ "sp_model_kwargs": {},
55
+ "split_by_punct": false,
56
+ "tokenizer_class": "DebertaV2Tokenizer",
57
+ "unk_token": "[UNK]",
58
+ "vocab_type": "spm"
59
+ }
Message_model/final_semantic_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14bca844639ea9d701739f1946f783acd435eceecb6fb1e12773a8a645269dd2
3
+ size 5777
Message_model/predict.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from scipy.special import softmax
5
+ import os
6
+
7
class PhishingPredictor:
    """Scores SMS messages with the fine-tuned sequence-classification model.

    Loads the tokenizer and model once at construction time and keeps them
    in eval mode on the best available device (CUDA if present, else CPU).
    """

    def __init__(self, model_path="final_semantic_model"):
        """Load tokenizer + model from *model_path*.

        Raises:
            FileNotFoundError: if *model_path* does not exist on disk.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found at {model_path}. Please run train.py first.")

        print(f"Loading model from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # use_safetensors avoids the pickle-based torch.load path entirely
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            use_safetensors=True
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")

    def predict(self, text):
        """Classify one message and return a result dict.

        Returns a dict with the message text, class probabilities, a
        "phishing"/"ham" verdict at the 0.5 threshold, and a coarse
        confidence tier ("high" > 0.8, "medium" > 0.6, else "low").
        Empty or whitespace-only input short-circuits to a neutral "ham".
        """
        if not text or not text.strip():
            # Nothing to score; report the neutral/low-confidence shape.
            return {
                "text": text,
                "phishing_probability": 0.0,
                "prediction": "ham",
                "confidence": "low"
            }

        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors="pt"
        )
        # Move every input tensor onto the same device as the model.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            raw_logits = self.model(**encoded).logits.cpu().numpy()

        class_probs = softmax(raw_logits, axis=1)[0]
        phishing_prob = float(class_probs[1])
        verdict = "phishing" if phishing_prob > 0.5 else "ham"

        # Confidence = distance from the decision boundary, whichever side won.
        certainty = max(phishing_prob, 1 - phishing_prob)
        tier = "high" if certainty > 0.8 else ("medium" if certainty > 0.6 else "low")

        return {
            "text": text,
            "phishing_probability": round(phishing_prob, 4),
            "ham_probability": round(float(class_probs[0]), 4),
            "prediction": verdict,
            "confidence": tier,
            "confidence_score": round(certainty, 4)
        }

    def predict_batch(self, texts):
        """Classify each message in *texts*; returns a list of result dicts."""
        return [self.predict(message) for message in texts]
76
+
77
+
78
def main():
    """Interactive CLI for SMS phishing classification.

    Loads the trained model via PhishingPredictor, then loops reading
    messages from stdin. Supports three commands: 'quit'/'exit'/'q' to
    leave, 'batch' to score several messages at once, anything else is
    scored as a single message. Prints a friendly error and returns if
    the trained model directory is missing.
    """
    try:
        predictor = PhishingPredictor()

        print("\n" + "="*60)
        print("SMS PHISHING DETECTION SYSTEM")
        print("="*60)
        print("Enter SMS messages to analyze (type 'quit' to exit)")
        print("Type 'batch' to analyze multiple messages at once")
        print("-"*60)

        while True:
            user_input = input("\nEnter SMS message: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            elif user_input.lower() == 'batch':
                # Collect messages until the user enters a blank line.
                print("\nBatch mode - Enter messages (empty line to finish):")
                messages = []
                while True:
                    msg = input(f"Message {len(messages) + 1}: ").strip()
                    if not msg:
                        break
                    messages.append(msg)

                if messages:
                    results = predictor.predict_batch(messages)
                    print(f"\n{'='*60}")
                    print("BATCH RESULTS")
                    print(f"{'='*60}")
                    for i, result in enumerate(results, 1):
                        # Truncate long messages to keep the table readable.
                        print(f"\nMessage {i}: {result['text'][:50]}...")
                        print(f"Prediction: {result['prediction'].upper()}")
                        print(f"Phishing Probability: {result['phishing_probability']:.1%}")
                        print(f"Confidence: {result['confidence'].upper()}")
                        print("-" * 40)
                else:
                    print("No messages entered.")

            elif user_input:
                result = predictor.predict(user_input)

                print(f"\n{'='*60}")
                print("PREDICTION RESULT")
                print(f"{'='*60}")
                print(f"Message: {result['text']}")
                print(f"Prediction: {result['prediction'].upper()}")
                print(f"Phishing Probability: {result['phishing_probability']:.1%}")
                print(f"Ham Probability: {result['ham_probability']:.1%}")
                print(f"Confidence: {result['confidence'].upper()} ({result['confidence_score']:.1%})")

                # NOTE: these display thresholds (0.7 / 0.3) are looser than the
                # 0.5 decision boundary used by predict() on purpose — the
                # middle band is surfaced as "medium risk" to the user.
                prob = result['phishing_probability']
                if prob > 0.7:
                    print("🚨 HIGH RISK - Likely phishing!")
                elif prob > 0.3:
                    print("⚠️ MEDIUM RISK - Be cautious")
                else:
                    print("✅ LOW RISK - Appears legitimate")

            else:
                print("Please enter a message or 'quit' to exit.")

    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please run train.py first to create the model.")
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()
Message_model/train.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import (
9
+ accuracy_score,
10
+ f1_score,
11
+ classification_report,
12
+ confusion_matrix
13
+ )
14
+ from transformers import (
15
+ AutoTokenizer,
16
+ AutoModelForSequenceClassification,
17
+ Trainer,
18
+ TrainingArguments
19
+ )
20
+ from transformers.trainer_utils import get_last_checkpoint
21
+ from scipy.special import softmax
22
+
23
# --- 1. Check for CUDA (GPU) ---
# Training falls back to CPU when no GPU is visible; that works but is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- 1. Using device: {device} ---")
if device == "cpu":
    print("--- WARNING: CUDA not available. Training will run on CPU and will be very slow. ---")
print("---------------------------------")
# --- End CUDA Check ---

# Model/checkpoint layout. FINAL_MODEL_DIR is what predict.py loads.
MODEL_NAME = "microsoft/deberta-v3-base"
FINAL_MODEL_DIR = "final_semantic_model"
REPORT_DIR = "evaluation_report"
CHECKPOINT_DIR = "training_checkpoints"

os.makedirs(FINAL_MODEL_DIR, exist_ok=True)
os.makedirs(REPORT_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("--- 2. Loading and splitting dataset ---")
try:
    # Expected columns: 'ext_type' ('spam'/'ham') and 'text'.
    df = pd.read_csv("dataset.csv")
except FileNotFoundError:
    print("Error: dataset.csv not found.")
    print("Please make sure the file is in the same directory as this script.")
    # NOTE(review): this empty DataFrame is dead code — exit() runs immediately after.
    df = pd.DataFrame(columns=['ext_type', 'text'])
    exit()

# Normalize labels: 'spam' -> 1, 'ham' -> 0; drop rows with unknown labels
# (map() yields NaN for anything else) or missing text.
df.rename(columns={"ext_type": "label"}, inplace=True)
df['label'] = df['label'].map({'spam': 1, 'ham': 0})
df.dropna(subset=['label', 'text'], inplace=True)
df['label'] = df['label'].astype(int)

if len(df['label'].unique()) < 2:
    print("Error: The dataset must contain both 'ham' (0) and 'spam' (1) labels.")
    print(f"Found labels: {df['label'].unique()}")
    print("Please update dataset.csv with examples for both classes.")
    exit()

# 70/15/15 stratified split: the 30% holdout is halved into val and test.
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42, stratify=df['label'])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['label'])

print(f"Total examples: {len(df)}")
print(f"Training examples: {len(train_df)}")
print(f"Validation examples: {len(val_df)}")
print(f"Test examples: {len(test_df)}")
print("---------------------------------")


print("--- 3. Loading model and tokenizer ---")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    use_safetensors=True  # Use secure safetensors format to avoid torch.load error
)
print("---------------------------------")
78
+
79
+
80
class PhishingDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized texts with integer labels.

    All texts are tokenized eagerly in the constructor; each item is a dict
    of tensors in the shape the HF Trainer expects ('input_ids', masks, and
    a 'labels' entry).
    """

    def __init__(self, texts, labels, tokenizer):
        # Tokenize the whole corpus up front; self.encodings maps each
        # field name to a list of per-example sequences.
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=128)
        self.labels = labels

    def __getitem__(self, idx):
        sample = {}
        for field, sequences in self.encodings.items():
            sample[field] = torch.tensor(sequences[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        return len(self.labels)
92
+
93
# Materialize the three Trainer-ready datasets from the stratified splits above.
train_dataset = PhishingDataset(train_df['text'].tolist(), train_df['label'].tolist(), tokenizer)
val_dataset = PhishingDataset(val_df['text'].tolist(), val_df['label'].tolist(), tokenizer)
test_dataset = PhishingDataset(test_df['text'].tolist(), test_df['label'].tolist(), tokenizer)
96
+
97
+
98
def compute_metrics(pred):
    """HF Trainer metric hook: accuracy and weighted F1 from an EvalPrediction.

    *pred* carries .label_ids (gold labels) and .predictions (raw logits);
    the argmax over the class axis gives the predicted label.
    """
    gold = pred.label_ids
    guessed = pred.predictions.argmax(-1)
    return {
        "accuracy": accuracy_score(gold, guessed),
        "f1": f1_score(gold, guessed, average="weighted"),
    }
104
+
105
+
106
print("--- 4. Starting model training ---")
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    # Evaluate/checkpoint every 10 steps so interrupted runs lose little work;
    # load_best_model_at_end keeps the checkpoint with the best weighted F1.
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    save_total_limit=2,
    no_cuda=(device == "cpu"),
    save_safetensors=True  # This saves new checkpoints securely
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# This logic automatically detects if a checkpoint exists
last_checkpoint = get_last_checkpoint(CHECKPOINT_DIR)
if last_checkpoint:
    print(f"--- Resuming training from: {last_checkpoint} ---")
else:
    print("--- No checkpoint found. Starting training from scratch. ---")

# Pass the found checkpoint (or None) to the trainer
trainer.train(resume_from_checkpoint=last_checkpoint)

print("--- Training finished ---")
print("---------------------------------")


print(f"--- 5. Saving best model to {FINAL_MODEL_DIR} ---")
trainer.save_model(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)
print("--- Model saved ---")
print("---------------------------------")


print(f"--- 6. Generating report on TEST set ---")
# Reload the saved best model so the report reflects exactly what was shipped.
model_for_eval = AutoModelForSequenceClassification.from_pretrained(
    FINAL_MODEL_DIR,
    use_safetensors=True
)
eval_tokenizer = AutoTokenizer.from_pretrained(FINAL_MODEL_DIR)

eval_trainer = Trainer(model=model_for_eval, args=training_args)

predictions = eval_trainer.predict(test_dataset)

y_true = predictions.label_ids
y_pred_logits = predictions.predictions
y_pred_probs = softmax(y_pred_logits, axis=1)  # logits -> per-class probabilities
y_pred_labels = np.argmax(y_pred_logits, axis=1)

print("--- Generating Classification Report ---")
report = classification_report(y_true, y_pred_labels, target_names=["Ham (0)", "Phishing (1)"])
report_path = os.path.join(REPORT_DIR, "classification_report.txt")

with open(report_path, "w") as f:
    f.write("--- Semantic Model Classification Report ---\n\n")
    f.write(report)

print(report)
print(f"Classification report saved to {report_path}")

print("--- Generating Confusion Matrix ---")
cm = confusion_matrix(y_true, y_pred_labels)
cm_path = os.path.join(REPORT_DIR, "confusion_matrix.png")

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["Ham (0)", "Phishing (1)"],
            yticklabels=["Ham (0)", "Phishing (1)"])
plt.title("Confusion Matrix for Semantic Model")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.savefig(cm_path)
plt.close()
print(f"Confusion matrix saved to {cm_path}")

print("--- Generating Probability Scatterplot ---")
# One point per test example: predicted phishing probability vs. true class,
# to visualize how cleanly the model separates the two labels.
prob_df = pd.DataFrame({
    'true_label': y_true,
    'predicted_phishing_prob': y_pred_probs[:, 1]
})
prob_path = os.path.join(REPORT_DIR, "probability_scatterplot.png")

plt.figure(figsize=(10, 6))
sns.stripplot(data=prob_df, x='true_label', y='predicted_phishing_prob', jitter=0.2, alpha=0.7)
plt.title("Model Confidence: Predicted Phishing Probability vs. True Label")
plt.xlabel("True Label")
plt.ylabel("Predicted Phishing Probability")
plt.xticks([0, 1], ["Ham (0)", "Phishing (1)"])
plt.axhline(0.5, color='r', linestyle='--', label='Decision Boundary (0.5)')
plt.legend()
plt.savefig(prob_path)
plt.close()
print(f"Probability scatterplot saved to {prob_path}")

print("---------------------------------")
print(f"--- Evaluation Complete. Reports saved to {REPORT_DIR} ---")
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AEGIS SECURE API
3
+ emoji: 🔥
4
+ colorFrom: indigo
5
+ colorTo: gray
6
+ sdk: docker
7
+ pinned: false
8
+ short_description: The main model API for the Aegis Secure project.
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_DEPLOYMENT.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quick Start: Deploy Phishing Detection API to GCP
2
+
3
+ ## Prerequisites Checklist
4
+
5
+ - [ ] Google Cloud Platform account with billing enabled
6
+ - [ ] gcloud CLI installed and authenticated (`gcloud auth login`)
7
+ - [ ] Docker installed and running
8
+ - [ ] GROQ API key (get from https://console.groq.com/)
9
+
10
+ ## Fastest Deployment (One Command)
11
+
12
+ ### For Linux/Mac:
13
+ ```bash
14
+ chmod +x deploy.sh
15
+ ./deploy.sh YOUR_PROJECT_ID us-central1
16
+ ```
17
+
18
+ ### For Windows PowerShell:
19
+ ```powershell
20
+ .\deploy.ps1 -ProjectId YOUR_PROJECT_ID -Region us-central1
21
+ ```
22
+
23
+ **Replace `YOUR_PROJECT_ID` with your actual GCP project ID.**
24
+
25
+ The script will:
26
+ 1. ✅ Enable required GCP APIs
27
+ 2. ✅ Create GROQ_API_KEY secret
28
+ 3. ✅ Build Docker image
29
+ 4. ✅ Deploy to Cloud Run
30
+ 5. ✅ Display your service URL
31
+
32
+ ## Manual Deployment Steps
33
+
34
+ If you prefer manual steps, see [DEPLOYMENT.md](./DEPLOYMENT.md) for detailed instructions.
35
+
36
+ ## Files Created for Deployment
37
+
38
+ - **Dockerfile** - Container configuration for the application
39
+ - **.dockerignore** - Files to exclude from Docker build
40
+ - **cloudbuild.yaml** - GCP Cloud Build configuration
41
+ - **deploy.sh** - Automated deployment script (Linux/Mac)
42
+ - **deploy.ps1** - Automated deployment script (Windows)
43
+ - **docker-compose.yml** - For local testing with Docker
44
+ - **DEPLOYMENT.md** - Comprehensive deployment guide
45
+
46
+ ## Testing Locally First (Optional)
47
+
48
+ Before deploying to GCP, test locally:
49
+
50
+ ```bash
51
+ # Build the image
52
+ docker build -t phishing-api:local .
53
+
54
+ # Run locally
55
+ docker run -p 8000:8000 \
56
+ -e GROQ_API_KEY=your-api-key \
57
+ phishing-api:local
58
+
59
+ # Test in another terminal
60
+ curl http://localhost:8000/health
61
+ ```
62
+
63
+ Or use Docker Compose:
64
+
65
+ ```bash
66
+ # Set your API key in environment
67
+ export GROQ_API_KEY=your-api-key # Linux/Mac
68
+ # or
69
+ $env:GROQ_API_KEY="your-api-key" # Windows PowerShell
70
+
71
+ # Run
72
+ docker-compose up
73
+ ```
74
+
75
+ ## After Deployment
76
+
77
+ Your service will be available at:
78
+ - **API**: `https://phishing-detection-api-[hash]-uc.a.run.app`
79
+ - **Health Check**: `https://phishing-detection-api-[hash]-uc.a.run.app/health`
80
+ - **API Docs**: `https://phishing-detection-api-[hash]-uc.a.run.app/docs`
81
+
82
+ ## Quick Test
83
+
84
+ ```bash
85
+ # Get your service URL
86
+ SERVICE_URL=$(gcloud run services describe phishing-detection-api \
87
+ --region us-central1 \
88
+ --format="value(status.url)")
89
+
90
+ # Test health endpoint
91
+ curl $SERVICE_URL/health
92
+
93
+ # Test prediction endpoint
94
+ curl -X POST "$SERVICE_URL/predict" \
95
+ -H "Content-Type: application/json" \
96
+ -d '{
97
+ "sender": "test@example.com",
98
+ "subject": "Urgent Action Required",
99
+ "text": "Please verify your account at http://suspicious-site.com",
100
+ "metadata": {}
101
+ }'
102
+ ```
103
+
104
+ ## Troubleshooting
105
+
106
+ 1. **Build fails**: Check that all model files are in the correct directories
107
+ 2. **Service won't start**: Check logs with `gcloud run services logs read phishing-detection-api --region us-central1`
108
+ 3. **GROQ API errors**: Verify secret is set correctly: `gcloud secrets versions access latest --secret="GROQ_API_KEY"`
109
+ 4. **Memory issues**: Increase memory in cloudbuild.yaml or deployment command
110
+
111
+ For detailed troubleshooting, see [DEPLOYMENT.md](./DEPLOYMENT.md#troubleshooting).
112
+
113
+ ## Next Steps
114
+
115
+ - Monitor your service in [Cloud Run Console](https://console.cloud.google.com/run)
116
+ - Set up alerts in [Cloud Monitoring](https://console.cloud.google.com/monitoring)
117
+ - Review API documentation at `/docs` endpoint
118
+ - Scale resources if needed (see DEPLOYMENT.md)
119
+
120
+ ## Support
121
+
122
+ For issues:
123
+ 1. Check logs: `gcloud run services logs read phishing-detection-api --region us-central1`
124
+ 2. Review [DEPLOYMENT.md](./DEPLOYMENT.md) for detailed information
125
+ 3. Verify all prerequisites are met
api_keys.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import List, Optional
4
+ import asyncio
5
+ from dataclasses import dataclass
6
+ from datetime import datetime, timedelta
7
+ from dotenv import load_dotenv
8
+ import logging
9
+
10
+ # Set up logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Load environment variables from .env file
15
+ env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env')
16
+ load_dotenv(dotenv_path=env_path)
17
+
18
+ # Debug: Print environment variables
19
+ logger.info(f"Current working directory: {os.getcwd()}")
20
+ logger.info(f"Loading .env from: {env_path}")
21
+ logger.info(f"GEMINI_API_KEY: {'*' * 8 + os.getenv('GEMINI_API_KEY', '')[-4:] if os.getenv('GEMINI_API_KEY') else 'Not set'}")
22
+ logger.info(f"GEMINI_API_KEYS: {'*' * 8 + os.getenv('GEMINI_API_KEYS', '')[-4:] if os.getenv('GEMINI_API_KEYS') else 'Not set'}")
23
+
24
@dataclass
class APIKey:
    """State record for one Gemini API key in the rotation pool."""
    # The raw API key string sent to the Gemini client.
    key: str
    # When this key was last handed out; None until first use.
    last_used: Optional[datetime] = None
    # False while the key is cooling down after a rate-limit hit.
    is_available: bool = True
    # UTC time at which a rate-limited key may be used again; None if not limited.
    rate_limit_reset: Optional[datetime] = None
30
+
31
class APIKeyManager:
    """Singleton that hands out Gemini API keys with rate-limit tracking.

    Fix: the original initialized ``_current_index`` but never used it —
    ``get_available_key`` always scanned from index 0, so the first key
    absorbed all traffic and extra keys in GEMINI_API_KEYS were nearly
    useless. Keys are now rotated round-robin; the public interface
    (``get_available_key`` / ``mark_key_unavailable``) is unchanged.
    """

    _instance = None
    # NOTE(review): never acquired anywhere in this module (the public
    # methods are synchronous); kept for backward compatibility.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Classic singleton: the first construction initializes; later
        # constructions return the same shared instance.
        if cls._instance is None:
            cls._instance = super(APIKeyManager, cls).__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self.keys: List[APIKey] = []
        self._current_index = 0  # round-robin cursor: next key to try
        self._load_api_keys()

    def _load_api_keys(self):
        """Populate self.keys from GEMINI_API_KEY or GEMINI_API_KEYS.

        A single GEMINI_API_KEY takes precedence; otherwise the
        comma-separated GEMINI_API_KEYS list is used. Stray quotes around
        values (common in .env files) are stripped.
        """
        single_key = os.getenv('GEMINI_API_KEY', '').strip()
        if single_key:
            single_key = single_key.strip('"\'')
            self.keys = [APIKey(key=single_key)]
            logger.info("Loaded 1 API key from GEMINI_API_KEY")
            return

        # Fall back to GEMINI_API_KEYS if GEMINI_API_KEY is not set
        api_keys_str = os.getenv('GEMINI_API_KEYS', '').strip()
        if api_keys_str:
            keys = [key.strip().strip('"\'') for key in api_keys_str.split(',') if key.strip()]
            self.keys = [APIKey(key=key) for key in keys]
            logger.info("Loaded %d API keys from GEMINI_API_KEYS", len(keys))
            return

        logger.warning("No API keys found in environment variables")

    def get_available_key(self) -> Optional[str]:
        """Return an available API key, or None if all are rate-limited.

        Keys are tried round-robin starting from the cursor so load is
        spread across the pool; keys whose rate-limit window has elapsed
        are automatically re-enabled.
        """
        if not self.keys:
            return None

        now = datetime.utcnow()
        pool_size = len(self.keys)

        for offset in range(pool_size):
            idx = (self._current_index + offset) % pool_size
            key_obj = self.keys[idx]

            if not key_obj.is_available:
                # Re-enable the key once its cooldown has expired.
                if key_obj.rate_limit_reset and now >= key_obj.rate_limit_reset:
                    key_obj.is_available = True
                    key_obj.rate_limit_reset = None
                else:
                    continue

            key_obj.last_used = now
            self._current_index = (idx + 1) % pool_size  # advance for fairness
            return key_obj.key

        return None

    def mark_key_unavailable(self, key: str, retry_after_seconds: int = 60):
        """Mark a key as unavailable due to rate limiting."""
        for key_obj in self.keys:
            if key_obj.key == key:
                key_obj.is_available = False
                key_obj.rate_limit_reset = datetime.utcnow() + timedelta(seconds=retry_after_seconds)
                logger.warning("Rate limit hit for API key. Will retry after %d seconds", retry_after_seconds)
                return
        logger.warning("Tried to mark unknown API key as unavailable")
91
+
92
# Module-level singleton; all importers share one key pool and one cursor.
api_key_manager = APIKeyManager()
app.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import time
5
+ import sys
6
+ import asyncio
7
+ from typing import List, Dict, Optional
8
+ from urllib.parse import urlparse
9
+ import socket
10
+ import httpx
11
+
12
+ import joblib
13
+ import torch
14
+ import numpy as np
15
+ import pandas as pd
16
+ from fastapi import FastAPI, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ import google.generativeai as genai
20
+ from dotenv import load_dotenv
21
+
22
+ import config
23
+ from models import get_ml_models, get_dl_models, FinetunedBERT
24
+ from feature_extraction import process_row
25
+
26
+ load_dotenv()
27
+ sys.path.append(os.path.join(config.BASE_DIR, 'Message_model'))
28
+ from predict import PhishingPredictor
29
+
30
# ASGI application instance; endpoint functions below register themselves
# on it via decorators.
app = FastAPI(
    title="Phishing Detection API",
    description="Advanced phishing detection system using multiple ML/DL models and Gemini AI",
    version="1.0.0"
)

# NOTE(review): fully open CORS (any origin, with credentials) — acceptable
# for a demo Space, but tighten allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
43
+
44
class MessageInput(BaseModel):
    """Request body for /predict."""
    # Raw message to analyze; may contain URLs mixed with text.
    text: str
    # Optional caller-supplied context; not read by the visible pipeline.
    metadata: Optional[Dict] = {}
47
+
48
class PredictionResponse(BaseModel):
    """Response body for /predict (shape produced by get_gemini_final_decision)."""
    # Confidence in the final decision, clamped to 0-100.
    confidence: float
    # Human-readable explanation of the verdict.
    reasoning: str
    # Original message with suspicious spans wrapped in $$...$$.
    highlighted_text: str
    # Either "phishing" or "legitimate".
    final_decision: str
    # Practical advice for the end user.
    suggestion: str
54
+
55
# Model registries populated once at startup by load_models();
# empty dict / None means the corresponding model failed to load or is absent.
ml_models = {}
dl_models = {}
bert_model = None
semantic_model = None
gemini_model = None

# Per-model decision thresholds consumed by custom_boundary(): raw scores
# above the boundary lean "phishing", below lean "legitimate".
MODEL_BOUNDARIES = {
    'logistic': 0.5,
    'svm': 0.5,
    'xgboost': 0.5,
    'attention_blstm': 0.5,
    'rcnn': 0.5,
    'bert': 0.5,
    'semantic': 0.5
}
70
+
71
def load_models():
    """Load every available model artifact plus the Gemini client.

    Populates the module-level registries (ml_models, dl_models, bert_model,
    semantic_model, gemini_model). Missing or broken artifacts are skipped
    with a console warning instead of aborting startup, so the API can run
    with whatever subset of models is present.
    """
    global ml_models, dl_models, bert_model, semantic_model, gemini_model

    print("Loading models...")

    # Classical ML models persisted with joblib.
    models_dir = config.MODELS_DIR
    for model_name in ['logistic', 'svm', 'xgboost']:
        model_path = os.path.join(models_dir, f'{model_name}.joblib')
        if os.path.exists(model_path):
            ml_models[model_name] = joblib.load(model_path)
            print(f"✓ Loaded {model_name} model")
        else:
            print(f"⚠ Warning: {model_name} model not found at {model_path}")

    # PyTorch models: instantiate the architecture, then load weights on CPU.
    for model_name in ['attention_blstm', 'rcnn']:
        model_path = os.path.join(models_dir, f'{model_name}.pt')
        if os.path.exists(model_path):
            model_template = get_dl_models(input_dim=len(config.NUMERICAL_FEATURES))
            dl_models[model_name] = model_template[model_name]
            dl_models[model_name].load_state_dict(torch.load(model_path, map_location='cpu'))
            dl_models[model_name].eval()  # inference mode (disables dropout etc.)
            print(f"✓ Loaded {model_name} model")
        else:
            print(f"⚠ Warning: {model_name} model not found at {model_path}")

    # Fine-tuned BERT URL classifier (optional).
    bert_path = os.path.join(config.BASE_DIR, 'finetuned_bert')
    if os.path.exists(bert_path):
        try:
            bert_model = FinetunedBERT(bert_path)
            print("✓ Loaded BERT model")
        except Exception as e:
            print(f"⚠ Warning: Could not load BERT model: {e}")

    # Semantic text model: prefer the exported final model; fall back to a
    # training checkpoint if the export directory is missing or empty.
    semantic_model_path = os.path.join(config.BASE_DIR, 'Message_model', 'final_semantic_model')
    if os.path.exists(semantic_model_path) and os.listdir(semantic_model_path):
        try:
            semantic_model = PhishingPredictor(model_path=semantic_model_path)
            print("✓ Loaded semantic model")
        except Exception as e:
            print(f"⚠ Warning: Could not load semantic model: {e}")
    else:
        checkpoint_path = os.path.join(config.BASE_DIR, 'Message_model', 'training_checkpoints', 'checkpoint-30')
        if os.path.exists(checkpoint_path):
            try:
                semantic_model = PhishingPredictor(model_path=checkpoint_path)
                print("✓ Loaded semantic model from checkpoint")
            except Exception as e:
                print(f"⚠ Warning: Could not load semantic model from checkpoint: {e}")

    # Gemini is optional: without it, get_gemini_final_decision falls back
    # to averaged model scores.
    gemini_api_key = os.environ.get('GEMINI_API_KEY')
    if gemini_api_key:
        genai.configure(api_key=gemini_api_key)
        gemini_model = genai.GenerativeModel('gemini-2.0-flash')
        print("✓ Initialized Gemini API")
    else:
        print("⚠ Warning: GEMINI_API_KEY not set. Set it as environment variable.")
        print(" Example: export GEMINI_API_KEY='your-api-key-here'")
128
+
129
def parse_message(text: str) -> tuple:
    """Split a raw message into (urls, cleaned_text).

    URLs (http/https links and bare domains) are extracted with a
    permissive pattern and stripped from the text; the remainder is
    lower-cased, restricted to a safe character set, and has repeated
    whitespace/punctuation collapsed.
    """
    url_pattern = (
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        r'|(?:www\.)?[a-zA-Z0-9-]+\.[a-z]{2,12}\b(?:/[^\s]*)?'
    )
    found_urls = re.findall(url_pattern, text)

    remainder = re.sub(url_pattern, '', text)
    remainder = ' '.join(remainder.lower().split())          # lower-case + squeeze whitespace
    remainder = re.sub(r'[^a-z0-9\s.,!?-]', '', remainder)   # drop everything outside the safe set
    remainder = re.sub(r'([.,!?])+', r'\1', remainder)       # "!!!" -> "!"
    remainder = ' '.join(remainder.split())                  # final whitespace collapse
    return found_urls, remainder
138
+
139
async def extract_url_features(urls: List[str]) -> pd.DataFrame:
    """Compute per-URL features concurrently.

    Returns a DataFrame with a 'url' column plus the feature columns
    produced by feature_extraction.process_row; empty DataFrame when no
    URLs were supplied.
    """
    if not urls:
        return pd.DataFrame()

    df = pd.DataFrame({'url': urls})
    # Shared caches so duplicate domains pay for WHOIS/SSL lookups only once.
    whois_cache = {}
    ssl_cache = {}

    # process_row does blocking network I/O, so each row runs in a worker
    # thread and the results are awaited together.
    tasks = []
    for _, row in df.iterrows():
        tasks.append(asyncio.to_thread(process_row, row, whois_cache, ssl_cache))

    feature_list = await asyncio.gather(*tasks)
    features_df = pd.DataFrame(feature_list)
    result_df = pd.concat([df, features_df], axis=1)
    return result_df
155
+
156
def custom_boundary(raw_score: float, boundary: float) -> float:
    """Center raw_score on the decision boundary and scale by 100.

    Positive results lean "phishing", negative lean "legitimate".
    """
    offset = raw_score - boundary
    return offset * 100
158
+
159
def get_model_predictions(features_df: pd.DataFrame, message_text: str) -> Dict:
    """Run every loaded model and collect raw + boundary-scaled scores.

    Returns {model_name: {'raw_score', 'scaled_score', ...}}; each model
    family is independent — a failure in one is printed and skipped so the
    others still contribute.
    """
    predictions = {}

    numerical_features = config.NUMERICAL_FEATURES
    categorical_features = config.CATEGORICAL_FEATURES

    try:
        X = features_df[numerical_features + categorical_features]
    except KeyError as e:
        # Feature extraction came back incomplete; fall back to an empty
        # frame so the ML/DL loops below no-op instead of crashing.
        print(f"Error: Missing columns in features_df. {e}")
        print(f"Available columns: {features_df.columns.tolist()}")
        X = pd.DataFrame(columns=numerical_features + categorical_features)

    if not X.empty:
        # Sentinel fills: -1 marks "unknown" numerics, 'N/A' unknown categoricals.
        X.loc[:, numerical_features] = X.loc[:, numerical_features].fillna(-1)
        X.loc[:, categorical_features] = X.loc[:, categorical_features].fillna('N/A')

    # Classical ML models: take the most suspicious URL's probability.
    for model_name, model in ml_models.items():
        try:
            all_probas = model.predict_proba(X)[:, 1]
            raw_score = np.max(all_probas)

            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
            predictions[model_name] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with {model_name} (Prediction Step): {e}")

    X_numerical = X[numerical_features].values

    # PyTorch models consume only the numerical features; again the max
    # score across URLs is taken.
    for model_name, model in dl_models.items():
        try:
            X_tensor = torch.tensor(X_numerical, dtype=torch.float32)
            with torch.no_grad():
                all_scores = model(X_tensor)
                raw_score = torch.max(all_scores).item()

            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
            predictions[model_name] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with {model_name}: {e}")

    # BERT scores the raw URL strings; note this averages across URLs
    # (unlike the max used above).
    if bert_model and len(features_df) > 0:
        try:
            urls = features_df['url'].tolist()
            raw_scores = bert_model.predict_proba(urls)
            avg_raw_score = np.mean([score[1] for score in raw_scores])
            scaled_score = custom_boundary(avg_raw_score, MODEL_BOUNDARIES['bert'])
            predictions['bert'] = {
                'raw_score': float(avg_raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with BERT: {e}")

    # Semantic model scores the cleaned message text itself.
    if semantic_model and message_text:
        try:
            result = semantic_model.predict(message_text)
            raw_score = result['phishing_probability']
            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES['semantic'])
            predictions['semantic'] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score),
                'confidence': result['confidence']
            }
        except Exception as e:
            print(f"Error with semantic model: {e}")

    return predictions
233
+
234
async def get_network_features_for_gemini(urls: List[str]) -> str:
    """Resolve up to three URLs and describe their IP geolocation as text.

    DNS resolution runs in a worker thread; geolocation comes from the
    free ip-api.com endpoint. Always returns a human-readable summary —
    failures become explanatory lines rather than exceptions.
    """
    if not urls:
        return "No URLs to analyze for network features."

    results = []
    async with httpx.AsyncClient() as client:
        for i, url_str in enumerate(urls[:3]):  # cap at 3 URLs to bound latency
            try:
                hostname = urlparse(url_str).hostname
                if not hostname:
                    results.append(f"\nURL {i+1} ({url_str}): Invalid URL, no hostname.")
                    continue

                try:
                    # gethostbyname is blocking, so off-load it.
                    ip_address = await asyncio.to_thread(socket.gethostbyname, hostname)
                except socket.gaierror:
                    results.append(f"\nURL {i+1} ({hostname}): Could not resolve domain to IP.")
                    continue

                try:
                    geo_url = f"http://ip-api.com/json/{ip_address}?fields=status,message,country,city,isp,org,as"
                    response = await client.get(geo_url, timeout=3.0)
                    response.raise_for_status()
                    data = response.json()

                    if data.get('status') == 'success':
                        geo_info = (
                            f" • IP Address: {ip_address}\n"
                            f" • Location: {data.get('city', 'N/A')}, {data.get('country', 'N/A')}\n"
                            f" • ISP: {data.get('isp', 'N/A')}\n"
                            f" • Organization: {data.get('org', 'N/A')}\n"
                            f" • ASN: {data.get('as', 'N/A')}"
                        )
                        results.append(f"\nURL {i+1} ({hostname}):\n{geo_info}")
                    else:
                        results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: API lookup failed ({data.get('message')})")

                except httpx.RequestError as e:
                    results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: Network error while fetching IP info ({str(e)})")

            except Exception as e:
                results.append(f"\nURL {i+1} ({url_str}): Error processing URL ({str(e)})")

    if not results:
        return "No valid hostnames found in URLs to analyze."

    return "\n".join(results)
281
+
282
+
283
async def get_gemini_final_decision(urls: List[str], features_df: pd.DataFrame,
                                    message_text: str, predictions: Dict,
                                    original_text: str) -> Dict:
    """Aggregate all evidence and ask Gemini for the final verdict.

    Builds a prompt from URL features, network geolocation, model scores and
    the message text, then requests a JSON verdict from Gemini (with retries
    and response sanitization). Whenever Gemini is unavailable or its output
    can't be parsed, falls back to the average of the models' scaled scores.
    Always returns a dict matching PredictionResponse.
    """

    if not gemini_model:
        # Fallback path: decision from the mean scaled score; confidence
        # grows with how far the average sits from the boundary.
        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + abs(avg_scaled_score)))

        return {
            "confidence": round(confidence, 2),
            "reasoning": "Gemini API not available. Using average model scores.",
            "highlighted_text": original_text,
            "final_decision": "phishing" if avg_scaled_score > 0 else "legitimate",
            "suggestion": "Do not interact with this message. Delete it immediately and report it to your IT department." if avg_scaled_score > 0 else "This message appears safe, but remain cautious with any links or attachments."
        }

    # --- Build the evidence sections of the prompt ---
    url_features_summary = "No URLs detected in message"
    has_urls = len(features_df) > 0

    if has_urls:
        feature_summary_parts = []
        for idx, row in features_df.iterrows():
            url = row.get('url', 'Unknown')
            feature_summary_parts.append(f"\nURL {idx+1}: {url}")
            feature_summary_parts.append(f" • Length: {row.get('url_length', 'N/A')} chars")
            feature_summary_parts.append(f" • Dots in URL: {row.get('count_dot', 'N/A')}")
            feature_summary_parts.append(f" • Special characters: {row.get('count_special_chars', 'N/A')}")
            feature_summary_parts.append(f" • Domain age: {row.get('domain_age_days', 'N/A')} days")
            feature_summary_parts.append(f" • SSL certificate valid: {row.get('cert_has_valid_hostname', 'N/A')}")
            feature_summary_parts.append(f" • Uses HTTPS: {row.get('https', 'N/A')}")
        url_features_summary = "\n".join(feature_summary_parts)

    network_features_summary = await get_network_features_for_gemini(urls)

    model_predictions_summary = []
    for model_name, pred_data in predictions.items():
        scaled = pred_data['scaled_score']
        raw = pred_data['raw_score']
        model_predictions_summary.append(
            f" • {model_name.upper()}: scaled_score={scaled:.2f} (raw={raw:.3f})"
        )
    model_scores_text = "\n".join(model_predictions_summary)

    # Keep the prompt bounded for very long messages.
    MAX_TEXT_LEN = 3000
    if len(original_text) > MAX_TEXT_LEN:
        truncated_original_text = original_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_original_text = original_text

    if len(message_text) > MAX_TEXT_LEN:
        truncated_message_text = message_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_message_text = message_text

    context = f"""You are a security model that must decide if a message is phishing or legitimate.

Use all evidence below:
- URL/network data (trust NETWORK_GEO more than URL_FEATURES when they disagree; domain_age = -1 means unknown).
- Model scores (scaled_score > 0 → more phishing, < 0 → more legitimate).
- Message content (urgency, threats, credential/OTP/payment requests, impersonation).

If strong phishing signals exist, prefer "phishing". If everything matches a normal, known service/organization and content is routine, prefer "legitimate".

Return only this JSON object:
{{
"confidence": <float 0-100>,
"reasoning": "<brief explanation referring to key evidence>",
"highlighted_text": "<full original message with suspicious spans wrapped in $$...$$>",
"final_decision": "phishing" or "legitimate",
"suggestion": "<practical advice for the user on what to do>"
}}

MESSAGE_ORIGINAL:
{truncated_original_text}

MESSAGE_CLEANED:
{truncated_message_text}

URLS:
{', '.join(urls) if urls else 'None'}

URL_FEATURES:
{url_features_summary}

NETWORK_GEO:
{network_features_summary}

MODEL_SCORES (scaled_score > 0 phishing, < 0 legitimate):
{model_scores_text}
"""

    try:
        # Low temperature + JSON MIME type to push Gemini toward strict JSON.
        generation_config = {
            'temperature': 0.2,
            'top_p': 0.85,
            'top_k': 40,
            'max_output_tokens': 8192,
            'response_mime_type': 'application/json'
        }

        # Phishing samples trip safety filters; disable them for analysis.
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }

        # Retry with exponential backoff on transient API failures.
        max_retries = 3
        retry_delay = 2

        for attempt in range(max_retries):
            try:
                response = await gemini_model.generate_content_async(
                    context,
                    generation_config=generation_config,
                    safety_settings=safety_settings
                )
                if not response.candidates or not response.candidates[0].content.parts:
                    raise ValueError(f"No content returned. Finish Reason: {response.candidates[0].finish_reason}")
                break

            except Exception as retry_error:
                print(f"Gemini API attempt {attempt + 1} failed: {retry_error}")
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    raise retry_error

        response_text = response.text.strip()

        # Strip markdown code fences if Gemini wrapped the JSON anyway.
        if '```json' in response_text:
            response_text = response_text.split('```json')[1].split('```')[0].strip()
        elif response_text.startswith('```') and response_text.endswith('```'):
            response_text = response_text[3:-3].strip()

        # Last resort: pull the first {...} object out of surrounding prose.
        if not response_text.startswith('{'):
            json_match = re.search(r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}', response_text, re.DOTALL)
            if json_match:
                response_text = json_match.group(0)
            else:
                raise ValueError(f"Could not find JSON in Gemini response: {response_text[:200]}")

        result = json.loads(response_text)

        # --- Validate and normalize the parsed verdict ---
        required_fields = ['confidence', 'reasoning', 'highlighted_text', 'final_decision', 'suggestion']
        if not all(field in result for field in required_fields):
            raise ValueError(f"Missing required fields. Got: {list(result.keys())}")

        result['confidence'] = float(result['confidence'])
        if not 0 <= result['confidence'] <= 100:
            result['confidence'] = max(0, min(100, result['confidence']))

        if result['final_decision'].lower() not in ['phishing', 'legitimate']:
            result['final_decision'] = 'phishing' if result['confidence'] >= 50 else 'legitimate'
        else:
            result['final_decision'] = result['final_decision'].lower()

        # '...' suggests Gemini elided text; fall back to the full original.
        if not result['highlighted_text'].strip() or '...' in result['highlighted_text']:
            result['highlighted_text'] = original_text

        if not result.get('suggestion', '').strip():
            if result['final_decision'] == 'phishing':
                result['suggestion'] = "Do not interact with this message. Delete it immediately and report it as phishing."
            else:
                result['suggestion'] = "This message appears safe, but always verify sender identity before taking any action."

        return result

    except json.JSONDecodeError as e:
        # Gemini answered but not with parseable JSON: fall back to model average.
        print(f"JSON parsing error: {e}")
        print(f"Response text that failed parsing: {response_text[:500]}")

        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + abs(avg_scaled_score)))

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Gemini response parsing failed. Fallback: Based on model average (score: {avg_scaled_score:.2f}), message appears {'legitimate' if avg_scaled_score <= 0 else 'suspicious'}.",
            "highlighted_text": original_text,
            "final_decision": "phishing" if avg_scaled_score > 0 else "legitimate",
            "suggestion": "Do not interact with this message. Delete it immediately and be cautious." if avg_scaled_score > 0 else "Exercise caution. Verify the sender before taking any action."
        }

    except Exception as e:
        # Any other Gemini failure (network, exhausted retries): same fallback.
        print(f"Error with Gemini API: {e}")

        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + abs(avg_scaled_score)))

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Gemini API error: {str(e)}. Fallback decision based on {len(predictions)} model predictions (average score: {avg_scaled_score:.2f}).",
            "highlighted_text": original_text,
            "final_decision": "phishing" if avg_scaled_score > 0 else "legitimate",
            "suggestion": "Treat this message with caution. Delete it if suspicious, or verify the sender through official channels before taking action." if avg_scaled_score > 0 else "This message appears safe based on models, but always verify sender identity before clicking links or providing information."
        }
481
+
482
@app.on_event("startup")
async def startup_event():
    """Load all model artifacts once, before the first request is served."""
    load_models()
    print("\n" + "="*60)
    print("Phishing Detection API is ready!")
    print("="*60)
    print("API Documentation: http://localhost:8000/docs")
    print("="*60 + "\n")
490
+
491
@app.get("/")
async def root():
    """Service index: name, version and the available endpoints."""
    endpoints = {
        "predict": "/predict (POST)",
        "health": "/health (GET)",
        "docs": "/docs (GET)",
    }
    return {
        "message": "Phishing Detection API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
502
+
503
@app.get("/health")
async def health_check():
    """Liveness probe that also reports which models finished loading."""
    return {
        "status": "healthy",
        "models_loaded": {
            "ml_models": list(ml_models.keys()),
            "dl_models": list(dl_models.keys()),
            "bert_model": bert_model is not None,
            "semantic_model": semantic_model is not None,
            "gemini_model": gemini_model is not None,
        },
    }
517
+
518
@app.post("/predict", response_model=PredictionResponse)
async def predict(message_input: MessageInput):
    """End-to-end phishing check.

    Pipeline: parse text/URLs → extract URL features → score with all
    loaded models (off the event loop) → ask Gemini for the final verdict.
    Raises 400 on empty input, 500 when nothing could be scored.
    """
    try:
        original_text = message_input.text

        if not original_text or not original_text.strip():
            raise HTTPException(status_code=400, detail="Message text cannot be empty")

        urls, cleaned_text = parse_message(original_text)

        features_df = pd.DataFrame()
        if urls:
            features_df = await extract_url_features(urls)

        # Model inference is CPU-bound; run it in a worker thread so the
        # event loop stays responsive.
        predictions = {}
        if len(features_df) > 0 or (cleaned_text and semantic_model):
            predictions = await asyncio.to_thread(get_model_predictions, features_df, cleaned_text)

        if not predictions:
            # Pick the most specific explanation for why nothing was scored.
            if not urls and not cleaned_text:
                detail = "Message text is empty after cleaning."
            elif not urls and not semantic_model:
                detail = "No URLs provided and semantic model is not loaded."
            elif not any([ml_models, dl_models, bert_model, semantic_model]):
                detail = "No models available for prediction. Please ensure models are trained and loaded."
            else:
                detail = "Could not generate predictions. Models may be missing or feature extraction failed."

            raise HTTPException(
                status_code=500,
                detail=detail
            )

        final_result = await get_gemini_final_decision(
            urls, features_df, cleaned_text, predictions, original_text
        )

        return PredictionResponse(**final_result)

    except HTTPException:
        # Re-raise our own deliberate HTTP errors untouched.
        raise
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
563
+
564
+ if __name__ == "__main__":
565
+ import uvicorn
566
+ uvicorn.run(app, host="0.0.0.0", port=8000)
app1.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import time
5
+ import sys
6
+ import asyncio
7
+ from typing import List, Dict, Optional
8
+ from urllib.parse import urlparse
9
+ import socket
10
+ import httpx
11
+
12
+ import joblib
13
+ import torch
14
+ import numpy as np
15
+ import pandas as pd
16
+ from fastapi import FastAPI, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ from groq import AsyncGroq
20
+ from dotenv import load_dotenv
21
+ from bs4 import BeautifulSoup
22
+
23
+ import config
24
+ from models import get_ml_models, get_dl_models, FinetunedBERT
25
+ from feature_extraction import process_row
26
+
27
+ load_dotenv()
28
+ sys.path.append(os.path.join(config.BASE_DIR, 'Message_model'))
29
+ from predict import PhishingPredictor
30
+
31
# FastAPI application; title/description/version appear in the /docs UI.
app = FastAPI(
    title="Phishing Detection API",
    description="Advanced phishing detection system using multiple ML/DL models and Groq",
    version="1.0.0"
)

# Wide-open CORS for development convenience.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests; pin origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
44
+
45
class MessageInput(BaseModel):
    """Request payload for /predict."""
    sender: Optional[str] = ""     # message "From" field; may be empty
    subject: Optional[str] = ""    # subject line; may be empty
    text: str                      # raw (possibly HTML) message body — required
    metadata: Optional[Dict] = {}  # free-form extras (pydantic copies defaults per instance)
50
+
51
class PredictionResponse(BaseModel):
    """Response payload for /predict (shape produced by the LLM judge)."""
    confidence: float       # directional 0-100 score; > 50 means phishing
    reasoning: str          # judge's explanation for the verdict
    highlighted_text: str   # original message, suspicious spans wrapped in $$...$$
    final_decision: str     # "phishing" or "legitimate"
    suggestion: str         # actionable advice for the end user
57
+
58
# Model registries, populated once by load_models() at startup.
ml_models = {}             # name -> fitted sklearn-style estimator (joblib)
dl_models = {}             # name -> torch module in eval mode
bert_model = None          # FinetunedBERT wrapper, if the artifact exists
semantic_model = None      # PhishingPredictor for message text, if available
groq_async_client = None   # AsyncGroq client, only when GROQ_API_KEY is set

# Decision thresholds fed to custom_boundary(); a model's raw probability is
# mapped to a signed score via (raw - boundary) * 100.
MODEL_BOUNDARIES = {
    'logistic': 0.5,
    'svm': 0.5,
    'xgboost': 0.5,
    'attention_blstm': 0.5,
    'rcnn': 0.5,
    'bert': 0.5,
    'semantic': 0.5
}
73
+
74
def load_models():
    """Populate the module-level model registries and the Groq client.

    Invoked once from the startup hook.  Every artifact is optional: a
    missing file only prints a warning, so the service can still run with
    whatever subset of models is available on disk.
    """
    global ml_models, dl_models, bert_model, semantic_model, groq_async_client

    print("Loading models...")

    models_dir = config.MODELS_DIR
    # Classic ML models persisted with joblib.
    for model_name in ['logistic', 'svm', 'xgboost']:
        model_path = os.path.join(models_dir, f'{model_name}.joblib')
        if os.path.exists(model_path):
            ml_models[model_name] = joblib.load(model_path)
            print(f"✓ Loaded {model_name} model")
        else:
            print(f"⚠ Warning: {model_name} model not found at {model_path}")

    # Torch models: rebuild the architecture, then load weights on CPU.
    for model_name in ['attention_blstm', 'rcnn']:
        model_path = os.path.join(models_dir, f'{model_name}.pt')
        if os.path.exists(model_path):
            model_template = get_dl_models(input_dim=len(config.NUMERICAL_FEATURES))
            dl_models[model_name] = model_template[model_name]
            dl_models[model_name].load_state_dict(torch.load(model_path, map_location='cpu'))
            dl_models[model_name].eval()  # inference mode (disables dropout etc.)
            print(f"✓ Loaded {model_name} model")
        else:
            print(f"⚠ Warning: {model_name} model not found at {model_path}")

    # Fine-tuned BERT for URL strings (optional).
    bert_path = os.path.join(config.BASE_DIR, 'finetuned_bert')
    if os.path.exists(bert_path):
        try:
            bert_model = FinetunedBERT(bert_path)
            print("✓ Loaded BERT model")
        except Exception as e:
            print(f"⚠ Warning: Could not load BERT model: {e}")

    # Semantic message model: prefer the exported model, else fall back to
    # the last training checkpoint.
    semantic_model_path = os.path.join(config.BASE_DIR, 'Message_model', 'final_semantic_model')
    if os.path.exists(semantic_model_path) and os.listdir(semantic_model_path):
        try:
            semantic_model = PhishingPredictor(model_path=semantic_model_path)
            print("✓ Loaded semantic model")
        except Exception as e:
            print(f"⚠ Warning: Could not load semantic model: {e}")
    else:
        checkpoint_path = os.path.join(config.BASE_DIR, 'Message_model', 'training_checkpoints', 'checkpoint-30')
        if os.path.exists(checkpoint_path):
            try:
                semantic_model = PhishingPredictor(model_path=checkpoint_path)
                print("✓ Loaded semantic model from checkpoint")
            except Exception as e:
                print(f"⚠ Warning: Could not load semantic model from checkpoint: {e}")

    # Groq client is created only when the key is present in the environment.
    groq_api_key = os.environ.get('GROQ_API_KEY')
    if groq_api_key:
        groq_async_client = AsyncGroq(api_key=groq_api_key)
        print("✓ Initialized Groq API Client")
    else:
        print("⚠ Warning: GROQ_API_KEY not set. Set it as environment variable.")
        print(" Example: export GROQ_API_KEY='your-api-key-here'")
130
+
131
def parse_message(text: str) -> tuple:
    """Split a raw message into (urls, cleaned_text).

    URLs (http(s) links and bare domains) are extracted, then stripped from
    the text; the remainder is lower-cased, restricted to a small character
    whitelist, and has runs of punctuation/whitespace collapsed so it is
    suitable as input for the text models.
    """
    link_re = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|(?:www\.)?[a-zA-Z0-9-]+\.[a-z]{2,12}\b(?:/[^\s]*)?'
    )

    found_urls = link_re.findall(text)

    # Normalise the remainder in stages: drop links, lower-case + collapse
    # whitespace, strip disallowed characters, squash repeated punctuation.
    remainder = link_re.sub('', text)
    remainder = ' '.join(remainder.lower().split())
    remainder = re.sub(r'[^a-z0-9\s.,!?-]', '', remainder)
    remainder = re.sub(r'([.,!?])+', r'\1', remainder)
    remainder = ' '.join(remainder.split())

    return found_urls, remainder
140
+
141
async def extract_url_features(urls: List[str]) -> pd.DataFrame:
    """Extract per-URL model features concurrently.

    Each URL row is processed by the (blocking) feature extractor in a
    worker thread; WHOIS/SSL lookups share caches across rows.  Returns
    the URL column concatenated with the extracted feature columns, or an
    empty frame when no URLs were given.
    """
    if not urls:
        return pd.DataFrame()

    url_df = pd.DataFrame({'url': urls})
    shared_whois_cache = {}
    shared_ssl_cache = {}

    # Fan out one worker-thread job per row so lookups overlap.
    jobs = [
        asyncio.to_thread(process_row, row, shared_whois_cache, shared_ssl_cache)
        for _, row in url_df.iterrows()
    ]
    extracted = await asyncio.gather(*jobs)

    return pd.concat([url_df, pd.DataFrame(extracted)], axis=1)
157
+
158
def custom_boundary(raw_score: float, boundary: float) -> float:
    """Map a raw probability onto a signed score centred on *boundary*.

    Positive results point towards phishing, negative towards legitimate;
    a raw score equal to the boundary yields 0.
    """
    offset = raw_score - boundary
    return offset * 100
160
+
161
def get_model_predictions(features_df: pd.DataFrame, message_text: str) -> Dict:
    """Run every loaded model and collect raw plus boundary-scaled scores.

    Returns {model_name: {'raw_score', 'scaled_score'[, 'confidence']}}.
    URL models score every row; sklearn/torch models keep the worst (max)
    score across URLs, while BERT averages across URLs.  Per-model failures
    are logged and skipped so one broken model never blocks the rest.
    """
    predictions = {}

    numerical_features = config.NUMERICAL_FEATURES
    categorical_features = config.CATEGORICAL_FEATURES

    try:
        X = features_df[numerical_features + categorical_features]
    except KeyError as e:
        # Feature extraction produced an unexpected schema; continue with an
        # empty frame so URL models are skipped but text models still run.
        print(f"Error: Missing columns in features_df. {e}")
        print(f"Available columns: {features_df.columns.tolist()}")
        X = pd.DataFrame(columns=numerical_features + categorical_features)

    if not X.empty:
        # Impute missing values with sentinel defaults.
        # NOTE(review): X is a slice of features_df, so these .loc writes may
        # trigger pandas' SettingWithCopyWarning; consider X = X.copy() first.
        X.loc[:, numerical_features] = X.loc[:, numerical_features].fillna(-1)
        X.loc[:, categorical_features] = X.loc[:, categorical_features].fillna('N/A')

    # Sklearn-style models: take the positive-class probability column.
    for model_name, model in ml_models.items():
        try:
            all_probas = model.predict_proba(X)[:, 1]
            raw_score = np.max(all_probas)  # most suspicious URL wins

            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
            predictions[model_name] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with {model_name} (Prediction Step): {e}")

    X_numerical = X[numerical_features].values

    # Torch models consume only the numerical features.
    for model_name, model in dl_models.items():
        try:
            X_tensor = torch.tensor(X_numerical, dtype=torch.float32)
            with torch.no_grad():
                all_scores = model(X_tensor)
                raw_score = torch.max(all_scores).item()

            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
            predictions[model_name] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with {model_name}: {e}")

    # BERT scores the raw URL strings; note it averages across URLs,
    # unlike the max-pooling used above.
    if bert_model and len(features_df) > 0:
        try:
            urls = features_df['url'].tolist()
            raw_scores = bert_model.predict_proba(urls)
            avg_raw_score = np.mean([score[1] for score in raw_scores])
            scaled_score = custom_boundary(avg_raw_score, MODEL_BOUNDARIES['bert'])
            predictions['bert'] = {
                'raw_score': float(avg_raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with BERT: {e}")

    # Semantic model scores the cleaned message text, not the URLs.
    if semantic_model and message_text:
        try:
            result = semantic_model.predict(message_text)
            raw_score = result['phishing_probability']
            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES['semantic'])
            predictions['semantic'] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score),
                'confidence': result['confidence']
            }
        except Exception as e:
            print(f"Error with semantic model: {e}")

    return predictions
235
+
236
async def get_network_features_for_gemini(urls: List[str]) -> str:
    """Build a live DNS/geolocation summary for up to three URLs.

    Resolves each hostname (off the event loop) and queries ip-api.com for
    country/city/ISP/org/ASN.  All failures are reported inline per URL —
    this function never raises.  The resulting text is injected into the
    LLM judge's prompt as "independent" evidence.
    """
    if not urls:
        return "No URLs to analyze for network features."

    results = []
    async with httpx.AsyncClient() as client:
        for i, url_str in enumerate(urls[:3]):  # cap lookups to bound latency
            try:
                hostname = urlparse(url_str).hostname
                if not hostname:
                    results.append(f"\nURL {i+1} ({url_str}): Invalid URL, no hostname.")
                    continue

                try:
                    # Blocking DNS resolution moved to a worker thread.
                    ip_address = await asyncio.to_thread(socket.gethostbyname, hostname)
                except socket.gaierror:
                    results.append(f"\nURL {i+1} ({hostname}): Could not resolve domain to IP.")
                    continue

                try:
                    # Free geolocation API (plain HTTP); short per-lookup timeout.
                    geo_url = f"http://ip-api.com/json/{ip_address}?fields=status,message,country,city,isp,org,as"
                    response = await client.get(geo_url, timeout=3.0)
                    response.raise_for_status()
                    data = response.json()

                    if data.get('status') == 'success':
                        geo_info = (
                            f" • IP Address: {ip_address}\n"
                            f" • Location: {data.get('city', 'N/A')}, {data.get('country', 'N/A')}\n"
                            f" • ISP: {data.get('isp', 'N/A')}\n"
                            f" • Organization: {data.get('org', 'N/A')}\n"
                            f" • ASN: {data.get('as', 'N/A')}"
                        )
                        results.append(f"\nURL {i+1} ({hostname}):\n{geo_info}")
                    else:
                        results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: API lookup failed ({data.get('message')})")

                except httpx.RequestError as e:
                    results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: Network error while fetching IP info ({str(e)})")

            except Exception as e:
                # Catch-all so one bad URL cannot abort the whole summary.
                results.append(f"\nURL {i+1} ({url_str}): Error processing URL ({str(e)})")

    if not results:
        return "No valid hostnames found in URLs to analyze."

    return "\n".join(results)
283
+
284
+ SYSTEM_PROMPT = """You are the FINAL JUDGE in a phishing detection system. Your role is critical: analyze ALL available evidence and make the ultimate decision.
285
+
286
+ IMPORTANT INSTRUCTIONS:
287
+ 1. You have FULL AUTHORITY to override model predictions if evidence suggests they're wrong.
288
+ 2. **TRUST THE 'INDEPENDENT NETWORK & GEO-DATA' OVER 'URL FEATURES'.** The ML model features (like `domain_age: -1`) can be wrong due to lookup failures. The 'INDEPENDENT' data is a real-time check.
289
+ 3. If 'INDEPENDENT' data shows a legitimate organization (e.g., "Cloudflare", "Google", "Codeforces") for a known domain, but the models score it as phishing (due to `domain_age: -1`), you **should override** and classify as 'legitimate'.
290
+ 4. Your confidence score is DIRECTIONAL (0-100):
291
+ - Scores > 50.0 mean 'phishing'.
292
+ - Scores < 50.0 mean 'legitimate'.
293
+ - 50.0 is neutral.
294
+ - The magnitude indicates certainty (e.g., 95.0 is 'very confident phishing'; 5.0 is 'very confident legitimate').
295
+ - Your confidence score MUST match your 'final_decision'.
296
+ 5. BE WARY OF FALSE POSITIVES. Legitimate messages (bank alerts, contest notifications) can seem urgent.
297
+
298
+ PRIORITY GUIDANCE (Use this logic):
299
+ - IF URLs are present: Focus heavily on URL features.
300
+ - Examine 'URL FEATURES' for patterns (e.g., domain_age: -1 or 0, high special_chars).
301
+ - **CRITICAL:** Cross-reference this with the 'INDEPENDENT NETWORK & GEO-DATA'. This real-time data (IP, Location, ISP) is your ground truth.
302
+ - **If `domain_age` is -1, it's a lookup failure.** IGNORE IT and trust the 'INDEPENDENT NETWORK & GEO-DATA' to see if the domain is real (e.g., 'codeforces.com' with a valid IP).
303
+ - Then supplement with message content analysis.
304
+ - IF NO URLs are present: Focus entirely on message content and semantics.
305
+ - Analyze language patterns, urgency tactics, and social engineering techniques
306
+ - Look for credential requests, financial solicitations, or threats
307
+ - Evaluate the semantic model's assessment heavily
308
+
309
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
310
+ FEW-SHOT EXAMPLES FOR GUIDANCE:
311
+
312
+ Example 1 - Clear Phishing:
313
+ Message: "URGENT! Click: http://paypa1-secure.xyz/verify"
314
+ URL Features: domain_age: 5
315
+ Network Data: IP: 123.45.67.89, Location: Russia, ISP: Shady-Host
316
+ Model Scores: All positive
317
+ Correct Decision: {{
318
+ "confidence": 95.0,
319
+ "reasoning": "Classic phishing. Misspelled domain, new age, and network data points to a suspicious ISP in Russia.",
320
+ "highlighted_text": "URGENT! Click: $$http://paypa1-secure.xyz/verify$$",
321
+ "final_decision": "phishing",
322
+ "suggestion": "Do NOT click. Delete immediately."
323
+ }}
324
+
325
+ Example 2 - Legitimate (False Positive Case):
326
+ Message: "Hi, join Codeforces Round 184. ... Unsubscribe: https://codeforces.com/unsubscribe/..."
327
+ URL Features: domain_age: -1 (This is a lookup failure!)
328
+ Network Data: URL (codeforces.com): IP: 104.22.6.109, Location: San Francisco, USA, ISP: Cloudflare, Inc.
329
+ Model Scores: Mixed (some positive due to domain_age: -1)
330
+ Correct Decision: {{
331
+ "confidence": 10.0,
332
+ "reasoning": "OVERRIDING models. The 'URL FEATURES' show a 'domain_age: -1' which is a clear lookup error that confused the models. The 'INDEPENDENT NETWORK & GEO-DATA' confirms the domain 'codeforces.com' is real and hosted on Cloudflare, a legitimate provider. The message content is a standard, safe notification.",
333
+ "highlighted_text": "Hi, join Codeforces Round 184. ... Unsubscribe: https://codeforces.com/unsubscribe/...",
334
+ "final_decision": "legitimate",
335
+ "suggestion": "This message is safe. It is a legitimate notification from Codeforces."
336
+ }}
337
+
338
+ Example 3 - Legitimate (Long Formal Text):
339
+ Message: "TATA MOTORS PASSENGER VEHICLES LIMITED... GENERAL GUIDANCE NOTE... [TRUNCATED]"
340
+ URL Features: domain_age: 8414
341
+ Network Data: URL (cars.tatamotors.com): IP: 23.209.113.12, Location: Boardman, USA, ISP: Akamai Technologies
342
+ Model Scores: All negative
343
+ Correct Decision: {{
344
+ "confidence": 5.0,
345
+ "reasoning": "This is a legitimate corporate communication. The text, although truncated, is clearly a formal guidance note for shareholders. The network data confirms 'cars.tatamotors.com' is hosted on Akamai, a major CDN used by large corporations. The models correctly identify this as safe.",
346
+ "highlighted_text": "TATA MOTORS PASSENGER VEHICLES LIMITED... GENERAL GUIDANCE NOTE... [TRUNCATED]",
347
+ "final_decision": "legitimate",
348
+ "suggestion": "This message is a legitimate corporate communication and appears safe."
349
+ }}
350
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
351
+
352
+ YOUR ANALYSIS TASK:
353
+ Analyze the message data provided by the user (in the 'user' message) following the steps and logic outlined above.
354
+
355
+ **CRITICAL for `highlighted_text`:** You MUST return the *entire original message*. Only wrap the specific words/URLs that are suspicious with `$$...$$`. If nothing is suspicious (i.e., `final_decision` is 'legitimate'), return the original message with NO `$$` markers.
356
+
357
+ OUTPUT FORMAT (respond with ONLY this JSON, no markdown, no explanation):
358
+ {{
359
+ "confidence": <float (0-100, directional score where >50 is phishing)>,
360
+ "reasoning": "<your detailed analysis explaining why this is/isn't phishing, mentioning why you trust/override models>",
361
+ "highlighted_text": "<THE FULL, ENTIRE original message with suspicious parts marked as $$suspicious text$$>",
362
+ "final_decision": "phishing" or "legitimate",
363
+ "suggestion": "<specific, actionable advice for the user on how to handle this message - what to do or not do>"
364
+ }}"""
365
+
366
async def get_groq_final_decision(urls: List[str], features_df: pd.DataFrame,
                                  message_text: str, predictions: Dict,
                                  original_text: str,
                                  sender: Optional[str] = "",
                                  subject: Optional[str] = "") -> Dict:
    """Ask the Groq LLM for the final verdict, with model-average fallbacks.

    Builds a structured evidence prompt (URL features, live network lookups,
    per-model scores), calls Groq with retries, then validates/normalises
    the JSON reply.  Whenever the client is missing or the call/parsing
    fails, the verdict falls back to the mean scaled model score mapped
    onto [0, 100].  Always returns a dict matching PredictionResponse.
    """

    # Fallback 1: no Groq client configured — decide from model average alone.
    if not groq_async_client:
        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + avg_scaled_score))  # map [-50,50] onto [0,100]
        final_decision = "phishing" if confidence > 50 else "legitimate"

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq API not available. Using average model scores. (Avg Scaled Score: {avg_scaled_score:.2f})",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Do not interact with this message. Delete it immediately and report it to your IT department." if final_decision == "phishing" else "This message appears safe, but remain cautious with any links or attachments."
        }

    # Human-readable per-URL feature summary for the prompt.
    url_features_summary = "No URLs detected in message"
    if len(features_df) > 0:
        feature_summary_parts = []
        for idx, row in features_df.iterrows():
            url = row.get('url', 'Unknown')
            feature_summary_parts.append(f"\nURL {idx+1}: {url}")
            feature_summary_parts.append(f" • Length: {row.get('url_length', 'N/A')} chars")
            feature_summary_parts.append(f" • Dots in URL: {row.get('count_dot', 'N/A')}")
            feature_summary_parts.append(f" • Special characters: {row.get('count_special_chars', 'N/A')}")
            feature_summary_parts.append(f" • Domain age: {row.get('domain_age_days', 'N/A')} days")
            feature_summary_parts.append(f" • SSL certificate valid: {row.get('cert_has_valid_hostname', 'N/A')}")
            feature_summary_parts.append(f" • Uses HTTPS: {row.get('https', 'N/A')}")
        url_features_summary = "\n".join(feature_summary_parts)

    # Live DNS/geo evidence; the system prompt tells the judge to trust this.
    network_features_summary = await get_network_features_for_gemini(urls)

    # One line per model: scaled (directional) and raw scores.
    model_predictions_summary = []
    for model_name, pred_data in predictions.items():
        scaled = pred_data['scaled_score']
        raw = pred_data['raw_score']
        model_predictions_summary.append(
            f" • {model_name.upper()}: scaled_score={scaled:.2f} (raw={raw:.3f})"
        )
    model_scores_text = "\n".join(model_predictions_summary)

    # Keep the prompt bounded; long texts are cut with an explicit marker.
    MAX_TEXT_LEN = 3000
    if len(original_text) > MAX_TEXT_LEN:
        truncated_original_text = original_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_original_text = original_text

    if len(message_text) > MAX_TEXT_LEN:
        truncated_message_text = message_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_message_text = message_text

    user_prompt = f"""MESSAGE DATA:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Sender: {sender if sender else 'N/A'}
Subject: {subject if subject else 'N/A'}

Original Message (Parsed from HTML):
{truncated_original_text}

Cleaned Text (for models):
{truncated_message_text}

URLs Found: {', '.join(urls) if urls else 'None'}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

URL FEATURES (from ML models):
{url_features_summary}

INDEPENDENT NETWORK & GEO-DATA (for Gemini analysis only):
{network_features_summary}

MODEL PREDICTIONS:
(Positive scaled scores → phishing, Negative → legitimate. Range: -50 to +50)
{model_scores_text}

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Please analyze this data and provide your JSON response."""

    try:
        # Call Groq with exponential backoff; the last failure is re-raised
        # and handled by the generic except below.
        max_retries = 3
        retry_delay = 2

        for attempt in range(max_retries):
            try:
                chat_completion = await groq_async_client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": SYSTEM_PROMPT,
                        },
                        {
                            "role": "user",
                            "content": user_prompt,
                        }
                    ],
                    model="meta-llama/llama-4-scout-17b-16e-instruct",
                    temperature=0.2,  # low temperature for consistent verdicts
                    max_tokens=4096,
                    top_p=0.85,
                    response_format={"type": "json_object"},  # force JSON output
                )

                response_text = chat_completion.choices[0].message.content
                break

            except Exception as retry_error:
                print(f"Groq API attempt {attempt + 1} failed: {retry_error}")
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    raise retry_error

        result = json.loads(response_text)

        # Schema validation: all five response fields must be present.
        required_fields = ['confidence', 'reasoning', 'highlighted_text', 'final_decision', 'suggestion']
        if not all(field in result for field in required_fields):
            raise ValueError(f"Missing required fields. Got: {list(result.keys())}")

        # Clamp confidence into [0, 100].
        result['confidence'] = float(result['confidence'])
        if not 0 <= result['confidence'] <= 100:
            result['confidence'] = max(0, min(100, result['confidence']))

        # Normalise the label; derive it from confidence when invalid.
        if result['final_decision'].lower() not in ['phishing', 'legitimate']:
            result['final_decision'] = 'phishing' if result['confidence'] > 50 else 'legitimate'
        else:
            result['final_decision'] = result['final_decision'].lower()

        # Keep the directional confidence consistent with the label.
        if result['final_decision'] == 'phishing' and result['confidence'] < 50:
            print(f"Warning: Groq decision 'phishing' mismatches confidence {result['confidence']}. Adjusting confidence.")
            result['confidence'] = 51.0
        elif result['final_decision'] == 'legitimate' and result['confidence'] > 50:
            print(f"Warning: Groq decision 'legitimate' mismatches confidence {result['confidence']}. Adjusting confidence.")
            result['confidence'] = 49.0

        # NOTE(review): the '...' test also fires on legitimate ellipses in the
        # judge's output, discarding the $$ highlighting in that case.
        if not result['highlighted_text'].strip() or '...' in result['highlighted_text'] or 'TRUNCATED' in result['highlighted_text']:
            print("Warning: Groq returned empty or truncated 'highlighted_text'. Falling back to original_text.")
            result['highlighted_text'] = original_text

        # Supply a default suggestion when the judge omitted one.
        if not result.get('suggestion', '').strip():
            if result['final_decision'] == 'phishing':
                result['suggestion'] = "Do not interact with this message. Delete it immediately and report it as phishing."
            else:
                result['suggestion'] = "This message appears safe, but always verify sender identity before taking any action."

        return result

    except json.JSONDecodeError as e:
        # Fallback 2: Groq replied, but not with parseable JSON.
        print(f"JSON parsing error: {e}")
        print(f"Response text that failed parsing: {response_text[:500]}")

        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + avg_scaled_score))
        final_decision = "phishing" if confidence > 50 else "legitimate"

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq response parsing failed. Fallback: Based on model average (directional score: {confidence:.2f}), message appears {'suspicious' if final_decision == 'phishing' else 'legitimate'}.",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Do not interact with this message. Delete it immediately and be cautious." if final_decision == 'phishing' else "Exercise caution. Verify the sender before taking any action."
        }

    except Exception as e:
        # Fallback 3: any other API/validation failure after retries.
        print(f"Error with Groq API: {e}")

        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + avg_scaled_score))
        final_decision = "phishing" if confidence > 50 else "legitimate"

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq API error: {str(e)}. Fallback decision based on {len(predictions)} model predictions (average directional score: {confidence:.2f}).",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Treat this message with caution. Delete it if suspicious, or verify the sender through official channels before taking action." if final_decision == 'phishing' else "This message appears safe based on models, but always verify sender identity before clicking links or providing information."
        }
548
+
549
+ @app.on_event("startup")
550
+ async def startup_event():
551
+ load_models()
552
+ print("\n" + "="*60)
553
+ print("Phishing Detection API is ready!")
554
+ print("="*60)
555
+ print("API Documentation: http://localhost:8000/docs")
556
+ print("="*60 + "\n")
557
+
558
+ @app.get("/")
559
+ async def root():
560
+ return {
561
+ "message": "Phishing Detection API",
562
+ "version": "1.0.0",
563
+ "endpoints": {
564
+ "predict": "/predict (POST)",
565
+ "health": "/health (GET)",
566
+ "docs": "/docs (GET)"
567
+ }
568
+ }
569
+
570
+ @app.get("/health")
571
+ async def health_check():
572
+ models_loaded = {
573
+ "ml_models": list(ml_models.keys()),
574
+ "dl_models": list(dl_models.keys()),
575
+ "bert_model": bert_model is not None,
576
+ "semantic_model": semantic_model is not None,
577
+ "groq_client": groq_async_client is not None
578
+ }
579
+
580
+ return {
581
+ "status": "healthy",
582
+ "models_loaded": models_loaded
583
+ }
584
+
585
+ @app.post("/predict", response_model=PredictionResponse)
586
+ async def predict(message_input: MessageInput):
587
+ try:
588
+ html_body = message_input.text
589
+ sender = message_input.sender
590
+ subject = message_input.subject
591
+
592
+ soup = BeautifulSoup(html_body, 'html.parser')
593
+ original_text = soup.get_text(separator=' ', strip=True)
594
+
595
+ if not original_text or not original_text.strip():
596
+ raise HTTPException(status_code=400, detail="Message text (after HTML parsing) cannot be empty")
597
+
598
+ urls, cleaned_text = parse_message(original_text)
599
+
600
+ features_df = pd.DataFrame()
601
+ if urls:
602
+ features_df = await extract_url_features(urls)
603
+
604
+ predictions = {}
605
+ if len(features_df) > 0 or (cleaned_text and semantic_model):
606
+ predictions = await asyncio.to_thread(get_model_predictions, features_df, cleaned_text)
607
+
608
+ if not predictions:
609
+ if not urls and not cleaned_text:
610
+ detail = "Message text is empty after cleaning."
611
+ elif not urls and not semantic_model:
612
+ detail = "No URLs provided and semantic model is not loaded."
613
+ elif not any([ml_models, dl_models, bert_model, semantic_model]):
614
+ detail = "No models available for prediction. Please ensure models are trained and loaded."
615
+ else:
616
+ detail = "Could not generate predictions. Models may be missing or feature extraction failed."
617
+
618
+ raise HTTPException(
619
+ status_code=500,
620
+ detail=detail
621
+ )
622
+
623
+ final_result = await get_groq_final_decision(
624
+ urls, features_df, cleaned_text, predictions, original_text,
625
+ sender, subject
626
+ )
627
+
628
+ return PredictionResponse(**final_result)
629
+
630
+ except HTTPException:
631
+ raise
632
+ except Exception as e:
633
+ import traceback
634
+ print(traceback.format_exc())
635
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
636
+
637
# Local development entry point; binds all interfaces on port 8000.
# In deployment, run via an ASGI server instead (e.g. `uvicorn app1:app`).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
app2.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import time
5
+ import sys
6
+ import asyncio
7
+ from typing import List, Dict, Optional
8
+ from urllib.parse import urlparse
9
+ import socket
10
+ import httpx
11
+
12
+ import joblib
13
+ import torch
14
+ import numpy as np
15
+ import pandas as pd
16
+ from fastapi import FastAPI, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ from groq import AsyncGroq
20
+ from dotenv import load_dotenv
21
+
22
+ # --- Make sure 'config.py' and 'models.py' are in the same directory or accessible
23
+ import config
24
+ from models import get_ml_models, get_dl_models, FinetunedBERT
25
+ from feature_extraction import process_row
26
+
27
+ load_dotenv()
28
+ sys.path.append(os.path.join(config.BASE_DIR, 'Message_model'))
29
+ from predict import PhishingPredictor
30
+
31
# FastAPI application; title/description/version appear in the /docs UI.
app = FastAPI(
    title="Phishing Detection API",
    description="Advanced phishing detection system using multiple ML/DL models and Groq",
    version="1.0.0"
)

# Wide-open CORS for development convenience.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests; pin origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
44
+
45
+ # --- Pydantic Models ---
46
+
47
class MessageInput(BaseModel):
    """Request payload for /predict (this variant carries no sender/subject)."""
    text: str                      # raw message text to analyse — required
    metadata: Optional[Dict] = {}  # free-form extras (pydantic copies defaults per instance)
50
+
51
class PredictionResponse(BaseModel):
    """Response payload for /predict (shape produced by the LLM judge)."""
    confidence: float       # directional 0-100 score; > 50 means phishing
    reasoning: str          # judge's explanation for the verdict
    highlighted_text: str   # original message, suspicious spans wrapped in $$...$$
    final_decision: str     # "phishing" or "legitimate"
    suggestion: str         # actionable advice for the end user
57
+
58
+ # --- Global Variables ---
59
+
60
# Model registries, populated once by load_models() at startup.
ml_models = {}             # name -> fitted sklearn-style estimator (joblib)
dl_models = {}             # name -> torch module in eval mode
bert_model = None          # FinetunedBERT wrapper, if the artifact exists
semantic_model = None      # PhishingPredictor for message text, if available
groq_async_client = None   # AsyncGroq client, only when GROQ_API_KEY is set

# Decision thresholds: a model's raw probability is mapped to a signed
# score via (raw - boundary) * 100.
MODEL_BOUNDARIES = {
    'logistic': 0.5,
    'svm': 0.5,
    'xgboost': 0.5,
    'attention_blstm': 0.5,
    'rcnn': 0.5,
    'bert': 0.5,
    'semantic': 0.5
}
75
+
76
+ # --- Model Loading ---
77
+
78
def load_models():
    """Populate the module-level model registries and the Groq client.

    Called once from the FastAPI startup hook. Every load is best-effort:
    a missing or broken artifact prints a warning and the corresponding
    registry entry is simply left unset, so the API degrades gracefully.
    """
    global ml_models, dl_models, bert_model, semantic_model, groq_async_client

    print("Loading models...")

    # Classic ML models, persisted with joblib.
    models_dir = config.MODELS_DIR
    for model_name in ['logistic', 'svm', 'xgboost']:
        model_path = os.path.join(models_dir, f'{model_name}.joblib')
        if os.path.exists(model_path):
            ml_models[model_name] = joblib.load(model_path)
            print(f"✓ Loaded {model_name} model")
        else:
            print(f"⚠ Warning: {model_name} model not found at {model_path}")

    # Deep models: instantiate an untrained template of the right input
    # width, then restore the saved weights onto CPU.
    for model_name in ['attention_blstm', 'rcnn']:
        model_path = os.path.join(models_dir, f'{model_name}.pt')
        if os.path.exists(model_path):
            model_template = get_dl_models(input_dim=len(config.NUMERICAL_FEATURES))
            dl_models[model_name] = model_template[model_name]
            dl_models[model_name].load_state_dict(torch.load(model_path, map_location='cpu'))
            dl_models[model_name].eval()
            print(f"✓ Loaded {model_name} model")
        else:
            print(f"⚠ Warning: {model_name} model not found at {model_path}")

    # Fine-tuned BERT URL classifier (optional).
    bert_path = os.path.join(config.BASE_DIR, 'finetuned_bert')
    if os.path.exists(bert_path):
        try:
            bert_model = FinetunedBERT(bert_path)
            print("✓ Loaded BERT model")
        except Exception as e:
            print(f"⚠ Warning: Could not load BERT model: {e}")

    # Semantic (message-content) model: prefer the exported final model,
    # fall back to a training checkpoint when the export is absent/empty.
    semantic_model_path = os.path.join(config.BASE_DIR, 'Message_model', 'final_semantic_model')
    if os.path.exists(semantic_model_path) and os.listdir(semantic_model_path):
        try:
            semantic_model = PhishingPredictor(model_path=semantic_model_path)
            print("✓ Loaded semantic model")
        except Exception as e:
            print(f"⚠ Warning: Could not load semantic model: {e}")
    else:
        checkpoint_path = os.path.join(config.BASE_DIR, 'Message_model', 'training_checkpoints', 'checkpoint-30')
        if os.path.exists(checkpoint_path):
            try:
                semantic_model = PhishingPredictor(model_path=checkpoint_path)
                print("✓ Loaded semantic model from checkpoint")
            except Exception as e:
                print(f"⚠ Warning: Could not load semantic model from checkpoint: {e}")

    # Groq client for the LLM "final judge"; without it the API falls back
    # to averaging model scores (see get_groq_final_decision).
    groq_api_key = os.environ.get('GROQ_API_KEY')
    if groq_api_key:
        groq_async_client = AsyncGroq(api_key=groq_api_key)
        print("✓ Initialized Groq API Client")
    else:
        print("⚠ Warning: GROQ_API_KEY not set. Set it as environment variable.")
        print(" Example: export GROQ_API_KEY='your-api-key-here'")
134
+
135
+ # --- Feature Extraction & Prediction Logic ---
136
+
137
def parse_message(text: str) -> tuple:
    """Split a raw message into its URLs and a normalised text body.

    Returns a ``(urls, cleaned_text)`` pair: every URL-looking substring
    found in *text*, plus the remaining message lower-cased, stripped of
    URLs and disallowed characters, with repeated punctuation collapsed
    and whitespace normalised to single spaces.
    """
    link_re = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|(?:www\.)?[a-zA-Z0-9-]+\.[a-z]{2,12}\b(?:/[^\s]*)?'
    found_links = re.findall(link_re, text)

    # Remove the links, then normalise the remainder step by step.
    body = re.sub(link_re, '', text)
    body = ' '.join(body.lower().split())
    body = re.sub(r'[^a-z0-9\s.,!?-]', '', body)   # keep only benign characters
    body = re.sub(r'([.,!?])+', r'\1', body)       # "!!!" -> "!"
    body = ' '.join(body.split())                  # collapse leftover gaps
    return found_links, body
146
+
147
async def extract_url_features(urls: List[str]) -> pd.DataFrame:
    """Engineer lexical/WHOIS/SSL features for every URL, concurrently.

    Each row is handed to ``process_row`` on a worker thread so the slow
    network lookups overlap; the two caches are shared across rows so
    repeated hosts are only looked up once. Returns the URL column
    concatenated with the extracted feature columns, or an empty frame
    when *urls* is empty.
    """
    if not urls:
        return pd.DataFrame()

    url_frame = pd.DataFrame({'url': urls})
    whois_cache = {}
    ssl_cache = {}

    workers = [
        asyncio.to_thread(process_row, row, whois_cache, ssl_cache)
        for _, row in url_frame.iterrows()
    ]
    extracted = await asyncio.gather(*workers)

    return pd.concat([url_frame, pd.DataFrame(extracted)], axis=1)
163
+
164
def custom_boundary(raw_score: float, boundary: float) -> float:
    """Center a raw probability on its decision boundary.

    Maps a score in [0, 1] to roughly [-50, +50]: negative values sit
    below the boundary (legitimate-leaning), positive values above it
    (phishing-leaning).
    """
    return 100 * (raw_score - boundary)
167
+
168
def get_model_predictions(features_df: pd.DataFrame, message_text: str) -> Dict:
    """Score the message with every loaded model.

    Args:
        features_df: One row per URL with engineered lexical/WHOIS/SSL
            features; may be empty when the message had no URLs.
        message_text: Cleaned message body, fed to the semantic model.

    Returns:
        Mapping of model name -> {'raw_score', 'scaled_score'} where
        scaled_score is centred on the model's boundary (roughly -50 =
        confidently legitimate, +50 = confidently phishing). Models that
        fail are omitted; their errors are printed, never raised.
    """
    predictions = {}

    numerical_features = config.NUMERICAL_FEATURES
    categorical_features = config.CATEGORICAL_FEATURES

    try:
        # FIX: .copy() so the fillna assignments below run on an independent
        # frame rather than a view of features_df -- avoids pandas'
        # SettingWithCopy/chained-assignment pitfall and any chance of
        # mutating the caller's DataFrame.
        X = features_df[numerical_features + categorical_features].copy()
    except KeyError as e:
        print(f"Error: Missing columns in features_df. {e}")
        print(f"Available columns: {features_df.columns.tolist()}")
        X = pd.DataFrame(columns=numerical_features + categorical_features)

    if not X.empty:
        # Sentinel fills for missing lookups: -1 for numerics, 'N/A' for strings.
        X[numerical_features] = X[numerical_features].fillna(-1)
        X[categorical_features] = X[categorical_features].fillna('N/A')

    # Classic ML models: take the worst-looking URL (max phishing probability).
    for model_name, model in ml_models.items():
        try:
            all_probas = model.predict_proba(X)[:, 1]
            raw_score = np.max(all_probas)

            # scaled_score is from -50 (legit) to +50 (phishing)
            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
            predictions[model_name] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with {model_name} (Prediction Step): {e}")

    # Deep models consume only the numerical feature columns.
    X_numerical = X[numerical_features].values

    for model_name, model in dl_models.items():
        try:
            X_tensor = torch.tensor(X_numerical, dtype=torch.float32)
            with torch.no_grad():
                all_scores = model(X_tensor)
            raw_score = torch.max(all_scores).item()

            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
            predictions[model_name] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with {model_name}: {e}")

    # URL-text BERT classifier: average phishing probability over all URLs.
    if bert_model and len(features_df) > 0:
        try:
            urls = features_df['url'].tolist()
            raw_scores = bert_model.predict_proba(urls)
            avg_raw_score = np.mean([score[1] for score in raw_scores])
            scaled_score = custom_boundary(avg_raw_score, MODEL_BOUNDARIES['bert'])
            predictions['bert'] = {
                'raw_score': float(avg_raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with BERT: {e}")

    # Semantic model runs on the message body alone.
    if semantic_model and message_text:
        try:
            result = semantic_model.predict(message_text)
            raw_score = result['phishing_probability']
            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES['semantic'])
            predictions['semantic'] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score),
                'confidence': result['confidence']  # Note: this is the semantic model's own confidence
            }
        except Exception as e:
            print(f"Error with semantic model: {e}")

    return predictions
243
+
244
+ # --- Groq/LLM Final Decision Logic ---
245
+
246
# NOTE(review): the "gemini" in the name looks like a leftover from an earlier
# Gemini-based judge -- the output now feeds the Groq prompt. Consider renaming
# (keeping this alias for compatibility) in a follow-up.
async def get_network_features_for_gemini(urls: List[str]) -> str:
    """
    Fetches real-time IP, Geo, and ISP data for URLs.
    This runs independently and is ONLY used to inform the LLM prompt.

    Only the first three URLs are inspected to bound latency. DNS resolution
    is pushed to a worker thread (socket.gethostbyname is blocking); geo data
    comes from the free ip-api.com endpoint with a 3s timeout. Every failure
    mode is folded into the returned human-readable summary string rather
    than raised.
    """
    if not urls:
        return "No URLs to analyze for network features."

    results = []
    async with httpx.AsyncClient() as client:
        # Cap at 3 URLs: each one costs a DNS round trip plus an HTTP call.
        for i, url_str in enumerate(urls[:3]):
            try:
                hostname = urlparse(url_str).hostname
                if not hostname:
                    results.append(f"\nURL {i+1} ({url_str}): Invalid URL, no hostname.")
                    continue

                # Blocking DNS lookup, moved off the event loop.
                try:
                    ip_address = await asyncio.to_thread(socket.gethostbyname, hostname)
                except socket.gaierror:
                    results.append(f"\nURL {i+1} ({hostname}): Could not resolve domain to IP.")
                    continue

                # Geo/ISP lookup; field list trims the ip-api.com response.
                try:
                    geo_url = f"http://ip-api.com/json/{ip_address}?fields=status,message,country,city,isp,org,as"
                    response = await client.get(geo_url, timeout=3.0)
                    response.raise_for_status()
                    data = response.json()

                    if data.get('status') == 'success':
                        geo_info = (
                            f" • IP Address: {ip_address}\n"
                            f" • Location: {data.get('city', 'N/A')}, {data.get('country', 'N/A')}\n"
                            f" • ISP: {data.get('isp', 'N/A')}\n"
                            f" • Organization: {data.get('org', 'N/A')}\n"
                            f" • ASN: {data.get('as', 'N/A')}"
                        )
                        results.append(f"\nURL {i+1} ({hostname}):\n{geo_info}")
                    else:
                        results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: API lookup failed ({data.get('message')})")

                except httpx.RequestError as e:
                    results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: Network error while fetching IP info ({str(e)})")

            except Exception as e:
                # Catch-all per URL so one bad entry can't sink the summary.
                results.append(f"\nURL {i+1} ({url_str}): Error processing URL ({str(e)})")

    if not results:
        return "No valid hostnames found in URLs to analyze."

    return "\n".join(results)
297
+
298
# --- CORRECTED: Static system prompt with fixed examples ---
# This contains all the instructions, few-shot examples, and output format.
# NOTE(review): the doubled braces ({{ ... }}) look like escapes left over from
# an earlier str.format() template. Since this string is no longer formatted,
# the model literally sees "{{" in the JSON examples -- confirm that is
# intended, or collapse them to single braces.
SYSTEM_PROMPT = """You are the FINAL JUDGE in a phishing detection system. Your role is critical: analyze ALL available evidence and make the ultimate decision.

IMPORTANT INSTRUCTIONS:
1. You have FULL AUTHORITY to override model predictions if evidence suggests they're wrong.
2. **TRUST THE 'INDEPENDENT NETWORK & GEO-DATA' OVER 'URL FEATURES'.** The ML model features (like `domain_age: -1`) can be wrong due to lookup failures. The 'INDEPENDENT' data is a real-time check.
3. If 'INDEPENDENT' data shows a legitimate organization (e.g., "Cloudflare", "Google", "Codeforces") for a known domain, but the models score it as phishing (due to `domain_age: -1`), you **should override** and classify as 'legitimate'.
4. Your confidence score is DIRECTIONAL (0-100):
   - Scores > 50.0 mean 'phishing'.
   - Scores < 50.0 mean 'legitimate'.
   - 50.0 is neutral.
   - The magnitude indicates certainty (e.g., 95.0 is 'very confident phishing'; 5.0 is 'very confident legitimate').
   - Your confidence score MUST match your 'final_decision'.
5. BE WARY OF FALSE POSITIVES. Legitimate messages (bank alerts, contest notifications) can seem urgent.

PRIORITY GUIDANCE (Use this logic):
- IF URLs are present: Focus heavily on URL features.
  - Examine 'URL FEATURES' for patterns (e.g., domain_age: -1 or 0, high special_chars).
  - **CRITICAL:** Cross-reference this with the 'INDEPENDENT NETWORK & GEO-DATA'. This real-time data (IP, Location, ISP) is your ground truth.
  - **If `domain_age` is -1, it's a lookup failure.** IGNORE IT and trust the 'INDEPENDENT NETWORK & GEO-DATA' to see if the domain is real (e.g., 'codeforces.com' with a valid IP).
  - Then supplement with message content analysis.
- IF NO URLs are present: Focus entirely on message content and semantics.
  - Analyze language patterns, urgency tactics, and social engineering techniques
  - Look for credential requests, financial solicitations, or threats
  - Evaluate the semantic model's assessment heavily

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
FEW-SHOT EXAMPLES FOR GUIDANCE:

Example 1 - Clear Phishing:
Message: "URGENT! Click: http://paypa1-secure.xyz/verify"
URL Features: domain_age: 5
Network Data: IP: 123.45.67.89, Location: Russia, ISP: Shady-Host
Model Scores: All positive
Correct Decision: {{
  "confidence": 95.0,
  "reasoning": "Classic phishing. Misspelled domain, new age, and network data points to a suspicious ISP in Russia.",
  "highlighted_text": "URGENT! Click: $$http://paypa1-secure.xyz/verify$$",
  "final_decision": "phishing",
  "suggestion": "Do NOT click. Delete immediately."
}}

Example 2 - Legitimate (False Positive Case):
Message: "Hi, join Codeforces Round 184. ... Unsubscribe: https://codeforces.com/unsubscribe/..."
URL Features: domain_age: -1 (This is a lookup failure!)
Network Data: URL (codeforces.com): IP: 104.22.6.109, Location: San Francisco, USA, ISP: Cloudflare, Inc.
Model Scores: Mixed (some positive due to domain_age: -1)
Correct Decision: {{
  "confidence": 10.0,
  "reasoning": "OVERRIDING models. The 'URL FEATURES' show a 'domain_age: -1' which is a clear lookup error that confused the models. The 'INDEPENDENT NETWORK & GEO-DATA' confirms the domain 'codeforces.com' is real and hosted on Cloudflare, a legitimate provider. The message content is a standard, safe notification.",
  "highlighted_text": "Hi, join Codeforces Round 184. ... Unsubscribe: https://codeforces.com/unsubscribe/...",
  "final_decision": "legitimate",
  "suggestion": "This message is safe. It is a legitimate notification from Codeforces."
}}

Example 3 - Legitimate (Long Formal Text):
Message: "TATA MOTORS PASSENGER VEHICLES LIMITED... GENERAL GUIDANCE NOTE... [TRUNCATED]"
URL Features: domain_age: 8414
Network Data: URL (cars.tatamotors.com): IP: 23.209.113.12, Location: Boardman, USA, ISP: Akamai Technologies
Model Scores: All negative
Correct Decision: {{
  "confidence": 5.0,
  "reasoning": "This is a legitimate corporate communication. The text, although truncated, is clearly a formal guidance note for shareholders. The network data confirms 'cars.tatamotors.com' is hosted on Akamai, a major CDN used by large corporations. The models correctly identify this as safe.",
  "highlighted_text": "TATA MOTORS PASSENGER VEHICLES LIMITED... GENERAL GUIDANCE NOTE... [TRUNCATED]",
  "final_decision": "legitimate",
  "suggestion": "This message is a legitimate corporate communication and appears safe."
}}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

YOUR ANALYSIS TASK:
Analyze the message data provided by the user (in the 'user' message) following the steps and logic outlined above.

**CRITICAL for `highlighted_text`:** You MUST return the *entire original message*. Only wrap the specific words/URLs that are suspicious with `$$...$$`. If nothing is suspicious (i.e., `final_decision` is 'legitimate'), return the original message with NO `$$` markers.

OUTPUT FORMAT (respond with ONLY this JSON, no markdown, no explanation):
{{
  "confidence": <float (0-100, directional score where >50 is phishing)>,
  "reasoning": "<your detailed analysis explaining why this is/isn't phishing, mentioning why you trust/override models>",
  "highlighted_text": "<THE FULL, ENTIRE original message with suspicious parts marked as $$suspicious text$$>",
  "final_decision": "phishing" or "legitimate",
  "suggestion": "<specific, actionable advice for the user on how to handle this message - what to do or not do>"
}}"""
381
+
382
+
383
async def get_groq_final_decision(urls: List[str], features_df: pd.DataFrame,
                                  message_text: str, predictions: Dict,
                                  original_text: str) -> Dict:
    """Ask the Groq LLM for the final phishing verdict.

    Builds a prompt from the message text, engineered URL features,
    independent network lookups and every model's score, then validates
    and normalises the LLM's JSON answer. Falls back to the average
    scaled model score (shifted into 0-100) whenever the Groq client is
    missing, the response isn't valid JSON, or the call ultimately fails
    after retries. Always returns a dict matching PredictionResponse.
    """

    if not groq_async_client:
        # --- MODIFIED: Fallback logic for confidence score ---
        # avg_scaled_score is from -50 (legit) to +50 (phishing)
        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        # We add 50 to shift the range to 0-100
        confidence = min(100, max(0, 50 + avg_scaled_score))
        final_decision = "phishing" if confidence > 50 else "legitimate"

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq API not available. Using average model scores. (Avg Scaled Score: {avg_scaled_score:.2f})",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Do not interact with this message. Delete it immediately and report it to your IT department." if final_decision == "phishing" else "This message appears safe, but remain cautious with any links or attachments."
        }

    # Summarise a handful of engineered features per URL for the prompt.
    url_features_summary = "No URLs detected in message"
    if len(features_df) > 0:
        feature_summary_parts = []
        for idx, row in features_df.iterrows():
            url = row.get('url', 'Unknown')
            feature_summary_parts.append(f"\nURL {idx+1}: {url}")
            feature_summary_parts.append(f" • Length: {row.get('url_length', 'N/A')} chars")
            feature_summary_parts.append(f" • Dots in URL: {row.get('count_dot', 'N/A')}")
            feature_summary_parts.append(f" • Special characters: {row.get('count_special_chars', 'N/A')}")
            feature_summary_parts.append(f" • Domain age: {row.get('domain_age_days', 'N/A')} days")
            feature_summary_parts.append(f" • SSL certificate valid: {row.get('cert_has_valid_hostname', 'N/A')}")
            feature_summary_parts.append(f" • Uses HTTPS: {row.get('https', 'N/A')}")
        url_features_summary = "\n".join(feature_summary_parts)

    # Real-time DNS/geo lookups; only ever shown to the LLM.
    network_features_summary = await get_network_features_for_gemini(urls)

    model_predictions_summary = []
    for model_name, pred_data in predictions.items():
        scaled = pred_data['scaled_score']  # This is now -50 to +50
        raw = pred_data['raw_score']
        model_predictions_summary.append(
            f" • {model_name.upper()}: scaled_score={scaled:.2f} (raw={raw:.3f})"
        )
    model_scores_text = "\n".join(model_predictions_summary)

    # Keep the prompt bounded; very long messages are truncated for the LLM
    # (the full original_text is still used for the highlighted_text fallback).
    MAX_TEXT_LEN = 3000
    if len(original_text) > MAX_TEXT_LEN:
        truncated_original_text = original_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_original_text = original_text

    if len(message_text) > MAX_TEXT_LEN:
        truncated_message_text = message_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_message_text = message_text

    # --- NEW: User prompt only contains dynamic data ---
    user_prompt = f"""MESSAGE DATA:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Original Message:
{truncated_original_text}

Cleaned Text:
{truncated_message_text}

URLs Found: {', '.join(urls) if urls else 'None'}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

URL FEATURES (from ML models):
{url_features_summary}

INDEPENDENT NETWORK & GEO-DATA (for Gemini analysis only):
{network_features_summary}

MODEL PREDICTIONS:
(Positive scaled scores → phishing, Negative → legitimate. Range: -50 to +50)
{model_scores_text}

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Please analyze this data and provide your JSON response."""

    try:
        # Exponential-backoff retry loop around the Groq call.
        max_retries = 3
        retry_delay = 2

        for attempt in range(max_retries):
            try:
                # --- MODIFIED: API call now uses system and user roles ---
                chat_completion = await groq_async_client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": SYSTEM_PROMPT,
                        },
                        {
                            "role": "user",
                            "content": user_prompt,
                        }
                    ],
                    model="meta-llama/llama-4-scout-17b-16e-instruct",  # Using 8B for speed, can be 70b
                    temperature=0.2,
                    max_tokens=4096,
                    top_p=0.85,
                    response_format={"type": "json_object"},
                )

                response_text = chat_completion.choices[0].message.content
                break  # Success

            except Exception as retry_error:
                print(f"Groq API attempt {attempt + 1} failed: {retry_error}")
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    raise retry_error  # Raise the final error

        result = json.loads(response_text)

        # Validate the LLM's JSON against the contract we demand.
        required_fields = ['confidence', 'reasoning', 'highlighted_text', 'final_decision', 'suggestion']
        if not all(field in result for field in required_fields):
            raise ValueError(f"Missing required fields. Got: {list(result.keys())}")

        result['confidence'] = float(result['confidence'])
        if not 0 <= result['confidence'] <= 100:
            result['confidence'] = max(0, min(100, result['confidence']))

        if result['final_decision'].lower() not in ['phishing', 'legitimate']:
            # --- MODIFIED: Decision is based on the directional confidence score ---
            result['final_decision'] = 'phishing' if result['confidence'] > 50 else 'legitimate'
        else:
            result['final_decision'] = result['final_decision'].lower()

        # --- MODIFIED: Check that confidence and decision match ---
        if result['final_decision'] == 'phishing' and result['confidence'] < 50:
            print(f"Warning: Groq decision 'phishing' mismatches confidence {result['confidence']}. Adjusting confidence.")
            result['confidence'] = 51.0  # Set to a default phishing score
        elif result['final_decision'] == 'legitimate' and result['confidence'] > 50:
            print(f"Warning: Groq decision 'legitimate' mismatches confidence {result['confidence']}. Adjusting confidence.")
            result['confidence'] = 49.0  # Set to a default legitimate score

        # --- Fallback for empty or truncated highlighted_text ---
        # NOTE(review): a legitimate message that itself contains "..." will
        # also trip this heuristic and lose its $$ highlights -- confirm
        # that trade-off is acceptable.
        if not result['highlighted_text'].strip() or '...' in result['highlighted_text'] or 'TRUNCATED' in result['highlighted_text']:
            print("Warning: Groq returned empty or truncated 'highlighted_text'. Falling back to original_text.")
            result['highlighted_text'] = original_text

        if not result.get('suggestion', '').strip():
            if result['final_decision'] == 'phishing':
                result['suggestion'] = "Do not interact with this message. Delete it immediately and report it as phishing."
            else:
                result['suggestion'] = "This message appears safe, but always verify sender identity before taking any action."

        return result

    except json.JSONDecodeError as e:
        # NOTE(review): response_text is only bound after a successful API
        # call; if the retry loop's final raise were itself a JSONDecodeError
        # the slice below would NameError. Unlikely with the Groq client, but
        # worth tightening.
        print(f"JSON parsing error: {e}")
        print(f"Response text that failed parsing: {response_text[:500]}")

        # --- MODIFIED: Fallback logic for confidence score ---
        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + avg_scaled_score))
        final_decision = "phishing" if confidence > 50 else "legitimate"

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq response parsing failed. Fallback: Based on model average (directional score: {confidence:.2f}), message appears {'suspicious' if final_decision == 'phishing' else 'legitimate'}.",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Do not interact with this message. Delete it immediately and be cautious." if final_decision == 'phishing' else "Exercise caution. Verify the sender before taking any action."
        }

    except Exception as e:
        print(f"Error with Groq API: {e}")

        # --- MODIFIED: Fallback logic for confidence score ---
        avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
        confidence = min(100, max(0, 50 + avg_scaled_score))
        final_decision = "phishing" if confidence > 50 else "legitimate"

        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq API error: {str(e)}. Fallback decision based on {len(predictions)} model predictions (average directional score: {confidence:.2f}).",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Treat this message with caution. Delete it if suspicious, or verify the sender through official channels before taking action." if final_decision == 'phishing' else "This message appears safe based on models, but always verify sender identity before clicking links or providing information."
        }
570
+
571
+ # --- FastAPI Endpoints ---
572
+
573
@app.on_event("startup")
async def startup_event():
    """Load every model once at boot, then print a readiness banner."""
    load_models()
    banner = "=" * 60
    print("\n" + banner)
    print("Phishing Detection API is ready!")
    print(banner)
    print("API Documentation: http://localhost:8000/docs")
    print(banner + "\n")
581
+
582
@app.get("/")
async def root():
    """Service banner: name, version and the routes this API exposes."""
    endpoint_map = {
        "predict": "/predict (POST)",
        "health": "/health (GET)",
        "docs": "/docs (GET)"
    }
    return {
        "message": "Phishing Detection API",
        "version": "1.0.0",
        "endpoints": endpoint_map
    }
593
+
594
@app.get("/health")
async def health_check():
    """Report which model components actually loaded at startup."""
    return {
        "status": "healthy",
        "models_loaded": {
            "ml_models": list(ml_models.keys()),
            "dl_models": list(dl_models.keys()),
            "bert_model": bert_model is not None,
            "semantic_model": semantic_model is not None,
            "groq_client": groq_async_client is not None
        }
    }
608
+
609
@app.post("/predict", response_model=PredictionResponse)
async def predict(message_input: MessageInput):
    """Full phishing-analysis pipeline for one message.

    Steps: split URLs from text -> engineer per-URL features -> score with
    every loaded model (in a worker thread) -> hand all evidence to the
    Groq judge (or its fallback) for the final verdict.

    Raises:
        HTTPException 400: empty message text.
        HTTPException 500: no model could produce a prediction, or an
            unexpected internal error.
    """
    try:
        original_text = message_input.text

        if not original_text or not original_text.strip():
            raise HTTPException(status_code=400, detail="Message text cannot be empty")

        urls, cleaned_text = parse_message(original_text)

        features_df = pd.DataFrame()
        if urls:
            features_df = await extract_url_features(urls)

        predictions = {}
        if len(features_df) > 0 or (cleaned_text and semantic_model):
            # --- MODIFIED: Run this in a thread to avoid blocking ---
            predictions = await asyncio.to_thread(get_model_predictions, features_df, cleaned_text)

        if not predictions:
            # Pick the most specific explanation for the failure.
            if not urls and not cleaned_text:
                detail = "Message text is empty after cleaning."
            elif not urls and not semantic_model:
                detail = "No URLs provided and semantic model is not loaded."
            elif not any([ml_models, dl_models, bert_model, semantic_model]):
                detail = "No models available for prediction. Please ensure models are trained and loaded."
            else:
                detail = "Could not generate predictions. Models may be missing or feature extraction failed."

            raise HTTPException(
                status_code=500,
                detail=detail
            )

        final_result = await get_groq_final_decision(
            urls, features_df, cleaned_text, predictions, original_text
        )

        return PredictionResponse(**final_result)

    except HTTPException:
        # Re-raise intentional HTTP errors untouched.
        raise
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
655
+
656
# Local development entry point; production runs under its own uvicorn/gunicorn
# command (see the Cloud Build/Cloud Run config).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
cloudbuild.yaml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Cloud Build pipeline: build the API image, push both tags, deploy to Cloud Run.
steps:
  # Build the Docker image
  - name: 'gcr.io/cloud-builders/docker'
    args:
      - 'build'
      - '-t'
      - 'gcr.io/$PROJECT_ID/phishing-detection-api:$SHORT_SHA'
      - '-t'
      - 'gcr.io/$PROJECT_ID/phishing-detection-api:latest'
      - '.'
    dir: '.'

  # Push the Docker image to Container Registry
  - name: 'gcr.io/cloud-builders/docker'
    args:
      - 'push'
      - 'gcr.io/$PROJECT_ID/phishing-detection-api:$SHORT_SHA'

  # Push the mutable 'latest' tag as well
  - name: 'gcr.io/cloud-builders/docker'
    args:
      - 'push'
      - 'gcr.io/$PROJECT_ID/phishing-detection-api:latest'

  # Deploy to Cloud Run (pinned to the immutable commit-SHA tag)
  - name: 'gcr.io/google.com/cloudsdktool/cloud-sdk'
    entrypoint: gcloud
    args:
      - 'run'
      - 'deploy'
      - 'phishing-detection-api'
      - '--image'
      - 'gcr.io/$PROJECT_ID/phishing-detection-api:$SHORT_SHA'
      - '--region'
      - 'us-central1'
      - '--platform'
      - 'managed'
      # Public endpoint; revisit if authentication is added to the API.
      - '--allow-unauthenticated'
      - '--memory'
      - '4Gi'
      - '--cpu'
      - '2'
      - '--timeout'
      - '300'
      - '--max-instances'
      - '10'
      - '--set-env-vars'
      - 'PYTHONUNBUFFERED=1'
      # GROQ_API_KEY comes from Secret Manager -- never baked into the image.
      - '--set-secrets'
      - 'GROQ_API_KEY=GROQ_API_KEY:latest'

images:
  - 'gcr.io/$PROJECT_ID/phishing-detection-api:$SHORT_SHA'
  - 'gcr.io/$PROJECT_ID/phishing-detection-api:latest'

options:
  machineType: 'E2_HIGHCPU_8'
  logging: CLOUD_LOGGING_ONLY

# Generous build timeout: the ML-heavy image takes a while to assemble.
timeout: '1200s'
config.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Project-root-relative paths; everything lives beside this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
REPORTS_DIR = os.path.join(BASE_DIR, "reports")

# Fraction of each raw CSV sampled for training, and where the engineered
# training features are written.
TRAIN_SAMPLE_FRACTION = 0.5
ENGINEERED_TRAIN_FILE = os.path.join(BASE_DIR, "engineered_features.csv")

# Evaluation/report settings.
REPORT_SAMPLE_SIZE = 1500
ENGINEERED_TEST_FILE = os.path.join(BASE_DIR, "engineered_features_test.csv")

# Features computed purely from the URL string itself.
LEXICAL_FEATURES = [
    'url_length',
    'hostname_length',
    'path_length',
    'query_length',
    'fragment_length',
    'count_dot',
    'count_hyphen',
    'count_underscore',
    'count_slash',
    'count_at',
    'count_equals',
    'count_percent',
    'count_digits',
    'count_letters',
    'count_special_chars',
    'has_ip_address',
    'has_http',
    'has_https',
]

# Features derived from WHOIS registration records.
WHOIS_FEATURES = [
    'domain_age_days',
    'domain_lifespan_days',
    'days_since_domain_update',
    'registrar_name',
]

# Features derived from the site's TLS certificate.
SSL_FEATURES = [
    'cert_age_days',
    'cert_validity_days',
    'cert_issuer_cn',
    'cert_subject_cn',
    'ssl_protocol_version',
    'cert_has_valid_hostname',
]

# Canonical ordered list of every engineered feature column.
ALL_FEATURE_COLUMNS = (
    LEXICAL_FEATURES +
    WHOIS_FEATURES +
    SSL_FEATURES
)

# String-valued features that need encoding before modelling.
CATEGORICAL_FEATURES = [
    'registrar_name',
    'cert_issuer_cn',
    'cert_subject_cn',
    'ssl_protocol_version'
]

# Everything that is not categorical is treated as numeric.
NUMERICAL_FEATURES = [
    col for col in ALL_FEATURE_COLUMNS if col not in CATEGORICAL_FEATURES
]

# Classic-ML training settings.
ML_MODEL_RANDOM_STATE = 42
ML_TEST_SIZE = 0.2

# Deep-learning training hyperparameters.
DL_EPOCHS = 50
DL_BATCH_SIZE = 64
DL_LEARNING_RATE = 0.001
data_pipeline.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import pandas as pd
4
+ import config
5
+ from feature_extraction import extract_features_from_dataframe
6
+
7
def load_and_sample_raw_data(data_dir, fraction=0.1, random_state=42):
    """Collect a random sample of labelled URLs from every CSV in *data_dir*.

    Each file must provide 'label' and 'url' columns; files that don't, or
    that fail to parse, are skipped with a warning. Returns a shuffled
    DataFrame of the combined samples, or an empty DataFrame when nothing
    usable was found.
    """
    csv_paths = glob.glob(os.path.join(data_dir, "*.csv"))

    if not csv_paths:
        print(f"Error: No .csv files found in '{data_dir}'.")
        print("Please add your raw data files (e.g., phishing.csv, legit.csv) to the /data/ folder.")
        return pd.DataFrame()

    print(f"Found {len(csv_paths)} raw data files.")

    sampled_frames = []
    for path in csv_paths:
        try:
            print(f"Loading and sampling {os.path.basename(path)}...")
            frame = pd.read_csv(path, on_bad_lines='skip')

            if 'label' not in frame.columns or 'url' not in frame.columns:
                print(f"Warning: Skipping {path}. Must contain 'label' and 'url' columns.")
                continue

            sampled_frames.append(frame.sample(frac=fraction, random_state=random_state))

        except Exception as e:
            print(f"Error processing {path}: {e}")

    if not sampled_frames:
        print("Error: No valid data could be loaded.")
        return pd.DataFrame()

    combined_df = pd.concat(sampled_frames, ignore_index=True)
    # NOTE(review): this second pass keeps only 10% of the already-sampled
    # rows regardless of `fraction` (it also serves as the shuffle) --
    # confirm the hard-coded 0.1 is intentional and not a leftover debug cap.
    combined_df = combined_df.sample(frac=0.1, random_state=random_state).reset_index(drop=True)

    print(f"Total raw training data prepared: {len(combined_df)} samples.")
    return combined_df
42
+
43
def main():
    """Run the offline pipeline: sample raw CSVs, engineer features, save."""
    print("--- Starting Data Pipeline ---")

    raw_df = load_and_sample_raw_data(
        data_dir=config.DATA_DIR,
        fraction=config.TRAIN_SAMPLE_FRACTION
    )

    if raw_df.empty:
        print("Data pipeline failed. Exiting.")
        return

    # Slow step: per-URL lexical/WHOIS/SSL feature extraction.
    engineered_df = extract_features_from_dataframe(raw_df)

    engineered_df.to_csv(config.ENGINEERED_TRAIN_FILE, index=False)

    print(f"\n--- Data Pipeline Complete ---")
    print(f"Engineered training set saved to: {config.ENGINEERED_TRAIN_FILE}")
    print(f"Total features: {len(config.ALL_FEATURE_COLUMNS)}")
+
63
if __name__ == "__main__":
    # First run convenience: seed the data directory with tiny dummy CSVs so
    # the pipeline can be smoke-tested before real data is dropped in.
    os.makedirs(config.DATA_DIR, exist_ok=True)
    if not glob.glob(os.path.join(config.DATA_DIR, "*.csv")):
        print("Creating dummy data files...")
        dummy_phish = pd.DataFrame({
            'label': [1, 1],
            'url': ['facebook.com.login-support.ru', 'myetherwallets.kr/wallet']
        })
        dummy_phish.to_csv(os.path.join(config.DATA_DIR, 'phishing_data_1.csv'), index=False)

        dummy_legit = pd.DataFrame({
            'label': [0, 0],
            'url': ['google.com', 'https://www.millect.com/Plans']
        })
        dummy_legit.to_csv(os.path.join(config.DATA_DIR, 'legit_data_1.csv'), index=False)
        print(f"Dummy files created in {config.DATA_DIR}. Please replace them with your real data.")

    main()
deploy.ps1 ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# GCP Deployment Script for Phishing Detection API (PowerShell)
# This script helps deploy the application to Google Cloud Run
#
# Usage: .\deploy.ps1 -ProjectId <gcp-project-id> [-Region us-central1]

param(
    [Parameter(Mandatory=$true)]
    [string]$ProjectId,

    [Parameter(Mandatory=$false)]
    [string]$Region = "us-central1"
)

$ServiceName = "phishing-detection-api"

Write-Host "==========================================" -ForegroundColor Cyan
Write-Host "Phishing Detection API - GCP Deployment" -ForegroundColor Cyan
Write-Host "==========================================" -ForegroundColor Cyan
Write-Host "Project ID: $ProjectId"
Write-Host "Region: $Region"
Write-Host "Service Name: $ServiceName"
Write-Host "==========================================" -ForegroundColor Cyan

# Check if gcloud is installed (Get-Command throws when the binary is absent)
try {
    $null = Get-Command gcloud -ErrorAction Stop
} catch {
    Write-Host "Error: gcloud CLI is not installed. Please install it first." -ForegroundColor Red
    exit 1
}

# Check if docker is installed
try {
    $null = Get-Command docker -ErrorAction Stop
} catch {
    Write-Host "Error: Docker is not installed. Please install it first." -ForegroundColor Red
    exit 1
}

# Set the project
Write-Host "Setting GCP project..." -ForegroundColor Yellow
gcloud config set project $ProjectId

# Enable required APIs
Write-Host "Enabling required APIs..." -ForegroundColor Yellow
gcloud services enable cloudbuild.googleapis.com
gcloud services enable run.googleapis.com
gcloud services enable containerregistry.googleapis.com

# Check if GROQ_API_KEY secret exists, if not create it
Write-Host "Checking for GROQ_API_KEY secret..." -ForegroundColor Yellow
$secretExists = gcloud secrets describe GROQ_API_KEY --project=$ProjectId 2>&1
if ($LASTEXITCODE -ne 0) {
    Write-Host "GROQ_API_KEY secret not found. Creating it..." -ForegroundColor Yellow
    # Read the key without echoing it, then decode the SecureString for piping.
    $GroqKey = Read-Host "Enter your GROQ_API_KEY" -AsSecureString
    $GroqKeyPlain = [Runtime.InteropServices.Marshal]::PtrToStringAuto(
        [Runtime.InteropServices.Marshal]::SecureStringToBSTR($GroqKey)
    )

    # NOTE(review): piping via echo appends a trailing newline to the stored
    # secret value, whereas deploy.sh uses `echo -n` — confirm the API trims it.
    echo $GroqKeyPlain | gcloud secrets create GROQ_API_KEY `
        --data-file=- `
        --replication-policy="automatic" `
        --project=$ProjectId

    # Grant Cloud Run service account access to the secret
    $ProjectNumber = (gcloud projects describe $ProjectId --format="value(projectNumber)")
    gcloud secrets add-iam-policy-binding GROQ_API_KEY `
        --member="serviceAccount:$ProjectNumber-compute@developer.gserviceaccount.com" `
        --role="roles/secretmanager.secretAccessor" `
        --project=$ProjectId
} else {
    Write-Host "GROQ_API_KEY secret already exists." -ForegroundColor Green
}

# Build and deploy using Cloud Build
Write-Host "Building and deploying using Cloud Build..." -ForegroundColor Yellow
gcloud builds submit --config=cloudbuild.yaml --project=$ProjectId

# Get the service URL
Write-Host "Deployment complete!" -ForegroundColor Green
Write-Host "Getting service URL..." -ForegroundColor Yellow
$ServiceUrl = (gcloud run services describe $ServiceName `
    --region=$Region `
    --format="value(status.url)" `
    --project=$ProjectId)

Write-Host "==========================================" -ForegroundColor Green
Write-Host "Deployment Successful!" -ForegroundColor Green
Write-Host "Service URL: $ServiceUrl"
Write-Host "Health Check: $ServiceUrl/health"
Write-Host "API Docs: $ServiceUrl/docs"
Write-Host "==========================================" -ForegroundColor Green
deploy.sh ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# GCP Deployment Script for Phishing Detection API
# This script helps deploy the application to Google Cloud Run
#
# Usage: ./deploy.sh [project-id] [region]

set -e  # abort on the first failing command

PROJECT_ID=${1:-"your-project-id"}
REGION=${2:-"us-central1"}
SERVICE_NAME="phishing-detection-api"

echo "=========================================="
echo "Phishing Detection API - GCP Deployment"
echo "=========================================="
echo "Project ID: $PROJECT_ID"
echo "Region: $REGION"
echo "Service Name: $SERVICE_NAME"
echo "=========================================="

# Check if gcloud is installed
if ! command -v gcloud &> /dev/null; then
    echo "Error: gcloud CLI is not installed. Please install it first."
    exit 1
fi

# Check if docker is installed
if ! command -v docker &> /dev/null; then
    echo "Error: Docker is not installed. Please install it first."
    exit 1
fi

# Set the project
echo "Setting GCP project..."
gcloud config set project $PROJECT_ID

# Enable required APIs
echo "Enabling required APIs..."
gcloud services enable cloudbuild.googleapis.com
gcloud services enable run.googleapis.com
gcloud services enable containerregistry.googleapis.com

# Check if GROQ_API_KEY secret exists, if not create it
echo "Checking for GROQ_API_KEY secret..."
if ! gcloud secrets describe GROQ_API_KEY --project=$PROJECT_ID &> /dev/null; then
    echo "GROQ_API_KEY secret not found. Creating it..."
    # -s: no terminal echo while typing the key.
    read -sp "Enter your GROQ_API_KEY: " GROQ_KEY
    echo
    # echo -n keeps a trailing newline out of the stored secret value.
    echo -n "$GROQ_KEY" | gcloud secrets create GROQ_API_KEY \
        --data-file=- \
        --replication-policy="automatic" \
        --project=$PROJECT_ID

    # Grant Cloud Run service account access to the secret
    PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format="value(projectNumber)")
    gcloud secrets add-iam-policy-binding GROQ_API_KEY \
        --member="serviceAccount:$PROJECT_NUMBER-compute@developer.gserviceaccount.com" \
        --role="roles/secretmanager.secretAccessor" \
        --project=$PROJECT_ID
else
    echo "GROQ_API_KEY secret already exists."
fi

# Build and deploy using Cloud Build
echo "Building and deploying using Cloud Build..."
gcloud builds submit --config=cloudbuild.yaml --project=$PROJECT_ID

# Get the service URL
echo "Deployment complete!"
echo "Getting service URL..."
SERVICE_URL=$(gcloud run services describe $SERVICE_NAME \
    --region=$REGION \
    --format="value(status.url)" \
    --project=$PROJECT_ID)

echo "=========================================="
echo "Deployment Successful!"
echo "Service URL: $SERVICE_URL"
echo "Health Check: $SERVICE_URL/health"
echo "API Docs: $SERVICE_URL/docs"
echo "=========================================="
docker-compose.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Local development/deployment stack for the phishing-detection FastAPI service.
# NOTE(review): the top-level `version` key is ignored by Docker Compose v2 —
# kept only for compatibility with older docker-compose binaries.
version: '3.8'

services:
  phishing-api:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: phishing-detection-api
    ports:
      - "8000:8000"  # host:container — API served on port 8000
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONDONTWRITEBYTECODE=1
      # Add your GROQ_API_KEY here or use .env file
      - GROQ_API_KEY=${GROQ_API_KEY:-your-api-key-here}
    volumes:
      # Optional: Mount models directory if you want to update models without rebuilding
      # - ./models:/app/models:ro
      # - ./finetuned_bert:/app/finetuned_bert:ro
      # - ./Message_model:/app/Message_model:ro
    restart: unless-stopped
    healthcheck:
      # Requires curl inside the image; start_period allows time for model loading.
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s
engineered_features.csv ADDED
The diff for this file is too large to render. See raw diff
 
feature_extraction.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import whois
2
+ import tldextract
3
+ import ssl
4
+ import socket
5
+ import pandas as pd
6
+ from datetime import datetime
7
+ from urllib.parse import urlparse
8
+ from OpenSSL import SSL
9
+ import re
10
+ import sys
11
+ from tqdm import tqdm
12
+
13
def get_lexical_features(url):
    """Compute purely string-based (lexical) features for a URL.

    The URL is normalised with an ``http://`` prefix when no scheme is
    present; every count below is taken over that normalised string.
    On any parsing failure all components — and the URL itself — are
    treated as empty strings.
    """
    try:
        if not (url.startswith('http://') or url.startswith('https://')):
            url = 'http://' + url
        parts = urlparse(url)
        hostname = parts.hostname or ''
        path, query, fragment = parts.path, parts.query, parts.fragment
    except Exception:
        url = hostname = path = query = fragment = ''

    features = {
        'url_length': len(url),
        'hostname_length': len(hostname),
        'path_length': len(path),
        'query_length': len(query),
        'fragment_length': len(fragment),
    }

    # Single-character counts over the full (normalised) URL string.
    for key, ch in (('count_dot', '.'), ('count_hyphen', '-'),
                    ('count_underscore', '_'), ('count_slash', '/'),
                    ('count_at', '@'), ('count_equals', '='),
                    ('count_percent', '%')):
        features[key] = url.count(ch)
    features['count_digits'] = sum(1 for c in url if c.isdigit())
    features['count_letters'] = sum(1 for c in url if c.isalpha())
    features['count_special_chars'] = len(re.findall(r'[^a-zA-Z0-9\s]', url))

    # Dotted-quad host (octet ranges are deliberately not validated).
    features['has_ip_address'] = 1 if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", hostname) else 0
    features['has_http'] = 1 if 'http:' in url else 0
    features['has_https'] = 1 if 'https:' in url else 0

    return features
52
+
53
def get_whois_features(registrable_domain, whois_cache):
    """Look up WHOIS-based features for a registrable domain (e.g. 'example.com').

    Results are memoised in ``whois_cache`` (a plain dict, mutated in place) so
    repeated domains trigger only one network lookup. All features default to
    sentinel values (-1 / 'N/A') and are returned unchanged on any failure,
    making the lookup strictly best-effort.
    """
    if registrable_domain in whois_cache:
        return whois_cache[registrable_domain]

    features = {
        'domain_age_days': -1,            # -1 == unknown / lookup failed
        'domain_lifespan_days': -1,
        'days_since_domain_update': -1,
        'registrar_name': 'N/A',
    }

    try:
        w = whois.whois(registrable_domain)
        # No creation date usually means the record was empty or unparseable.
        if not w.creation_date:
            whois_cache[registrable_domain] = features
            return features

        # The whois library may return a single datetime or a list of them;
        # normalise to the first entry.
        creation_date = w.creation_date[0] if isinstance(w.creation_date, list) else w.creation_date
        expiration_date = w.expiration_date[0] if isinstance(w.expiration_date, list) else w.expiration_date
        updated_date = w.updated_date[0] if isinstance(w.updated_date, list) else w.updated_date

        # NOTE(review): assumes the returned datetimes are timezone-naive —
        # subtracting an aware datetime from datetime.now() would raise and
        # fall into the except below, leaving sentinels. Confirm against the
        # whois library in use.
        if creation_date:
            features['domain_age_days'] = (datetime.now() - creation_date).days

        if creation_date and expiration_date:
            features['domain_lifespan_days'] = (expiration_date - creation_date).days

        if updated_date:
            features['days_since_domain_update'] = (datetime.now() - updated_date).days

        if w.registrar:
            # Keep only the first token of the registrar string, stripped of
            # commas/quotes, to reduce the cardinality of this categorical.
            features['registrar_name'] = str(w.registrar).split(' ')[0].replace(',', '').replace('"', '')

    except Exception as e:
        # Best-effort: any network/parse error leaves the sentinel defaults.
        pass

    whois_cache[registrable_domain] = features
    return features
91
+
92
def get_ssl_features(hostname, ssl_cache):
    """Collect TLS-certificate features by performing a live handshake on port 443.

    Results are memoised in ``ssl_cache`` (mutated in place). Numeric fields
    default to -1; on failure ``cert_issuer_cn`` is overwritten with a sentinel
    naming the failure mode (SSL_ERROR / TIMEOUT / CONN_FAILED /
    SSL_UNKNOWN_ERROR).
    """
    if hostname in ssl_cache:
        return ssl_cache[hostname]

    features = {
        'cert_age_days': -1,
        'cert_validity_days': -1,
        'cert_issuer_cn': 'N/A',
        'cert_subject_cn': 'N/A',
        'ssl_protocol_version': 'N/A',
        'cert_has_valid_hostname': 0,
    }

    try:
        # SSLv23_METHOD lets OpenSSL negotiate the highest mutually supported
        # protocol version (reported via get_protocol_version_name below).
        context = SSL.Context(SSL.SSLv23_METHOD)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(2)  # keep slow/unreachable hosts from stalling the run

        ssl_sock = SSL.Connection(context, sock)
        # Send SNI so virtual-hosted servers present the right certificate.
        ssl_sock.set_tlsext_host_name(hostname.encode('utf-8'))
        ssl_sock.connect((hostname, 443))
        ssl_sock.do_handshake()

        cert = ssl_sock.get_peer_certificate()

        features['ssl_protocol_version'] = ssl_sock.get_protocol_version_name()
        # Certificate validity timestamps come in ASN.1 'YYYYMMDDHHMMSSZ' form.
        not_before_str = cert.get_notBefore().decode('utf-8')
        not_before_date = datetime.strptime(not_before_str, '%Y%m%d%H%M%SZ')
        features['cert_age_days'] = (datetime.now() - not_before_date).days
        not_after_str = cert.get_notAfter().decode('utf-8')
        not_after_date = datetime.strptime(not_after_str, '%Y%m%d%H%M%SZ')
        features['cert_validity_days'] = (not_after_date - not_before_date).days
        issuer_components = dict(cert.get_issuer().get_components())
        subject_components = dict(cert.get_subject().get_components())

        features['cert_issuer_cn'] = issuer_components.get(b'CN', b'N/A').decode('utf-8')
        features['cert_subject_cn'] = subject_components.get(b'CN', b'N/A').decode('utf-8')
        # Exact CN match or a wildcard over the registered domain counts as valid.
        # NOTE(review): this ignores subjectAltName entries, which modern
        # certificates rely on — many valid certs will score 0 here; confirm
        # this is intended for the feature.
        if features['cert_subject_cn'] == hostname or f"*.{tldextract.extract(hostname).registered_domain}" == features['cert_subject_cn']:
            features['cert_has_valid_hostname'] = 1

        ssl_sock.close()
        sock.close()

    except SSL.Error as e:
        features['cert_issuer_cn'] = 'SSL_ERROR'
    except socket.timeout:
        features['cert_issuer_cn'] = 'TIMEOUT'
    except socket.error:
        features['cert_issuer_cn'] = 'CONN_FAILED'
    except Exception as e:
        # Anything unexpected (e.g. strptime on a malformed date) lands here.
        features['cert_issuer_cn'] = 'SSL_UNKNOWN_ERROR'
        pass

    ssl_cache[hostname] = features
    return features
147
+
148
def process_row(row, whois_cache, ssl_cache):
    """Build the full feature Series (lexical + WHOIS + SSL) for one dataframe row.

    ``row`` must expose a 'url' entry; the two cache dicts are shared across
    rows so repeated hosts/domains are looked up over the network only once.
    """
    url = row['url']
    try:
        # Normalise a scheme onto the URL so urlparse/tldextract see a host.
        if not (url.startswith('http://') or url.startswith('https://')):
            url_for_parse = 'http://' + url
        else:
            url_for_parse = url

        parsed_url = urlparse(url_for_parse)
        hostname = parsed_url.hostname if parsed_url.hostname else ''

        ext = tldextract.extract(url_for_parse)
        # NOTE(review): when the suffix is empty (e.g. bare IP hosts) this
        # yields a trailing-dot string like '1.2.3.4.' — the WHOIS lookup then
        # just falls back to its sentinels, but confirm that is intended.
        registrable_domain = f"{ext.domain}.{ext.suffix}"

    except Exception:
        hostname = ''
        registrable_domain = ''
    # Lexical features need no network access and always succeed.
    lexical_data = get_lexical_features(url)

    whois_data = {}
    if registrable_domain:
        whois_data = get_whois_features(registrable_domain, whois_cache)

    ssl_data = {}
    if hostname:
        ssl_data = get_ssl_features(hostname, ssl_cache)
    # The three helpers produce disjoint key sets, so merging is collision-free.
    all_features = {**lexical_data, **whois_data, **ssl_data}

    return pd.Series(all_features)
177
+
178
+
179
def extract_features_from_dataframe(df):
    """Append engineered URL features (lexical + WHOIS + SSL) to ``df``.

    Expects a 'url' column; returns a new dataframe with the original columns
    followed by the extracted feature columns. Network lookups make this slow.

    Raises:
        ValueError: if ``df`` has no 'url' column.
    """
    if 'url' not in df.columns:
        raise ValueError("DataFrame must contain a 'url' column.")
    # Shared per-run caches so duplicate domains/hosts hit the network once.
    whois_cache = {}
    ssl_cache = {}

    print("Starting feature extraction... This may take a very long time.")
    try:

        tqdm.pandas(desc="Extracting features")
        feature_df = df.progress_apply(process_row, args=(whois_cache, ssl_cache), axis=1)
    except ImportError:
        # NOTE(review): effectively unreachable — tqdm is imported at module
        # top, so a missing tqdm would already have failed at import time.
        print("tqdm not found. Running without progress bar. `pip install tqdm` to see progress.")
        feature_df = df.apply(process_row, args=(whois_cache, ssl_cache), axis=1)

    print("Feature extraction complete.")
    # apply() preserves df's index, so this axis=1 concat is a column-wise join.
    final_df = pd.concat([df, feature_df], axis=1)

    return final_df
models.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.linear_model import LogisticRegression
2
+ from sklearn.svm import SVC
3
+ from xgboost import XGBClassifier
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import Dataset
7
+ from transformers import BertTokenizer, BertForSequenceClassification
8
+ import config
9
+ import os
10
+
11
def get_ml_models():
    """Return the untrained classical-ML classifiers, keyed by short name.

    All models share config.ML_MODEL_RANDOM_STATE for reproducibility.
    """
    models = {
        'logistic': LogisticRegression(
            max_iter=1000,
            random_state=config.ML_MODEL_RANDOM_STATE
        ),
        'svm': SVC(
            probability=True,  # required so predict_proba works downstream
            random_state=config.ML_MODEL_RANDOM_STATE
        ),
        'xgboost': XGBClassifier(
            n_estimators=100,
            random_state=config.ML_MODEL_RANDOM_STATE,
            # NOTE(review): use_label_encoder is deprecated/removed in recent
            # xgboost releases — harmless with older pins, but confirm the
            # installed version accepts it.
            use_label_encoder=False,
            eval_metric='logloss'
        )
    }
    return models
29
+
30
class Attention_BLSTM(nn.Module):
    """Bidirectional LSTM with additive attention over the sequence axis.

    Accepts (batch, features) or (batch, seq, features) float input and emits
    a (batch, 1) sigmoid probability.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, dropout=0.3):
        super(Attention_BLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout is only meaningful for stacked LSTMs.
            dropout=dropout if num_layers > 1 else 0
        )

        # Scores each timestep's bidirectional state (hence hidden_dim * 2).
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc = nn.Linear(hidden_dim * 2, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # A 2-D (batch, features) input is treated as a length-1 sequence.
        if len(x.shape) == 2:
            x = x.unsqueeze(1)

        lstm_out, _ = self.lstm(x)  # (batch, seq, hidden_dim * 2)

        # Softmax over the sequence dimension, then attention-weighted sum.
        attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)

        out = self.fc(context_vector)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.output(out)
        out = self.sigmoid(out)  # (batch, 1) probability in [0, 1]

        return out
68
+
69
class RCNN(nn.Module):
    """Recurrent-convolutional classifier over a numeric feature vector.

    The input vector is treated as a length-``input_dim`` sequence of scalars:
    a small bidirectional LSTM produces an ``embed_dim``-wide representation
    per position, parallel 1-D convolutions with several kernel sizes are
    max-pooled over the sequence, and a small MLP head emits a sigmoid
    probability in [0, 1].
    """

    # Changes vs. original: removed unused locals (batch_size, seq_len) in
    # forward(), and made the filter_sizes default an immutable tuple to avoid
    # the shared-mutable-default pitfall. Behaviour is otherwise identical.
    def __init__(self, input_dim, embed_dim=64, num_filters=100, filter_sizes=(3, 4, 5), dropout=0.5):
        super(RCNN, self).__init__()

        # Each sequence position is a single scalar, hence LSTM input size 1;
        # the two bidirectional halves together give embed_dim channels.
        self.lstm = nn.LSTM(1, embed_dim // 2, batch_first=True, bidirectional=True)

        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])

        self.fc = nn.Linear(len(filter_sizes) * num_filters, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """x: (batch, input_dim) float tensor -> (batch, 1) probabilities.

        input_dim must be >= max(filter_sizes) for the convolutions to apply.
        """
        # (batch, seq) -> (batch, seq, 1): each feature is one timestep.
        x = x.unsqueeze(2)

        lstm_out, _ = self.lstm(x)

        # Conv1d expects (batch, channels, seq).
        lstm_out = lstm_out.permute(0, 2, 1)

        conv_outs = [torch.relu(conv(lstm_out)) for conv in self.convs]

        # Global max-pool over the remaining sequence length per filter bank.
        pooled = [torch.max_pool1d(conv_out, conv_out.size(2)).squeeze(2) for conv_out in conv_outs]

        cat = torch.cat(pooled, dim=1)

        out = self.fc(cat)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.output(out)
        out = self.sigmoid(out)

        return out
109
+
110
class FinetunedBERT:
    """Thin inference wrapper around a locally fine-tuned BERT URL classifier.

    The tokenizer is the stock 'bert-base-uncased' vocabulary, while the
    2-class classification weights are loaded strictly from the local model
    directory (local_files_only=True). The model is moved to GPU when
    available and kept in eval mode.
    """

    def __init__(self, model_path=None):
        # Default to the finetuned_bert directory next to the project root.
        if model_path is None:
            model_path = os.path.join(config.BASE_DIR, 'finetuned_bert')

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=2,
            local_files_only=True  # never fetch the weights over the network
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def predict(self, urls):
        """Return hard argmax class indices (numpy array) for one URL or a list."""
        if isinstance(urls, str):
            urls = [urls]

        encodings = self.tokenizer(
            urls,
            padding=True,
            truncation=True,
            max_length=512,  # BERT's positional-embedding limit
            return_tensors='pt'
        )

        encodings = {key: val.to(self.device) for key, val in encodings.items()}

        with torch.no_grad():
            outputs = self.model(**encodings)
            predictions = torch.argmax(outputs.logits, dim=1)

        return predictions.cpu().numpy()

    def predict_proba(self, urls):
        """Return per-class softmax probabilities, shape (n, 2), for URL(s)."""
        if isinstance(urls, str):
            urls = [urls]

        encodings = self.tokenizer(
            urls,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        encodings = {key: val.to(self.device) for key, val in encodings.items()}

        with torch.no_grad():
            outputs = self.model(**encodings)
            probas = torch.softmax(outputs.logits, dim=1)

        return probas.cpu().numpy()
164
+
165
class PhishingDataset(Dataset):
    """Minimal in-memory dataset pairing feature rows with float labels.

    Both inputs are eagerly converted to float32 tensors; each label is
    returned as a shape-(1,) tensor to match a sigmoid output head.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        # One sample per label.
        return len(self.y)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx].view(-1)
        return features, target
175
+
176
def get_dl_models(input_dim):
    """Instantiate the (untrained) deep models for a given feature-vector width,
    keyed by the short names used for their checkpoint files."""
    models = {
        'attention_blstm': Attention_BLSTM(input_dim),
        'rcnn': RCNN(input_dim)
    }
    return models
models/attention_blstm.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa49103ee969cc2194fc6369905ccb6942e0a472263544a40072d678cf4abae9
3
+ size 2283420
models/dl_scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5684232d23cab5e5579c68e653318a032867e2f22d101fdfa6756b80fe097042
3
+ size 1191
models/logistic.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f030cc6dd0405aab6f4ad3813d58cf1d17723426e6ba6e0fe5b54466776be921
3
+ size 848679
models/rcnn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9195f2166c067d0a1d9f15476f94cf13b6ea7cc6946745be1c4eaccb9c968a7
3
+ size 425800
models/svm.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fb6afd8b17462766f22b0a24ec31d9012b41430de30cb5ae816d0933e0be124
3
+ size 1811267
models/xgboost.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d2c829005e6249c5aeca9e80ff272660c65268b69ce482b00ef436cc2c10841
3
+ size 861278
report.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import torch
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
9
+ from sklearn.preprocessing import StandardScaler
10
+ import config
11
+ from models import get_dl_models, PhishingDataset, FinetunedBERT
12
+
13
+ sns.set_style("whitegrid")
14
+ plt.rcParams['figure.figsize'] = (12, 8)
15
+ plt.rcParams['font.size'] = 11
16
+
17
+ COLORS = {
18
+ 'primary': '#FF6B6B',
19
+ 'secondary': '#4ECDC4',
20
+ 'tertiary': '#45B7D1',
21
+ 'quaternary': '#FFA07A',
22
+ 'quinary': '#98D8C8',
23
+ 'bg': '#F7F7F7',
24
+ 'text': '#2C3E50'
25
+ }
26
+
27
+ MODEL_THRESHOLDS = {
28
+ 'attention_blstm': 0.8,
29
+ 'rcnn': 0.8,
30
+ 'logistic': 0.5,
31
+ 'svm': 0.5,
32
+ 'xgboost': 0.5,
33
+ 'bert': 0.5
34
+ }
35
+
36
def load_sample_data(sample_fraction=0.05):
    """Load a random sample of URL data for report generation.

    Preference order: engineered test set, engineered train set, then raw
    labelled CSVs. The sample has at least config.REPORT_SAMPLE_SIZE rows
    (capped by the data size) and uses a fixed seed for repeatability.

    Raises:
        FileNotFoundError: if none of the candidate files exist.
    """
    print(f"Loading {sample_fraction*100}% sample from data...")

    if os.path.exists(config.ENGINEERED_TEST_FILE):
        df = pd.read_csv(config.ENGINEERED_TEST_FILE)
        print(f"Loaded test data: {len(df)} samples")
    elif os.path.exists(config.ENGINEERED_TRAIN_FILE):
        df = pd.read_csv(config.ENGINEERED_TRAIN_FILE)
        print(f"Loaded train data: {len(df)} samples")
    else:
        # Fall back to the raw (un-engineered) labelled data files.
        data_files = [
            os.path.join(config.DATA_DIR, 'url_data_labeled.csv'),
            os.path.join(config.DATA_DIR, 'data_bal - 20000.csv')
        ]
        df = None
        for file in data_files:
            if os.path.exists(file):
                df = pd.read_csv(file)
                print(f"Loaded raw data: {len(df)} samples")
                break

        if df is None:
            raise FileNotFoundError("No data file found!")

    # Floor at REPORT_SAMPLE_SIZE, but never request more rows than exist.
    sample_size = max(int(len(df) * sample_fraction), config.REPORT_SAMPLE_SIZE)
    sample_size = min(sample_size, len(df))
    df_sample = df.sample(n=sample_size, random_state=42)

    print(f"Sampled {len(df_sample)} URLs for report generation")
    return df_sample
66
+
67
def prepare_ml_data(df):
    """Split an engineered dataframe into (X, y) for the classical ML models.

    Numerical NaNs are imputed with -1 (the "unavailable" sentinel used at
    feature-extraction time) and categorical NaNs with 'N/A'.

    Returns:
        X: DataFrame of config.NUMERICAL_FEATURES + config.CATEGORICAL_FEATURES.
        y: numpy array of 0/1 labels.
    """
    # Take an explicit copy so the fillna assignments below cannot raise
    # chained-assignment warnings or write back into the caller's frame
    # (the original `.loc` writes on a column-sliced frame did; this form is
    # also correct under pandas copy-on-write).
    X = df[config.NUMERICAL_FEATURES + config.CATEGORICAL_FEATURES].copy()
    y = df['label'].values

    X[config.NUMERICAL_FEATURES] = X[config.NUMERICAL_FEATURES].fillna(-1)
    X[config.CATEGORICAL_FEATURES] = X[config.CATEGORICAL_FEATURES].fillna('N/A')

    return X, y
75
+
76
def prepare_dl_data(df):
    """Build a scaled numeric matrix (and label vector) for the PyTorch models.

    Uses only config.NUMERICAL_FEATURES with NaNs imputed to -1. Reuses the
    training-time scaler (models/dl_scaler.pkl) when present.
    """
    X = df[config.NUMERICAL_FEATURES].fillna(-1).values
    y = df['label'].values

    scaler_path = os.path.join(config.MODELS_DIR, "dl_scaler.pkl")
    if os.path.exists(scaler_path):
        scaler = joblib.load(scaler_path)
        X_scaled = scaler.transform(X)
    else:
        # NOTE(review): fitting on the report sample yields a different
        # scaling than the models saw at training time — predictions may be
        # skewed when the saved scaler is missing.
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

    return X_scaled, y
89
+
90
def predict_ml_models(X, y):
    """Run the saved classical models over X.

    Returns ({name: hard preds}, {name: positive-class probabilities}).
    Missing model files are skipped with a warning; ``y`` is used only to
    print a per-model accuracy.
    """
    predictions = {}
    scores = {}

    ml_models = ['logistic', 'svm', 'xgboost']

    for model_name in ml_models:
        model_path = os.path.join(config.MODELS_DIR, f"{model_name}.joblib")
        if not os.path.exists(model_path):
            print(f"WARNING: Model {model_name} not found, skipping...")
            continue

        print(f"Loading {model_name} model...")
        model = joblib.load(model_path)

        y_pred = model.predict(X)
        # Column 1 == probability of the positive (label 1) class.
        y_proba = model.predict_proba(X)[:, 1]

        predictions[model_name] = y_pred
        scores[model_name] = y_proba

        acc = accuracy_score(y, y_pred)
        print(f" {model_name} accuracy: {acc:.4f}")

    return predictions, scores
115
+
116
def predict_dl_models(X, y):
    """Run the saved PyTorch models over the scaled matrix X.

    Returns ({name: hard 0/1 preds}, {name: raw sigmoid scores}); hard labels
    use the per-model cutoff from MODEL_THRESHOLDS (default 0.5). Missing
    checkpoint files are skipped; GPU memory is released after each model.
    """
    predictions = {}
    scores = {}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = X.shape[1]

    dl_models_dict = get_dl_models(input_dim)

    for model_name, model in dl_models_dict.items():
        model_path = os.path.join(config.MODELS_DIR, f"{model_name}.pt")
        if not os.path.exists(model_path):
            print(f"WARNING: Model {model_name} not found, skipping...")
            continue

        print(f"Loading {model_name} model...")
        # weights_only=True avoids unpickling arbitrary objects from disk.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        model.to(device)
        model.eval()

        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)

        with torch.no_grad():
            outputs = model(X_tensor).cpu().numpy().flatten()

        threshold = MODEL_THRESHOLDS.get(model_name, 0.5)
        y_pred = (outputs > threshold).astype(int)

        predictions[model_name] = y_pred
        scores[model_name] = outputs

        acc = accuracy_score(y, y_pred)
        print(f" {model_name} accuracy: {acc:.4f} (threshold: {threshold})")

        # Free (GPU) memory before loading the next model.
        del model, X_tensor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return predictions, scores
155
+
156
def predict_bert_model(df, y):
    """Score the sample URLs with the fine-tuned BERT model.

    Returns (hard 0/1 preds, positive-class probabilities), or (None, None)
    when the model directory or the 'url' column is missing, or inference
    fails (including CUDA OOM).
    """
    bert_path = os.path.join(config.BASE_DIR, 'finetuned_bert')
    if not os.path.exists(bert_path):
        print(f"WARNING: BERT model not found at {bert_path}, skipping...")
        return None, None

    if 'url' not in df.columns:
        print("WARNING: 'url' column not found in data, skipping BERT...")
        return None, None

    try:
        print("Loading BERT model...")
        bert_model = FinetunedBERT(bert_path)

        urls = df['url'].tolist()

        # Batched inference keeps (GPU) memory bounded.
        batch_size = 32
        all_preds = []
        all_probas = []

        print(f"Processing {len(urls)} URLs in batches of {batch_size}...")
        for i in range(0, len(urls), batch_size):
            batch_urls = urls[i:i+batch_size]
            batch_preds = bert_model.predict(batch_urls)
            batch_probas = bert_model.predict_proba(batch_urls)[:, 1]
            all_preds.extend(batch_preds)
            all_probas.extend(batch_probas)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # NOTE(review): outputs are inverted (1 - x) — presumably the BERT
        # head was fine-tuned with the opposite class ordering; confirm
        # against the fine-tuning label mapping.
        y_pred = 1-np.array(all_preds)
        y_proba = 1-np.array(all_probas)

        acc = accuracy_score(y, y_pred)
        print(f" BERT accuracy: {acc:.4f}")

        return y_pred, y_proba

    except torch.cuda.OutOfMemoryError:
        print("WARNING: CUDA out of memory for BERT model, skipping...")
        print(" Try reducing batch size or use CPU by setting CUDA_VISIBLE_DEVICES=''")
        return None, None
    except Exception as e:
        print(f"WARNING: Error loading BERT model: {e}")
        return None, None
202
+
203
def plot_confusion_matrices(y_true, all_predictions, save_dir):
    """Render one confusion-matrix heatmap per model onto a shared grid figure.

    Saves 'confusion_matrices.png' into ``save_dir``; does nothing when
    ``all_predictions`` is empty.
    """
    print("\nGenerating confusion matrices...")

    n_models = len(all_predictions)
    if n_models == 0:
        print("No predictions to plot!")
        return

    # Up to 3 plots per row; enough rows to hold every model.
    cols = min(3, n_models)
    rows = (n_models + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 5*rows))
    # Normalise `axes` into a flat indexable sequence regardless of grid shape
    # (plt.subplots returns a scalar, 1-D, or 2-D array depending on rows/cols).
    if n_models == 1:
        axes = [axes]
    else:
        axes = axes.flatten() if rows > 1 else axes

    cmap = sns.color_palette("RdYlGn_r", as_cmap=True)

    for idx, (model_name, y_pred) in enumerate(all_predictions.items()):
        ax = axes[idx]

        cm = confusion_matrix(y_true, y_pred)

        sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax,
                    cbar_kws={'label': 'Count'},
                    annot_kws={'size': 14, 'weight': 'bold'})

        ax.set_title(f'{model_name.upper()} Confusion Matrix',
                     fontsize=14, fontweight='bold', color=COLORS['text'])
        ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
        ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
        ax.set_xticklabels(['Legitimate (0)', 'Phishing (1)'])
        ax.set_yticklabels(['Legitimate (0)', 'Phishing (1)'])

    # Remove any unused axes from the grid.
    for idx in range(n_models, len(axes)):
        fig.delaxes(axes[idx])

    plt.tight_layout()
    save_path = os.path.join(save_dir, 'confusion_matrices.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved confusion matrices to {save_path}")
    plt.close()
246
+
247
def plot_accuracy_comparison(y_true, all_predictions, save_dir):
    """Render a bar chart of per-model accuracy and save 'accuracy_comparison.png'.

    Skips (with a message) when no predictions are supplied.
    """
    print("\nGenerating accuracy comparison plot...")

    if len(all_predictions) == 0:
        print("No predictions to plot!")
        return

    # Accuracy per model, keyed by model name.
    accuracies = {}
    for model_name, y_pred in all_predictions.items():
        acc = accuracy_score(y_true, y_pred)
        accuracies[model_name] = acc

    models = list(accuracies.keys())
    accs = list(accuracies.values())

    # Cycle through the palette when there are more models than colours.
    colors_list = [COLORS['primary'], COLORS['secondary'], COLORS['tertiary'],
                   COLORS['quaternary'], COLORS['quinary']]
    bar_colors = [colors_list[i % len(colors_list)] for i in range(len(models))]

    fig, ax = plt.subplots(figsize=(12, 7))

    bars = ax.bar(models, accs, color=bar_colors, edgecolor='black', linewidth=2, alpha=0.8)

    # Print the exact accuracy just above each bar.
    for bar, acc in zip(bars, accs):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{acc:.4f}',
                ha='center', va='bottom', fontsize=13, fontweight='bold')

    ax.set_xlabel('Models', fontsize=14, fontweight='bold', color=COLORS['text'])
    ax.set_ylabel('Accuracy', fontsize=14, fontweight='bold', color=COLORS['text'])
    ax.set_title('Model Accuracy Comparison', fontsize=18, fontweight='bold',
                 color=COLORS['text'], pad=20)
    ax.set_ylim([0, 1.1])  # headroom for text labels above near-1.0 bars
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    plt.xticks(rotation=45, ha='right', fontsize=12)
    plt.tight_layout()

    save_path = os.path.join(save_dir, 'accuracy_comparison.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved accuracy comparison to {save_path}")
    plt.close()
291
+
292
def plot_score_vs_label(y_true, all_scores, save_dir):
    """Scatter-plot raw prediction scores per sample, coloured by true label.

    One subplot per model; a dashed horizontal line marks the model's
    decision threshold (from MODEL_THRESHOLDS, defaulting to 0.5).
    """
    print("\nGenerating score vs label scatter plots...")

    if not all_scores:
        print("No scores to plot!")
        return

    n_models = len(all_scores)
    cols = min(3, n_models)
    rows = (n_models + cols - 1) // cols  # ceiling division

    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows))
    # Normalise `axes` to a flat, indexable sequence whatever the grid shape.
    if n_models == 1:
        axes = [axes]
    elif rows > 1:
        axes = axes.flatten()

    label_colors = {0: COLORS['secondary'], 1: COLORS['primary']}

    for plot_idx, (model_name, scores) in enumerate(all_scores.items()):
        ax = axes[plot_idx]

        # One scatter series per class so each gets its own legend entry.
        for cls in (0, 1):
            selector = y_true == cls
            ax.scatter(np.where(selector)[0], scores[selector],
                       c=label_colors[cls],
                       label='Legitimate' if cls == 0 else 'Phishing',
                       alpha=0.6, s=50, edgecolors='black', linewidth=0.5)

        threshold = MODEL_THRESHOLDS.get(model_name, 0.5)
        ax.axhline(y=threshold, color='red', linestyle='--', linewidth=2,
                   label=f'Threshold ({threshold})', alpha=0.7)

        ax.set_title(f'{model_name.upper()} Prediction Scores',
                     fontsize=14, fontweight='bold', color=COLORS['text'])
        ax.set_xlabel('Sample Index', fontsize=11, fontweight='bold')
        ax.set_ylabel('Prediction Score', fontsize=11, fontweight='bold')
        ax.set_ylim([-0.1, 1.1])
        ax.legend(loc='best', framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')

    # Remove any unused axes left over in the grid.
    for extra in range(n_models, len(axes)):
        fig.delaxes(axes[extra])

    plt.tight_layout()
    save_path = os.path.join(save_dir, 'score_vs_label.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved score vs label plots to {save_path}")
    plt.close()
+
342
def main():
    """End-to-end evaluation: load a data sample, score every trained model
    family (classic ML, deep learning, optional BERT), and save all plots."""
    banner = "=" * 60
    print(banner)
    print("PHISHING DETECTION MODEL EVALUATION REPORT")
    print(banner)
    print("\nCustom Thresholds Configuration:")
    for name, thr in MODEL_THRESHOLDS.items():
        print(f"  • {name}: {thr}")
    print()

    # Make sure output locations exist before touching any artifacts.
    os.makedirs(config.REPORTS_DIR, exist_ok=True)
    os.makedirs(config.MODELS_DIR, exist_ok=True)

    df = load_sample_data(sample_fraction=0.05)

    predictions = {}
    scores = {}

    # Classic ML models.
    X_ml, y = prepare_ml_data(df)
    ml_preds, ml_scores = predict_ml_models(X_ml, y)
    predictions.update(ml_preds)
    scores.update(ml_scores)

    # Deep-learning models.
    X_dl, y_dl = prepare_dl_data(df)
    dl_preds, dl_scores = predict_dl_models(X_dl, y_dl)
    predictions.update(dl_preds)
    scores.update(dl_scores)

    # BERT is optional: include it only when a checkpoint was available.
    bert_pred, bert_score = predict_bert_model(df, y)
    if bert_pred is not None:
        predictions['bert'] = bert_pred
        scores['bert'] = bert_score

    if not predictions:
        print("\nWARNING: No models found! Please train models first.")
        print("Run: python train_ml.py && python train_dl.py")
        return

    plot_confusion_matrices(y, predictions, config.REPORTS_DIR)
    plot_accuracy_comparison(y, predictions, config.REPORTS_DIR)
    plot_score_vs_label(y, scores, config.REPORTS_DIR)

    print("\n" + banner)
    print("REPORT GENERATION COMPLETE!")
    print(f"All visualizations saved to: {config.REPORTS_DIR}")
    print(banner)


if __name__ == "__main__":
    main()
reports/model_architecture.pdf ADDED
Binary file (35.3 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas>=1.5.0
2
+ numpy>=1.23.0,<2.0.0
3
+ scikit-learn>=1.2.0
4
+ matplotlib>=3.6.0
5
+ seaborn>=0.12.0
6
+ torch>=2.0.0
7
+ xgboost>=1.7.0
8
+ transformers>=4.30.0
9
+ python-whois>=0.8.0
10
+ tldextract>=3.4.0
11
+ pyOpenSSL>=23.0.0
12
+ joblib>=1.2.0
13
+ tqdm>=4.65.0
14
+ fastapi>=0.104.0
15
+ uvicorn>=0.24.0
16
+ google-generativeai>=0.3.0
17
+ scipy>=1.10.0
18
+ pydantic>=2.0.0
19
+ python-multipart>=0.0.6
20
+ httpx>=0.24.0
21
+ groq
22
+ python-dotenv>=1.0.0
23
+ beautifulsoup4>=4.12.0
test_api.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from asyncio import sleep
2
+ import requests
3
+ import json
4
+ import textwrap
5
+
6
+
7
# Public ngrok tunnel endpoint for the deployed /predict API.
# NOTE(review): this URL is environment-specific and will break when the
# tunnel restarts — confirm before running.
API_URL = "https://sharyl-liberalistic-procrastinatingly.ngrok-free.dev/predict"


# JSON content type plus the header that suppresses ngrok's browser
# interstitial page for programmatic clients.
HEADERS = {
    'Content-Type': 'application/json',
    'ngrok-skip-browser-warning': 'true'
}
14
+
15
+
16
# Fixture emails for smoke-testing the /predict endpoint: real-world-style
# promotional / recruiting messages, some carrying classic phishing traits
# (tinyurl links, pay-per-task offers, requests to install software).
messages_to_test = [
    """Dear Akshat Nitinkumar Bhatt of DAIICT
Become a Campus Manager at Unlox: Lead, Earn, and Grow with AI EdTech!
Join Unlox as a Campus Manager and be the face of India's top AI-powered EdTech revolution right on your campus! Unlox is a next-gen platform dedicated to revolutionising student learning and career preparation. We're looking for driven, tech-savvy student leaders to champion this movement at their colleges.
What You'll Gain:
Earn up to ₹21,000 in Stipends* (get ₹1,000 instantly for every 4 enrollments!).
Free Edulet after 15 enrollments (Tab with smartlab,blu, job bridge program)
Official Certificate of Completion & Appraisal Letter for your resume.
Invaluable Real Leadership & Management Experience.
FREE Access to PrepFree, our exclusive career platform, featuring:
AI Resume Builder
Mock Interviews
Smart Job Matching
Dedicated Career Assistance and Job Portals
Masterclasses and bootcamps
Your Mission:
Lead a simple WhatsApp campaign to engage 50+ students from your college. Your goal is to grow the Unlox Student Community—all while gaining excellent real-world experience and significantly boosting your resume.

Ready to lead and make a tangible impact?
Application form - Apply by clicking this link

Regards
Aman K
Team Leader Unlox.
https://unlox.com/
https://in.linkedin.com/company/unloxacademy""",

    """You're invited to join the Internshala Student Partner program to develop professional skills while gaining valuable experience.
Ready to represent us
in DAIICT Gandhinagar?
Join the Internshala Student Partner Program
Program benefits include:
Learning opportunity - Masterclasses by industry professionals
Career development - Certificates and recommendation letters
Skill building - Develop leadership and communication skills
Who can apply? B.Tech students at DAIICT Gandhinagar can learn more about the program by clicking below.
LEARN MOREInternshala (Scholiverse Educare Pvt. Ltd.)
901A and 901B, Iris Tech Park, Sector - 48, Sohna Road, Gurugram
Not interested? Unsubscribe""",

    """Hey Akshat,
While you may still be (or not be!) on the hunt for a full-time opportunity, I wanted to share that we are also starting to partner with organizations that have part-time roles and tasks that you can complete to earn some additional income while still looking for full-time work.

We’re currently collaborating with an AI research lab on a project to help train AI models to better understand how real people use computers in day-to-day tasks. As a part of this project, you will have a front-row seat into helping shape the way future AI systems interact with common software, making technology more accessible and useful for everyone.

What's involved?
You'll be asked to perform simple, everyday computer tasks—like creating a Word document, making a presentation slide, or organizing videos online.
Each task is quick (around 2–3 minutes) and you can do as many or as few as you’d like within a 24-hour period. (Note: Project expires on November 6th, 2025)

Requirements:
For verification, you'll need to install two small programs and allow temporary recording of your screen and keyboard activities, strictly for checking that tasks are completed as described. Your data is protected and only used for research purposes.

Compensation:
You'll earn $0.30 (or Rs. 30 if in India) for each correctly completed and verified task (minimum quality standard: 75% accuracy).
There are at least 1,000 tasks available—meaning you could earn a upto ~$350 in one day if you choose to finish all. The most active participants can take on up to 20,000 tasks.

Next Steps and How to Start:
Just fill out the form here to get started: https://tinyurl.com/lightning-puneet
You will receive onboarding instructions over email in 5-10 minutes after filling out the form.

You can also join our WhatsApp group for this project to discuss with other contributors - cheers (Link will be provided after registration)!

We appreciate your help building smarter, fairer AI systems. Happy to answer any questions!

Regards,
Puneet Kohli
careerflow.ai""",
    """Dear Akshat!
😎 Turn your Curiosity into Epic rewards with the TATA group*!

Yes, its true! Join the Tata Crucible Campus Quiz 2025 🎉 open to all in-college students from every stream and background.

✨ Here's what you can win:
A brand-new iPhone 17
Win cash prizes worth up to ₹2.5 Lakh*
Internships* with the TATA group
A Luxury holiday worth ₹50,000
Certificates & loads of other rewards (every quiz taker gets a reward)
👉 How to Participate? (follow 3 simple steps):

1️⃣ Click the button below to Register
2️⃣ Once logged in, click on continue and hit "Complete Details.”
3️⃣ Fill in your details, take the quiz, and DM your quiz completion screenshot to Tata Crucible’s official Instagram by following them.
⚠️ Important: Registration can only be completed once the basic details are filled in.

🔥 Thousands of students are already in. Don’t miss your chance to shine at India’s biggest campus quiz!
Register Now


internshala (scholiverse educare pvt. ltd.)
iris tech park, sohna road, gurugram

view it in your browser.
unsubscribe me from this list
"""
]
112
+
113
def analyze_message(text):
    """POST *text* to the phishing-detection API and pretty-print the verdict.

    Network and HTTP failures are reported on stdout rather than raised, so
    a single bad request doesn't abort the whole test run.
    """
    try:
        resp = requests.post(API_URL, headers=HEADERS,
                             json={"text": text}, timeout=10)

        if resp.status_code == 200:
            verdict = resp.json()

            # Full raw payload first, then the headline fields.
            print(json.dumps(verdict, indent=4))

            print("-" * 20)
            print(f"Decision: {verdict.get('final_decision')}")
            print(f"Confidence: {verdict.get('confidence')}%")
            print(f"Reasoning: {verdict.get('reasoning')}")
            print(f"Suggestion: {verdict.get('suggestion')}")
            print("-" * 20)
        else:
            print(f"Error: Received status code {resp.status_code}")
            print(f"Response text: {resp.text}")

    except requests.exceptions.ConnectionError:
        print(f"Connection Error: Could not connect to {API_URL}.")
        print("Please ensure the ngrok tunnel is running and the URL is correct.")
    except requests.exceptions.RequestException as exc:
        print(f"An error occurred during the request: {exc}")
146
+
147
+
148
if __name__ == "__main__":
    # BUG FIX: the module imported `sleep` from asyncio, but asyncio.sleep()
    # called from synchronous code just returns an un-awaited coroutine — no
    # pause ever happened (plus a "coroutine was never awaited" warning).
    # Use the blocking time.sleep instead.
    import time

    for i, message in enumerate(messages_to_test):
        print(f"================== TESTING MESSAGE {i+1} ==================")

        print(f"Message Snippet: {textwrap.shorten(message, width=70, placeholder='...')}\n")

        analyze_message(message)

        print(f"================END OF TEST FOR MESSAGE {i+1} ================\n\n")
        # Pause between requests so we don't hammer the rate-limited endpoint.
        time.sleep(6.0)
train_dl.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import DataLoader
8
+ import config
9
+ from models import get_dl_models, PhishingDataset
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.preprocessing import StandardScaler
12
+ import warnings
13
+
14
+ warnings.filterwarnings('ignore', category=UserWarning)
15
+ warnings.filterwarnings('ignore', category=FutureWarning)
16
+
17
def prepare_data(df, numerical_cols):
    """Extract the numeric feature matrix and labels, then z-score the features.

    Missing values are encoded with the sentinel -1 before scaling.

    Returns:
        (X_scaled, y, fitted_scaler) — the scaler must be reused at
        inference time to transform new samples identically.
    """
    print("Preparing data for DL training...")
    feature_matrix = df[numerical_cols].fillna(-1).values
    labels = df['label'].values

    standardiser = StandardScaler()
    scaled_features = standardiser.fit_transform(feature_matrix)

    return scaled_features, labels, standardiser
26
+
27
def train_dl_model(model, train_loader, val_loader, device, epochs=50, lr=0.001):
    """Train a binary classifier with BCE loss + Adam and return the best model.

    Args:
        model: torch.nn.Module whose forward() emits probabilities in [0, 1],
            shape-compatible with the targets yielded by the loaders.
        train_loader / val_loader: DataLoaders yielding (X, y) float batches.
        device: torch.device to run on.
        epochs: number of full passes over train_loader.
        lr: Adam learning rate.

    Returns:
        `model` with the weights from the epoch whose summed validation loss
        was lowest. (Previously the best loss was tracked but the matching
        weights were discarded, so the *last*-epoch model was returned.)
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.to(device)
    best_val_loss = float('inf')
    # Snapshot of the best-performing weights seen so far (clone so later
    # optimizer steps cannot mutate the saved tensors in place).
    best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation pass: accumulate loss and accuracy at threshold 0.5.
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                val_loss += loss.item()

                predicted = (outputs > 0.5).float()
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()

        val_accuracy = correct / total

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_accuracy:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Restore the best checkpoint before handing the model back.
    model.load_state_dict(best_state)
    return model
73
+
74
def main():
    """Train every DL architecture on the engineered features and persist the
    weights plus the fitted feature scaler."""
    print("--- Starting DL Model Training ---")
    os.makedirs(config.MODELS_DIR, exist_ok=True)

    try:
        df = pd.read_csv(config.ENGINEERED_TRAIN_FILE)
    except FileNotFoundError:
        print(f"Error: '{config.ENGINEERED_TRAIN_FILE}' not found.")
        print("Please run `python data_pipeline.py` first.")
        return

    X_scaled, y, scaler = prepare_data(df, config.NUMERICAL_FEATURES)

    # Stratified hold-out split, mirroring the ML training script.
    X_train, X_val, y_train, y_val = train_test_split(
        X_scaled, y,
        test_size=config.ML_TEST_SIZE,
        random_state=config.ML_MODEL_RANDOM_STATE,
        stratify=y,
    )

    print(f"Training on {len(X_train)} samples, validating on {len(X_val)} samples.")

    train_loader = DataLoader(PhishingDataset(X_train, y_train),
                              batch_size=config.DL_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(PhishingDataset(X_val, y_val),
                            batch_size=config.DL_BATCH_SIZE, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    for name, net in get_dl_models(X_train.shape[1]).items():
        print(f"\n--- Training {name} ---")

        trained = train_dl_model(
            net, train_loader, val_loader, device,
            epochs=config.DL_EPOCHS,
            lr=config.DL_LEARNING_RATE,
        )

        weights_path = os.path.join(config.MODELS_DIR, f"{name}.pt")
        torch.save(trained.state_dict(), weights_path)
        print(f"Model saved to {weights_path}")

    # The scaler is required at inference time to transform raw features.
    scaler_path = os.path.join(config.MODELS_DIR, "dl_scaler.pkl")
    import joblib
    joblib.dump(scaler, scaler_path)
    print(f"Scaler saved to {scaler_path}")

    print("\n--- DL Model Training Complete ---")

if __name__ == "__main__":
    main()
train_ml.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import joblib
4
+ import config
5
+ from models import get_ml_models
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
8
+ from sklearn.compose import ColumnTransformer
9
+ from sklearn.pipeline import Pipeline
10
+ from sklearn.metrics import accuracy_score
11
+ import warnings
12
+
13
+ warnings.filterwarnings('ignore', category=UserWarning)
14
+ warnings.filterwarnings('ignore', category=FutureWarning)
15
+
16
def prepare_data(df, numerical_cols, categorical_cols):
    """Build the ColumnTransformer preprocessor and the (X, y) training frame.

    Numeric NaNs are filled with the sentinel -1 and categorical NaNs with
    'N/A' up front, so the pipeline "imputer" steps are simple passthroughs.

    Args:
        df: engineered-features DataFrame containing a 'label' column.
        numerical_cols: names of the numeric feature columns.
        categorical_cols: names of the categorical feature columns.

    Returns:
        (preprocessor, X, y) where preprocessor scales numerics and
        one-hot-encodes categoricals (unknown categories ignored).
    """
    print("Preparing data for ML training...")
    # BUG FIX: take an explicit copy. The original assigned into a slice of
    # `df` via X.loc[...], which is chained assignment on a view — it raises
    # SettingWithCopyWarning and, under pandas copy-on-write, the fillna
    # results may silently not stick.
    X = df[numerical_cols + categorical_cols].copy()
    y = df['label']

    numerical_transformer = Pipeline(steps=[
        ('imputer', 'passthrough'),  # NaNs already filled below
        ('scaler', StandardScaler())
    ])
    categorical_transformer = Pipeline(steps=[
        ('imputer', 'passthrough'),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])

    X.loc[:, numerical_cols] = X.loc[:, numerical_cols].fillna(-1)
    X.loc[:, categorical_cols] = X.loc[:, categorical_cols].fillna('N/A')

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numerical_transformer, numerical_cols),
            ('cat', categorical_transformer, categorical_cols)
        ],
        remainder='passthrough'
    )

    return preprocessor, X, y
40
+
41
def main():
    """Fit every classic-ML pipeline on the engineered features and persist
    each fitted pipeline (preprocessing + classifier) with joblib."""
    print("--- Starting ML Model Training ---")
    os.makedirs(config.MODELS_DIR, exist_ok=True)

    try:
        df = pd.read_csv(config.ENGINEERED_TRAIN_FILE)
    except FileNotFoundError:
        print(f"Error: '{config.ENGINEERED_TRAIN_FILE}' not found.")
        print("Please run `python data_pipeline.py` first.")
        return

    preprocessor, X, y = prepare_data(
        df,
        config.NUMERICAL_FEATURES,
        config.CATEGORICAL_FEATURES
    )

    # Stratified hold-out split for validation accuracy.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y,
        test_size=config.ML_TEST_SIZE,
        random_state=config.ML_MODEL_RANDOM_STATE,
        stratify=y,
    )

    print(f"Training on {len(X_train)} samples, validating on {len(X_val)} samples.")

    for name, estimator in get_ml_models().items():
        print(f"\n--- Training {name} ---")

        # Bundle preprocessing with the classifier so the saved artifact is
        # self-contained at inference time.
        pipeline = Pipeline(steps=[
            ('preprocessor', preprocessor),
            ('classifier', estimator)
        ])
        pipeline.fit(X_train, y_train)

        val_accuracy = accuracy_score(y_val, pipeline.predict(X_val))
        print(f"Validation Accuracy for {name}: {val_accuracy:.4f}")

        out_path = os.path.join(config.MODELS_DIR, f"{name}.joblib")
        joblib.dump(pipeline, out_path)
        print(f"Model saved to {out_path}")

    print("\n--- ML Model Training Complete ---")

if __name__ == "__main__":
    main()