Spaces:
Sleeping
Sleeping
| # server/env.py | |
| import os | |
| import re | |
| import shutil | |
| import tempfile | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Tuple, Dict, Any | |
| import sys | |
| from openenv.core.env_server import Environment, State | |
| from src.jira_to_code.models import JiraCodeAction, JiraCodeObservation | |
class JiraToCodeEnv(Environment):
    """Coding environment in which an agent works on a Jira-style ticket.

    ``reset()`` copies one task directory from ``TASKS`` into a fresh
    temporary workspace. ``step()`` then applies one action to that
    workspace: ``list_files``, ``read_file``, ``write_file``, ``run_tests``
    or ``submit``. The last two run pytest inside the workspace; a
    ``submit`` whose pytest run completes ends the episode.
    """

    # Task level -> source directory (relative paths are resolved against the
    # process working directory in reset()) and the ticket text shown to the
    # agent.
    TASKS = {
        "easy": {
            "dir": "src/jira_to_code/tasks/easy",
            "ticket": (
                "TICKET-101: Fix the off-by-one bug in calculator.add() function. "
                "It should correctly sum two numbers."
            ),
        },
        "easy_2": {
            "dir": "src/jira_to_code/tasks/easy_2",
            "ticket": (
                "TICKET-102: Fix the bug in string_utils.count_vowels(). "
                "It currently only counts lowercase vowels but should be case-insensitive."
            ),
        },
        "easy_3": {"dir": "src/jira_to_code/tasks/easy_3", "ticket": "TICKET-E3: The API endpoint crashes with a KeyError when a user payload doesn't contain an optional 'phone_number' field. Change dictionary indexing to .get() with a fallback."},
        "easy_4": {"dir": "src/jira_to_code/tasks/easy_4", "ticket": "TICKET-E4: Off-by-One Pagination. get_page_bounds(page, size) misses the 10th item on page 1. Fix the math index logic."},
        "easy_5": {"dir": "src/jira_to_code/tasks/easy_5", "ticket": "TICKET-E5: FastAPI Route Typo. Route signature is id instead of user_id. Fix the parameter mismatch."},
        "medium": {
            "dir": "src/jira_to_code/tasks/medium",
            "ticket": (
                "TICKET-201: Implement format_user_data in formatter.py. "
                "It should format dictionary data to 'LAST_NAME, First_name (Age: X)'. "
                "Handle missing age by defaulting to 'Unknown'."
            ),
        },
        "medium_2": {
            "dir": "src/jira_to_code/tasks/medium_2",
            "ticket": (
                "TICKET-202: Implement validate_email() and validate_password() in validator.py. "
                "Email: must have exactly one '@', at least 1 char before '@', a '.' after '@' with chars around it. "
                "Password: at least 8 chars, one uppercase, one lowercase, one digit."
            ),
        },
        "medium_3": {"dir": "src/jira_to_code/tasks/medium_3", "ticket": "TICKET-M3: Missing Authentication Middleware. A sensitive endpoint (/api/billing) is exposed. Import @require_auth from auth.py and apply it to the route in routes.py."},
        "medium_4": {"dir": "src/jira_to_code/tasks/medium_4", "ticket": "TICKET-M4: N+1 Database Problem. Rewrite the ORM query to use a JOIN (e.g., select_related)."},
        "medium_5": {"dir": "src/jira_to_code/tasks/medium_5", "ticket": "TICKET-M5: Flawed Regex Validation. validate_email rejects emails with a plus sign. Update regex to allow user+test@gmail.com."},
        "medium_6": {"dir": "src/jira_to_code/tasks/medium_6", "ticket": "TICKET-M6: Incomplete Error Handling. fetching currency rates crashes on timeout. Wrap in try/except and return a cached fallback value."},
        "medium_7": {"dir": "src/jira_to_code/tasks/medium_7", "ticket": "TICKET-M7: Stale Cache Bug. update_user_profile updates DB but forgets to call redis.delete('user:id'). Invalidate the cache."},
        "medium_8": {"dir": "src/jira_to_code/tasks/medium_8", "ticket": "TICKET-M8: Timezone Naive Conversion. Event scheduling function creates naive datetimes. Make them UTC aware."},
        "medium_9": {"dir": "src/jira_to_code/tasks/medium_9", "ticket": "TICKET-M9: State Machine Loophole. Cart state machine allows CANCELLED to SHIPPED. Add transition guards."},
        "medium_10": {"dir": "src/jira_to_code/tasks/medium_10", "ticket": "TICKET-M10: Config Merge Overwrite. YAML merge completely overwrites nested dictionaries. Fix recursion logic."},
        "hard": {
            "dir": "src/jira_to_code/tasks/hard",
            "ticket": (
                "TICKET-301: Implement an LRUCache class in lru_cache.py with put() and get() methods. "
                "O(1) time complexity expected. Evict least recently used when capacity is reached."
            ),
        },
        "hard_2": {
            "dir": "src/jira_to_code/tasks/hard_2",
            "ticket": (
                "TICKET-302: Implement a DirectedGraph class in graph.py with add_edge(), "
                "has_path() (BFS/DFS), and topological_sort() methods. "
                "topological_sort() must return an empty list if a cycle is detected."
            ),
        },
        "hard_3": {"dir": "src/jira_to_code/tasks/hard_3", "ticket": "TICKET-H3: Circular Dependency Resolution. models.py, utils.py, config.py. Extract shared logic into base.py."},
        "hard_4": {"dir": "src/jira_to_code/tasks/hard_4", "ticket": "TICKET-H4: Race Condition in Thread Worker. Refactor the architecture to use queue.Queue."},
        "hard_5": {"dir": "src/jira_to_code/tasks/hard_5", "ticket": "TICKET-H5: OOM Generator Fix. Readlines causes crash on 5GB file. Rewrite to yield generators."},
        "hard_6": {"dir": "src/jira_to_code/tasks/hard_6", "ticket": "TICKET-H6: Implement Abstract Base Class. Implement StripeGateway matching PaymentGateway abstract class."},
        "hard_7": {"dir": "src/jira_to_code/tasks/hard_7", "ticket": "TICKET-H7: Deadlock in Asyncio. Route acquires threading.Lock but forgets to release on exception. Use async context managers."},
    }

    # Reward configuration
    STEP_PENALTY = -0.01  # Added to the reward of every step after the grace phase
    GRACE_STEPS = 3       # Steps 1..GRACE_STEPS are not penalised (orientation phase)
    GRACE_BONUS = 0.02    # Added to the reward of every step inside the grace phase
    WRITE_REWARD = 0.05   # Small shaping reward for a successful write_file
    MIN_REWARD = 0.01     # Lower clamp applied to every returned reward
    MAX_REWARD = 0.99     # Upper clamp applied to every returned reward
    TEST_TIMEOUT = 30     # Seconds allowed for one pytest run

    def __init__(self):
        """Create an environment with no workspace; call ``reset()`` to start."""
        super().__init__()
        self.step_count = 0
        self.workspace_dir = None
        self.task_level = "easy"
        self.task_source_dir = None
        self.jira_ticket = ""

    def _get_file_tree(self) -> list[str]:
        """Return workspace-relative paths of the files in the workspace.

        Anything under a path containing ``__pycache__`` and any ``.pyc``
        file is omitted. Returns an empty list when there is no workspace.
        """
        if not self.workspace_dir:
            return []
        tree = []
        for root, _, files in os.walk(self.workspace_dir):
            if "__pycache__" in root:
                continue
            for file in files:
                if file.endswith(".pyc"):
                    continue
                tree.append(str((Path(root) / file).relative_to(self.workspace_dir)))
        return tree

    # Must be a staticmethod: it takes no ``self`` but is invoked through the
    # instance, which would otherwise pass the instance as ``output`` and
    # raise TypeError on every run_tests/submit action.
    @staticmethod
    def _parse_pytest_results(output: str) -> tuple[int, int]:
        """Extract ``(passed, total)`` from pytest output for partial credit.

        ``total`` is passed + failed + errors, taken from the first
        "<n> passed" / "<n> failed" / "<n> error" occurrences in ``output``,
        and is never less than 1 so callers can divide by it safely.
        """
        def first_count(pattern: str) -> int:
            match = re.search(pattern, output)
            return int(match.group(1)) if match else 0

        passed = first_count(r'(\d+) passed')
        failed = first_count(r'(\d+) failed')
        errors = first_count(r'(\d+) error')
        return passed, max(passed + failed + errors, 1)

    def _remove_workspace(self) -> None:
        """Delete the current workspace directory, if any, and forget it."""
        if self.workspace_dir and Path(self.workspace_dir).exists():
            shutil.rmtree(self.workspace_dir)
        # Drop the stale path so later calls cannot silently recreate it.
        self.workspace_dir = None

    def _run_pytest(self) -> Tuple[int, str, int, int]:
        """Run ``pytest -v`` in the workspace.

        Returns:
            ``(returncode, combined stdout/stderr, passed, total)``.

        Raises:
            subprocess.TimeoutExpired: if pytest exceeds ``TEST_TIMEOUT``.
        """
        result = subprocess.run(
            [sys.executable, "-m", "pytest", "-v"],
            cwd=self.workspace_dir,
            capture_output=True, text=True, timeout=self.TEST_TIMEOUT,
        )
        output = result.stdout + "\n" + result.stderr
        passed, total = self._parse_pytest_results(output)
        return result.returncode, output, passed, total

    def _handle_file_action(self, action: JiraCodeAction) -> Tuple[Any, Any, float]:
        """Perform a ``read_file`` or ``write_file`` action.

        Returns:
            ``(file content or None, error message or None, reward)``.
        """
        if not action.file_path:
            return None, "file_path must be provided for read/write actions.", 0.0

        workspace_path = Path(self.workspace_dir).resolve()
        target_path = (workspace_path / action.file_path).resolve()
        # Resolve first, then compare: rejects "../" and absolute-path escapes.
        if not target_path.is_relative_to(workspace_path):
            return None, "Access denied: cannot access files outside workspace.", 0.0

        if action.action_type == "read_file":
            # is_file() rather than exists(): a directory is not readable text.
            if not target_path.is_file():
                return None, f"File not found: {action.file_path}", 0.0
            return target_path.read_text(encoding="utf-8"), None, 0.0

        if action.content is None:
            return None, "content must be provided for write_file action.", 0.0
        target_path.parent.mkdir(parents=True, exist_ok=True)
        target_path.write_text(action.content, encoding="utf-8")
        return action.content, None, self.WRITE_REWARD

    def reset(self) -> JiraCodeObservation:
        """Start a new episode in a fresh temporary workspace.

        The task level is read from the ``JIRA_TASK_LEVEL`` environment
        variable (default ``"medium"``); a value that is not a key of
        ``TASKS`` falls back to ``"easy"``. Any previous workspace is removed.

        Returns:
            The initial observation: ticket text and workspace file tree.
        """
        self.step_count = 0
        self._remove_workspace()

        # Re-read task level from environment variable on every reset
        # NOTE(review): the default here is "medium" while __init__ and the
        # fallback below use "easy" - confirm this is intentional.
        self.task_level = os.getenv("JIRA_TASK_LEVEL", "medium").lower()
        if self.task_level not in self.TASKS:
            self.task_level = "easy"

        task = self.TASKS[self.task_level]
        self.task_source_dir = Path(task["dir"]).resolve()
        self.jira_ticket = task["ticket"]
        self.workspace_dir = tempfile.mkdtemp(prefix=f"jira_env_{self.task_level}_")

        if self.task_source_dir.exists():
            shutil.copytree(self.task_source_dir, self.workspace_dir, dirs_exist_ok=True)
        else:
            print(f"Warning: Task directory {self.task_source_dir} not found!")

        return JiraCodeObservation(
            jira_ticket=self.jira_ticket,
            file_tree=self._get_file_tree(),
        )

    def step(self, action: JiraCodeAction) -> Tuple[JiraCodeObservation, float, bool, Dict[str, Any]]:
        """Apply one agent action and return ``(observation, reward, done, info)``.

        Failures (bad arguments, paths outside the workspace, pytest timeout,
        unexpected exceptions) are reported through the observation's
        ``error`` field rather than raised. ``done`` becomes True only when a
        ``submit`` action's pytest run completes. ``info`` is always empty.
        """
        self.step_count += 1
        reward = 0.0
        done = False
        current_file_content = None
        test_output = None
        error = None

        try:
            if not self.workspace_dir:
                # step() before reset() (or after close()): report it instead
                # of crashing on Path(None).
                error = "Environment not initialized: call reset() before step()."
            elif action.action_type == "list_files":
                current_file_content = "\n".join(self._get_file_tree())
            elif action.action_type in ("read_file", "write_file"):
                current_file_content, error, reward = self._handle_file_action(action)
            elif action.action_type == "run_tests":
                returncode, test_output, passed, total = self._run_pytest()
                if returncode == 0:
                    # All tests pass - strong positive signal
                    reward = 0.1 + 0.4 * (passed / total)
                elif returncode == 1:
                    # Some tests fail - partial credit
                    reward = 0.1 * (passed / total)
                else:
                    # Any other pytest exit code (collection error / crash)
                    reward = -0.1
            elif action.action_type == "submit":
                returncode, test_output, passed, total = self._run_pytest()
                # NOTE(review): a timed-out submit never reaches this line, so
                # the episode stays open - confirm that is intended.
                done = True
                if returncode == 0:
                    reward = 1.0  # Full marks
                else:
                    reward = 0.5 * (passed / total)  # Partial credit on submit
            else:
                error = f"Unknown action_type: {action.action_type}"
        except subprocess.TimeoutExpired:
            error = f"Tests timed out after {self.TEST_TIMEOUT} seconds."
            test_output = "TIMEOUT"
            reward = -0.1
        except Exception as e:
            # Top-level boundary: surface any failure to the agent as an error.
            error = f"System error: {str(e)}"
            reward = -0.2

        # Apply shaping rewards based on step count
        if self.step_count <= self.GRACE_STEPS:
            reward += self.GRACE_BONUS
        else:
            reward += self.STEP_PENALTY

        # Enforce strictly bounded rewards for OpenEnv requirements (between
        # 0.01 and 0.99). Every value at or below MIN_REWARD, including the
        # negative penalties above, is therefore returned as MIN_REWARD.
        reward = max(self.MIN_REWARD, min(self.MAX_REWARD, reward))

        obs = JiraCodeObservation(
            jira_ticket=self.jira_ticket,
            file_tree=self._get_file_tree(),
            current_file_content=current_file_content,
            test_output=test_output,
            error=error,
        )
        return obs, reward, done, {}

    def state(self) -> State:
        """Return a ``State`` carrying the step count and a derived episode id."""
        return State(
            episode_id=f"jira-{self.task_level}-{self.step_count}",
            step_count=self.step_count,
        )

    def close(self):
        """Remove the temporary workspace; safe to call more than once."""
        self._remove_workspace()