Rthur2003 commited on
Commit
cbff5e0
·
1 Parent(s): f37f03d

fix: enhance CORS origin handling and update validation for YouTube URLs

Browse files
app/main.py CHANGED
@@ -23,7 +23,7 @@ logger = get_logger(__name__)
23
 
24
 
25
  def _load_origins() -> list[str]:
26
- raw = os.getenv("CROWNCODE_CORS_ORIGINS") or os.getenv("CORS_ORIGIN", "http://localhost:3000")
27
  if raw.strip() == "*":
28
  logger.warning("CORS configured to allow all origins")
29
  return ["*"]
@@ -74,10 +74,12 @@ async def general_exception_handler(request: Request, exc: Exception) -> JSONRes
74
  )
75
 
76
 
 
77
  app.add_middleware(
78
  CORSMiddleware,
79
- allow_origins=_load_origins(),
80
- allow_credentials=True,
 
81
  allow_methods=["*"],
82
  allow_headers=["*"],
83
  )
 
23
 
24
 
25
  def _load_origins() -> list[str]:
26
+ raw = os.getenv("CROWNCODE_CORS_ORIGINS") or os.getenv("CORS_ORIGIN") or "http://localhost:3000"
27
  if raw.strip() == "*":
28
  logger.warning("CORS configured to allow all origins")
29
  return ["*"]
 
74
  )
75
 
76
 
77
+ _origins = _load_origins()
78
  app.add_middleware(
79
  CORSMiddleware,
80
+ allow_origins=_origins,
81
+ # credentials=True is only safe with an explicit origin list, not wildcard
82
+ allow_credentials="*" not in _origins,
83
  allow_methods=["*"],
84
  allow_headers=["*"],
85
  )
app/schemas.py CHANGED
@@ -54,9 +54,11 @@ class YouTubeAnalyzeResponse(BaseModel):
54
 
55
 
56
  class AudioAugmentationOptions(BaseModel):
57
- pitch_shift: bool = Field(default=False, description="Apply random pitch shifting")
58
- speed_change: bool = Field(default=False, description="Apply random speed change")
59
- bass_boost: bool = Field(default=False, description="Apply bass boost equalization")
60
- trim_silence: bool = Field(default=False, description="Trim leading and trailing silence")
61
- mix_audio: bool = Field(default=False, description="Mix with another audio track (placeholder)")
62
- add_noise: bool = Field(default=False, description="Add Gaussian noise")
 
 
 
54
 
55
 
56
  class AudioAugmentationOptions(BaseModel):
57
+ model_config = {"populate_by_name": True}
58
+
59
+ pitch_shift: bool = Field(default=False, alias="pitchShift", description="Apply random pitch shifting")
60
+ speed_change: bool = Field(default=False, alias="speedChange", description="Apply random speed change")
61
+ bass_boost: bool = Field(default=False, alias="bassBoost", description="Apply bass boost equalization")
62
+ trim_silence: bool = Field(default=False, alias="trimSilence", description="Trim leading and trailing silence")
63
+ mix_audio: bool = Field(default=False, alias="mixAudio", description="Mix with another audio track (placeholder)")
64
+ add_noise: bool = Field(default=False, alias="addNoise", description="Add Gaussian noise")
app/services/url_parser.py CHANGED
@@ -51,7 +51,7 @@ def _extract_video_id(parsed_url) -> Optional[str]:
51
  candidate = path.strip("/").split("/")[0]
52
  return candidate or None
53
 
54
- if "youtube.com" in host or "music.youtube.com" in host:
55
  if path == "/watch":
56
  return query.get("v", [None])[0]
57
  if path.startswith("/shorts/") or path.startswith("/live/") or path.startswith("/embed/"):
 
51
  candidate = path.strip("/").split("/")[0]
52
  return candidate or None
53
 
54
+ if host in {"youtube.com", "www.youtube.com", "m.youtube.com", "music.youtube.com"}:
55
  if path == "/watch":
56
  return query.get("v", [None])[0]
57
  if path.startswith("/shorts/") or path.startswith("/live/") or path.startswith("/embed/"):
app/services/validation.py CHANGED
@@ -68,18 +68,29 @@ def validate_url(url: str) -> bool:
68
  if any(char in url for char in dangerous_chars):
69
  return False
70
 
71
- allowed_domains = [
 
 
 
 
 
 
 
 
72
  'youtube.com',
73
- 'youtu.be',
 
74
  'music.youtube.com',
 
 
75
  'spotify.com',
76
- 'open.spotify.com'
77
- ]
78
-
79
- url_lower = url.lower()
80
- if not any(domain in url_lower for domain in allowed_domains):
81
  return False
82
-
83
  return True
84
 
85
 
 
68
  if any(char in url for char in dangerous_chars):
69
  return False
70
 
71
+ from urllib.parse import urlparse
72
+
73
+ try:
74
+ parsed = urlparse(url)
75
+ host = (parsed.hostname or '').lower()
76
+ except Exception:
77
+ return False
78
+
79
+ allowed_hosts = {
80
  'youtube.com',
81
+ 'www.youtube.com',
82
+ 'm.youtube.com',
83
  'music.youtube.com',
84
+ 'youtu.be',
85
+ 'www.youtu.be',
86
  'spotify.com',
87
+ 'www.spotify.com',
88
+ 'open.spotify.com',
89
+ }
90
+
91
+ if host not in allowed_hosts:
92
  return False
93
+
94
  return True
95
 
96