diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index 91a0cb42f..83656a910 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -38,13 +38,22 @@ GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
-GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
-# Airtable OAuth for Aitable Connector
+# OAuth for Airtable Connector
AIRTABLE_CLIENT_ID=your_airtable_client_id
AIRTABLE_CLIENT_SECRET=your_airtable_client_secret
AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
+
+# OAuth for Linear Connector
+LINEAR_CLIENT_ID=your_linear_client_id
+LINEAR_CLIENT_SECRET=your_linear_client_secret
+LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback
+
+# OAuth for Notion Connector
+NOTION_CLIENT_ID=your_notion_client_id
+NOTION_CLIENT_SECRET=your_notion_client_secret
+NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback
+
# Embedding Model
# Examples:
# # Get sentence transformers embeddings
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index 9c503fb18..7c7703470 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -90,6 +90,16 @@ class Config:
AIRTABLE_CLIENT_SECRET = os.getenv("AIRTABLE_CLIENT_SECRET")
AIRTABLE_REDIRECT_URI = os.getenv("AIRTABLE_REDIRECT_URI")
+ # Notion OAuth
+ NOTION_CLIENT_ID = os.getenv("NOTION_CLIENT_ID")
+ NOTION_CLIENT_SECRET = os.getenv("NOTION_CLIENT_SECRET")
+ NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
+
+ # Linear OAuth
+ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
+ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
+ LINEAR_REDIRECT_URI = os.getenv("LINEAR_REDIRECT_URI")
+
# LLM instances are now managed per-user through the LLMConfig system
# Legacy environment variables removed in favor of user-specific configurations
diff --git a/surfsense_backend/app/connectors/google_calendar_connector.py b/surfsense_backend/app/connectors/google_calendar_connector.py
index 164d230e0..6d389ddd5 100644
--- a/surfsense_backend/app/connectors/google_calendar_connector.py
+++ b/surfsense_backend/app/connectors/google_calendar_connector.py
@@ -109,7 +109,36 @@ async def _get_credentials(
raise RuntimeError(
"GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
)
- connector.config = json.loads(self._credentials.to_json())
+
+ # Encrypt sensitive credentials before storing
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ creds_dict = json.loads(self._credentials.to_json())
+ token_encrypted = connector.config.get("_token_encrypted", False)
+
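+                    # Re-encrypt only when the stored config was already
+                    # encrypted; legacy plaintext configs are persisted as-is.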
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ # Encrypt sensitive fields
+ if creds_dict.get("token"):
+ creds_dict["token"] = token_encryption.encrypt_token(
+ creds_dict["token"]
+ )
+ if creds_dict.get("refresh_token"):
+ creds_dict["refresh_token"] = (
+ token_encryption.encrypt_token(
+ creds_dict["refresh_token"]
+ )
+ )
+ if creds_dict.get("client_secret"):
+ creds_dict["client_secret"] = (
+ token_encryption.encrypt_token(
+ creds_dict["client_secret"]
+ )
+ )
+ creds_dict["_token_encrypted"] = True
+
+ connector.config = creds_dict
flag_modified(connector, "config")
await self._session.commit()
except Exception as e:
@@ -182,6 +211,18 @@ async def get_all_primary_calendar_events(
Tuple containing (events list, error message or None)
"""
try:
+            # Validate date strings (reject empty or literal "undefined"/"null"/"none" values)
+ if not start_date or start_date.lower() in ("undefined", "null", "none"):
+ return (
+ [],
+ "Invalid start_date: must be a valid date string in YYYY-MM-DD format",
+ )
+ if not end_date or end_date.lower() in ("undefined", "null", "none"):
+ return (
+ [],
+ "Invalid end_date: must be a valid date string in YYYY-MM-DD format",
+ )
+
service = await self._get_service()
# Parse both dates
diff --git a/surfsense_backend/app/connectors/google_drive/credentials.py b/surfsense_backend/app/connectors/google_drive/credentials.py
index f88486468..7e6335f6d 100644
--- a/surfsense_backend/app/connectors/google_drive/credentials.py
+++ b/surfsense_backend/app/connectors/google_drive/credentials.py
@@ -1,6 +1,7 @@
"""Google Drive OAuth credential management."""
import json
+import logging
from datetime import datetime
from google.auth.transport.requests import Request
@@ -9,7 +10,11 @@
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
+from app.config import config
from app.db import SearchSourceConnector
+from app.utils.oauth_security import TokenEncryption
+
+logger = logging.getLogger(__name__)
async def get_valid_credentials(
@@ -38,7 +43,41 @@ async def get_valid_credentials(
if not connector:
raise ValueError(f"Connector {connector_id} not found")
- config_data = connector.config
+ config_data = (
+ connector.config.copy()
+ ) # Work with a copy to avoid modifying original
+
+ # Decrypt credentials if they are encrypted
+ token_encrypted = config_data.get("_token_encrypted", False)
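+    # Legacy plaintext configs skip decryption and are used as-is.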
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+
+ logger.info(
+ f"Decrypted Google Drive credentials for connector {connector_id}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to decrypt Google Drive credentials for connector {connector_id}: {e!s}"
+ )
+ raise ValueError(
+ f"Failed to decrypt Google Drive credentials: {e!s}"
+ ) from e
+
exp = config_data.get("expiry", "").replace("Z", "")
if not all(
@@ -66,7 +105,29 @@ async def get_valid_credentials(
try:
credentials.refresh(Request())
- connector.config = json.loads(credentials.to_json())
+ creds_dict = json.loads(credentials.to_json())
+
+ # Encrypt sensitive credentials before storing
+ token_encrypted = connector.config.get("_token_encrypted", False)
+
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ # Encrypt sensitive fields
+ if creds_dict.get("token"):
+ creds_dict["token"] = token_encryption.encrypt_token(
+ creds_dict["token"]
+ )
+ if creds_dict.get("refresh_token"):
+ creds_dict["refresh_token"] = token_encryption.encrypt_token(
+ creds_dict["refresh_token"]
+ )
+ if creds_dict.get("client_secret"):
+ creds_dict["client_secret"] = token_encryption.encrypt_token(
+ creds_dict["client_secret"]
+ )
+ creds_dict["_token_encrypted"] = True
+
+ connector.config = creds_dict
flag_modified(connector, "config")
await session.commit()
diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py
index b4c54fda3..148aa4d0a 100644
--- a/surfsense_backend/app/connectors/linear_connector.py
+++ b/surfsense_backend/app/connectors/linear_connector.py
@@ -5,33 +5,153 @@
Allows fetching issue lists and their comments with date range filtering.
"""
+import logging
from datetime import datetime
from typing import Any
import requests
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.config import config
+from app.db import SearchSourceConnector
+from app.routes.linear_add_connector_route import refresh_linear_token
+from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
+from app.utils.oauth_security import TokenEncryption
+
+logger = logging.getLogger(__name__)
class LinearConnector:
"""Class for retrieving issues and comments from Linear."""
- def __init__(self, token: str | None = None):
+ def __init__(
+ self,
+ session: AsyncSession,
+ connector_id: int,
+ credentials: LinearAuthCredentialsBase | None = None,
+ ):
"""
- Initialize the LinearConnector class.
+ Initialize the LinearConnector class with auto-refresh capability.
Args:
- token: Linear API token (optional, can be set later with set_token)
+ session: Database session for updating connector
+ connector_id: Connector ID for direct updates
+ credentials: Linear OAuth credentials (optional, will be loaded from DB if not provided)
"""
- self.token = token
+ self._session = session
+ self._connector_id = connector_id
+ self._credentials = credentials
self.api_url = "https://api.linear.app/graphql"
- def set_token(self, token: str) -> None:
+ async def _get_valid_token(self) -> str:
"""
- Set the Linear API token.
+ Get valid Linear access token, refreshing if needed.
- Args:
- token: Linear API token
+ Returns:
+ Valid access token
+
+ Raises:
+ ValueError: If credentials are missing or invalid
+ Exception: If token refresh fails
"""
- self.token = token
+ # Load credentials from DB if not provided
+ if self._credentials is None:
+ result = await self._session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == self._connector_id
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ raise ValueError(f"Connector {self._connector_id} not found")
+
+ config_data = connector.config.copy()
+
+ # Decrypt credentials if they are encrypted
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("access_token"):
+ config_data["access_token"] = token_encryption.decrypt_token(
+ config_data["access_token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+
+ logger.info(
+ f"Decrypted Linear credentials for connector {self._connector_id}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to decrypt Linear credentials for connector {self._connector_id}: {e!s}"
+ )
+ raise ValueError(
+ f"Failed to decrypt Linear credentials: {e!s}"
+ ) from e
+
+ try:
+ self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
+ except Exception as e:
+ raise ValueError(f"Invalid Linear credentials: {e!s}") from e
+
+ # Check if token is expired and refreshable
+ if self._credentials.is_expired and self._credentials.is_refreshable:
+ try:
+ logger.info(
+ f"Linear token expired for connector {self._connector_id}, refreshing..."
+ )
+
+ # Get connector for refresh
+ result = await self._session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == self._connector_id
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ raise RuntimeError(
+ f"Connector {self._connector_id} not found; cannot refresh token."
+ )
+
+ # Refresh token
+ connector = await refresh_linear_token(self._session, connector)
+
+ # Reload credentials after refresh
+ config_data = connector.config.copy()
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("access_token"):
+ config_data["access_token"] = token_encryption.decrypt_token(
+ config_data["access_token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+
+ self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
+
+ logger.info(
+ f"Successfully refreshed Linear token for connector {self._connector_id}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to refresh Linear token for connector {self._connector_id}: {e!s}"
+ )
+ raise Exception(
+ f"Failed to refresh Linear OAuth credentials: {e!s}"
+ ) from e
+
+ return self._credentials.access_token
def get_headers(self) -> dict[str, str]:
"""
@@ -41,18 +161,26 @@ def get_headers(self) -> dict[str, str]:
Dictionary of headers
Raises:
- ValueError: If no Linear token has been set
+ ValueError: If no Linear access token has been set
"""
- if not self.token:
- raise ValueError("Linear token not initialized. Call set_token() first.")
+        # This method is synchronous and cannot refresh an expired token.
+        # It raises unless credentials are already loaded; API calls should
+        # go through execute_graphql_query(), which refreshes asynchronously.
+ if not self._credentials or not self._credentials.access_token:
+ raise ValueError(
+ "Linear access token not initialized. Use execute_graphql_query() method."
+ )
- return {"Content-Type": "application/json", "Authorization": self.token}
+ return {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self._credentials.access_token}",
+ }
- def execute_graphql_query(
+ async def execute_graphql_query(
self, query: str, variables: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
- Execute a GraphQL query against the Linear API.
+ Execute a GraphQL query against the Linear API with automatic token refresh.
Args:
query: GraphQL query string
@@ -62,13 +190,17 @@ def execute_graphql_query(
Response data from the API
Raises:
- ValueError: If no Linear token has been set
+ ValueError: If no Linear access token has been set
Exception: If the API request fails
"""
- if not self.token:
- raise ValueError("Linear token not initialized. Call set_token() first.")
+ # Get valid token (refreshes if needed)
+ access_token = await self._get_valid_token()
+
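+        # Build the headers inline with the just-validated token; get_headers()
+        # is synchronous and cannot trigger a refresh.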
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {access_token}",
+ }
- headers = self.get_headers()
payload = {"query": query}
if variables:
@@ -83,7 +215,9 @@ def execute_graphql_query(
f"Query failed with status code {response.status_code}: {response.text}"
)
- def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
+ async def get_all_issues(
+ self, include_comments: bool = True
+ ) -> list[dict[str, Any]]:
"""
Fetch all issues from Linear.
@@ -94,7 +228,7 @@ def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
List of issue objects
Raises:
- ValueError: If no Linear token has been set
+ ValueError: If no Linear access token has been set
Exception: If the API request fails
"""
comments_query = ""
@@ -146,7 +280,7 @@ def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
}}
"""
- result = self.execute_graphql_query(query)
+ result = await self.execute_graphql_query(query)
# Extract issues from the response
if (
@@ -158,7 +292,7 @@ def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
return []
- def get_issues_by_date_range(
+ async def get_issues_by_date_range(
self, start_date: str, end_date: str, include_comments: bool = True
) -> tuple[list[dict[str, Any]], str | None]:
"""
@@ -172,6 +306,18 @@ def get_issues_by_date_range(
Returns:
Tuple containing (issues list, error message or None)
"""
+        # Validate date strings (reject empty or literal "undefined"/"null"/"none" values)
+ if not start_date or start_date.lower() in ("undefined", "null", "none"):
+ return (
+ [],
+ "Invalid start_date: must be a valid date string in YYYY-MM-DD format",
+ )
+ if not end_date or end_date.lower() in ("undefined", "null", "none"):
+ return (
+ [],
+ "Invalid end_date: must be a valid date string in YYYY-MM-DD format",
+ )
+
# Convert date strings to ISO format
try:
# For Linear API: we need to use a more specific format for the filter
@@ -258,7 +404,7 @@ def get_issues_by_date_range(
# Handle pagination to get all issues
while has_next_page:
variables = {"after": cursor} if cursor else {}
- result = self.execute_graphql_query(query, variables)
+ result = await self.execute_graphql_query(query, variables)
# Check for errors
if "errors" in result:
@@ -446,37 +592,3 @@ def format_date(iso_date: str) -> str:
return dt.strftime("%Y-%m-%d %H:%M:%S")
except ValueError:
return iso_date
-
-
-# Example usage (uncomment to use):
-"""
-if __name__ == "__main__":
- # Set your token here
- token = "YOUR_LINEAR_API_KEY"
-
- linear = LinearConnector(token)
-
- try:
- # Get all issues with comments
- issues = linear.get_all_issues()
- print(f"Retrieved {len(issues)} issues")
-
- # Format and print the first issue as markdown
- if issues:
- issue_md = linear.format_issue_to_markdown(issues[0])
- print("\nSample Issue in Markdown:\n")
- print(issue_md)
-
- # Get issues by date range
- start_date = "2023-01-01"
- end_date = "2023-01-31"
- date_issues, error = linear.get_issues_by_date_range(start_date, end_date)
-
- if error:
- print(f"Error: {error}")
- else:
- print(f"\nRetrieved {len(date_issues)} issues from {start_date} to {end_date}")
-
- except Exception as e:
- print(f"Error: {e}")
-"""
diff --git a/surfsense_backend/app/connectors/notion_history.py b/surfsense_backend/app/connectors/notion_history.py
index 81f6642f1..e38218a6e 100644
--- a/surfsense_backend/app/connectors/notion_history.py
+++ b/surfsense_backend/app/connectors/notion_history.py
@@ -1,19 +1,167 @@
+import logging
+
from notion_client import AsyncClient
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.config import config
+from app.db import SearchSourceConnector
+from app.routes.notion_add_connector_route import refresh_notion_token
+from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
+from app.utils.oauth_security import TokenEncryption
+
+logger = logging.getLogger(__name__)
class NotionHistoryConnector:
- def __init__(self, token):
+ def __init__(
+ self,
+ session: AsyncSession,
+ connector_id: int,
+ credentials: NotionAuthCredentialsBase | None = None,
+ ):
"""
- Initialize the NotionPageFetcher with a token.
+ Initialize the NotionHistoryConnector with auto-refresh capability.
Args:
- token (str): Notion integration token
+ session: Database session for updating connector
+ connector_id: Connector ID for direct updates
+ credentials: Notion OAuth credentials (optional, will be loaded from DB if not provided)
+ """
+ self._session = session
+ self._connector_id = connector_id
+ self._credentials = credentials
+ self._notion_client: AsyncClient | None = None
+
+ async def _get_valid_token(self) -> str:
+ """
+ Get valid Notion access token, refreshing if needed.
+
+ Returns:
+ Valid access token
+
+ Raises:
+ ValueError: If credentials are missing or invalid
+ Exception: If token refresh fails
+ """
+ # Load credentials from DB if not provided
+ if self._credentials is None:
+ result = await self._session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == self._connector_id
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ raise ValueError(f"Connector {self._connector_id} not found")
+
+ config_data = connector.config.copy()
+
+ # Decrypt credentials if they are encrypted
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("access_token"):
+ config_data["access_token"] = token_encryption.decrypt_token(
+ config_data["access_token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+
+ logger.info(
+ f"Decrypted Notion credentials for connector {self._connector_id}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to decrypt Notion credentials for connector {self._connector_id}: {e!s}"
+ )
+ raise ValueError(
+ f"Failed to decrypt Notion credentials: {e!s}"
+ ) from e
+
+ try:
+ self._credentials = NotionAuthCredentialsBase.from_dict(config_data)
+ except Exception as e:
+ raise ValueError(f"Invalid Notion credentials: {e!s}") from e
+
+ # Check if token is expired and refreshable
+ if self._credentials.is_expired and self._credentials.is_refreshable:
+ try:
+ logger.info(
+ f"Notion token expired for connector {self._connector_id}, refreshing..."
+ )
+
+ # Get connector for refresh
+ result = await self._session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.id == self._connector_id
+ )
+ )
+ connector = result.scalars().first()
+
+ if not connector:
+ raise RuntimeError(
+ f"Connector {self._connector_id} not found; cannot refresh token."
+ )
+
+ # Refresh token
+ connector = await refresh_notion_token(self._session, connector)
+
+ # Reload credentials after refresh
+ config_data = connector.config.copy()
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+ if config_data.get("access_token"):
+ config_data["access_token"] = token_encryption.decrypt_token(
+ config_data["access_token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+
+ self._credentials = NotionAuthCredentialsBase.from_dict(config_data)
+
+ # Invalidate cached client so it's recreated with new token
+ self._notion_client = None
+
+ logger.info(
+ f"Successfully refreshed Notion token for connector {self._connector_id}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to refresh Notion token for connector {self._connector_id}: {e!s}"
+ )
+ raise Exception(
+ f"Failed to refresh Notion OAuth credentials: {e!s}"
+ ) from e
+
+ return self._credentials.access_token
+
+ async def _get_client(self) -> AsyncClient:
+ """
+ Get or create Notion AsyncClient with valid token.
+
+ Returns:
+ Notion AsyncClient instance
"""
- self.notion = AsyncClient(auth=token)
+        # Always validate the token first: a refresh invalidates the cached
+        # client, so it is rebuilt below with the new token.
+        token = await self._get_valid_token()
+        if self._notion_client is None:
+            self._notion_client = AsyncClient(auth=token)
+        return self._notion_client
async def close(self):
"""Close the async client connection."""
- await self.notion.aclose()
+ if self._notion_client:
+ await self._notion_client.aclose()
+ self._notion_client = None
async def __aenter__(self):
"""Async context manager entry."""
@@ -34,6 +182,8 @@ async def get_all_pages(self, start_date=None, end_date=None):
Returns:
list: List of dictionaries containing page data
"""
+ notion = await self._get_client()
+
# Build the filter for the search
# Note: Notion API requires specific filter structure
search_params = {}
@@ -67,7 +217,7 @@ async def get_all_pages(self, start_date=None, end_date=None):
if cursor:
search_params["start_cursor"] = cursor
- search_results = await self.notion.search(**search_params)
+ search_results = await notion.search(**search_params)
pages.extend(search_results["results"])
has_more = search_results.get("has_more", False)
@@ -125,6 +275,8 @@ async def get_page_content(self, page_id):
Returns:
list: List of processed blocks from the page
"""
+ notion = await self._get_client()
+
blocks = []
has_more = True
cursor = None
@@ -132,11 +284,11 @@ async def get_page_content(self, page_id):
# Paginate through all blocks
while has_more:
if cursor:
- response = await self.notion.blocks.children.list(
+ response = await notion.blocks.children.list(
block_id=page_id, start_cursor=cursor
)
else:
- response = await self.notion.blocks.children.list(block_id=page_id)
+ response = await notion.blocks.children.list(block_id=page_id)
blocks.extend(response["results"])
has_more = response["has_more"]
@@ -162,6 +314,8 @@ async def process_block(self, block):
Returns:
dict: Processed block with content and children
"""
+ notion = await self._get_client()
+
block_id = block["id"]
block_type = block["type"]
@@ -174,9 +328,7 @@ async def process_block(self, block):
if has_children:
# Fetch and process child blocks
- children_response = await self.notion.blocks.children.list(
- block_id=block_id
- )
+ children_response = await notion.blocks.children.list(block_id=block_id)
for child_block in children_response["results"]:
child_blocks.append(await self.process_block(child_block))
diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py
index 3c18650ae..1d1fa39ad 100644
--- a/surfsense_backend/app/routes/__init__.py
+++ b/surfsense_backend/app/routes/__init__.py
@@ -15,11 +15,13 @@
from .google_gmail_add_connector_route import (
router as google_gmail_add_connector_router,
)
+from .linear_add_connector_route import router as linear_add_connector_router
from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .new_chat_routes import router as new_chat_router
from .new_llm_config_routes import router as new_llm_config_router
from .notes_routes import router as notes_router
+from .notion_add_connector_route import router as notion_add_connector_router
from .podcasts_routes import router as podcasts_router
from .rbac_routes import router as rbac_router
from .search_source_connectors_routes import router as search_source_connectors_router
@@ -39,7 +41,9 @@
router.include_router(google_gmail_add_connector_router)
router.include_router(google_drive_add_connector_router)
router.include_router(airtable_add_connector_router)
+router.include_router(linear_add_connector_router)
router.include_router(luma_add_connector_router)
+router.include_router(notion_add_connector_router)
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
router.include_router(logs_router)
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py
index 3bcbe4dc0..9284d89e8 100644
--- a/surfsense_backend/app/routes/airtable_add_connector_route.py
+++ b/surfsense_backend/app/routes/airtable_add_connector_route.py
@@ -1,6 +1,5 @@
import base64
import hashlib
-import json
import logging
import secrets
from datetime import UTC, datetime, timedelta
@@ -23,6 +22,7 @@
)
from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase
from app.users import current_active_user
+from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
@@ -40,6 +40,30 @@
"user.email:read",
]
+# Initialize security utilities
+_state_manager = None
+_token_encryption = None
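+# (created lazily so module import succeeds even if SECRET_KEY is unset;
+# misconfiguration then surfaces on the first OAuth request instead)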
+
+
+def get_state_manager() -> OAuthStateManager:
+ """Get or create OAuth state manager instance."""
+ global _state_manager
+ if _state_manager is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for OAuth security")
+ _state_manager = OAuthStateManager(config.SECRET_KEY)
+ return _state_manager
+
+
+def get_token_encryption() -> TokenEncryption:
+ """Get or create token encryption instance."""
+ global _token_encryption
+ if _token_encryption is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for token encryption")
+ _token_encryption = TokenEncryption(config.SECRET_KEY)
+ return _token_encryption
+
def make_basic_auth_header(client_id: str, client_secret: str) -> str:
credentials = f"{client_id}:{client_secret}".encode()
@@ -90,18 +114,19 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us
status_code=500, detail="Airtable OAuth not configured."
)
+ if not config.SECRET_KEY:
+ raise HTTPException(
+ status_code=500, detail="SECRET_KEY not configured for OAuth security."
+ )
+
# Generate PKCE parameters
code_verifier, code_challenge = generate_pkce_pair()
- # Generate state parameter
- state_payload = json.dumps(
- {
- "space_id": space_id,
- "user_id": str(user.id),
- "code_verifier": code_verifier,
- }
+ # Generate secure state parameter with HMAC signature (including code_verifier for PKCE)
+ state_manager = get_state_manager()
+ state_encoded = state_manager.generate_secure_state(
+ space_id, user.id, code_verifier=code_verifier
)
- state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
# Build authorization URL
auth_params = {
@@ -134,8 +159,9 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us
@router.get("/auth/airtable/connector/callback")
async def airtable_callback(
request: Request,
- code: str,
- state: str,
+ code: str | None = None,
+ error: str | None = None,
+ state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
"""
@@ -143,7 +169,8 @@ async def airtable_callback(
Args:
request: FastAPI request object
- code: Authorization code from Airtable
+ code: Authorization code from Airtable (if user granted access)
+ error: Error code from Airtable (if user denied access or error occurred)
state: State parameter containing user/space info
session: Database session
@@ -151,10 +178,42 @@ async def airtable_callback(
Redirect response to frontend
"""
try:
- # Decode and parse the state
+ # Handle OAuth errors (e.g., user denied access)
+ if error:
+ logger.warning(f"Airtable OAuth error: {error}")
+ # Try to decode state to get space_id for redirect, but don't fail if it's invalid
+ space_id = None
+ if state:
+ try:
+ state_manager = get_state_manager()
+ data = state_manager.validate_state(state)
+ space_id = data.get("space_id")
+ except Exception:
+ # If state is invalid, we'll redirect without space_id
+ logger.warning("Failed to validate state in error handler")
+
+ # Redirect to frontend with error parameter
+ if space_id:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied"
+ )
+ else:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=airtable_oauth_denied"
+ )
+
+ # Validate required parameters for successful flow
+ if not code:
+ raise HTTPException(status_code=400, detail="Missing authorization code")
+ if not state:
+ raise HTTPException(status_code=400, detail="Missing state parameter")
+
+ # Validate and decode state with signature verification
+ state_manager = get_state_manager()
try:
- decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
- data = json.loads(decoded_state)
+ data = state_manager.validate_state(state)
+ except HTTPException:
+ raise
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Invalid state parameter: {e!s}"
@@ -162,7 +221,12 @@ async def airtable_callback(
user_id = UUID(data["user_id"])
space_id = data["space_id"]
- code_verifier = data["code_verifier"]
+ code_verifier = data.get("code_verifier")
+
+ if not code_verifier:
+ raise HTTPException(
+ status_code=400, detail="Missing code_verifier in state parameter"
+ )
auth_header = make_basic_auth_header(
config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET
)
@@ -201,22 +265,38 @@ async def airtable_callback(
token_json = token_response.json()
+ # Encrypt sensitive tokens before storing
+ token_encryption = get_token_encryption()
+ access_token = token_json.get("access_token")
+ refresh_token = token_json.get("refresh_token")
+
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="No access token received from Airtable"
+ )
+
# Calculate expiration time (UTC, tz-aware)
expires_at = None
if token_json.get("expires_in"):
now_utc = datetime.now(UTC)
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
- # Create credentials object
+ # Create credentials object with encrypted tokens
credentials = AirtableAuthCredentialsBase(
- access_token=token_json["access_token"],
- refresh_token=token_json.get("refresh_token"),
+ access_token=token_encryption.encrypt_token(access_token),
+ refresh_token=token_encryption.encrypt_token(refresh_token)
+ if refresh_token
+ else None,
token_type=token_json.get("token_type", "Bearer"),
expires_in=token_json.get("expires_in"),
expires_at=expires_at,
scope=token_json.get("scope"),
)
+        # Flag the tokens as encrypted (legacy plaintext configs lack this marker)
+ credentials_dict = credentials.to_dict()
+ credentials_dict["_token_encrypted"] = True
+
# Check if connector already exists for this search space and user
existing_connector_result = await session.execute(
select(SearchSourceConnector).filter(
@@ -230,7 +310,7 @@ async def airtable_callback(
if existing_connector:
# Update existing connector
- existing_connector.config = credentials.to_dict()
+ existing_connector.config = credentials_dict
existing_connector.name = "Airtable Connector"
existing_connector.is_indexable = True
logger.info(
@@ -242,7 +322,7 @@ async def airtable_callback(
name="Airtable Connector",
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
is_indexable=True,
- config=credentials.to_dict(),
+ config=credentials_dict,
search_space_id=space_id,
user_id=user_id,
)
@@ -306,6 +386,21 @@ async def refresh_airtable_token(
logger.info(f"Refreshing Airtable token for connector {connector.id}")
credentials = AirtableAuthCredentialsBase.from_dict(connector.config)
+
+ # Decrypt tokens if they are encrypted
+ token_encryption = get_token_encryption()
+ is_encrypted = connector.config.get("_token_encrypted", False)
+
+ refresh_token = credentials.refresh_token
+ if is_encrypted and refresh_token:
+ try:
+ refresh_token = token_encryption.decrypt_token(refresh_token)
+ except Exception as e:
+ logger.error(f"Failed to decrypt refresh token: {e!s}")
+ raise HTTPException(
+ status_code=500, detail="Failed to decrypt stored refresh token"
+ ) from e
+
auth_header = make_basic_auth_header(
config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET
)
@@ -313,7 +408,7 @@ async def refresh_airtable_token(
# Prepare token refresh data
refresh_data = {
"grant_type": "refresh_token",
- "refresh_token": credentials.refresh_token,
+ "refresh_token": refresh_token,
"client_id": config.AIRTABLE_CLIENT_ID,
"client_secret": config.AIRTABLE_CLIENT_SECRET,
}
@@ -342,14 +437,29 @@ async def refresh_airtable_token(
now_utc = datetime.now(UTC)
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
- # Update credentials object
- credentials.access_token = token_json["access_token"]
+ # Encrypt new tokens before storing
+ access_token = token_json.get("access_token")
+ new_refresh_token = token_json.get("refresh_token")
+
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="No access token received from Airtable refresh"
+ )
+
+ # Update credentials object with encrypted tokens
+ credentials.access_token = token_encryption.encrypt_token(access_token)
+ if new_refresh_token:
+ credentials.refresh_token = token_encryption.encrypt_token(
+ new_refresh_token
+ )
credentials.expires_in = token_json.get("expires_in")
credentials.expires_at = expires_at
credentials.scope = token_json.get("scope")
- # Update connector config
- connector.config = credentials.to_dict()
+ # Update connector config with encrypted tokens
+ credentials_dict = credentials.to_dict()
+ credentials_dict["_token_encrypted"] = True
+ connector.config = credentials_dict
await session.commit()
await session.refresh(connector)
diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py
index 8bb685450..6c6ae4e40 100644
--- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py
+++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py
@@ -2,7 +2,6 @@
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
-import base64
import json
import logging
from uuid import UUID
@@ -23,6 +22,7 @@
get_async_session,
)
from app.users import current_active_user
+from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
@@ -31,6 +31,30 @@
SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"]
REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI
+# Initialize security utilities
+_state_manager = None
+_token_encryption = None
+
+
+def get_state_manager() -> OAuthStateManager:
+ """Get or create OAuth state manager instance."""
+ global _state_manager
+ if _state_manager is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for OAuth security")
+ _state_manager = OAuthStateManager(config.SECRET_KEY)
+ return _state_manager
+
+
+def get_token_encryption() -> TokenEncryption:
+ """Get or create token encryption instance."""
+ global _token_encryption
+ if _token_encryption is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for token encryption")
+ _token_encryption = TokenEncryption(config.SECRET_KEY)
+ return _token_encryption
+
def get_google_flow():
try:
@@ -59,16 +83,16 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
+ if not config.SECRET_KEY:
+ raise HTTPException(
+ status_code=500, detail="SECRET_KEY not configured for OAuth security."
+ )
+
flow = get_google_flow()
- # Encode space_id and user_id in state
- state_payload = json.dumps(
- {
- "space_id": space_id,
- "user_id": str(user.id),
- }
- )
- state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
+ # Generate secure state parameter with HMAC signature
+ state_manager = get_state_manager()
+ state_encoded = state_manager.generate_secure_state(space_id, user.id)
auth_url, _ = flow.authorization_url(
access_type="offline",
@@ -86,24 +110,86 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
@router.get("/auth/google/calendar/connector/callback")
async def calendar_callback(
request: Request,
- code: str,
- state: str,
+ code: str | None = None,
+ error: str | None = None,
+ state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
try:
- # Decode and parse the state
- decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
- data = json.loads(decoded_state)
+ # Handle OAuth errors (e.g., user denied access)
+ if error:
+ logger.warning(f"Google Calendar OAuth error: {error}")
+ # Try to decode state to get space_id for redirect, but don't fail if it's invalid
+ space_id = None
+ if state:
+ try:
+ state_manager = get_state_manager()
+ data = state_manager.validate_state(state)
+ space_id = data.get("space_id")
+ except Exception:
+ # If state is invalid, we'll redirect without space_id
+ logger.warning("Failed to validate state in error handler")
+
+ # Redirect to frontend with error parameter
+ if space_id:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied"
+ )
+ else:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_calendar_oauth_denied"
+ )
+
+ # Validate required parameters for successful flow
+ if not code:
+ raise HTTPException(status_code=400, detail="Missing authorization code")
+ if not state:
+ raise HTTPException(status_code=400, detail="Missing state parameter")
+
+ # Validate and decode state with signature verification
+ state_manager = get_state_manager()
+ try:
+ data = state_manager.validate_state(state)
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid state parameter: {e!s}"
+ ) from e
user_id = UUID(data["user_id"])
space_id = data["space_id"]
+        # Ensure the redirect URI is configured before exchanging the code
+ if not config.GOOGLE_CALENDAR_REDIRECT_URI:
+ raise HTTPException(
+ status_code=500, detail="GOOGLE_CALENDAR_REDIRECT_URI not configured"
+ )
+
flow = get_google_flow()
flow.fetch_token(code=code)
creds = flow.credentials
creds_dict = json.loads(creds.to_json())
+ # Encrypt sensitive credentials before storing
+ token_encryption = get_token_encryption()
+
+ # Encrypt sensitive fields: token, refresh_token, client_secret
+ if creds_dict.get("token"):
+ creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"])
+ if creds_dict.get("refresh_token"):
+ creds_dict["refresh_token"] = token_encryption.encrypt_token(
+ creds_dict["refresh_token"]
+ )
+ if creds_dict.get("client_secret"):
+ creds_dict["client_secret"] = token_encryption.encrypt_token(
+ creds_dict["client_secret"]
+ )
+
+        # Flag the credentials as encrypted (legacy plaintext configs lack this marker)
+ creds_dict["_token_encrypted"] = True
+
try:
# Check if a connector with the same type already exists for this search space and user
result = await session.execute(
diff --git a/surfsense_backend/app/routes/google_drive_add_connector_route.py b/surfsense_backend/app/routes/google_drive_add_connector_route.py
index 52461319b..6caf3f204 100644
--- a/surfsense_backend/app/routes/google_drive_add_connector_route.py
+++ b/surfsense_backend/app/routes/google_drive_add_connector_route.py
@@ -10,7 +10,6 @@
- GET /connectors/{connector_id}/google-drive/folders - List user's folders (for index-time selection)
"""
-import base64
import json
import logging
import os
@@ -37,6 +36,7 @@
get_async_session,
)
from app.users import current_active_user
+from app.utils.oauth_security import OAuthStateManager, TokenEncryption
# Relax token scope validation for Google OAuth
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
@@ -44,6 +44,31 @@
logger = logging.getLogger(__name__)
router = APIRouter()
+# Initialize security utilities
+_state_manager = None
+_token_encryption = None
+
+
+def get_state_manager() -> OAuthStateManager:
+ """Get or create OAuth state manager instance."""
+ global _state_manager
+ if _state_manager is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for OAuth security")
+ _state_manager = OAuthStateManager(config.SECRET_KEY)
+ return _state_manager
+
+
+def get_token_encryption() -> TokenEncryption:
+ """Get or create token encryption instance."""
+ global _token_encryption
+ if _token_encryption is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for token encryption")
+ _token_encryption = TokenEncryption(config.SECRET_KEY)
+ return _token_encryption
+
+
# Google Drive OAuth scopes
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive
@@ -90,16 +115,16 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
+ if not config.SECRET_KEY:
+ raise HTTPException(
+ status_code=500, detail="SECRET_KEY not configured for OAuth security."
+ )
+
flow = get_google_flow()
- # Encode space_id and user_id in state parameter
- state_payload = json.dumps(
- {
- "space_id": space_id,
- "user_id": str(user.id),
- }
- )
- state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
+ # Generate secure state parameter with HMAC signature
+ state_manager = get_state_manager()
+ state_encoded = state_manager.generate_secure_state(space_id, user.id)
# Generate authorization URL
auth_url, _ = flow.authorization_url(
@@ -124,8 +149,9 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
@router.get("/auth/google/drive/connector/callback")
async def drive_callback(
request: Request,
- code: str,
- state: str,
+ code: str | None = None,
+ error: str | None = None,
+ state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
"""
@@ -133,15 +159,53 @@ async def drive_callback(
Query params:
code: Authorization code from Google
+ error: OAuth error (if user denied access)
state: Encoded state with space_id and user_id
Returns:
Redirect to frontend success page
"""
try:
- # Decode and parse state
- decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
- data = json.loads(decoded_state)
+ # Handle OAuth errors (e.g., user denied access)
+ if error:
+ logger.warning(f"Google Drive OAuth error: {error}")
+ # Try to decode state to get space_id for redirect, but don't fail if it's invalid
+ space_id = None
+ if state:
+ try:
+ state_manager = get_state_manager()
+ data = state_manager.validate_state(state)
+ space_id = data.get("space_id")
+ except Exception:
+ # If state is invalid, we'll redirect without space_id
+ logger.warning("Failed to validate state in error handler")
+
+ # Redirect to frontend with error parameter
+ if space_id:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied"
+ )
+ else:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_drive_oauth_denied"
+ )
+
+ # Validate required parameters for successful flow
+ if not code:
+ raise HTTPException(status_code=400, detail="Missing authorization code")
+ if not state:
+ raise HTTPException(status_code=400, detail="Missing state parameter")
+
+ # Validate and decode state with signature verification
+ state_manager = get_state_manager()
+ try:
+ data = state_manager.validate_state(state)
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid state parameter: {e!s}"
+ ) from e
user_id = UUID(data["user_id"])
space_id = data["space_id"]
@@ -150,6 +214,12 @@ async def drive_callback(
f"Processing Google Drive callback for user {user_id}, space {space_id}"
)
+        # Ensure the redirect URI is configured before exchanging the code
+ if not config.GOOGLE_DRIVE_REDIRECT_URI:
+ raise HTTPException(
+ status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
+ )
+
# Exchange authorization code for tokens
flow = get_google_flow()
flow.fetch_token(code=code)
@@ -157,6 +227,24 @@ async def drive_callback(
creds = flow.credentials
creds_dict = json.loads(creds.to_json())
+ # Encrypt sensitive credentials before storing
+ token_encryption = get_token_encryption()
+
+ # Encrypt sensitive fields: token, refresh_token, client_secret
+ if creds_dict.get("token"):
+ creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"])
+ if creds_dict.get("refresh_token"):
+ creds_dict["refresh_token"] = token_encryption.encrypt_token(
+ creds_dict["refresh_token"]
+ )
+ if creds_dict.get("client_secret"):
+ creds_dict["client_secret"] = token_encryption.encrypt_token(
+ creds_dict["client_secret"]
+ )
+
+        # Flag the credentials as encrypted (legacy plaintext configs lack this marker)
+ creds_dict["_token_encrypted"] = True
+
# Check if connector already exists for this space/user
result = await session.execute(
select(SearchSourceConnector).filter(
diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py
index 21fcf2c38..20a51c1a1 100644
--- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py
+++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py
@@ -2,7 +2,6 @@
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
-import base64
import json
import logging
from uuid import UUID
@@ -23,51 +22,90 @@
get_async_session,
)
from app.users import current_active_user
+from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
router = APIRouter()
+# Initialize security utilities
+_state_manager = None
+_token_encryption = None
+
+
+def get_state_manager() -> OAuthStateManager:
+ """Get or create OAuth state manager instance."""
+ global _state_manager
+ if _state_manager is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for OAuth security")
+ _state_manager = OAuthStateManager(config.SECRET_KEY)
+ return _state_manager
+
+
+def get_token_encryption() -> TokenEncryption:
+ """Get or create token encryption instance."""
+ global _token_encryption
+ if _token_encryption is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for token encryption")
+ _token_encryption = TokenEncryption(config.SECRET_KEY)
+ return _token_encryption
+
def get_google_flow():
"""Create and return a Google OAuth flow for Gmail API."""
- flow = Flow.from_client_config(
- {
- "web": {
- "client_id": config.GOOGLE_OAUTH_CLIENT_ID,
- "client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
- "auth_uri": "https://accounts.google.com/o/oauth2/auth",
- "token_uri": "https://oauth2.googleapis.com/token",
- "redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI],
- }
- },
- scopes=[
- "https://www.googleapis.com/auth/gmail.readonly",
- "https://www.googleapis.com/auth/userinfo.email",
- "https://www.googleapis.com/auth/userinfo.profile",
- "openid",
- ],
- )
- flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI
- return flow
+ try:
+ flow = Flow.from_client_config(
+ {
+ "web": {
+ "client_id": config.GOOGLE_OAUTH_CLIENT_ID,
+ "client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI],
+ }
+ },
+ scopes=[
+ "https://www.googleapis.com/auth/gmail.readonly",
+ "https://www.googleapis.com/auth/userinfo.email",
+ "https://www.googleapis.com/auth/userinfo.profile",
+ "openid",
+ ],
+ )
+ flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI
+ return flow
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Failed to create Google flow: {e!s}"
+ ) from e
@router.get("/auth/google/gmail/connector/add")
async def connect_gmail(space_id: int, user: User = Depends(current_active_user)):
+ """
+ Initiate Google Gmail OAuth flow.
+
+ Query params:
+ space_id: Search space ID to add connector to
+
+ Returns:
+ JSON with auth_url to redirect user to Google authorization
+ """
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
+ if not config.SECRET_KEY:
+ raise HTTPException(
+ status_code=500, detail="SECRET_KEY not configured for OAuth security."
+ )
+
flow = get_google_flow()
- # Encode space_id and user_id in state
- state_payload = json.dumps(
- {
- "space_id": space_id,
- "user_id": str(user.id),
- }
- )
- state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
+ # Generate secure state parameter with HMAC signature
+ state_manager = get_state_manager()
+ state_encoded = state_manager.generate_secure_state(space_id, user.id)
auth_url, _ = flow.authorization_url(
access_type="offline",
@@ -75,8 +113,13 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
include_granted_scopes="true",
state=state_encoded,
)
+
+ logger.info(
+ f"Initiating Google Gmail OAuth for user {user.id}, space {space_id}"
+ )
return {"auth_url": auth_url}
except Exception as e:
+ logger.error(f"Failed to initiate Google Gmail OAuth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Google OAuth: {e!s}"
) from e
@@ -85,24 +128,99 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
@router.get("/auth/google/gmail/connector/callback")
async def gmail_callback(
request: Request,
- code: str,
- state: str,
+ code: str | None = None,
+ error: str | None = None,
+ state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
+ """
+ Handle Google Gmail OAuth callback.
+
+ Args:
+ request: FastAPI request object
+ code: Authorization code from Google (if user granted access)
+ error: Error code from Google (if user denied access or error occurred)
+ state: State parameter containing user/space info
+ session: Database session
+
+ Returns:
+ Redirect response to frontend
+ """
try:
- # Decode and parse the state
- decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
- data = json.loads(decoded_state)
+ # Handle OAuth errors (e.g., user denied access)
+ if error:
+ logger.warning(f"Google Gmail OAuth error: {error}")
+ # Try to decode state to get space_id for redirect, but don't fail if it's invalid
+ space_id = None
+ if state:
+ try:
+ state_manager = get_state_manager()
+ data = state_manager.validate_state(state)
+ space_id = data.get("space_id")
+ except Exception:
+ # If state is invalid, we'll redirect without space_id
+ logger.warning("Failed to validate state in error handler")
+
+ # Redirect to frontend with error parameter
+ if space_id:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied"
+ )
+ else:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_gmail_oauth_denied"
+ )
+
+ # Validate required parameters for successful flow
+ if not code:
+ raise HTTPException(status_code=400, detail="Missing authorization code")
+ if not state:
+ raise HTTPException(status_code=400, detail="Missing state parameter")
+
+ # Validate and decode state with signature verification
+ state_manager = get_state_manager()
+ try:
+ data = state_manager.validate_state(state)
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid state parameter: {e!s}"
+ ) from e
user_id = UUID(data["user_id"])
space_id = data["space_id"]
+        # Ensure the redirect URI is configured before exchanging the code
+ if not config.GOOGLE_GMAIL_REDIRECT_URI:
+ raise HTTPException(
+ status_code=500, detail="GOOGLE_GMAIL_REDIRECT_URI not configured"
+ )
+
flow = get_google_flow()
flow.fetch_token(code=code)
creds = flow.credentials
creds_dict = json.loads(creds.to_json())
+ # Encrypt sensitive credentials before storing
+ token_encryption = get_token_encryption()
+
+ # Encrypt sensitive fields: token, refresh_token, client_secret
+ if creds_dict.get("token"):
+ creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"])
+ if creds_dict.get("refresh_token"):
+ creds_dict["refresh_token"] = token_encryption.encrypt_token(
+ creds_dict["refresh_token"]
+ )
+ if creds_dict.get("client_secret"):
+ creds_dict["client_secret"] = token_encryption.encrypt_token(
+ creds_dict["client_secret"]
+ )
+
+        # Flag the credentials as encrypted (legacy plaintext configs lack this marker)
+ creds_dict["_token_encrypted"] = True
+
try:
# Check if a connector with the same type already exists for this search space and user
result = await session.execute(
@@ -160,3 +278,6 @@ async def gmail_callback(
raise
except Exception as e:
logger.error(f"Unexpected error in Gmail callback: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to complete Google Gmail OAuth: {e!s}"
+ ) from e
diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py
new file mode 100644
index 000000000..7a7fc196a
--- /dev/null
+++ b/surfsense_backend/app/routes/linear_add_connector_route.py
@@ -0,0 +1,448 @@
+"""
+Linear Connector OAuth Routes.
+
+Handles OAuth 2.0 authentication flow for Linear connector.
+"""
+
+import logging
+from datetime import UTC, datetime, timedelta
+from uuid import UUID
+
+import httpx
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import RedirectResponse
+from pydantic import ValidationError
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.config import config
+from app.db import (
+ SearchSourceConnector,
+ SearchSourceConnectorType,
+ User,
+ get_async_session,
+)
+from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
+from app.users import current_active_user
+from app.utils.oauth_security import OAuthStateManager, TokenEncryption
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+# Linear OAuth endpoints
+AUTHORIZATION_URL = "https://linear.app/oauth/authorize"
+TOKEN_URL = "https://api.linear.app/oauth/token"
+
+# OAuth scopes for Linear
+SCOPES = ["read", "write"]
+
+# Initialize security utilities
+_state_manager = None
+_token_encryption = None
+
+
+def get_state_manager() -> OAuthStateManager:
+ """Get or create OAuth state manager instance."""
+ global _state_manager
+ if _state_manager is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for OAuth security")
+ _state_manager = OAuthStateManager(config.SECRET_KEY)
+ return _state_manager
+
+
+def get_token_encryption() -> TokenEncryption:
+ """Get or create token encryption instance."""
+ global _token_encryption
+ if _token_encryption is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for token encryption")
+ _token_encryption = TokenEncryption(config.SECRET_KEY)
+ return _token_encryption
+
+
+def make_basic_auth_header(client_id: str, client_secret: str) -> str:
+ """Create Basic Auth header for Linear OAuth."""
+ import base64
+
+ credentials = f"{client_id}:{client_secret}".encode()
+ b64 = base64.b64encode(credentials).decode("ascii")
+ return f"Basic {b64}"
+
+
+@router.get("/auth/linear/connector/add")
+async def connect_linear(space_id: int, user: User = Depends(current_active_user)):
+ """
+ Initiate Linear OAuth flow.
+
+ Args:
+ space_id: The search space ID
+ user: Current authenticated user
+
+ Returns:
+ Authorization URL for redirect
+ """
+ try:
+ if not space_id:
+ raise HTTPException(status_code=400, detail="space_id is required")
+
+ if not config.LINEAR_CLIENT_ID:
+ raise HTTPException(status_code=500, detail="Linear OAuth not configured.")
+
+ if not config.SECRET_KEY:
+ raise HTTPException(
+ status_code=500, detail="SECRET_KEY not configured for OAuth security."
+ )
+
+ # Generate secure state parameter with HMAC signature
+ state_manager = get_state_manager()
+ state_encoded = state_manager.generate_secure_state(space_id, user.id)
+
+ # Build authorization URL
+ from urllib.parse import urlencode
+
+ auth_params = {
+ "client_id": config.LINEAR_CLIENT_ID,
+ "response_type": "code",
+ "redirect_uri": config.LINEAR_REDIRECT_URI,
+ "scope": " ".join(SCOPES),
+ "state": state_encoded,
+ }
+
+ auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
+
+ logger.info(f"Generated Linear OAuth URL for user {user.id}, space {space_id}")
+ return {"auth_url": auth_url}
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to initiate Linear OAuth: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to initiate Linear OAuth: {e!s}"
+ ) from e
+
+
+@router.get("/auth/linear/connector/callback")
+async def linear_callback(
+ request: Request,
+ code: str | None = None,
+ error: str | None = None,
+ state: str | None = None,
+ session: AsyncSession = Depends(get_async_session),
+):
+ """
+ Handle Linear OAuth callback.
+
+ Args:
+ request: FastAPI request object
+ code: Authorization code from Linear (if user granted access)
+ error: Error code from Linear (if user denied access or error occurred)
+ state: State parameter containing user/space info
+ session: Database session
+
+ Returns:
+ Redirect response to frontend
+ """
+ try:
+ # Handle OAuth errors (e.g., user denied access)
+ if error:
+ logger.warning(f"Linear OAuth error: {error}")
+ # Try to decode state to get space_id for redirect, but don't fail if it's invalid
+ space_id = None
+ if state:
+ try:
+ state_manager = get_state_manager()
+ data = state_manager.validate_state(state)
+ space_id = data.get("space_id")
+ except Exception:
+ # If state is invalid, we'll redirect without space_id
+ logger.warning("Failed to validate state in error handler")
+
+ # Redirect to frontend with error parameter
+ if space_id:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied"
+ )
+ else:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=linear_oauth_denied"
+ )
+
+ # Validate required parameters for successful flow
+ if not code:
+ raise HTTPException(status_code=400, detail="Missing authorization code")
+ if not state:
+ raise HTTPException(status_code=400, detail="Missing state parameter")
+
+ # Validate and decode state with signature verification
+ state_manager = get_state_manager()
+ try:
+ data = state_manager.validate_state(state)
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid state parameter: {e!s}"
+ ) from e
+
+ user_id = UUID(data["user_id"])
+ space_id = data["space_id"]
+
+ # Ensure the redirect URI is configured; the token exchange below always uses this configured value
+ if not config.LINEAR_REDIRECT_URI:
+ raise HTTPException(
+ status_code=500, detail="LINEAR_REDIRECT_URI not configured"
+ )
+
+ # Exchange authorization code for access token
+ auth_header = make_basic_auth_header(
+ config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET
+ )
+
+ token_data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": config.LINEAR_REDIRECT_URI, # Use stored value, not from request
+ }
+
+ async with httpx.AsyncClient() as client:
+ token_response = await client.post(
+ TOKEN_URL,
+ data=token_data,
+ headers={
+ "Content-Type": "application/x-www-form-urlencoded",
+ "Authorization": auth_header,
+ },
+ timeout=30.0,
+ )
+
+ if token_response.status_code != 200:
+ error_detail = token_response.text
+ try:
+ error_json = token_response.json()
+ error_detail = error_json.get("error_description", error_detail)
+ except Exception:
+ pass
+ raise HTTPException(
+ status_code=400, detail=f"Token exchange failed: {error_detail}"
+ )
+
+ token_json = token_response.json()
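+ # The response is expected to carry access_token, token_type, expires_in
+ # and scope; refresh_token may be absent depending on the app configuration.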
+
+ # Encrypt sensitive tokens before storing
+ token_encryption = get_token_encryption()
+ access_token = token_json.get("access_token")
+ refresh_token = token_json.get("refresh_token")
+
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="No access token received from Linear"
+ )
+
+ # Calculate expiration time (UTC, tz-aware)
+ expires_at = None
+ if token_json.get("expires_in"):
+ now_utc = datetime.now(UTC)
+ expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
+
+ # Store the encrypted access token and refresh token in connector config
+ connector_config = {
+ "access_token": token_encryption.encrypt_token(access_token),
+ "refresh_token": token_encryption.encrypt_token(refresh_token)
+ if refresh_token
+ else None,
+ "token_type": token_json.get("token_type", "Bearer"),
+ "expires_in": token_json.get("expires_in"),
+ "expires_at": expires_at.isoformat() if expires_at else None,
+ "scope": token_json.get("scope"),
+ # Mark that tokens are encrypted for backward compatibility
+ "_token_encrypted": True,
+ }
+
+ # Check if connector already exists for this search space and user
+ existing_connector_result = await session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.LINEAR_CONNECTOR,
+ )
+ )
+ existing_connector = existing_connector_result.scalars().first()
+
+ if existing_connector:
+ # Update existing connector
+ existing_connector.config = connector_config
+ existing_connector.name = "Linear Connector"
+ existing_connector.is_indexable = True
+ logger.info(
+ f"Updated existing Linear connector for user {user_id} in space {space_id}"
+ )
+ else:
+ # Create new connector
+ new_connector = SearchSourceConnector(
+ name="Linear Connector",
+ connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
+ is_indexable=True,
+ config=connector_config,
+ search_space_id=space_id,
+ user_id=user_id,
+ )
+ session.add(new_connector)
+ logger.info(
+ f"Created new Linear connector for user {user_id} in space {space_id}"
+ )
+
+ try:
+ await session.commit()
+ logger.info(f"Successfully saved Linear connector for user {user_id}")
+
+ # Redirect to the frontend with success params
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector"
+ )
+
+ except ValidationError as e:
+ await session.rollback()
+ raise HTTPException(
+ status_code=422, detail=f"Validation error: {e!s}"
+ ) from e
+ except IntegrityError as e:
+ await session.rollback()
+ raise HTTPException(
+ status_code=409,
+ detail=f"Integrity error: A connector with this type already exists. {e!s}",
+ ) from e
+ except Exception as e:
+ logger.error(f"Failed to create search source connector: {e!s}")
+ await session.rollback()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create search source connector: {e!s}",
+ ) from e
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to complete Linear OAuth: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to complete Linear OAuth: {e!s}"
+ ) from e
+
+
+async def refresh_linear_token(
+ session: AsyncSession, connector: SearchSourceConnector
+) -> SearchSourceConnector:
+ """
+ Refresh the Linear access token for a connector.
+
+ Args:
+ session: Database session
+ connector: Linear connector to refresh
+
+ Returns:
+ Updated connector object
+ """
+ try:
+ logger.info(f"Refreshing Linear token for connector {connector.id}")
+
+ credentials = LinearAuthCredentialsBase.from_dict(connector.config)
+
+ # Decrypt tokens if they are encrypted
+ token_encryption = get_token_encryption()
+ is_encrypted = connector.config.get("_token_encrypted", False)
+
+ refresh_token = credentials.refresh_token
+ if is_encrypted and refresh_token:
+ try:
+ refresh_token = token_encryption.decrypt_token(refresh_token)
+ except Exception as e:
+ logger.error(f"Failed to decrypt refresh token: {e!s}")
+ raise HTTPException(
+ status_code=500, detail="Failed to decrypt stored refresh token"
+ ) from e
+
+ if not refresh_token:
+ raise HTTPException(
+ status_code=400,
+ detail="No refresh token available. Please re-authenticate.",
+ )
+
+ auth_header = make_basic_auth_header(
+ config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET
+ )
+
+ # Prepare token refresh data
+ refresh_data = {
+ "grant_type": "refresh_token",
+ "refresh_token": refresh_token,
+ }
+
+ async with httpx.AsyncClient() as client:
+ token_response = await client.post(
+ TOKEN_URL,
+ data=refresh_data,
+ headers={
+ "Content-Type": "application/x-www-form-urlencoded",
+ "Authorization": auth_header,
+ },
+ timeout=30.0,
+ )
+
+ if token_response.status_code != 200:
+ error_detail = token_response.text
+ try:
+ error_json = token_response.json()
+ error_detail = error_json.get("error_description", error_detail)
+ except Exception:
+ pass
+ raise HTTPException(
+ status_code=400, detail=f"Token refresh failed: {error_detail}"
+ )
+
+ token_json = token_response.json()
+
+ # Calculate expiration time (UTC, tz-aware)
+ expires_at = None
+ expires_in = token_json.get("expires_in")
+ if expires_in:
+ now_utc = datetime.now(UTC)
+ expires_at = now_utc + timedelta(seconds=int(expires_in))
+
+ # Encrypt new tokens before storing
+ access_token = token_json.get("access_token")
+ new_refresh_token = token_json.get("refresh_token")
+
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="No access token received from Linear refresh"
+ )
+
+ # Update credentials object with encrypted tokens
+ credentials.access_token = token_encryption.encrypt_token(access_token)
+ if new_refresh_token:
+ credentials.refresh_token = token_encryption.encrypt_token(
+ new_refresh_token
+ )
+ credentials.expires_in = expires_in
+ credentials.expires_at = expires_at
+ credentials.scope = token_json.get("scope")
+
+ # Update connector config with encrypted tokens
+ credentials_dict = credentials.to_dict()
+ credentials_dict["_token_encrypted"] = True
+ connector.config = credentials_dict
+ await session.commit()
+ await session.refresh(connector)
+
+ logger.info(f"Successfully refreshed Linear token for connector {connector.id}")
+
+ return connector
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to refresh Linear token: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to refresh Linear token: {e!s}"
+ ) from e
diff --git a/surfsense_backend/app/routes/notion_add_connector_route.py b/surfsense_backend/app/routes/notion_add_connector_route.py
new file mode 100644
index 000000000..462ac398c
--- /dev/null
+++ b/surfsense_backend/app/routes/notion_add_connector_route.py
@@ -0,0 +1,459 @@
+"""
+Notion Connector OAuth Routes.
+
+Handles the OAuth 2.0 authentication flow for the Notion connector.
+"""
+
+import logging
+from datetime import UTC, datetime, timedelta
+from uuid import UUID
+
+import httpx
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import RedirectResponse
+from pydantic import ValidationError
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.config import config
+from app.db import (
+ SearchSourceConnector,
+ SearchSourceConnectorType,
+ User,
+ get_async_session,
+)
+from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
+from app.users import current_active_user
+from app.utils.oauth_security import OAuthStateManager, TokenEncryption
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+# Notion OAuth endpoints
+AUTHORIZATION_URL = "https://api.notion.com/v1/oauth/authorize"
+TOKEN_URL = "https://api.notion.com/v1/oauth/token"
+
+# Initialize security utilities
+_state_manager = None
+_token_encryption = None
+
+
+def get_state_manager() -> OAuthStateManager:
+ """Get or create OAuth state manager instance."""
+ global _state_manager
+ if _state_manager is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for OAuth security")
+ _state_manager = OAuthStateManager(config.SECRET_KEY)
+ return _state_manager
+
+
+def get_token_encryption() -> TokenEncryption:
+ """Get or create token encryption instance."""
+ global _token_encryption
+ if _token_encryption is None:
+ if not config.SECRET_KEY:
+ raise ValueError("SECRET_KEY must be set for token encryption")
+ _token_encryption = TokenEncryption(config.SECRET_KEY)
+ return _token_encryption
+
+
+def make_basic_auth_header(client_id: str, client_secret: str) -> str:
+ """Create Basic Auth header for Notion OAuth."""
+ import base64
+
+ credentials = f"{client_id}:{client_secret}".encode()
+ b64 = base64.b64encode(credentials).decode("ascii")
+ return f"Basic {b64}"
+
+
+@router.get("/auth/notion/connector/add")
+async def connect_notion(space_id: int, user: User = Depends(current_active_user)):
+ """
+ Initiate Notion OAuth flow.
+
+ Args:
+ space_id: The search space ID
+ user: Current authenticated user
+
+ Returns:
+ Authorization URL for redirect
+ """
+ try:
+ if not space_id:
+ raise HTTPException(status_code=400, detail="space_id is required")
+
+ if not config.NOTION_CLIENT_ID:
+ raise HTTPException(status_code=500, detail="Notion OAuth not configured.")
+
+ if not config.SECRET_KEY:
+ raise HTTPException(
+ status_code=500, detail="SECRET_KEY not configured for OAuth security."
+ )
+
+ # Generate secure state parameter with HMAC signature
+ state_manager = get_state_manager()
+ state_encoded = state_manager.generate_secure_state(space_id, user.id)
+
+ # Build authorization URL
+ from urllib.parse import urlencode
+
+ auth_params = {
+ "client_id": config.NOTION_CLIENT_ID,
+ "response_type": "code",
+ "owner": "user", # Allows both admins and members to authorize
+ "redirect_uri": config.NOTION_REDIRECT_URI,
+ "state": state_encoded,
+ }
+
+ auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
+
+ logger.info(f"Generated Notion OAuth URL for user {user.id}, space {space_id}")
+ return {"auth_url": auth_url}
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to initiate Notion OAuth: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to initiate Notion OAuth: {e!s}"
+ ) from e
+
+
+@router.get("/auth/notion/connector/callback")
+async def notion_callback(
+ request: Request,
+ code: str | None = None,
+ error: str | None = None,
+ state: str | None = None,
+ session: AsyncSession = Depends(get_async_session),
+):
+ """
+ Handle Notion OAuth callback.
+
+ Args:
+ request: FastAPI request object
+ code: Authorization code from Notion (if user granted access)
+ error: Error code from Notion (if user denied access or error occurred)
+ state: State parameter containing user/space info
+ session: Database session
+
+ Returns:
+ Redirect response to frontend
+ """
+ try:
+ # Handle OAuth errors (e.g., user denied access)
+ if error:
+ logger.warning(f"Notion OAuth error: {error}")
+ # Try to decode state to get space_id for redirect, but don't fail if it's invalid
+ space_id = None
+ if state:
+ try:
+ state_manager = get_state_manager()
+ data = state_manager.validate_state(state)
+ space_id = data.get("space_id")
+ except Exception:
+ # If state is invalid, we'll redirect without space_id
+ logger.warning("Failed to validate state in error handler")
+
+ # Redirect to frontend with error parameter
+ if space_id:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied"
+ )
+ else:
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=notion_oauth_denied"
+ )
+
+ # Validate required parameters for successful flow
+ if not code:
+ raise HTTPException(status_code=400, detail="Missing authorization code")
+ if not state:
+ raise HTTPException(status_code=400, detail="Missing state parameter")
+
+ # Validate and decode state with signature verification
+ state_manager = get_state_manager()
+ try:
+ data = state_manager.validate_state(state)
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid state parameter: {e!s}"
+ ) from e
+
+ user_id = UUID(data["user_id"])
+ space_id = data["space_id"]
+
+ # Ensure the redirect URI is configured. Notion doesn't echo redirect_uri
+ # in the callback; the token exchange below always uses the configured value.
+ if not config.NOTION_REDIRECT_URI:
+ raise HTTPException(
+ status_code=500, detail="NOTION_REDIRECT_URI not configured"
+ )
+
+ # Exchange authorization code for access token
+ auth_header = make_basic_auth_header(
+ config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET
+ )
+
+ token_data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": config.NOTION_REDIRECT_URI, # Use stored value, not from request
+ }
+
+ async with httpx.AsyncClient() as client:
+ token_response = await client.post(
+ TOKEN_URL,
+ json=token_data,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": auth_header,
+ },
+ timeout=30.0,
+ )
+
+ if token_response.status_code != 200:
+ error_detail = token_response.text
+ try:
+ error_json = token_response.json()
+ error_detail = error_json.get("error_description", error_detail)
+ except Exception:
+ pass
+ raise HTTPException(
+ status_code=400, detail=f"Token exchange failed: {error_detail}"
+ )
+
+ token_json = token_response.json()
+
+ # Encrypt sensitive tokens before storing
+ token_encryption = get_token_encryption()
+ access_token = token_json.get("access_token")
+ refresh_token = token_json.get("refresh_token")
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="No access token received from Notion"
+ )
+
+ # Calculate expiration time (UTC, tz-aware)
+ expires_at = None
+ expires_in = token_json.get("expires_in")
+ if expires_in:
+ now_utc = datetime.now(UTC)
+ expires_at = now_utc + timedelta(seconds=int(expires_in))
+
+ # Notion returns access_token, refresh_token (if available), and workspace information
+ # Store the encrypted tokens and workspace info in connector config
+ connector_config = {
+ "access_token": token_encryption.encrypt_token(access_token),
+ "refresh_token": token_encryption.encrypt_token(refresh_token)
+ if refresh_token
+ else None,
+ "expires_in": expires_in,
+ "expires_at": expires_at.isoformat() if expires_at else None,
+ "workspace_id": token_json.get("workspace_id"),
+ "workspace_name": token_json.get("workspace_name"),
+ "workspace_icon": token_json.get("workspace_icon"),
+ "bot_id": token_json.get("bot_id"),
+ # Mark that token is encrypted for backward compatibility
+ "_token_encrypted": True,
+ }
+
+ # Check if connector already exists for this search space and user
+ existing_connector_result = await session.execute(
+ select(SearchSourceConnector).filter(
+ SearchSourceConnector.search_space_id == space_id,
+ SearchSourceConnector.user_id == user_id,
+ SearchSourceConnector.connector_type
+ == SearchSourceConnectorType.NOTION_CONNECTOR,
+ )
+ )
+ existing_connector = existing_connector_result.scalars().first()
+
+ if existing_connector:
+ # Update existing connector
+ existing_connector.config = connector_config
+ existing_connector.name = "Notion Connector"
+ existing_connector.is_indexable = True
+ logger.info(
+ f"Updated existing Notion connector for user {user_id} in space {space_id}"
+ )
+ else:
+ # Create new connector
+ new_connector = SearchSourceConnector(
+ name="Notion Connector",
+ connector_type=SearchSourceConnectorType.NOTION_CONNECTOR,
+ is_indexable=True,
+ config=connector_config,
+ search_space_id=space_id,
+ user_id=user_id,
+ )
+ session.add(new_connector)
+ logger.info(
+ f"Created new Notion connector for user {user_id} in space {space_id}"
+ )
+
+ try:
+ await session.commit()
+ logger.info(f"Successfully saved Notion connector for user {user_id}")
+
+ # Redirect to the frontend with success params
+ return RedirectResponse(
+ url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector"
+ )
+
+ except ValidationError as e:
+ await session.rollback()
+ raise HTTPException(
+ status_code=422, detail=f"Validation error: {e!s}"
+ ) from e
+ except IntegrityError as e:
+ await session.rollback()
+ raise HTTPException(
+ status_code=409,
+ detail=f"Integrity error: A connector with this type already exists. {e!s}",
+ ) from e
+ except Exception as e:
+ logger.error(f"Failed to create search source connector: {e!s}")
+ await session.rollback()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create search source connector: {e!s}",
+ ) from e
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to complete Notion OAuth: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to complete Notion OAuth: {e!s}"
+ ) from e
+
+
+async def refresh_notion_token(
+ session: AsyncSession, connector: SearchSourceConnector
+) -> SearchSourceConnector:
+ """
+ Refresh the Notion access token for a connector.
+
+ Args:
+ session: Database session
+ connector: Notion connector to refresh
+
+ Returns:
+ Updated connector object
+ """
+ try:
+ logger.info(f"Refreshing Notion token for connector {connector.id}")
+
+ credentials = NotionAuthCredentialsBase.from_dict(connector.config)
+
+ # Decrypt tokens if they are encrypted
+ token_encryption = get_token_encryption()
+ is_encrypted = connector.config.get("_token_encrypted", False)
+
+ refresh_token = credentials.refresh_token
+ if is_encrypted and refresh_token:
+ try:
+ refresh_token = token_encryption.decrypt_token(refresh_token)
+ except Exception as e:
+ logger.error(f"Failed to decrypt refresh token: {e!s}")
+ raise HTTPException(
+ status_code=500, detail="Failed to decrypt stored refresh token"
+ ) from e
+
+ if not refresh_token:
+ raise HTTPException(
+ status_code=400,
+ detail="No refresh token available. Please re-authenticate.",
+ )
+
+ auth_header = make_basic_auth_header(
+ config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET
+ )
+
+ # Prepare token refresh data
+ refresh_data = {
+ "grant_type": "refresh_token",
+ "refresh_token": refresh_token,
+ }
+
+ async with httpx.AsyncClient() as client:
+ token_response = await client.post(
+ TOKEN_URL,
+ json=refresh_data,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": auth_header,
+ },
+ timeout=30.0,
+ )
+
+ if token_response.status_code != 200:
+ error_detail = token_response.text
+ try:
+ error_json = token_response.json()
+ error_detail = error_json.get("error_description", error_detail)
+ except Exception:
+ pass
+ raise HTTPException(
+ status_code=400, detail=f"Token refresh failed: {error_detail}"
+ )
+
+ token_json = token_response.json()
+
+ # Calculate expiration time (UTC, tz-aware)
+ expires_at = None
+ expires_in = token_json.get("expires_in")
+ if expires_in:
+ now_utc = datetime.now(UTC)
+ expires_at = now_utc + timedelta(seconds=int(expires_in))
+
+ # Encrypt new tokens before storing
+ access_token = token_json.get("access_token")
+ new_refresh_token = token_json.get("refresh_token")
+
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="No access token received from Notion refresh"
+ )
+
+ # Update credentials object with encrypted tokens
+ credentials.access_token = token_encryption.encrypt_token(access_token)
+ if new_refresh_token:
+ credentials.refresh_token = token_encryption.encrypt_token(
+ new_refresh_token
+ )
+ credentials.expires_in = expires_in
+ credentials.expires_at = expires_at
+
+ # Preserve workspace info
+ if not credentials.workspace_id:
+ credentials.workspace_id = connector.config.get("workspace_id")
+ if not credentials.workspace_name:
+ credentials.workspace_name = connector.config.get("workspace_name")
+ if not credentials.workspace_icon:
+ credentials.workspace_icon = connector.config.get("workspace_icon")
+ if not credentials.bot_id:
+ credentials.bot_id = connector.config.get("bot_id")
+
+ # Update connector config with encrypted tokens
+ credentials_dict = credentials.to_dict()
+ credentials_dict["_token_encrypted"] = True
+ connector.config = credentials_dict
+ await session.commit()
+ await session.refresh(connector)
+
+ logger.info(f"Successfully refreshed Notion token for connector {connector.id}")
+
+ return connector
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to refresh Notion token: {e!s}", exc_info=True)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to refresh Notion token: {e!s}"
+ ) from e
diff --git a/surfsense_backend/app/schemas/linear_auth_credentials.py b/surfsense_backend/app/schemas/linear_auth_credentials.py
new file mode 100644
index 000000000..99e8d9111
--- /dev/null
+++ b/surfsense_backend/app/schemas/linear_auth_credentials.py
@@ -0,0 +1,66 @@
+from datetime import UTC, datetime
+
+from pydantic import BaseModel, field_validator
+
+
+class LinearAuthCredentialsBase(BaseModel):
+ access_token: str
+ refresh_token: str | None = None
+ token_type: str = "Bearer"
+ expires_in: int | None = None
+ expires_at: datetime | None = None
+ scope: str | None = None
+
+ @property
+ def is_expired(self) -> bool:
+ """Check if the credentials have expired."""
+ if self.expires_at is None:
+ return False
+ return self.expires_at <= datetime.now(UTC)
+
+ @property
+ def is_refreshable(self) -> bool:
+ """Check if the credentials can be refreshed."""
+ return self.refresh_token is not None
+
+ def to_dict(self) -> dict:
+ """Convert credentials to dictionary for storage."""
+ return {
+ "access_token": self.access_token,
+ "refresh_token": self.refresh_token,
+ "token_type": self.token_type,
+ "expires_in": self.expires_in,
+ "expires_at": self.expires_at.isoformat() if self.expires_at else None,
+ "scope": self.scope,
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict) -> "LinearAuthCredentialsBase":
+ """Create credentials from dictionary."""
+ expires_at = None
+ if data.get("expires_at"):
+ expires_at = datetime.fromisoformat(data["expires_at"])
+
+ return cls(
+ access_token=data["access_token"],
+ refresh_token=data.get("refresh_token"),
+ token_type=data.get("token_type", "Bearer"),
+ expires_in=data.get("expires_in"),
+ expires_at=expires_at,
+ scope=data.get("scope"),
+ )
+
+ @field_validator("expires_at", mode="before")
+ @classmethod
+ def ensure_aware_utc(cls, v):
+ # Strings like "2025-08-26T14:46:57.367184"
+ if isinstance(v, str):
+ # add +00:00 if missing tz info
+ if v.endswith("Z"):
+ return datetime.fromisoformat(v.replace("Z", "+00:00"))
+ dt = datetime.fromisoformat(v)
+ return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
+ # datetime objects
+ if isinstance(v, datetime):
+ return v if v.tzinfo else v.replace(tzinfo=UTC)
+ return v
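+
+
+# Example: the validator above coerces naive ISO strings to UTC-aware datetimes,
+# e.g. LinearAuthCredentialsBase(access_token="t", expires_at="2025-08-26T14:46:57")
+# yields expires_at == datetime(2025, 8, 26, 14, 46, 57, tzinfo=UTC).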
diff --git a/surfsense_backend/app/schemas/notion_auth_credentials.py b/surfsense_backend/app/schemas/notion_auth_credentials.py
new file mode 100644
index 000000000..e66afb903
--- /dev/null
+++ b/surfsense_backend/app/schemas/notion_auth_credentials.py
@@ -0,0 +1,72 @@
+from datetime import UTC, datetime
+
+from pydantic import BaseModel, field_validator
+
+
+class NotionAuthCredentialsBase(BaseModel):
+ access_token: str
+ refresh_token: str | None = None
+ expires_in: int | None = None
+ expires_at: datetime | None = None
+ workspace_id: str | None = None
+ workspace_name: str | None = None
+ workspace_icon: str | None = None
+ bot_id: str | None = None
+
+ @property
+ def is_expired(self) -> bool:
+ """Check if the credentials have expired."""
+ if self.expires_at is None:
+ return False # Long-lived token, treat as not expired
+ return self.expires_at <= datetime.now(UTC)
+
+ @property
+ def is_refreshable(self) -> bool:
+ """Check if the credentials can be refreshed."""
+ return self.refresh_token is not None
+
+ def to_dict(self) -> dict:
+ """Convert credentials to dictionary for storage."""
+ return {
+ "access_token": self.access_token,
+ "refresh_token": self.refresh_token,
+ "expires_in": self.expires_in,
+ "expires_at": self.expires_at.isoformat() if self.expires_at else None,
+ "workspace_id": self.workspace_id,
+ "workspace_name": self.workspace_name,
+ "workspace_icon": self.workspace_icon,
+ "bot_id": self.bot_id,
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict) -> "NotionAuthCredentialsBase":
+ """Create credentials from dictionary."""
+ expires_at = None
+ if data.get("expires_at"):
+ expires_at = datetime.fromisoformat(data["expires_at"])
+
+ return cls(
+ access_token=data["access_token"],
+ refresh_token=data.get("refresh_token"),
+ expires_in=data.get("expires_in"),
+ expires_at=expires_at,
+ workspace_id=data.get("workspace_id"),
+ workspace_name=data.get("workspace_name"),
+ workspace_icon=data.get("workspace_icon"),
+ bot_id=data.get("bot_id"),
+ )
+
+ @field_validator("expires_at", mode="before")
+ @classmethod
+ def ensure_aware_utc(cls, v):
+ # Strings like "2025-08-26T14:46:57.367184"
+ if isinstance(v, str):
+ # add +00:00 if missing tz info
+ if v.endswith("Z"):
+ return datetime.fromisoformat(v.replace("Z", "+00:00"))
+ dt = datetime.fromisoformat(v)
+ return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
+ # datetime objects
+ if isinstance(v, datetime):
+ return v if v.tzinfo else v.replace(tzinfo=UTC)
+ return v
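+
+
+# Note: when expires_at is absent the token is treated as long-lived
+# (is_expired stays False); a refresh is only attempted when a refresh_token
+# was issued (is_refreshable).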
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 69b75e5c4..3b87c33f1 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -270,7 +270,8 @@ async def stream_new_chat(
# Track if we just finished a tool (text flows silently after tools)
just_finished_tool: bool = False
# Track write_todos calls to show "Creating plan" vs "Updating plan"
- write_todos_call_count: int = 0
+ # Disabled for now
+ # write_todos_call_count: int = 0
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
@@ -479,60 +480,60 @@ def complete_current_step() -> str | None:
status="in_progress",
items=last_active_step_items,
)
- elif tool_name == "write_todos":
- # Track write_todos calls for better messaging
- write_todos_call_count += 1
- todos = (
- tool_input.get("todos", [])
- if isinstance(tool_input, dict)
- else []
- )
- todo_count = len(todos) if isinstance(todos, list) else 0
-
- if write_todos_call_count == 1:
- # First call - creating the plan
- last_active_step_title = "Creating plan"
- last_active_step_items = [f"Defining {todo_count} tasks..."]
- else:
- # Subsequent calls - updating the plan
- # Try to provide context about what's being updated
- in_progress_count = (
- sum(
- 1
- for t in todos
- if isinstance(t, dict)
- and t.get("status") == "in_progress"
- )
- if isinstance(todos, list)
- else 0
- )
- completed_count = (
- sum(
- 1
- for t in todos
- if isinstance(t, dict)
- and t.get("status") == "completed"
- )
- if isinstance(todos, list)
- else 0
- )
-
- last_active_step_title = "Updating progress"
- last_active_step_items = (
- [
- f"Progress: {completed_count}/{todo_count} completed",
- f"In progress: {in_progress_count} tasks",
- ]
- if completed_count > 0
- else [f"Working on {todo_count} tasks"]
- )
+ # elif tool_name == "write_todos": # Disabled for now
+ # # Track write_todos calls for better messaging
+ # write_todos_call_count += 1
+ # todos = (
+ # tool_input.get("todos", [])
+ # if isinstance(tool_input, dict)
+ # else []
+ # )
+ # todo_count = len(todos) if isinstance(todos, list) else 0
+
+ # if write_todos_call_count == 1:
+ # # First call - creating the plan
+ # last_active_step_title = "Creating plan"
+ # last_active_step_items = [f"Defining {todo_count} tasks..."]
+ # else:
+ # # Subsequent calls - updating the plan
+ # # Try to provide context about what's being updated
+ # in_progress_count = (
+ # sum(
+ # 1
+ # for t in todos
+ # if isinstance(t, dict)
+ # and t.get("status") == "in_progress"
+ # )
+ # if isinstance(todos, list)
+ # else 0
+ # )
+ # completed_count = (
+ # sum(
+ # 1
+ # for t in todos
+ # if isinstance(t, dict)
+ # and t.get("status") == "completed"
+ # )
+ # if isinstance(todos, list)
+ # else 0
+ # )
+
+ # last_active_step_title = "Updating progress"
+ # last_active_step_items = (
+ # [
+ # f"Progress: {completed_count}/{todo_count} completed",
+ # f"In progress: {in_progress_count} tasks",
+ # ]
+ # if completed_count > 0
+ # else [f"Working on {todo_count} tasks"]
+ # )
- yield streaming_service.format_thinking_step(
- step_id=tool_step_id,
- title=last_active_step_title,
- status="in_progress",
- items=last_active_step_items,
- )
+ # yield streaming_service.format_thinking_step(
+ # step_id=tool_step_id,
+ # title=last_active_step_title,
+ # status="in_progress",
+ # items=last_active_step_items,
+ # )
elif tool_name == "generate_podcast":
podcast_title = (
tool_input.get("podcast_title", "SurfSense Podcast")
@@ -596,10 +597,12 @@ def complete_current_step() -> str | None:
raw_output = event.get("data", {}).get("output", "")
# Handle deepagents' write_todos Command object specially
- if tool_name == "write_todos" and hasattr(raw_output, "update"):
- # deepagents returns a Command object - extract todos directly
- tool_output = extract_todos_from_deepagents(raw_output)
- elif hasattr(raw_output, "content"):
+ # Disabled for now
+ # if tool_name == "write_todos" and hasattr(raw_output, "update"):
+ # # deepagents returns a Command object - extract todos directly
+ # tool_output = extract_todos_from_deepagents(raw_output)
+ # elif hasattr(raw_output, "content"):
+ if hasattr(raw_output, "content"):
# It's a ToolMessage object - extract the content
content = raw_output.content
# If content is a string that looks like JSON, try to parse it
@@ -758,63 +761,63 @@ def complete_current_step() -> str | None:
status="completed",
items=completed_items,
)
- elif tool_name == "write_todos":
- # Build completion items for planning/updating
- if isinstance(tool_output, dict):
- todos = tool_output.get("todos", [])
- todo_count = len(todos) if isinstance(todos, list) else 0
- completed_count = (
- sum(
- 1
- for t in todos
- if isinstance(t, dict)
- and t.get("status") == "completed"
- )
- if isinstance(todos, list)
- else 0
- )
- in_progress_count = (
- sum(
- 1
- for t in todos
- if isinstance(t, dict)
- and t.get("status") == "in_progress"
- )
- if isinstance(todos, list)
- else 0
- )
-
- # Use context-aware completion message
- if last_active_step_title == "Creating plan":
- completed_items = [f"Created {todo_count} tasks"]
- else:
- # Updating progress - show stats
- completed_items = [
- f"Progress: {completed_count}/{todo_count} completed",
- ]
- if in_progress_count > 0:
- # Find the currently in-progress task name
- in_progress_task = next(
- (
- t.get("content", "")[:40]
- for t in todos
- if isinstance(t, dict)
- and t.get("status") == "in_progress"
- ),
- None,
- )
- if in_progress_task:
- completed_items.append(
- f"Current: {in_progress_task}..."
- )
- else:
- completed_items = ["Plan updated"]
- yield streaming_service.format_thinking_step(
- step_id=original_step_id,
- title=last_active_step_title,
- status="completed",
- items=completed_items,
- )
+ # elif tool_name == "write_todos": # Disabled for now
+ # # Build completion items for planning/updating
+ # if isinstance(tool_output, dict):
+ # todos = tool_output.get("todos", [])
+ # todo_count = len(todos) if isinstance(todos, list) else 0
+ # completed_count = (
+ # sum(
+ # 1
+ # for t in todos
+ # if isinstance(t, dict)
+ # and t.get("status") == "completed"
+ # )
+ # if isinstance(todos, list)
+ # else 0
+ # )
+ # in_progress_count = (
+ # sum(
+ # 1
+ # for t in todos
+ # if isinstance(t, dict)
+ # and t.get("status") == "in_progress"
+ # )
+ # if isinstance(todos, list)
+ # else 0
+ # )
+
+ # # Use context-aware completion message
+ # if last_active_step_title == "Creating plan":
+ # completed_items = [f"Created {todo_count} tasks"]
+ # else:
+ # # Updating progress - show stats
+ # completed_items = [
+ # f"Progress: {completed_count}/{todo_count} completed",
+ # ]
+ # if in_progress_count > 0:
+ # # Find the currently in-progress task name
+ # in_progress_task = next(
+ # (
+ # t.get("content", "")[:40]
+ # for t in todos
+ # if isinstance(t, dict)
+ # and t.get("status") == "in_progress"
+ # ),
+ # None,
+ # )
+ # if in_progress_task:
+ # completed_items.append(
+ # f"Current: {in_progress_task}..."
+ # )
+ # else:
+ # completed_items = ["Plan updated"]
+ # yield streaming_service.format_thinking_step(
+ # step_id=original_step_id,
+ # title=last_active_step_title,
+ # status="completed",
+ # items=completed_items,
+ # )
elif tool_name == "ls":
# Build completion items showing file names found
if isinstance(tool_output, dict):
@@ -992,27 +995,27 @@ def complete_current_step() -> str | None:
yield streaming_service.format_terminal_info(
"Knowledge base search completed", "success"
)
- elif tool_name == "write_todos":
- # Stream the full write_todos result so frontend can render the Plan component
- yield streaming_service.format_tool_output_available(
- tool_call_id,
- tool_output
- if isinstance(tool_output, dict)
- else {"result": tool_output},
- )
- # Send terminal message with plan info
- if isinstance(tool_output, dict):
- todos = tool_output.get("todos", [])
- todo_count = len(todos) if isinstance(todos, list) else 0
- yield streaming_service.format_terminal_info(
- f"Plan created ({todo_count} tasks)",
- "success",
- )
- else:
- yield streaming_service.format_terminal_info(
- "Plan created",
- "success",
- )
+ # elif tool_name == "write_todos": # Disabled for now
+ # # Stream the full write_todos result so frontend can render the Plan component
+ # yield streaming_service.format_tool_output_available(
+ # tool_call_id,
+ # tool_output
+ # if isinstance(tool_output, dict)
+ # else {"result": tool_output},
+ # )
+ # # Send terminal message with plan info
+ # if isinstance(tool_output, dict):
+ # todos = tool_output.get("todos", [])
+ # todo_count = len(todos) if isinstance(todos, list) else 0
+ # yield streaming_service.format_terminal_info(
+ # f"Plan created ({todo_count} tasks)",
+ # "success",
+ # )
+ # else:
+ # yield streaming_service.format_terminal_info(
+ # "Plan created",
+ # "success",
+ # )
else:
# Default handling for other tools
yield streaming_service.format_tool_output_available(
diff --git a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py
index cea2a0529..3ea6dccc9 100644
--- a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py
@@ -18,6 +18,7 @@
generate_document_summary,
generate_unique_identifier_hash,
)
+from app.utils.oauth_security import TokenEncryption
from .base import (
calculate_date_range,
@@ -85,7 +86,52 @@ async def index_airtable_records(
return 0, f"Connector with ID {connector_id} not found"
# Create credentials from connector config
- config_data = connector.config
+ config_data = (
+ connector.config.copy()
+ ) # Work with a copy to avoid modifying original
+
+ # Decrypt tokens if they are encrypted (only when explicitly marked)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted:
+ # Tokens are explicitly marked as encrypted; attempt decryption
+ if not config.SECRET_KEY:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"SECRET_KEY not configured but tokens are marked as encrypted for connector {connector_id}",
+ "Missing SECRET_KEY for token decryption",
+ {"error_type": "MissingSecretKey"},
+ )
+ return 0, "SECRET_KEY not configured but tokens are marked as encrypted"
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt access_token
+ if config_data.get("access_token"):
+ config_data["access_token"] = token_encryption.decrypt_token(
+ config_data["access_token"]
+ )
+ logger.info(
+ f"Decrypted Airtable access token for connector {connector_id}"
+ )
+
+ # Decrypt refresh_token if present
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ logger.info(
+ f"Decrypted Airtable refresh token for connector {connector_id}"
+ )
+ except Exception as e:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"Failed to decrypt Airtable tokens for connector {connector_id}: {e!s}",
+ "Token decryption failed",
+ {"error_type": "TokenDecryptionError"},
+ )
+ return 0, f"Failed to decrypt Airtable tokens: {e!s}"
+ # If _token_encrypted is False or not set, treat tokens as plaintext
+
try:
credentials = AirtableAuthCredentialsBase.from_dict(config_data)
except Exception as e:
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
index a5d2bc73a..499f01d66 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
@@ -8,7 +8,6 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
-from app.config import config
from app.connectors.google_calendar_connector import GoogleCalendarConnector
from app.db import Document, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
@@ -84,15 +83,52 @@ async def index_google_calendar_events(
return 0, f"Connector with ID {connector_id} not found"
# Get the Google Calendar credentials from the connector config
- exp = connector.config.get("expiry").replace("Z", "")
+ config_data = connector.config.copy() # copy so decrypted values aren't written back to the stored config
+
+ # Decrypt sensitive credentials if encrypted (for backward compatibility)
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+
+ logger.info(
+ f"Decrypted Google Calendar credentials for connector {connector_id}"
+ )
+ except Exception as e:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"Failed to decrypt Google Calendar credentials for connector {connector_id}: {e!s}",
+ "Credential decryption failed",
+ {"error_type": "CredentialDecryptionError"},
+ )
+ return 0, f"Failed to decrypt Google Calendar credentials: {e!s}"
+
+ exp = (config_data.get("expiry") or "").replace("Z", "")
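+ # The "Z" suffix is stripped above because google-auth expects a naive UTC expiry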
credentials = Credentials(
- token=connector.config.get("token"),
- refresh_token=connector.config.get("refresh_token"),
- token_uri=connector.config.get("token_uri"),
- client_id=connector.config.get("client_id"),
- client_secret=connector.config.get("client_secret"),
- scopes=connector.config.get("scopes"),
- expiry=datetime.fromisoformat(exp),
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes"),
+ expiry=datetime.fromisoformat(exp) if exp else None,
)
if (
@@ -122,6 +158,12 @@ async def index_google_calendar_events(
connector_id=connector_id,
)
+ # Handle 'undefined' string from frontend (treat as None)
+ if start_date == "undefined" or start_date == "":
+ start_date = None
+ if end_date == "undefined" or end_date == "":
+ end_date = None
+
# Calculate date range
if start_date is None or end_date is None:
# Fall back to calculating dates based on last_indexed_at
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
index 343d44072..9eeb46fc8 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
@@ -5,6 +5,7 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
+from app.config import config
from app.connectors.google_drive import (
GoogleDriveClient,
categorize_change,
@@ -87,6 +88,26 @@ async def index_google_drive_files(
{"stage": "client_initialization"},
)
+ # Check if credentials are encrypted (only when explicitly marked)
+ token_encrypted = connector.config.get("_token_encrypted", False)
+ if token_encrypted:
+ # Credentials are explicitly marked as encrypted; they will be decrypted during client initialization
+ if not config.SECRET_KEY:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
+ "Missing SECRET_KEY for token decryption",
+ {"error_type": "MissingSecretKey"},
+ )
+ return (
+ 0,
+ "SECRET_KEY not configured but credentials are marked as encrypted",
+ )
+ logger.info(
+ f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
+ )
+ # If _token_encrypted is False or not set, treat credentials as plaintext
+
drive_client = GoogleDriveClient(session, connector_id)
if not folder_id:
@@ -249,6 +270,26 @@ async def index_google_drive_single_file(
{"stage": "client_initialization"},
)
+ # Check if credentials are encrypted (only when explicitly marked)
+ token_encrypted = connector.config.get("_token_encrypted", False)
+ if token_encrypted:
+ # Credentials are explicitly marked as encrypted; they will be decrypted during client initialization
+ if not config.SECRET_KEY:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
+ "Missing SECRET_KEY for token decryption",
+ {"error_type": "MissingSecretKey"},
+ )
+ return (
+ 0,
+ "SECRET_KEY not configured but credentials are marked as encrypted",
+ )
+ logger.info(
+ f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
+ )
+ # If _token_encrypted is False or not set, treat credentials as plaintext
+
drive_client = GoogleDriveClient(session, connector_id)
# Fetch the file metadata
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
index d350411e1..e10297057 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
@@ -8,7 +8,6 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
-from app.config import config
from app.connectors.google_gmail_connector import GoogleGmailConnector
from app.db import (
Document,
@@ -88,9 +87,47 @@ async def index_google_gmail_messages(
)
return 0, error_msg
- # Create credentials from connector config
+ # Get the Google Gmail credentials from the connector config
config_data = connector.config
- exp = config_data.get("expiry").replace("Z", "")
+
+ # Decrypt sensitive credentials if encrypted (for backward compatibility)
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+
+ logger.info(
+ f"Decrypted Google Gmail credentials for connector {connector_id}"
+ )
+ except Exception as e:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"Failed to decrypt Google Gmail credentials for connector {connector_id}: {e!s}",
+ "Credential decryption failed",
+ {"error_type": "CredentialDecryptionError"},
+ )
+ return 0, f"Failed to decrypt Google Gmail credentials: {e!s}"
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
credentials = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
@@ -98,7 +135,7 @@ async def index_google_gmail_messages(
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
- expiry=datetime.fromisoformat(exp),
+ expiry=datetime.fromisoformat(exp) if exp else None,
)
if (
diff --git a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py
index afc9ffd3b..f1bfd42e8 100644
--- a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py
@@ -92,25 +92,34 @@ async def index_linear_issues(
f"Connector with ID {connector_id} not found or is not a Linear connector",
)
- # Get the Linear token from the connector config
- linear_token = connector.config.get("LINEAR_API_KEY")
- if not linear_token:
+ # Check if access_token exists (support both new OAuth format and old API key format)
+ if not connector.config.get("access_token") and not connector.config.get(
+ "LINEAR_API_KEY"
+ ):
await task_logger.log_task_failure(
log_entry,
- f"Linear API token not found in connector config for connector {connector_id}",
- "Missing Linear token",
+ f"Linear access token not found in connector config for connector {connector_id}",
+ "Missing Linear access token",
{"error_type": "MissingToken"},
)
- return 0, "Linear API token not found in connector config"
+ return 0, "Linear access token not found in connector config"
- # Initialize Linear client
+ # Initialize Linear client with internal refresh capability
await task_logger.log_task_progress(
log_entry,
f"Initializing Linear client for connector {connector_id}",
{"stage": "client_initialization"},
)
- linear_client = LinearConnector(token=linear_token)
+ # Create connector with session and connector_id for internal refresh
+ # Token refresh will happen automatically when needed
+ linear_client = LinearConnector(session=session, connector_id=connector_id)
+
+ # Handle 'undefined' string from frontend (treat as None)
+ if start_date == "undefined" or start_date == "":
+ start_date = None
+ if end_date == "undefined" or end_date == "":
+ end_date = None
# Calculate date range
start_date_str, end_date_str = calculate_date_range(
@@ -131,7 +140,7 @@ async def index_linear_issues(
# Get issues within date range
try:
- issues, error = linear_client.get_issues_by_date_range(
+ issues, error = await linear_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
)
diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py
index 332d3e39d..13923269d 100644
--- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py
@@ -2,7 +2,7 @@
Notion connector indexer.
"""
-from datetime import datetime, timedelta
+from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
@@ -20,6 +20,7 @@
from .base import (
build_document_metadata_string,
+ calculate_date_range,
check_document_by_unique_identifier,
get_connector_by_id,
get_current_timestamp,
@@ -91,18 +92,19 @@ async def index_notion_pages(
f"Connector with ID {connector_id} not found or is not a Notion connector",
)
- # Get the Notion token from the connector config
- notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN")
- if not notion_token:
+ # Check if access_token exists (support both new OAuth format and old integration token format)
+ if not connector.config.get("access_token") and not connector.config.get(
+ "NOTION_INTEGRATION_TOKEN"
+ ):
await task_logger.log_task_failure(
log_entry,
- f"Notion integration token not found in connector config for connector {connector_id}",
- "Missing Notion token",
+ f"Notion access token not found in connector config for connector {connector_id}",
+ "Missing Notion access token",
{"error_type": "MissingToken"},
)
- return 0, "Notion integration token not found in connector config"
+ return 0, "Notion access token not found in connector config"
- # Initialize Notion client
+ # Initialize Notion client with internal refresh capability
await task_logger.log_task_progress(
log_entry,
f"Initializing Notion client for connector {connector_id}",
@@ -111,40 +113,30 @@ async def index_notion_pages(
logger.info(f"Initializing Notion client for connector {connector_id}")
- # Calculate date range
- if start_date is None or end_date is None:
- # Fall back to calculating dates
- calculated_end_date = datetime.now()
- calculated_start_date = calculated_end_date - timedelta(
- days=365
- ) # Check for last 1 year of pages
-
- # Use calculated dates if not provided
- if start_date is None:
- start_date_iso = calculated_start_date.strftime("%Y-%m-%dT%H:%M:%SZ")
- else:
- # Convert YYYY-MM-DD to ISO format
- start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime(
- "%Y-%m-%dT%H:%M:%SZ"
- )
+ # Handle 'undefined' string from frontend (treat as None)
+ if start_date == "undefined" or start_date == "":
+ start_date = None
+ if end_date == "undefined" or end_date == "":
+ end_date = None
- if end_date is None:
- end_date_iso = calculated_end_date.strftime("%Y-%m-%dT%H:%M:%SZ")
- else:
- # Convert YYYY-MM-DD to ISO format
- end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime(
- "%Y-%m-%dT%H:%M:%SZ"
- )
- else:
- # Convert provided dates to ISO format for Notion API
- start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime(
- "%Y-%m-%dT%H:%M:%SZ"
- )
- end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime(
- "%Y-%m-%dT%H:%M:%SZ"
- )
+ # Calculate date range using the shared utility function
+ start_date_str, end_date_str = calculate_date_range(
+ connector, start_date, end_date, default_days_back=365
+ )
- notion_client = NotionHistoryConnector(token=notion_token)
+ # Convert YYYY-MM-DD to ISO format for Notion API
+ start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime(
+ "%Y-%m-%dT%H:%M:%SZ"
+ )
+ end_date_iso = datetime.strptime(end_date_str, "%Y-%m-%d").strftime(
+ "%Y-%m-%dT%H:%M:%SZ"
+ )
+
+ # Create connector with session and connector_id for internal refresh
+ # Token refresh will happen automatically when needed
+ notion_client = NotionHistoryConnector(
+ session=session, connector_id=connector_id
+ )
logger.info(f"Fetching Notion pages from {start_date_iso} to {end_date_iso}")
diff --git a/surfsense_backend/app/utils/oauth_security.py b/surfsense_backend/app/utils/oauth_security.py
new file mode 100644
index 000000000..5135cdef4
--- /dev/null
+++ b/surfsense_backend/app/utils/oauth_security.py
@@ -0,0 +1,210 @@
+"""
+OAuth Security Utilities.
+
+Provides secure state parameter generation/validation and token encryption
+for OAuth 2.0 flows.
+"""
+
+import base64
+import hashlib
+import hmac
+import json
+import logging
+import time
+from uuid import UUID
+
+from cryptography.fernet import Fernet
+from fastapi import HTTPException
+
+logger = logging.getLogger(__name__)
+
+
+class OAuthStateManager:
+ """Manages secure OAuth state parameters with HMAC signatures."""
+
+ def __init__(self, secret_key: str, max_age_seconds: int = 600):
+ """
+ Initialize OAuth state manager.
+
+ Args:
+ secret_key: Secret key for HMAC signing (should be SECRET_KEY from config)
+ max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes)
+ """
+ if not secret_key:
+ raise ValueError("secret_key is required for OAuth state management")
+ self.secret_key = secret_key
+ self.max_age_seconds = max_age_seconds
+
+ def generate_secure_state(
+ self, space_id: int, user_id: UUID, **extra_fields
+ ) -> str:
+ """
+ Generate cryptographically signed state parameter.
+
+ Args:
+ space_id: The search space ID
+ user_id: The user ID
+ **extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE)
+
+ Returns:
+ Base64-encoded state parameter with HMAC signature
+ """
+ timestamp = int(time.time())
+ state_payload = {
+ "space_id": space_id,
+ "user_id": str(user_id),
+ "timestamp": timestamp,
+ }
+
+ # Add any extra fields (e.g., code_verifier for PKCE)
+ state_payload.update(extra_fields)
+
+ # Create signature
+ payload_str = json.dumps(state_payload, sort_keys=True)
+ signature = hmac.new(
+ self.secret_key.encode(),
+ payload_str.encode(),
+ hashlib.sha256,
+ ).hexdigest()
+
+ # Include signature in state
+ state_payload["signature"] = signature
+ state_encoded = base64.urlsafe_b64encode(
+ json.dumps(state_payload).encode()
+ ).decode()
+
+ return state_encoded
+
+ def validate_state(self, state: str) -> dict:
+ """
+ Validate and decode state parameter with signature verification.
+
+ Args:
+ state: The state parameter from OAuth callback
+
+ Returns:
+ Decoded state data (space_id, user_id, timestamp, plus any extra fields)
+
+ Raises:
+ HTTPException: If state is invalid, expired, or tampered with
+ """
+ try:
+ decoded = base64.urlsafe_b64decode(state.encode()).decode()
+ data = json.loads(decoded)
+ except Exception as e:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid state format: {e!s}"
+ ) from e
+
+ # Verify signature exists
+ signature = data.pop("signature", None)
+ if not signature:
+ raise HTTPException(status_code=400, detail="Missing state signature")
+
+ # Verify signature
+ payload_str = json.dumps(data, sort_keys=True)
+ expected_signature = hmac.new(
+ self.secret_key.encode(),
+ payload_str.encode(),
+ hashlib.sha256,
+ ).hexdigest()
+
+ if not hmac.compare_digest(signature, expected_signature):
+ raise HTTPException(
+ status_code=400, detail="Invalid state signature - possible tampering"
+ )
+
+ # Verify timestamp (prevent replay attacks)
+ timestamp = data.get("timestamp", 0)
+ current_time = time.time()
+ age = current_time - timestamp
+
+ if age < 0:
+ raise HTTPException(status_code=400, detail="Invalid state timestamp")
+
+ if age > self.max_age_seconds:
+ raise HTTPException(
+ status_code=400,
+ detail="State parameter expired. Please try again.",
+ )
+
+ return data
+
+
+class TokenEncryption:
+ """Encrypt/decrypt sensitive OAuth tokens for storage."""
+
+ def __init__(self, secret_key: str):
+ """
+ Initialize token encryption.
+
+ Args:
+ secret_key: Secret key for encryption (should be SECRET_KEY from config)
+ """
+ if not secret_key:
+ raise ValueError("secret_key is required for token encryption")
+ # Derive Fernet key from secret using SHA256
+ # Note: In production, consider using HKDF for key derivation
+ key = base64.urlsafe_b64encode(hashlib.sha256(secret_key.encode()).digest())
+ try:
+ self.cipher = Fernet(key)
+ except Exception as e:
+ raise ValueError(f"Failed to initialize encryption cipher: {e!s}") from e
+
+ def encrypt_token(self, token: str) -> str:
+ """
+ Encrypt a token for storage.
+
+ Args:
+ token: Plaintext token to encrypt
+
+ Returns:
+ Encrypted token string
+ """
+ if not token:
+ return token
+ try:
+ return self.cipher.encrypt(token.encode()).decode()
+ except Exception as e:
+ logger.error(f"Failed to encrypt token: {e!s}")
+ raise ValueError(f"Token encryption failed: {e!s}") from e
+
+ def decrypt_token(self, encrypted_token: str) -> str:
+ """
+ Decrypt a stored token.
+
+ Args:
+ encrypted_token: Encrypted token string
+
+ Returns:
+ Decrypted plaintext token
+ """
+ if not encrypted_token:
+ return encrypted_token
+ try:
+ return self.cipher.decrypt(encrypted_token.encode()).decode()
+ except Exception as e:
+ logger.error(f"Failed to decrypt token: {e!s}")
+ raise ValueError(f"Token decryption failed: {e!s}") from e
+
+ def is_encrypted(self, token: str) -> bool:
+ """
+ Check if a token appears to be encrypted.
+
+ Args:
+ token: Token string to check
+
+ Returns:
+ True if token appears encrypted, False otherwise
+ """
+ if not token:
+ return False
+ # Heuristic only: Fernet tokens are URL-safe base64 (they start with "gAAAAA"),
+ # so a string that decodes cleanly and is reasonably long is likely one of ours.
+ try:
+ # Try to decode as base64
+ base64.urlsafe_b64decode(token.encode())
+ # If it's base64 and reasonably long, likely encrypted
+ return len(token) > 20
+ except Exception:
+ return False
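Usage round trip for the two helpers introduced above, as a reviewer aid. The SECRET_KEY literal and the token value are placeholders; in the app both come from config. The calls themselves match the signatures in this patch.

from uuid import uuid4

from app.utils.oauth_security import OAuthStateManager, TokenEncryption

SECRET_KEY = "replace-with-config.SECRET_KEY"

# State round trip: sign on the authorize redirect, verify on the callback.
states = OAuthStateManager(SECRET_KEY, max_age_seconds=600)
state = states.generate_secure_state(space_id=42, user_id=uuid4())
data = states.validate_state(state)  # raises HTTPException if forged or expired
assert data["space_id"] == 42

# Token round trip: encrypt before persisting connector.config, decrypt on use.
crypto = TokenEncryption(SECRET_KEY)
stored = crypto.encrypt_token("secret_notion_access_token")
assert crypto.is_encrypted(stored)  # Fernet output is URL-safe base64
assert crypto.decrypt_token(stored) == "secret_notion_access_token"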
diff --git a/surfsense_backend/app/utils/validators.py b/surfsense_backend/app/utils/validators.py
index 6b69fb3e1..d6622bafd 100644
--- a/surfsense_backend/app/utils/validators.py
+++ b/surfsense_backend/app/utils/validators.py
@@ -514,10 +514,6 @@ def validate_initial_urls() -> None:
"validators": {},
},
"SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}},
- "NOTION_CONNECTOR": {
- "required": ["NOTION_INTEGRATION_TOKEN"],
- "validators": {},
- },
"GITHUB_CONNECTOR": {
"required": ["GITHUB_PAT", "repo_full_names"],
"validators": {
@@ -526,7 +522,6 @@ def validate_initial_urls() -> None:
)
},
},
- "LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}},
"DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}},
"JIRA_CONNECTOR": {
"required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"],
diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
index 35a096497..b1abd647f 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
@@ -20,7 +20,7 @@ import {
} from "@/atoms/chat/mentioned-documents.atom";
import {
clearPlanOwnerRegistry,
- extractWriteTodosFromContent,
+ // extractWriteTodosFromContent,
hydratePlanStateAtom,
} from "@/atoms/chat/plan-state.atom";
import { Thread } from "@/components/assistant-ui/thread";
@@ -30,7 +30,7 @@ import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
-import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
+// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
import { getBearerToken } from "@/lib/auth-utils";
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
import {
@@ -199,7 +199,7 @@ const TOOLS_WITH_UI = new Set([
"link_preview",
"display_image",
"scrape_webpage",
- "write_todos",
+ // "write_todos", // Disabled for now
]);
/**
@@ -291,10 +291,11 @@ export default function NewChatPage() {
restoredThinkingSteps.set(`msg-${msg.id}`, steps);
}
// Hydrate write_todos plan state from persisted tool calls
- const writeTodosCalls = extractWriteTodosFromContent(msg.content);
- for (const todoData of writeTodosCalls) {
- hydratePlanState(todoData);
- }
+ // Disabled for now
+ // const writeTodosCalls = extractWriteTodosFromContent(msg.content);
+ // for (const todoData of writeTodosCalls) {
+ // hydratePlanState(todoData);
+ // }
}
if (msg.role === "user") {
const docs = extractMentionedDocuments(msg.content);
@@ -911,7 +912,7 @@ export default function NewChatPage() {
+
{formatDocumentCount(documentCount)}
)}
@@ -130,12 +169,10 @@ export const ConnectorCard: FC