# BrowserForge / client.py
# cryptodarth's picture
# V1
# 42d1599
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""OpenEnv client for the Browser RL environment."""
import os
from typing import Dict, Optional
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from websockets.asyncio.client import connect as ws_connect
try:
from .models import (
BrowserAction,
BrowserElement,
BrowserObservation,
ConstraintState,
RewardBreakdown,
)
except ImportError: # pragma: no cover - direct repo execution
from models import (
BrowserAction,
BrowserElement,
BrowserObservation,
ConstraintState,
RewardBreakdown,
)
class BrowserEnv(EnvClient[BrowserAction, BrowserObservation, State]):
    """Client for BrowserGym-backed OpenEnv browser episodes.

    On top of ``EnvClient`` this class adds an optional
    ``Authorization: Bearer <token>`` header to the WebSocket handshake and
    converts raw payload dicts into ``StepResult`` / ``State`` objects.
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        max_message_size_mb: float = 100.0,
        provider=None,
        mode: Optional[str] = None,
        auth_token: Optional[str] = None,
    ):
        """Create the client and resolve the optional bearer token.

        All arguments except ``auth_token`` are forwarded unchanged to
        ``EnvClient.__init__``.

        Args:
            base_url: Forwarded to ``EnvClient``.
            connect_timeout_s: Forwarded to ``EnvClient``.
            message_timeout_s: Forwarded to ``EnvClient``.
            max_message_size_mb: Forwarded to ``EnvClient``.
            provider: Forwarded to ``EnvClient``.
            mode: Forwarded to ``EnvClient``.
            auth_token: Bearer token for the WebSocket handshake. When falsy,
                the ``HF_TOKEN`` and then ``HUGGING_FACE_HUB_TOKEN``
                environment variables are consulted; if none is set, no
                ``Authorization`` header is sent.
        """
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            max_message_size_mb=max_message_size_mb,
            provider=provider,
            mode=mode,
        )
        token = auth_token or os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
        # None (rather than an empty dict) signals "send no extra headers".
        self._ws_headers = {"Authorization": f"Bearer {token}"} if token else None

    async def connect(self) -> "BrowserEnv":
        """Establish WebSocket connection to the server, with optional HF auth headers.

        If a connection is already open (``self._ws`` is set) this is a no-op.
        For URLs containing ``localhost`` or ``127.0.0.1`` the ``NO_PROXY``
        environment variable is temporarily extended with those hosts for the
        duration of the handshake and restored afterwards.

        Returns:
            This client instance.

        Raises:
            ConnectionError: If opening the WebSocket fails for any reason;
                the original exception is chained as the cause.
        """
        if self._ws is not None:
            return self

        ws_url_lower = self._ws_url.lower()
        is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower

        # Remember the exact prior value (None == unset) so it can be restored.
        old_no_proxy = os.environ.get("NO_PROXY")
        if is_localhost:
            # Presumably this keeps the handshake from being routed through a
            # configured proxy when talking to a local server.
            # NOTE: os.environ is process-global, so this change is visible to
            # other code while the handshake below is being awaited.
            current_no_proxy = old_no_proxy or ""
            if "localhost" not in current_no_proxy.lower():
                os.environ["NO_PROXY"] = (
                    f"{current_no_proxy},localhost,127.0.0.1"
                    if current_no_proxy
                    else "localhost,127.0.0.1"
                )

        try:
            kwargs = {
                "open_timeout": self._connect_timeout,
                "max_size": self._max_message_size,
            }
            if self._ws_headers:
                kwargs["additional_headers"] = self._ws_headers
            self._ws = await ws_connect(self._ws_url, **kwargs)
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        finally:
            # Put NO_PROXY back exactly as it was, on success and on failure.
            if is_localhost:
                if old_no_proxy is None:
                    os.environ.pop("NO_PROXY", None)
                else:
                    os.environ["NO_PROXY"] = old_no_proxy

        return self

    def _step_payload(self, action: BrowserAction) -> Dict:
        """Serialize ``action`` to a plain dict.

        Uses ``model_dump()`` when the action provides it and falls back to
        ``dict()`` otherwise.
        """
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return action.dict()

    def _parse_result(self, payload: Dict) -> StepResult[BrowserObservation]:
        """Build a ``StepResult`` from a raw step/reset payload.

        A missing or ``null`` ``"observation"`` entry is treated as an empty
        observation so that parsing falls back to defaults instead of raising.
        ``reward`` defaults to ``None`` and ``done`` to ``False`` when absent.
        """
        # `or {}` also covers an explicit null, which `.get(key, {})` would
        # pass through as None and crash the observation parser.
        obs_data = payload.get("observation") or {}
        observation = _parse_observation(obs_data, payload)
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from a raw payload.

        ``episode_id`` defaults to ``None`` and ``step_count`` to ``0`` when
        the corresponding keys are absent.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
def _parse_observation(obs_data: Dict, payload: Dict) -> BrowserObservation:
    """Hydrate a typed observation from the raw OpenEnv/WebSocket payload.

    The server sends a mostly-JSON observation plus a few transport-level fields
    such as the final reward and done flag. Keeping this translation in one
    place makes notebook/debug tooling much easier to reason about.

    Args:
        obs_data: The ``"observation"`` sub-dict of the payload. ``None`` or an
            empty dict yields an observation built entirely from defaults.
        payload: The enclosing payload. Its ``"done"`` and ``"reward"`` keys,
            when present, take precedence over the same keys in ``obs_data``.

    Returns:
        A ``BrowserObservation`` populated from ``obs_data`` with per-field
        defaults for anything missing.
    """
    # Tolerate an explicit null observation the same way as a missing one.
    obs_data = obs_data or {}

    # Nested objects and the element list may be sent as null; normalise them
    # to empty containers so the constructors below receive valid input.
    constraints = obs_data.get("constraints") or {}
    reward_breakdown = obs_data.get("reward_breakdown") or {}
    raw_elements = obs_data.get("elements") or []

    return BrowserObservation(
        episode_id=obs_data.get("episode_id", ""),
        task_id=obs_data.get("task_id", ""),
        task_family=obs_data.get("task_family", ""),
        difficulty=obs_data.get("difficulty", "easy"),
        instruction=obs_data.get("instruction", ""),
        url=obs_data.get("url", ""),
        step_index=obs_data.get("step_index", 0),
        max_steps=obs_data.get("max_steps", 0),
        elements=[BrowserElement(**element) for element in raw_elements],
        history=obs_data.get("history", []),
        constraints=ConstraintState(**constraints),
        reward_breakdown=RewardBreakdown(**reward_breakdown),
        # Transport-level values win over the copies inside the observation.
        done=payload.get("done", obs_data.get("done", False)),
        reward=payload.get("reward", obs_data.get("reward")),
        success=obs_data.get("success", False),
        failure_reason=obs_data.get("failure_reason", "none"),
        metadata=obs_data.get("metadata", {}),
    )