5 changes: 5 additions & 0 deletions .changeset/adamant-black-perch.md
@@ -0,0 +1,5 @@
+---
+"stagehand": patch
+---
+
+Remove nest-asyncio dependency
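Context for the changes below: `asyncio.run()` raises a `RuntimeError` when called while an event loop is already running, which is what `nest_asyncio.apply()` previously papered over in the metrics path. A minimal standalone sketch of that failure mode (not part of the diff):

```python
import asyncio

async def main():
    try:
        # This is roughly what the old metrics path attempted from an async
        # context, and why nest_asyncio.apply() was needed before this PR.
        asyncio.run(asyncio.sleep(0))
    except RuntimeError as e:
        print(f"Nested asyncio.run() fails: {e}")

asyncio.run(main())
```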
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ description = "Python SDK for Stagehand"
 readme = "README.md"
 classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent",]
 requires-python = ">=3.9"
-dependencies = [ "httpx>=0.24.0", "python-dotenv>=1.0.0", "pydantic>=1.10.0", "playwright>=1.42.1", "requests>=2.31.0", "browserbase>=1.4.0", "rich>=13.7.0", "openai>=1.99.6", "anthropic>=0.51.0", "litellm>=1.72.0,<=1.80.0", "nest-asyncio>=1.6.0", "google-genai>=1.40.0",]
+dependencies = [ "httpx>=0.24.0", "python-dotenv>=1.0.0", "pydantic>=1.10.0", "playwright>=1.42.1", "requests>=2.31.0", "browserbase>=1.4.0", "rich>=13.7.0", "openai>=1.99.6", "anthropic>=0.51.0", "litellm>=1.72.0,<=1.80.0", "google-genai>=1.40.0",]
 [[project.authors]]
 name = "Browserbase, Inc."
 email = "[email protected]"
135 changes: 87 additions & 48 deletions stagehand/api.py
@@ -6,7 +6,7 @@
 from .metrics import StagehandMetrics
 from .utils import convert_dict_keys_to_camel_case
 
-__all__ = ["_create_session", "_execute", "_get_replay_metrics"]
+__all__ = ["_create_session", "_execute", "_get_replay_metrics", "_get_replay_metrics_sync"]
 
 
 async def _create_session(self):
@@ -210,11 +210,59 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
         raise
 
 
-async def _get_replay_metrics(self):
+def _parse_replay_metrics_data(data: dict) -> StagehandMetrics:
     """
-    Fetch replay metrics from the API and parse them into StagehandMetrics.
+    Parse raw API response data into StagehandMetrics.
+    Shared by both async and sync fetch paths.
     """
+    if not data.get("success"):
+        raise RuntimeError(
+            f"Failed to fetch metrics: {data.get('error', 'Unknown error')}"
+        )
+
+    api_data = data.get("data", {})
+    metrics = StagehandMetrics()
+
+    pages = api_data.get("pages", [])
+    for page in pages:
+        actions = page.get("actions", [])
+        for action in actions:
+            method = action.get("method", "").lower()
+            token_usage = action.get("tokenUsage", {})
+
+            if token_usage:
+                input_tokens = token_usage.get("inputTokens", 0)
+                output_tokens = token_usage.get("outputTokens", 0)
+                time_ms = token_usage.get("timeMs", 0)
+
+                if method == "act":
+                    metrics.act_prompt_tokens += input_tokens
+                    metrics.act_completion_tokens += output_tokens
+                    metrics.act_inference_time_ms += time_ms
+                elif method == "extract":
+                    metrics.extract_prompt_tokens += input_tokens
+                    metrics.extract_completion_tokens += output_tokens
+                    metrics.extract_inference_time_ms += time_ms
+                elif method == "observe":
+                    metrics.observe_prompt_tokens += input_tokens
+                    metrics.observe_completion_tokens += output_tokens
+                    metrics.observe_inference_time_ms += time_ms
+                elif method == "agent":
+                    metrics.agent_prompt_tokens += input_tokens
+                    metrics.agent_completion_tokens += output_tokens
+                    metrics.agent_inference_time_ms += time_ms
+
+                metrics.total_prompt_tokens += input_tokens
+                metrics.total_completion_tokens += output_tokens
+                metrics.total_inference_time_ms += time_ms
+
+    return metrics
+
+
+async def _get_replay_metrics(self):
+    """
+    Fetch replay metrics from the API (async version).
+    """
     if not self.session_id:
         raise ValueError("session_id is required to fetch metrics.")
 
@@ -241,55 +289,46 @@ async def _get_replay_metrics(self):
                 f"Failed to fetch metrics with status {response.status_code}: {error_text}"
             )
 
-        data = response.json()
+        return _parse_replay_metrics_data(response.json())
 
+    except Exception as e:
+        self.logger.error(f"[EXCEPTION] Error fetching replay metrics: {str(e)}")
+        raise
+
+
+def _get_replay_metrics_sync(self):
+    """
+    Fetch replay metrics from the API (sync version).
+    Uses a synchronous httpx request so it can be called from sync contexts
+    even when an async event loop is already running.
+    """
+    import httpx
+
+    if not self.session_id:
+        raise ValueError("session_id is required to fetch metrics.")
+
+    headers = {
+        "x-bb-api-key": self.browserbase_api_key,
+        "x-bb-project-id": self.browserbase_project_id,
+        "Content-Type": "application/json",
+    }
+
-        if not data.get("success"):
+    try:
+        response = httpx.get(
+            f"{self.api_url}/sessions/{self.session_id}/replay",
+            headers=headers,
+            timeout=self.timeout_settings,
+        )
+
+        if response.status_code != 200:
+            self.logger.error(
+                f"[HTTP ERROR] Failed to fetch metrics. Status {response.status_code}: {response.text}"
+            )
             raise RuntimeError(
-                f"Failed to fetch metrics: {data.get('error', 'Unknown error')}"
+                f"Failed to fetch metrics with status {response.status_code}: {response.text}"
             )
 
-        # Parse the API data into StagehandMetrics format
-        api_data = data.get("data", {})
-        metrics = StagehandMetrics()
-
-        # Parse pages and their actions
-        pages = api_data.get("pages", [])
-        for page in pages:
-            actions = page.get("actions", [])
-            for action in actions:
-                # Get method name and token usage
-                method = action.get("method", "").lower()
-                token_usage = action.get("tokenUsage", {})
-
-                if token_usage:
-                    input_tokens = token_usage.get("inputTokens", 0)
-                    output_tokens = token_usage.get("outputTokens", 0)
-                    time_ms = token_usage.get("timeMs", 0)
-
-                    # Map method to metrics fields
-                    if method == "act":
-                        metrics.act_prompt_tokens += input_tokens
-                        metrics.act_completion_tokens += output_tokens
-                        metrics.act_inference_time_ms += time_ms
-                    elif method == "extract":
-                        metrics.extract_prompt_tokens += input_tokens
-                        metrics.extract_completion_tokens += output_tokens
-                        metrics.extract_inference_time_ms += time_ms
-                    elif method == "observe":
-                        metrics.observe_prompt_tokens += input_tokens
-                        metrics.observe_completion_tokens += output_tokens
-                        metrics.observe_inference_time_ms += time_ms
-                    elif method == "agent":
-                        metrics.agent_prompt_tokens += input_tokens
-                        metrics.agent_completion_tokens += output_tokens
-                        metrics.agent_inference_time_ms += time_ms
-
-                    # Always update totals for any method with token usage
-                    metrics.total_prompt_tokens += input_tokens
-                    metrics.total_completion_tokens += output_tokens
-                    metrics.total_inference_time_ms += time_ms
-
-        return metrics
+        return _parse_replay_metrics_data(response.json())
 
     except Exception as e:
         self.logger.error(f"[EXCEPTION] Error fetching replay metrics: {str(e)}")
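For reference, a hypothetical payload for the new `_parse_replay_metrics_data` helper, built only from the keys the parser reads (`success`, `data`, `pages`, `actions`, `method`, and `tokenUsage` with `inputTokens`/`outputTokens`/`timeMs`); the real replay response may carry additional fields, and the import path here is assumed:

```python
from stagehand.api import _parse_replay_metrics_data  # assumed import path

sample = {
    "success": True,
    "data": {
        "pages": [
            {
                "actions": [
                    # One "act" call and one "extract" call with token usage
                    {"method": "act",
                     "tokenUsage": {"inputTokens": 120, "outputTokens": 30, "timeMs": 850}},
                    {"method": "extract",
                     "tokenUsage": {"inputTokens": 200, "outputTokens": 75, "timeMs": 1200}},
                ]
            }
        ]
    },
}

metrics = _parse_replay_metrics_data(sample)
assert metrics.act_prompt_tokens == 120
assert metrics.total_completion_tokens == 105  # 30 + 75 across all methods
```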
21 changes: 14 additions & 7 deletions stagehand/main.py
@@ -7,7 +7,6 @@
 from typing import Any, Optional
 
 import httpx
-import nest_asyncio
 from dotenv import load_dotenv
 from playwright.async_api import (
     BrowserContext,
@@ -17,7 +16,12 @@
 from playwright.async_api import Page as PlaywrightPage
 
 from .agent import Agent
-from .api import _create_session, _execute, _get_replay_metrics
+from .api import (
+    _create_session,
+    _execute,
+    _get_replay_metrics,
+    _get_replay_metrics_sync,
+)
 from .browser import (
     cleanup_browser_resources,
     connect_browserbase_browser,
@@ -782,12 +786,14 @@ def __getattribute__(self, name):
             # Try to get current event loop
             try:
                 asyncio.get_running_loop()
-                # We're in an async context, need to handle this carefully
-                # Create a new task and wait for it
-                nest_asyncio.apply()
-                return asyncio.run(get_replay_metrics())
+                # Already in async context - use sync HTTP to avoid
+                # event loop nesting issues
+                get_replay_metrics_sync = object.__getattribute__(
+                    self, "_get_replay_metrics_sync"
+                )
+                return get_replay_metrics_sync()
             except RuntimeError:
-                # No event loop running, we can use asyncio.run directly
+                # No event loop running, safe to use asyncio.run
                 return asyncio.run(get_replay_metrics())
         except Exception as e:
             # Log error and return empty metrics
@@ -807,3 +813,4 @@ def __getattribute__(self, name):
 Stagehand._create_session = _create_session
 Stagehand._execute = _execute
 Stagehand._get_replay_metrics = _get_replay_metrics
+Stagehand._get_replay_metrics_sync = _get_replay_metrics_sync
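Taken together, the patched methods let metrics fetching work from both sync and async call sites without patching the event loop. A usage sketch, assuming the `__getattribute__` hook above fires for an attribute named `metrics` (the attribute name lives outside the hunks shown) and that `Stagehand` is importable from the package root:

```python
import asyncio

from stagehand import Stagehand  # assumed top-level export

stagehand = Stagehand()  # constructor arguments omitted for brevity

# Sync context: asyncio.get_running_loop() raises RuntimeError, so the
# hook falls through to asyncio.run(get_replay_metrics()).
print(stagehand.metrics)

async def main():
    # Async context: a loop is already running, so the hook calls
    # _get_replay_metrics_sync(), which fetches over plain httpx instead
    # of applying nest_asyncio.
    print(stagehand.metrics)

asyncio.run(main())
```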