| """ |
| Enhanced Code validation utilities for Manim code generation |
| """ |
|
|
| import re |
| import ast |
| import logging |
| from typing import Tuple, Optional, Dict, Any, List |
|
|
| logger = logging.getLogger(__name__) |
|
|
| class CodeValidator: |
| """Validates and fixes Manim code generation issues.""" |
| |
| @staticmethod |
| def fix_common_issues(code: str) -> str: |
| """Fix common issues in generated Manim code.""" |
| |
| code = re.sub(r'```python\n?', '', code) |
| code = re.sub(r'```\n?', '', code) |
| code = code.replace('```', '') |
| code = code.replace('`', '') |
| |
| |
| code = re.sub(r'Text\(([^"\'()\[\]]*?)\)', r'Text("\1")', code) |
| code = re.sub(r'Text\(([^"\']*?), font_size', r'Text("\1", font_size', code) |
| |
| |
| lines = code.split('\n') |
| fixed_lines = [] |
| for line in lines: |
| if 'Text(' in line and line.count('"') % 2 != 0: |
| |
| line = re.sub(r'([^"]*?)(,|\))', r'\1"\2', line) |
| fixed_lines.append(line) |
| |
| code = '\n'.join(fixed_lines) |
| |
| |
| if not re.search(r'from manim import \*', code): |
| code = 'from manim import *\n\n' + code |
| |
| return code.strip() |
| |
| @staticmethod |
| def validate_python_syntax(code: str) -> Tuple[bool, Optional[str]]: |
| """Validate Python syntax of the generated code.""" |
| try: |
| ast.parse(code) |
| return True, None |
| except SyntaxError as e: |
| return False, str(e) |
| |
| @staticmethod |
| def has_scene_class(code: str) -> bool: |
| """Check if the code has a proper Scene class.""" |
| try: |
| tree = ast.parse(code) |
| for node in ast.walk(tree): |
| if isinstance(node, ast.ClassDef): |
| |
| for base in node.bases: |
| if isinstance(base, ast.Name) and base.id == 'Scene': |
| return True |
| elif isinstance(base, ast.Attribute) and base.attr == 'Scene': |
| return True |
| return False |
| except SyntaxError: |
| return False |
| |
| @staticmethod |
| def has_construct_method(code: str) -> bool: |
| """Check if the Scene class has a construct method.""" |
| try: |
| tree = ast.parse(code) |
| for node in ast.walk(tree): |
| if isinstance(node, ast.ClassDef): |
| for item in node.body: |
| if isinstance(item, ast.FunctionDef) and item.name == 'construct': |
| return True |
| return False |
| except SyntaxError: |
| return False |
| |
| @staticmethod |
| def has_animations(code: str) -> bool: |
| """Check if the code has at least one animation.""" |
| required_patterns = [ |
| r'self\.play\(', |
| r'self\.wait\(' |
| ] |
| |
| for pattern in required_patterns: |
| if not re.search(pattern, code): |
| return False |
| return True |
| |
| @staticmethod |
| def validate_manim_code(code: str) -> Tuple[bool, Optional[str]]: |
| """Validate that the code meets Manim requirements.""" |
| |
| if not CodeValidator.has_scene_class(code): |
| return False, "Code must contain a class that inherits from Scene" |
| |
| if not CodeValidator.has_construct_method(code): |
| return False, "Scene class must have a construct method" |
| |
| if not CodeValidator.has_animations(code): |
| return False, "Code must include at least one animation (self.play) and wait" |
| |
| |
| problematic_patterns = [ |
| r'for\s+\*', |
| r'\*\s*=', |
| r'^\s*\*', |
| r'get\*center', |
| r'import\s+\*', |
| r'from\s+[^\s]+\s+import\s+\*' |
| ] |
| |
| for pattern in problematic_patterns: |
| if re.search(pattern, code, re.MULTILINE): |
| return False, f"Code contains problematic pattern: {pattern}" |
| |
| return True, None |