diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index 91a0cb42f..83656a910 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -38,13 +38,22 @@ GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
 GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
 GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
 GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
-GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
-# Airtable OAuth for Aitable Connector
+# OAuth for Airtable Connector
 AIRTABLE_CLIENT_ID=your_airtable_client_id
 AIRTABLE_CLIENT_SECRET=your_airtable_client_secret
 AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
 
+# OAuth for Linear Connector
+LINEAR_CLIENT_ID=your_linear_client_id
+LINEAR_CLIENT_SECRET=your_linear_client_secret
+LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback
+
+# OAuth for Notion Connector
+NOTION_CLIENT_ID=your_notion_client_id
+NOTION_CLIENT_SECRET=your_notion_client_secret
+NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback
+
 # Embedding Model
 # Examples:
 #
 # Get sentence transformers embeddings
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index 9c503fb18..7c7703470 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -90,6 +90,16 @@ class Config:
     AIRTABLE_CLIENT_SECRET = os.getenv("AIRTABLE_CLIENT_SECRET")
     AIRTABLE_REDIRECT_URI = os.getenv("AIRTABLE_REDIRECT_URI")
 
+    # Notion OAuth
+    NOTION_CLIENT_ID = os.getenv("NOTION_CLIENT_ID")
+    NOTION_CLIENT_SECRET = os.getenv("NOTION_CLIENT_SECRET")
+    NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
+
+    # Linear OAuth
+    LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
+    LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
+    LINEAR_REDIRECT_URI = os.getenv("LINEAR_REDIRECT_URI")
+
     # LLM instances are now managed per-user through the LLMConfig system
     # Legacy environment variables removed in favor of user-specific configurations
diff --git a/surfsense_backend/app/connectors/google_calendar_connector.py b/surfsense_backend/app/connectors/google_calendar_connector.py
index 164d230e0..6d389ddd5 100644
--- a/surfsense_backend/app/connectors/google_calendar_connector.py
+++ b/surfsense_backend/app/connectors/google_calendar_connector.py
@@ -109,7 +109,36 @@ async def _get_credentials(
                 raise RuntimeError(
                     "GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
) - connector.config = json.loads(self._credentials.to_json()) + + # Encrypt sensitive credentials before storing + from app.config import config + from app.utils.oauth_security import TokenEncryption + + creds_dict = json.loads(self._credentials.to_json()) + token_encrypted = connector.config.get("_token_encrypted", False) + + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + # Encrypt sensitive fields + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token( + creds_dict["token"] + ) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = ( + token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = ( + token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + ) + creds_dict["_token_encrypted"] = True + + connector.config = creds_dict flag_modified(connector, "config") await self._session.commit() except Exception as e: @@ -182,6 +211,18 @@ async def get_all_primary_calendar_events( Tuple containing (events list, error message or None) """ try: + # Validate date strings + if not start_date or start_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid start_date: must be a valid date string in YYYY-MM-DD format", + ) + if not end_date or end_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid end_date: must be a valid date string in YYYY-MM-DD format", + ) + service = await self._get_service() # Parse both dates diff --git a/surfsense_backend/app/connectors/google_drive/credentials.py b/surfsense_backend/app/connectors/google_drive/credentials.py index f88486468..7e6335f6d 100644 --- a/surfsense_backend/app/connectors/google_drive/credentials.py +++ b/surfsense_backend/app/connectors/google_drive/credentials.py @@ -1,6 +1,7 @@ """Google Drive OAuth credential management.""" import json +import logging from datetime import datetime from google.auth.transport.requests import Request @@ -9,7 +10,11 @@ from sqlalchemy.future import select from sqlalchemy.orm.attributes import flag_modified +from app.config import config from app.db import SearchSourceConnector +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) async def get_valid_credentials( @@ -38,7 +43,41 @@ async def get_valid_credentials( if not connector: raise ValueError(f"Connector {connector_id} not found") - config_data = connector.config + config_data = ( + connector.config.copy() + ) # Work with a copy to avoid modifying original + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + if config_data.get("client_secret"): + config_data["client_secret"] = token_encryption.decrypt_token( + config_data["client_secret"] + ) + + logger.info( + f"Decrypted Google Drive credentials for connector {connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Google Drive credentials for connector {connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Google Drive credentials: {e!s}" + ) from e + exp 
= config_data.get("expiry", "").replace("Z", "") if not all( @@ -66,7 +105,29 @@ async def get_valid_credentials( try: credentials.refresh(Request()) - connector.config = json.loads(credentials.to_json()) + creds_dict = json.loads(credentials.to_json()) + + # Encrypt sensitive credentials before storing + token_encrypted = connector.config.get("_token_encrypted", False) + + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + # Encrypt sensitive fields + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token( + creds_dict["token"] + ) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + creds_dict["_token_encrypted"] = True + + connector.config = creds_dict flag_modified(connector, "config") await session.commit() diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index b4c54fda3..148aa4d0a 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -5,33 +5,153 @@ Allows fetching issue lists and their comments with date range filtering. """ +import logging from datetime import datetime from typing import Any import requests +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.linear_add_connector_route import refresh_linear_token +from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) class LinearConnector: """Class for retrieving issues and comments from Linear.""" - def __init__(self, token: str | None = None): + def __init__( + self, + session: AsyncSession, + connector_id: int, + credentials: LinearAuthCredentialsBase | None = None, + ): """ - Initialize the LinearConnector class. + Initialize the LinearConnector class with auto-refresh capability. Args: - token: Linear API token (optional, can be set later with set_token) + session: Database session for updating connector + connector_id: Connector ID for direct updates + credentials: Linear OAuth credentials (optional, will be loaded from DB if not provided) """ - self.token = token + self._session = session + self._connector_id = connector_id + self._credentials = credentials self.api_url = "https://api.linear.app/graphql" - def set_token(self, token: str) -> None: + async def _get_valid_token(self) -> str: """ - Set the Linear API token. + Get valid Linear access token, refreshing if needed. 
- Args: - token: Linear API token + Returns: + Valid access token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails """ - self.token = token + # Load credentials from DB if not provided + if self._credentials is None: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Linear credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Linear credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Linear credentials: {e!s}" + ) from e + + try: + self._credentials = LinearAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Linear credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Linear token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_linear_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = LinearAuthCredentialsBase.from_dict(config_data) + + logger.info( + f"Successfully refreshed Linear token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Linear token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Linear OAuth credentials: {e!s}" + ) from e + + return self._credentials.access_token def get_headers(self) -> dict[str, str]: """ @@ -41,18 +161,26 @@ def get_headers(self) -> dict[str, str]: Dictionary of headers Raises: - ValueError: If no Linear token has been set + ValueError: If no Linear access token has been set """ - if not self.token: - raise ValueError("Linear token not initialized. 
Call set_token() first.") + # This is a synchronous method, but we need async token refresh + # For now, we'll raise an error if called directly + # All API calls should go through execute_graphql_query which handles async refresh + if not self._credentials or not self._credentials.access_token: + raise ValueError( + "Linear access token not initialized. Use execute_graphql_query() method." + ) - return {"Content-Type": "application/json", "Authorization": self.token} + return { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._credentials.access_token}", + } - def execute_graphql_query( + async def execute_graphql_query( self, query: str, variables: dict[str, Any] | None = None ) -> dict[str, Any]: """ - Execute a GraphQL query against the Linear API. + Execute a GraphQL query against the Linear API with automatic token refresh. Args: query: GraphQL query string @@ -62,13 +190,17 @@ def execute_graphql_query( Response data from the API Raises: - ValueError: If no Linear token has been set + ValueError: If no Linear access token has been set Exception: If the API request fails """ - if not self.token: - raise ValueError("Linear token not initialized. Call set_token() first.") + # Get valid token (refreshes if needed) + access_token = await self._get_valid_token() + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } - headers = self.get_headers() payload = {"query": query} if variables: @@ -83,7 +215,9 @@ def execute_graphql_query( f"Query failed with status code {response.status_code}: {response.text}" ) - def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: + async def get_all_issues( + self, include_comments: bool = True + ) -> list[dict[str, Any]]: """ Fetch all issues from Linear. 
@@ -94,7 +228,7 @@ def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: List of issue objects Raises: - ValueError: If no Linear token has been set + ValueError: If no Linear access token has been set Exception: If the API request fails """ comments_query = "" @@ -146,7 +280,7 @@ def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: }} """ - result = self.execute_graphql_query(query) + result = await self.execute_graphql_query(query) # Extract issues from the response if ( @@ -158,7 +292,7 @@ def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: return [] - def get_issues_by_date_range( + async def get_issues_by_date_range( self, start_date: str, end_date: str, include_comments: bool = True ) -> tuple[list[dict[str, Any]], str | None]: """ @@ -172,6 +306,18 @@ def get_issues_by_date_range( Returns: Tuple containing (issues list, error message or None) """ + # Validate date strings + if not start_date or start_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid start_date: must be a valid date string in YYYY-MM-DD format", + ) + if not end_date or end_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid end_date: must be a valid date string in YYYY-MM-DD format", + ) + # Convert date strings to ISO format try: # For Linear API: we need to use a more specific format for the filter @@ -258,7 +404,7 @@ def get_issues_by_date_range( # Handle pagination to get all issues while has_next_page: variables = {"after": cursor} if cursor else {} - result = self.execute_graphql_query(query, variables) + result = await self.execute_graphql_query(query, variables) # Check for errors if "errors" in result: @@ -446,37 +592,3 @@ def format_date(iso_date: str) -> str: return dt.strftime("%Y-%m-%d %H:%M:%S") except ValueError: return iso_date - - -# Example usage (uncomment to use): -""" -if __name__ == "__main__": - # Set your token here - token = "YOUR_LINEAR_API_KEY" - - linear = LinearConnector(token) - - try: - # Get all issues with comments - issues = linear.get_all_issues() - print(f"Retrieved {len(issues)} issues") - - # Format and print the first issue as markdown - if issues: - issue_md = linear.format_issue_to_markdown(issues[0]) - print("\nSample Issue in Markdown:\n") - print(issue_md) - - # Get issues by date range - start_date = "2023-01-01" - end_date = "2023-01-31" - date_issues, error = linear.get_issues_by_date_range(start_date, end_date) - - if error: - print(f"Error: {error}") - else: - print(f"\nRetrieved {len(date_issues)} issues from {start_date} to {end_date}") - - except Exception as e: - print(f"Error: {e}") -""" diff --git a/surfsense_backend/app/connectors/notion_history.py b/surfsense_backend/app/connectors/notion_history.py index 81f6642f1..e38218a6e 100644 --- a/surfsense_backend/app/connectors/notion_history.py +++ b/surfsense_backend/app/connectors/notion_history.py @@ -1,19 +1,167 @@ +import logging + from notion_client import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.notion_add_connector_route import refresh_notion_token +from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) class NotionHistoryConnector: - def __init__(self, token): + def __init__( + self, + session: AsyncSession, + 
connector_id: int, + credentials: NotionAuthCredentialsBase | None = None, + ): """ - Initialize the NotionPageFetcher with a token. + Initialize the NotionHistoryConnector with auto-refresh capability. Args: - token (str): Notion integration token + session: Database session for updating connector + connector_id: Connector ID for direct updates + credentials: Notion OAuth credentials (optional, will be loaded from DB if not provided) + """ + self._session = session + self._connector_id = connector_id + self._credentials = credentials + self._notion_client: AsyncClient | None = None + + async def _get_valid_token(self) -> str: + """ + Get valid Notion access token, refreshing if needed. + + Returns: + Valid access token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails + """ + # Load credentials from DB if not provided + if self._credentials is None: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Notion credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Notion credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Notion credentials: {e!s}" + ) from e + + try: + self._credentials = NotionAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Notion credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Notion token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." 
+ ) + + # Refresh token + connector = await refresh_notion_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = NotionAuthCredentialsBase.from_dict(config_data) + + # Invalidate cached client so it's recreated with new token + self._notion_client = None + + logger.info( + f"Successfully refreshed Notion token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Notion token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Notion OAuth credentials: {e!s}" + ) from e + + return self._credentials.access_token + + async def _get_client(self) -> AsyncClient: + """ + Get or create Notion AsyncClient with valid token. + + Returns: + Notion AsyncClient instance """ - self.notion = AsyncClient(auth=token) + if self._notion_client is None: + token = await self._get_valid_token() + self._notion_client = AsyncClient(auth=token) + return self._notion_client async def close(self): """Close the async client connection.""" - await self.notion.aclose() + if self._notion_client: + await self._notion_client.aclose() + self._notion_client = None async def __aenter__(self): """Async context manager entry.""" @@ -34,6 +182,8 @@ async def get_all_pages(self, start_date=None, end_date=None): Returns: list: List of dictionaries containing page data """ + notion = await self._get_client() + # Build the filter for the search # Note: Notion API requires specific filter structure search_params = {} @@ -67,7 +217,7 @@ async def get_all_pages(self, start_date=None, end_date=None): if cursor: search_params["start_cursor"] = cursor - search_results = await self.notion.search(**search_params) + search_results = await notion.search(**search_params) pages.extend(search_results["results"]) has_more = search_results.get("has_more", False) @@ -125,6 +275,8 @@ async def get_page_content(self, page_id): Returns: list: List of processed blocks from the page """ + notion = await self._get_client() + blocks = [] has_more = True cursor = None @@ -132,11 +284,11 @@ async def get_page_content(self, page_id): # Paginate through all blocks while has_more: if cursor: - response = await self.notion.blocks.children.list( + response = await notion.blocks.children.list( block_id=page_id, start_cursor=cursor ) else: - response = await self.notion.blocks.children.list(block_id=page_id) + response = await notion.blocks.children.list(block_id=page_id) blocks.extend(response["results"]) has_more = response["has_more"] @@ -162,6 +314,8 @@ async def process_block(self, block): Returns: dict: Processed block with content and children """ + notion = await self._get_client() + block_id = block["id"] block_type = block["type"] @@ -174,9 +328,7 @@ async def process_block(self, block): if has_children: # Fetch and process child blocks - children_response = await self.notion.blocks.children.list( - block_id=block_id - ) + children_response = await notion.blocks.children.list(block_id=block_id) for child_block in children_response["results"]: 
child_blocks.append(await self.process_block(child_block)) diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 3c18650ae..1d1fa39ad 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -15,11 +15,13 @@ from .google_gmail_add_connector_route import ( router as google_gmail_add_connector_router, ) +from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .new_chat_routes import router as new_chat_router from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router +from .notion_add_connector_route import router as notion_add_connector_router from .podcasts_routes import router as podcasts_router from .rbac_routes import router as rbac_router from .search_source_connectors_routes import router as search_source_connectors_router @@ -39,7 +41,9 @@ router.include_router(google_gmail_add_connector_router) router.include_router(google_drive_add_connector_router) router.include_router(airtable_add_connector_router) +router.include_router(linear_add_connector_router) router.include_router(luma_add_connector_router) +router.include_router(notion_add_connector_router) router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index 3bcbe4dc0..9284d89e8 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -1,6 +1,5 @@ import base64 import hashlib -import json import logging import secrets from datetime import UTC, datetime, timedelta @@ -23,6 +22,7 @@ ) from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) @@ -40,6 +40,30 @@ "user.email:read", ] +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + def make_basic_auth_header(client_id: str, client_secret: str) -> str: credentials = f"{client_id}:{client_secret}".encode() @@ -90,18 +114,19 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us status_code=500, detail="Airtable OAuth not configured." ) + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." 
+ ) + # Generate PKCE parameters code_verifier, code_challenge = generate_pkce_pair() - # Generate state parameter - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - "code_verifier": code_verifier, - } + # Generate secure state parameter with HMAC signature (including code_verifier for PKCE) + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state( + space_id, user.id, code_verifier=code_verifier ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() # Build authorization URL auth_params = { @@ -134,8 +159,9 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us @router.get("/auth/airtable/connector/callback") async def airtable_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): """ @@ -143,7 +169,8 @@ async def airtable_callback( Args: request: FastAPI request object - code: Authorization code from Airtable + code: Authorization code from Airtable (if user granted access) + error: Error code from Airtable (if user denied access or error occurred) state: State parameter containing user/space info session: Database session @@ -151,10 +178,42 @@ async def airtable_callback( Redirect response to frontend """ try: - # Decode and parse the state + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Airtable OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=airtable_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() try: - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + data = state_manager.validate_state(state) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=400, detail=f"Invalid state parameter: {e!s}" @@ -162,7 +221,12 @@ async def airtable_callback( user_id = UUID(data["user_id"]) space_id = data["space_id"] - code_verifier = data["code_verifier"] + code_verifier = data.get("code_verifier") + + if not code_verifier: + raise HTTPException( + status_code=400, detail="Missing code_verifier in state parameter" + ) auth_header = make_basic_auth_header( config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET ) @@ -201,22 +265,38 @@ async def airtable_callback( token_json = token_response.json() + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + + if 
not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Airtable" + ) + # Calculate expiration time (UTC, tz-aware) expires_at = None if token_json.get("expires_in"): now_utc = datetime.now(UTC) expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) - # Create credentials object + # Create credentials object with encrypted tokens credentials = AirtableAuthCredentialsBase( - access_token=token_json["access_token"], - refresh_token=token_json.get("refresh_token"), + access_token=token_encryption.encrypt_token(access_token), + refresh_token=token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, token_type=token_json.get("token_type", "Bearer"), expires_in=token_json.get("expires_in"), expires_at=expires_at, scope=token_json.get("scope"), ) + # Mark that tokens are encrypted for backward compatibility + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + # Check if connector already exists for this search space and user existing_connector_result = await session.execute( select(SearchSourceConnector).filter( @@ -230,7 +310,7 @@ async def airtable_callback( if existing_connector: # Update existing connector - existing_connector.config = credentials.to_dict() + existing_connector.config = credentials_dict existing_connector.name = "Airtable Connector" existing_connector.is_indexable = True logger.info( @@ -242,7 +322,7 @@ async def airtable_callback( name="Airtable Connector", connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR, is_indexable=True, - config=credentials.to_dict(), + config=credentials_dict, search_space_id=space_id, user_id=user_id, ) @@ -306,6 +386,21 @@ async def refresh_airtable_token( logger.info(f"Refreshing Airtable token for connector {connector.id}") credentials = AirtableAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + auth_header = make_basic_auth_header( config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET ) @@ -313,7 +408,7 @@ async def refresh_airtable_token( # Prepare token refresh data refresh_data = { "grant_type": "refresh_token", - "refresh_token": credentials.refresh_token, + "refresh_token": refresh_token, "client_id": config.AIRTABLE_CLIENT_ID, "client_secret": config.AIRTABLE_CLIENT_SECRET, } @@ -342,14 +437,29 @@ async def refresh_airtable_token( now_utc = datetime.now(UTC) expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) - # Update credentials object - credentials.access_token = token_json["access_token"] + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Airtable refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) credentials.expires_in = 
token_json.get("expires_in") credentials.expires_at = expires_at credentials.scope = token_json.get("scope") - # Update connector config - connector.config = credentials.to_dict() + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict await session.commit() await session.refresh(connector) diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index 8bb685450..6c6ae4e40 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -2,7 +2,6 @@ os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" -import base64 import json import logging from uuid import UUID @@ -23,6 +22,7 @@ get_async_session, ) from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) @@ -31,6 +31,30 @@ SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"] REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + def get_google_flow(): try: @@ -59,16 +83,16 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us if not space_id: raise HTTPException(status_code=400, detail="space_id is required") + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." 
+ ) + flow = get_google_flow() - # Encode space_id and user_id in state - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - } - ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) auth_url, _ = flow.authorization_url( access_type="offline", @@ -86,24 +110,86 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us @router.get("/auth/google/calendar/connector/callback") async def calendar_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): try: - # Decode and parse the state - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Google Calendar OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_calendar_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e user_id = UUID(data["user_id"]) space_id = data["space_id"] + # Validate redirect URI (security: ensure it matches configured value) + if not config.GOOGLE_CALENDAR_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="GOOGLE_CALENDAR_REDIRECT_URI not configured" + ) + flow = get_google_flow() flow.fetch_token(code=code) creds = flow.credentials creds_dict = json.loads(creds.to_json()) + # Encrypt sensitive credentials before storing + token_encryption = get_token_encryption() + + # Encrypt sensitive fields: token, refresh_token, client_secret + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"]) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + + # Mark that credentials are encrypted for backward compatibility + creds_dict["_token_encrypted"] = True + try: # Check if a connector with the same type already exists for this search space and user result = await session.execute( diff --git 
a/surfsense_backend/app/routes/google_drive_add_connector_route.py b/surfsense_backend/app/routes/google_drive_add_connector_route.py index 52461319b..6caf3f204 100644 --- a/surfsense_backend/app/routes/google_drive_add_connector_route.py +++ b/surfsense_backend/app/routes/google_drive_add_connector_route.py @@ -10,7 +10,6 @@ - GET /connectors/{connector_id}/google-drive/folders - List user's folders (for index-time selection) """ -import base64 import json import logging import os @@ -37,6 +36,7 @@ get_async_session, ) from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption # Relax token scope validation for Google OAuth os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" @@ -44,6 +44,31 @@ logger = logging.getLogger(__name__) router = APIRouter() +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + # Google Drive OAuth scopes SCOPES = [ "https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive @@ -90,16 +115,16 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user) if not space_id: raise HTTPException(status_code=400, detail="space_id is required") + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." 
+ ) + flow = get_google_flow() - # Encode space_id and user_id in state parameter - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - } - ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) # Generate authorization URL auth_url, _ = flow.authorization_url( @@ -124,8 +149,9 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user) @router.get("/auth/google/drive/connector/callback") async def drive_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): """ @@ -133,15 +159,53 @@ async def drive_callback( Query params: code: Authorization code from Google + error: OAuth error (if user denied access) state: Encoded state with space_id and user_id Returns: Redirect to frontend success page """ try: - # Decode and parse state - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Google Drive OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_drive_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e user_id = UUID(data["user_id"]) space_id = data["space_id"] @@ -150,6 +214,12 @@ async def drive_callback( f"Processing Google Drive callback for user {user_id}, space {space_id}" ) + # Validate redirect URI (security: ensure it matches configured value) + if not config.GOOGLE_DRIVE_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured" + ) + # Exchange authorization code for tokens flow = get_google_flow() flow.fetch_token(code=code) @@ -157,6 +227,24 @@ async def drive_callback( creds = flow.credentials creds_dict = json.loads(creds.to_json()) + # Encrypt sensitive credentials before storing + token_encryption = get_token_encryption() + + # Encrypt sensitive fields: token, refresh_token, client_secret + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"]) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = 
token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + + # Mark that credentials are encrypted for backward compatibility + creds_dict["_token_encrypted"] = True + # Check if connector already exists for this space/user result = await session.execute( select(SearchSourceConnector).filter( diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index 21fcf2c38..20a51c1a1 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -2,7 +2,6 @@ os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" -import base64 import json import logging from uuid import UUID @@ -23,51 +22,90 @@ get_async_session, ) from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) router = APIRouter() +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + def get_google_flow(): """Create and return a Google OAuth flow for Gmail API.""" - flow = Flow.from_client_config( - { - "web": { - "client_id": config.GOOGLE_OAUTH_CLIENT_ID, - "client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET, - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI], - } - }, - scopes=[ - "https://www.googleapis.com/auth/gmail.readonly", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - "openid", - ], - ) - flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI - return flow + try: + flow = Flow.from_client_config( + { + "web": { + "client_id": config.GOOGLE_OAUTH_CLIENT_ID, + "client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI], + } + }, + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "openid", + ], + ) + flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI + return flow + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to create Google flow: {e!s}" + ) from e @router.get("/auth/google/gmail/connector/add") async def connect_gmail(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Google Gmail OAuth flow. 
+ + Query params: + space_id: Search space ID to add connector to + + Returns: + JSON with auth_url to redirect user to Google authorization + """ try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + flow = get_google_flow() - # Encode space_id and user_id in state - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - } - ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) auth_url, _ = flow.authorization_url( access_type="offline", @@ -75,8 +113,13 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user) include_granted_scopes="true", state=state_encoded, ) + + logger.info( + f"Initiating Google Gmail OAuth for user {user.id}, space {space_id}" + ) return {"auth_url": auth_url} except Exception as e: + logger.error(f"Failed to initiate Google Gmail OAuth: {e!s}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to initiate Google OAuth: {e!s}" ) from e @@ -85,24 +128,99 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user) @router.get("/auth/google/gmail/connector/callback") async def gmail_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): + """ + Handle Google Gmail OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Google (if user granted access) + error: Error code from Google (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ try: - # Decode and parse the state - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Google Gmail OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_gmail_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e user_id = 
UUID(data["user_id"]) space_id = data["space_id"] + # Validate redirect URI (security: ensure it matches configured value) + if not config.GOOGLE_GMAIL_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="GOOGLE_GMAIL_REDIRECT_URI not configured" + ) + flow = get_google_flow() flow.fetch_token(code=code) creds = flow.credentials creds_dict = json.loads(creds.to_json()) + # Encrypt sensitive credentials before storing + token_encryption = get_token_encryption() + + # Encrypt sensitive fields: token, refresh_token, client_secret + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"]) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + + # Mark that credentials are encrypted for backward compatibility + creds_dict["_token_encrypted"] = True + try: # Check if a connector with the same type already exists for this search space and user result = await session.execute( @@ -160,3 +278,6 @@ async def gmail_callback( raise except Exception as e: logger.error(f"Unexpected error in Gmail callback: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Google Gmail OAuth: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py new file mode 100644 index 000000000..7a7fc196a --- /dev/null +++ b/surfsense_backend/app/routes/linear_add_connector_route.py @@ -0,0 +1,448 @@ +""" +Linear Connector OAuth Routes. + +Handles OAuth 2.0 authentication flow for Linear connector. 
+""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import ValidationError +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase +from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Linear OAuth endpoints +AUTHORIZATION_URL = "https://linear.app/oauth/authorize" +TOKEN_URL = "https://api.linear.app/oauth/token" + +# OAuth scopes for Linear +SCOPES = ["read", "write"] + +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +def make_basic_auth_header(client_id: str, client_secret: str) -> str: + """Create Basic Auth header for Linear OAuth.""" + import base64 + + credentials = f"{client_id}:{client_secret}".encode() + b64 = base64.b64encode(credentials).decode("ascii") + return f"Basic {b64}" + + +@router.get("/auth/linear/connector/add") +async def connect_linear(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Linear OAuth flow. + + Args: + space_id: The search space ID + user: Current authenticated user + + Returns: + Authorization URL for redirect + """ + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + if not config.LINEAR_CLIENT_ID: + raise HTTPException(status_code=500, detail="Linear OAuth not configured.") + + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." 
+ ) + + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + # Build authorization URL + from urllib.parse import urlencode + + auth_params = { + "client_id": config.LINEAR_CLIENT_ID, + "response_type": "code", + "redirect_uri": config.LINEAR_REDIRECT_URI, + "scope": " ".join(SCOPES), + "state": state_encoded, + } + + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info(f"Generated Linear OAuth URL for user {user.id}, space {space_id}") + return {"auth_url": auth_url} + + except Exception as e: + logger.error(f"Failed to initiate Linear OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate Linear OAuth: {e!s}" + ) from e + + +@router.get("/auth/linear/connector/callback") +async def linear_callback( + request: Request, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """ + Handle Linear OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Linear (if user granted access) + error: Error code from Linear (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ + try: + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Linear OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=linear_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + # Validate redirect URI (security: ensure it matches configured value) + if not config.LINEAR_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="LINEAR_REDIRECT_URI not configured" + ) + + # Exchange authorization code for access token + auth_header = make_basic_auth_header( + config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET + ) + + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": config.LINEAR_REDIRECT_URI, # Use stored value, not from request + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=token_data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": auth_header, + 
}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) + + token_json = token_response.json() + + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Linear" + ) + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + if token_json.get("expires_in"): + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) + + # Store the encrypted access token and refresh token in connector config + connector_config = { + "access_token": token_encryption.encrypt_token(access_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "token_type": token_json.get("token_type", "Bearer"), + "expires_in": token_json.get("expires_in"), + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + # Mark that tokens are encrypted for backward compatibility + "_token_encrypted": True, + } + + # Check if connector already exists for this search space and user + existing_connector_result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + existing_connector = existing_connector_result.scalars().first() + + if existing_connector: + # Update existing connector + existing_connector.config = connector_config + existing_connector.name = "Linear Connector" + existing_connector.is_indexable = True + logger.info( + f"Updated existing Linear connector for user {user_id} in space {space_id}" + ) + else: + # Create new connector + new_connector = SearchSourceConnector( + name="Linear Connector", + connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + logger.info( + f"Created new Linear connector for user {user_id} in space {space_id}" + ) + + try: + await session.commit() + logger.info(f"Successfully saved Linear connector for user {user_id}") + + # Redirect to the frontend with success params + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector" + ) + + except ValidationError as e: + await session.rollback() + raise HTTPException( + status_code=422, detail=f"Validation error: {e!s}" + ) from e + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, + detail=f"Integrity error: A connector with this type already exists. 
{e!s}", + ) from e + except Exception as e: + logger.error(f"Failed to create search source connector: {e!s}") + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create search source connector: {e!s}", + ) from e + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to complete Linear OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Linear OAuth: {e!s}" + ) from e + + +async def refresh_linear_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Linear access token for a connector. + + Args: + session: Database session + connector: Linear connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Linear token for connector {connector.id}") + + credentials = LinearAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + auth_header = make_basic_auth_header( + config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET + ) + + # Prepare token refresh data + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=refresh_data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Linear refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + credentials.scope = token_json.get("scope") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Linear token for connector {connector.id}") + 
+ return connector + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Linear token: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Linear token: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/notion_add_connector_route.py b/surfsense_backend/app/routes/notion_add_connector_route.py new file mode 100644 index 000000000..462ac398c --- /dev/null +++ b/surfsense_backend/app/routes/notion_add_connector_route.py @@ -0,0 +1,459 @@ +""" +Notion Connector OAuth Routes. + +Handles OAuth 2.0 authentication flow for Notion connector. +""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import ValidationError +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase +from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Notion OAuth endpoints +AUTHORIZATION_URL = "https://api.notion.com/v1/oauth/authorize" +TOKEN_URL = "https://api.notion.com/v1/oauth/token" + +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +def make_basic_auth_header(client_id: str, client_secret: str) -> str: + """Create Basic Auth header for Notion OAuth.""" + import base64 + + credentials = f"{client_id}:{client_secret}".encode() + b64 = base64.b64encode(credentials).decode("ascii") + return f"Basic {b64}" + + +@router.get("/auth/notion/connector/add") +async def connect_notion(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Notion OAuth flow. + + Args: + space_id: The search space ID + user: Current authenticated user + + Returns: + Authorization URL for redirect + """ + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + if not config.NOTION_CLIENT_ID: + raise HTTPException(status_code=500, detail="Notion OAuth not configured.") + + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." 
+ ) + + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + # Build authorization URL + from urllib.parse import urlencode + + auth_params = { + "client_id": config.NOTION_CLIENT_ID, + "response_type": "code", + "owner": "user", # Allows both admins and members to authorize + "redirect_uri": config.NOTION_REDIRECT_URI, + "state": state_encoded, + } + + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info(f"Generated Notion OAuth URL for user {user.id}, space {space_id}") + return {"auth_url": auth_url} + + except Exception as e: + logger.error(f"Failed to initiate Notion OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate Notion OAuth: {e!s}" + ) from e + + +@router.get("/auth/notion/connector/callback") +async def notion_callback( + request: Request, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """ + Handle Notion OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Notion (if user granted access) + error: Error code from Notion (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ + try: + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Notion OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=notion_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + # Validate redirect URI (security: ensure it matches configured value) + # Note: Notion doesn't send redirect_uri in callback, but we validate + # that we're using the configured one in token exchange + if not config.NOTION_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="NOTION_REDIRECT_URI not configured" + ) + + # Exchange authorization code for access token + auth_header = make_basic_auth_header( + config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET + ) + + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": config.NOTION_REDIRECT_URI, # Use stored value, not from request + } + + async with httpx.AsyncClient() as client: + 
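+ # Note: unlike Linear's form-encoded token exchange, Notion's token
+ # endpoint expects a JSON body; both use HTTP Basic client auth.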
token_response = await client.post( + TOKEN_URL, + json=token_data, + headers={ + "Content-Type": "application/json", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) + + token_json = token_response.json() + + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Notion" + ) + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Notion returns access_token, refresh_token (if available), and workspace information + # Store the encrypted tokens and workspace info in connector config + connector_config = { + "access_token": token_encryption.encrypt_token(access_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "expires_in": expires_in, + "expires_at": expires_at.isoformat() if expires_at else None, + "workspace_id": token_json.get("workspace_id"), + "workspace_name": token_json.get("workspace_name"), + "workspace_icon": token_json.get("workspace_icon"), + "bot_id": token_json.get("bot_id"), + # Mark that token is encrypted for backward compatibility + "_token_encrypted": True, + } + + # Check if connector already exists for this search space and user + existing_connector_result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + existing_connector = existing_connector_result.scalars().first() + + if existing_connector: + # Update existing connector + existing_connector.config = connector_config + existing_connector.name = "Notion Connector" + existing_connector.is_indexable = True + logger.info( + f"Updated existing Notion connector for user {user_id} in space {space_id}" + ) + else: + # Create new connector + new_connector = SearchSourceConnector( + name="Notion Connector", + connector_type=SearchSourceConnectorType.NOTION_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + logger.info( + f"Created new Notion connector for user {user_id} in space {space_id}" + ) + + try: + await session.commit() + logger.info(f"Successfully saved Notion connector for user {user_id}") + + # Redirect to the frontend with success params + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector" + ) + + except ValidationError as e: + await session.rollback() + raise HTTPException( + status_code=422, detail=f"Validation error: {e!s}" + ) from e + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, + detail=f"Integrity error: A connector with this type already exists. 
{e!s}", + ) from e + except Exception as e: + logger.error(f"Failed to create search source connector: {e!s}") + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create search source connector: {e!s}", + ) from e + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to complete Notion OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Notion OAuth: {e!s}" + ) from e + + +async def refresh_notion_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Notion access token for a connector. + + Args: + session: Database session + connector: Notion connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Notion token for connector {connector.id}") + + credentials = NotionAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + auth_header = make_basic_auth_header( + config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET + ) + + # Prepare token refresh data + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + json=refresh_data, + headers={ + "Content-Type": "application/json", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Notion refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + + # Preserve workspace info + if not credentials.workspace_id: + credentials.workspace_id = connector.config.get("workspace_id") + if not credentials.workspace_name: + credentials.workspace_name = connector.config.get("workspace_name") + if not credentials.workspace_icon: + credentials.workspace_icon = connector.config.get("workspace_icon") + if not credentials.bot_id: + 
credentials.bot_id = connector.config.get("bot_id") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Notion token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Notion token: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Notion token: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/linear_auth_credentials.py b/surfsense_backend/app/schemas/linear_auth_credentials.py new file mode 100644 index 000000000..99e8d9111 --- /dev/null +++ b/surfsense_backend/app/schemas/linear_auth_credentials.py @@ -0,0 +1,66 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class LinearAuthCredentialsBase(BaseModel): + access_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + } + + @classmethod + def from_dict(cls, data: dict) -> "LinearAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/schemas/notion_auth_credentials.py b/surfsense_backend/app/schemas/notion_auth_credentials.py new file mode 100644 index 000000000..e66afb903 --- /dev/null +++ b/surfsense_backend/app/schemas/notion_auth_credentials.py @@ -0,0 +1,72 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class NotionAuthCredentialsBase(BaseModel): + access_token: str + refresh_token: str | None = None + expires_in: int | None = None + expires_at: datetime | None = None + workspace_id: str | None = None + workspace_name: str | None = None + workspace_icon: str | None = None + bot_id: str | None = None + + @property + def is_expired(self) -> bool: + """Check if 
the credentials have expired.""" + if self.expires_at is None: + return False # Long-lived token, treat as not expired + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "workspace_id": self.workspace_id, + "workspace_name": self.workspace_name, + "workspace_icon": self.workspace_icon, + "bot_id": self.bot_id, + } + + @classmethod + def from_dict(cls, data: dict) -> "NotionAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + workspace_id=data.get("workspace_id"), + workspace_name=data.get("workspace_name"), + workspace_icon=data.get("workspace_icon"), + bot_id=data.get("bot_id"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 69b75e5c4..3b87c33f1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -270,7 +270,8 @@ async def stream_new_chat( # Track if we just finished a tool (text flows silently after tools) just_finished_tool: bool = False # Track write_todos calls to show "Creating plan" vs "Updating plan" - write_todos_call_count: int = 0 + # Disabled for now + # write_todos_call_count: int = 0 def next_thinking_step_id() -> str: nonlocal thinking_step_counter @@ -479,60 +480,60 @@ def complete_current_step() -> str | None: status="in_progress", items=last_active_step_items, ) - elif tool_name == "write_todos": - # Track write_todos calls for better messaging - write_todos_call_count += 1 - todos = ( - tool_input.get("todos", []) - if isinstance(tool_input, dict) - else [] - ) - todo_count = len(todos) if isinstance(todos, list) else 0 - - if write_todos_call_count == 1: - # First call - creating the plan - last_active_step_title = "Creating plan" - last_active_step_items = [f"Defining {todo_count} tasks..."] - else: - # Subsequent calls - updating the plan - # Try to provide context about what's being updated - in_progress_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "in_progress" - ) - if isinstance(todos, list) - else 0 - ) - completed_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "completed" - ) - if isinstance(todos, list) - else 0 - ) - - last_active_step_title = "Updating progress" - last_active_step_items = ( - [ - f"Progress: {completed_count}/{todo_count} completed", - f"In progress: 
{in_progress_count} tasks", - ] - if completed_count > 0 - else [f"Working on {todo_count} tasks"] - ) + # elif tool_name == "write_todos": # Disabled for now + # # Track write_todos calls for better messaging + # write_todos_call_count += 1 + # todos = ( + # tool_input.get("todos", []) + # if isinstance(tool_input, dict) + # else [] + # ) + # todo_count = len(todos) if isinstance(todos, list) else 0 + + # if write_todos_call_count == 1: + # # First call - creating the plan + # last_active_step_title = "Creating plan" + # last_active_step_items = [f"Defining {todo_count} tasks..."] + # else: + # # Subsequent calls - updating the plan + # # Try to provide context about what's being updated + # in_progress_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "in_progress" + # ) + # if isinstance(todos, list) + # else 0 + # ) + # completed_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "completed" + # ) + # if isinstance(todos, list) + # else 0 + # ) + + # last_active_step_title = "Updating progress" + # last_active_step_items = ( + # [ + # f"Progress: {completed_count}/{todo_count} completed", + # f"In progress: {in_progress_count} tasks", + # ] + # if completed_count > 0 + # else [f"Working on {todo_count} tasks"] + # ) - yield streaming_service.format_thinking_step( - step_id=tool_step_id, - title=last_active_step_title, - status="in_progress", - items=last_active_step_items, - ) + # yield streaming_service.format_thinking_step( + # step_id=tool_step_id, + # title=last_active_step_title, + # status="in_progress", + # items=last_active_step_items, + # ) elif tool_name == "generate_podcast": podcast_title = ( tool_input.get("podcast_title", "SurfSense Podcast") @@ -596,10 +597,12 @@ def complete_current_step() -> str | None: raw_output = event.get("data", {}).get("output", "") # Handle deepagents' write_todos Command object specially - if tool_name == "write_todos" and hasattr(raw_output, "update"): - # deepagents returns a Command object - extract todos directly - tool_output = extract_todos_from_deepagents(raw_output) - elif hasattr(raw_output, "content"): + # Disabled for now + # if tool_name == "write_todos" and hasattr(raw_output, "update"): + # # deepagents returns a Command object - extract todos directly + # tool_output = extract_todos_from_deepagents(raw_output) + # elif hasattr(raw_output, "content"): + if hasattr(raw_output, "content"): # It's a ToolMessage object - extract the content content = raw_output.content # If content is a string that looks like JSON, try to parse it @@ -758,63 +761,63 @@ def complete_current_step() -> str | None: status="completed", items=completed_items, ) - elif tool_name == "write_todos": - # Build completion items for planning/updating - if isinstance(tool_output, dict): - todos = tool_output.get("todos", []) - todo_count = len(todos) if isinstance(todos, list) else 0 - completed_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "completed" - ) - if isinstance(todos, list) - else 0 - ) - in_progress_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "in_progress" - ) - if isinstance(todos, list) - else 0 - ) - - # Use context-aware completion message - if last_active_step_title == "Creating plan": - completed_items = [f"Created {todo_count} tasks"] - else: - # Updating progress - show stats - completed_items = [ - f"Progress: {completed_count}/{todo_count} completed", - ] - if 
in_progress_count > 0: - # Find the currently in-progress task name - in_progress_task = next( - ( - t.get("content", "")[:40] - for t in todos - if isinstance(t, dict) - and t.get("status") == "in_progress" - ), - None, - ) - if in_progress_task: - completed_items.append( - f"Current: {in_progress_task}..." - ) - else: - completed_items = ["Plan updated"] - yield streaming_service.format_thinking_step( - step_id=original_step_id, - title=last_active_step_title, - status="completed", - items=completed_items, - ) + # elif tool_name == "write_todos": # Disabled for now + # # Build completion items for planning/updating + # if isinstance(tool_output, dict): + # todos = tool_output.get("todos", []) + # todo_count = len(todos) if isinstance(todos, list) else 0 + # completed_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "completed" + # ) + # if isinstance(todos, list) + # else 0 + # ) + # in_progress_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "in_progress" + # ) + # if isinstance(todos, list) + # else 0 + # ) + + # # Use context-aware completion message + # if last_active_step_title == "Creating plan": + # completed_items = [f"Created {todo_count} tasks"] + # else: + # # Updating progress - show stats + # completed_items = [ + # f"Progress: {completed_count}/{todo_count} completed", + # ] + # if in_progress_count > 0: + # # Find the currently in-progress task name + # in_progress_task = next( + # ( + # t.get("content", "")[:40] + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "in_progress" + # ), + # None, + # ) + # if in_progress_task: + # completed_items.append( + # f"Current: {in_progress_task}..." + # ) + # else: + # completed_items = ["Plan updated"] + # yield streaming_service.format_thinking_step( + # step_id=original_step_id, + # title=last_active_step_title, + # status="completed", + # items=completed_items, + # ) elif tool_name == "ls": # Build completion items showing file names found if isinstance(tool_output, dict): @@ -992,27 +995,27 @@ def complete_current_step() -> str | None: yield streaming_service.format_terminal_info( "Knowledge base search completed", "success" ) - elif tool_name == "write_todos": - # Stream the full write_todos result so frontend can render the Plan component - yield streaming_service.format_tool_output_available( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - # Send terminal message with plan info - if isinstance(tool_output, dict): - todos = tool_output.get("todos", []) - todo_count = len(todos) if isinstance(todos, list) else 0 - yield streaming_service.format_terminal_info( - f"Plan created ({todo_count} tasks)", - "success", - ) - else: - yield streaming_service.format_terminal_info( - "Plan created", - "success", - ) + # elif tool_name == "write_todos": # Disabled for now + # # Stream the full write_todos result so frontend can render the Plan component + # yield streaming_service.format_tool_output_available( + # tool_call_id, + # tool_output + # if isinstance(tool_output, dict) + # else {"result": tool_output}, + # ) + # # Send terminal message with plan info + # if isinstance(tool_output, dict): + # todos = tool_output.get("todos", []) + # todo_count = len(todos) if isinstance(todos, list) else 0 + # yield streaming_service.format_terminal_info( + # f"Plan created ({todo_count} tasks)", + # "success", + # ) + # else: + # yield 
streaming_service.format_terminal_info(
+ # "Plan created",
+ # "success",
+ # )
 else:
 # Default handling for other tools
 yield streaming_service.format_tool_output_available(
diff --git a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py
index cea2a0529..3ea6dccc9 100644
--- a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py
@@ -18,6 +18,7 @@
 generate_document_summary,
 generate_unique_identifier_hash,
 )
+from app.utils.oauth_security import TokenEncryption
 from .base import (
 calculate_date_range,
@@ -85,7 +86,52 @@ async def index_airtable_records(
 return 0, f"Connector with ID {connector_id} not found"
 # Create credentials from connector config
- config_data = connector.config
+ config_data = (
+ connector.config.copy()
+ ) # Work with a copy to avoid modifying original
+
+ # Decrypt tokens if they are encrypted (only when explicitly marked)
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted:
+ # Tokens are explicitly marked as encrypted, attempt decryption
+ if not config.SECRET_KEY:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"SECRET_KEY not configured but tokens are marked as encrypted for connector {connector_id}",
+ "Missing SECRET_KEY for token decryption",
+ {"error_type": "MissingSecretKey"},
+ )
+ return 0, "SECRET_KEY not configured but tokens are marked as encrypted"
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt access_token
+ if config_data.get("access_token"):
+ config_data["access_token"] = token_encryption.decrypt_token(
+ config_data["access_token"]
+ )
+ logger.info(
+ f"Decrypted Airtable access token for connector {connector_id}"
+ )
+
+ # Decrypt refresh_token if present
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ logger.info(
+ f"Decrypted Airtable refresh token for connector {connector_id}"
+ )
+ except Exception as e:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"Failed to decrypt Airtable tokens for connector {connector_id}: {e!s}",
+ "Token decryption failed",
+ {"error_type": "TokenDecryptionError"},
+ )
+ return 0, f"Failed to decrypt Airtable tokens: {e!s}"
+ # If _token_encrypted is False or not set, treat tokens as plaintext
+
 try:
 credentials = AirtableAuthCredentialsBase.from_dict(config_data)
 except Exception as e:
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
index a5d2bc73a..499f01d66 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
@@ -8,7 +8,6 @@
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.ext.asyncio import AsyncSession
-from app.config import config
 from app.connectors.google_calendar_connector import GoogleCalendarConnector
 from app.db import Document, DocumentType, SearchSourceConnectorType
 from app.services.llm_service import get_user_long_context_llm
@@ -84,15 +83,52 @@ async def index_google_calendar_events(
 return 0, f"Connector with ID {connector_id} not found"
 # Get the Google Calendar credentials from the connector config
- exp = connector.config.get("expiry").replace("Z", "")
+ config_data = (
+ connector.config.copy()
+ ) # Work with a copy to avoid persisting decrypted values
+
+ # Decrypt sensitive credentials if encrypted (for backward compatibility)
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+
+ logger.info(
+ f"Decrypted Google Calendar credentials for connector {connector_id}"
+ )
+ except Exception as e:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"Failed to decrypt Google Calendar credentials for connector {connector_id}: {e!s}",
+ "Credential decryption failed",
+ {"error_type": "CredentialDecryptionError"},
+ )
+ return 0, f"Failed to decrypt Google Calendar credentials: {e!s}"
+
+ exp = (config_data.get("expiry") or "").replace("Z", "")
 credentials = Credentials(
- token=connector.config.get("token"),
- refresh_token=connector.config.get("refresh_token"),
- token_uri=connector.config.get("token_uri"),
- client_id=connector.config.get("client_id"),
- client_secret=connector.config.get("client_secret"),
- scopes=connector.config.get("scopes"),
- expiry=datetime.fromisoformat(exp),
+ token=config_data.get("token"),
+ refresh_token=config_data.get("refresh_token"),
+ token_uri=config_data.get("token_uri"),
+ client_id=config_data.get("client_id"),
+ client_secret=config_data.get("client_secret"),
+ scopes=config_data.get("scopes"),
+ expiry=datetime.fromisoformat(exp) if exp else None,
 )
 if (
@@ -122,6 +158,12 @@ async def index_google_calendar_events(
 connector_id=connector_id,
 )
+ # Handle 'undefined' string from frontend (treat as None)
+ if start_date == "undefined" or start_date == "":
+ start_date = None
+ if end_date == "undefined" or end_date == "":
+ end_date = None
+
 # Calculate date range
 if start_date is None or end_date is None:
 # Fall back to calculating dates based on last_indexed_at
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
index 343d44072..9eeb46fc8 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
@@ -5,6 +5,7 @@
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.ext.asyncio import AsyncSession
+from app.config import config
 from app.connectors.google_drive import (
 GoogleDriveClient,
 categorize_change,
@@ -87,6 +88,26 @@ async def index_google_drive_files(
 {"stage": "client_initialization"},
 )
+ # Check if credentials are encrypted (only when explicitly marked)
+ token_encrypted = connector.config.get("_token_encrypted", False)
+ if token_encrypted:
+ # Credentials are explicitly marked as encrypted; they will be decrypted during client initialization
+ if not config.SECRET_KEY:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
+ "Missing SECRET_KEY for token decryption",
+ {"error_type": "MissingSecretKey"},
+ )
+ return (
+ 0,
+ "SECRET_KEY not configured but credentials are marked as encrypted",
+ )
+
logger.info(
+ f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
+ )
+ # If _token_encrypted is False or not set, treat credentials as plaintext
+
 drive_client = GoogleDriveClient(session, connector_id)
 if not folder_id:
@@ -249,6 +270,26 @@ async def index_google_drive_single_file(
 {"stage": "client_initialization"},
 )
+ # Check if credentials are encrypted (only when explicitly marked)
+ token_encrypted = connector.config.get("_token_encrypted", False)
+ if token_encrypted:
+ # Credentials are explicitly marked as encrypted; they will be decrypted during client initialization
+ if not config.SECRET_KEY:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
+ "Missing SECRET_KEY for token decryption",
+ {"error_type": "MissingSecretKey"},
+ )
+ return (
+ 0,
+ "SECRET_KEY not configured but credentials are marked as encrypted",
+ )
+ logger.info(
+ f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
+ )
+ # If _token_encrypted is False or not set, treat credentials as plaintext
+
 drive_client = GoogleDriveClient(session, connector_id)
 # Fetch the file metadata
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
index d350411e1..e10297057 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
@@ -8,7 +8,6 @@
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.ext.asyncio import AsyncSession
-from app.config import config
 from app.connectors.google_gmail_connector import GoogleGmailConnector
 from app.db import (
 Document,
@@ -88,9 +87,47 @@ async def index_google_gmail_messages(
 )
 return 0, error_msg
- # Create credentials from connector config
+ # Get the Google Gmail credentials from the connector config
- config_data = connector.config
+ config_data = (
+ connector.config.copy()
+ ) # Work with a copy to avoid persisting decrypted values
+
+ # Decrypt sensitive credentials if encrypted (for backward compatibility)
+ from app.config import config
+ from app.utils.oauth_security import TokenEncryption
+
+ token_encrypted = config_data.get("_token_encrypted", False)
+ if token_encrypted and config.SECRET_KEY:
+ try:
+ token_encryption = TokenEncryption(config.SECRET_KEY)
+
+ # Decrypt sensitive fields
+ if config_data.get("token"):
+ config_data["token"] = token_encryption.decrypt_token(
+ config_data["token"]
+ )
+ if config_data.get("refresh_token"):
+ config_data["refresh_token"] = token_encryption.decrypt_token(
+ config_data["refresh_token"]
+ )
+ if config_data.get("client_secret"):
+ config_data["client_secret"] = token_encryption.decrypt_token(
+ config_data["client_secret"]
+ )
+
+ logger.info(
+ f"Decrypted Google Gmail credentials for connector {connector_id}"
+ )
+ except Exception as e:
+ await task_logger.log_task_failure(
+ log_entry,
+ f"Failed to decrypt Google Gmail credentials for connector {connector_id}: {e!s}",
+ "Credential decryption failed",
+ {"error_type": "CredentialDecryptionError"},
+ )
+ return 0, f"Failed to decrypt Google Gmail credentials: {e!s}"
+
+ exp = config_data.get("expiry", "")
+ if exp:
+ exp = exp.replace("Z", "")
 credentials = Credentials(
 token=config_data.get("token"),
 refresh_token=config_data.get("refresh_token"),
@@ -98,7 +135,7 @@ async def 
index_google_gmail_messages( client_id=config_data.get("client_id"), client_secret=config_data.get("client_secret"), scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp), + expiry=datetime.fromisoformat(exp) if exp else None, ) if ( diff --git a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py index afc9ffd3b..f1bfd42e8 100644 --- a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py @@ -92,25 +92,34 @@ async def index_linear_issues( f"Connector with ID {connector_id} not found or is not a Linear connector", ) - # Get the Linear token from the connector config - linear_token = connector.config.get("LINEAR_API_KEY") - if not linear_token: + # Check if access_token exists (support both new OAuth format and old API key format) + if not connector.config.get("access_token") and not connector.config.get( + "LINEAR_API_KEY" + ): await task_logger.log_task_failure( log_entry, - f"Linear API token not found in connector config for connector {connector_id}", - "Missing Linear token", + f"Linear access token not found in connector config for connector {connector_id}", + "Missing Linear access token", {"error_type": "MissingToken"}, ) - return 0, "Linear API token not found in connector config" + return 0, "Linear access token not found in connector config" - # Initialize Linear client + # Initialize Linear client with internal refresh capability await task_logger.log_task_progress( log_entry, f"Initializing Linear client for connector {connector_id}", {"stage": "client_initialization"}, ) - linear_client = LinearConnector(token=linear_token) + # Create connector with session and connector_id for internal refresh + # Token refresh will happen automatically when needed + linear_client = LinearConnector(session=session, connector_id=connector_id) + + # Handle 'undefined' string from frontend (treat as None) + if start_date == "undefined" or start_date == "": + start_date = None + if end_date == "undefined" or end_date == "": + end_date = None # Calculate date range start_date_str, end_date_str = calculate_date_range( @@ -131,7 +140,7 @@ async def index_linear_issues( # Get issues within date range try: - issues, error = linear_client.get_issues_by_date_range( + issues, error = await linear_client.get_issues_by_date_range( start_date=start_date_str, end_date=end_date_str, include_comments=True ) diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py index 332d3e39d..13923269d 100644 --- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py @@ -2,7 +2,7 @@ Notion connector indexer. 
""" -from datetime import datetime, timedelta +from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession @@ -20,6 +20,7 @@ from .base import ( build_document_metadata_string, + calculate_date_range, check_document_by_unique_identifier, get_connector_by_id, get_current_timestamp, @@ -91,18 +92,19 @@ async def index_notion_pages( f"Connector with ID {connector_id} not found or is not a Notion connector", ) - # Get the Notion token from the connector config - notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN") - if not notion_token: + # Check if access_token exists (support both new OAuth format and old integration token format) + if not connector.config.get("access_token") and not connector.config.get( + "NOTION_INTEGRATION_TOKEN" + ): await task_logger.log_task_failure( log_entry, - f"Notion integration token not found in connector config for connector {connector_id}", - "Missing Notion token", + f"Notion access token not found in connector config for connector {connector_id}", + "Missing Notion access token", {"error_type": "MissingToken"}, ) - return 0, "Notion integration token not found in connector config" + return 0, "Notion access token not found in connector config" - # Initialize Notion client + # Initialize Notion client with internal refresh capability await task_logger.log_task_progress( log_entry, f"Initializing Notion client for connector {connector_id}", @@ -111,40 +113,30 @@ async def index_notion_pages( logger.info(f"Initializing Notion client for connector {connector_id}") - # Calculate date range - if start_date is None or end_date is None: - # Fall back to calculating dates - calculated_end_date = datetime.now() - calculated_start_date = calculated_end_date - timedelta( - days=365 - ) # Check for last 1 year of pages - - # Use calculated dates if not provided - if start_date is None: - start_date_iso = calculated_start_date.strftime("%Y-%m-%dT%H:%M:%SZ") - else: - # Convert YYYY-MM-DD to ISO format - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + # Handle 'undefined' string from frontend (treat as None) + if start_date == "undefined" or start_date == "": + start_date = None + if end_date == "undefined" or end_date == "": + end_date = None - if end_date is None: - end_date_iso = calculated_end_date.strftime("%Y-%m-%dT%H:%M:%SZ") - else: - # Convert YYYY-MM-DD to ISO format - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) - else: - # Convert provided dates to ISO format for Notion API - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + # Calculate date range using the shared utility function + start_date_str, end_date_str = calculate_date_range( + connector, start_date, end_date, default_days_back=365 + ) - notion_client = NotionHistoryConnector(token=notion_token) + # Convert YYYY-MM-DD to ISO format for Notion API + start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + end_date_iso = datetime.strptime(end_date_str, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + + # Create connector with session and connector_id for internal refresh + # Token refresh will happen automatically when needed + notion_client = NotionHistoryConnector( + session=session, connector_id=connector_id + ) logger.info(f"Fetching Notion pages 
from {start_date_iso} to {end_date_iso}") diff --git a/surfsense_backend/app/utils/oauth_security.py b/surfsense_backend/app/utils/oauth_security.py new file mode 100644 index 000000000..5135cdef4 --- /dev/null +++ b/surfsense_backend/app/utils/oauth_security.py @@ -0,0 +1,210 @@ +""" +OAuth Security Utilities. + +Provides secure state parameter generation/validation and token encryption +for OAuth 2.0 flows. +""" + +import base64 +import hashlib +import hmac +import json +import logging +import time +from uuid import UUID + +from cryptography.fernet import Fernet +from fastapi import HTTPException + +logger = logging.getLogger(__name__) + + +class OAuthStateManager: + """Manages secure OAuth state parameters with HMAC signatures.""" + + def __init__(self, secret_key: str, max_age_seconds: int = 600): + """ + Initialize OAuth state manager. + + Args: + secret_key: Secret key for HMAC signing (should be SECRET_KEY from config) + max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes) + """ + if not secret_key: + raise ValueError("secret_key is required for OAuth state management") + self.secret_key = secret_key + self.max_age_seconds = max_age_seconds + + def generate_secure_state( + self, space_id: int, user_id: UUID, **extra_fields + ) -> str: + """ + Generate cryptographically signed state parameter. + + Args: + space_id: The search space ID + user_id: The user ID + **extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE) + + Returns: + Base64-encoded state parameter with HMAC signature + """ + timestamp = int(time.time()) + state_payload = { + "space_id": space_id, + "user_id": str(user_id), + "timestamp": timestamp, + } + + # Add any extra fields (e.g., code_verifier for PKCE) + state_payload.update(extra_fields) + + # Create signature + payload_str = json.dumps(state_payload, sort_keys=True) + signature = hmac.new( + self.secret_key.encode(), + payload_str.encode(), + hashlib.sha256, + ).hexdigest() + + # Include signature in state + state_payload["signature"] = signature + state_encoded = base64.urlsafe_b64encode( + json.dumps(state_payload).encode() + ).decode() + + return state_encoded + + def validate_state(self, state: str) -> dict: + """ + Validate and decode state parameter with signature verification. 
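+
+ The payload is re-serialized with sorted keys and compared in constant
+ time against its embedded HMAC-SHA256 signature; the timestamp is then
+ checked against max_age_seconds to reject replayed or stale states.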
+
+        Args:
+            state: The state parameter from OAuth callback
+
+        Returns:
+            Decoded state data (space_id, user_id, timestamp)
+
+        Raises:
+            HTTPException: If state is invalid, expired, or tampered with
+        """
+        try:
+            decoded = base64.urlsafe_b64decode(state.encode()).decode()
+            data = json.loads(decoded)
+        except Exception as e:
+            raise HTTPException(
+                status_code=400, detail=f"Invalid state format: {e!s}"
+            ) from e
+
+        # Verify signature exists
+        signature = data.pop("signature", None)
+        if not signature:
+            raise HTTPException(status_code=400, detail="Missing state signature")
+
+        # Verify signature
+        payload_str = json.dumps(data, sort_keys=True)
+        expected_signature = hmac.new(
+            self.secret_key.encode(),
+            payload_str.encode(),
+            hashlib.sha256,
+        ).hexdigest()
+
+        if not hmac.compare_digest(signature, expected_signature):
+            raise HTTPException(
+                status_code=400, detail="Invalid state signature - possible tampering"
+            )
+
+        # Verify timestamp (prevent replay attacks)
+        timestamp = data.get("timestamp", 0)
+        current_time = time.time()
+        age = current_time - timestamp
+
+        if age < 0:
+            raise HTTPException(status_code=400, detail="Invalid state timestamp")
+
+        if age > self.max_age_seconds:
+            raise HTTPException(
+                status_code=400,
+                detail="State parameter expired. Please try again.",
+            )
+
+        return data
+
+
+class TokenEncryption:
+    """Encrypt/decrypt sensitive OAuth tokens for storage."""
+
+    def __init__(self, secret_key: str):
+        """
+        Initialize token encryption.
+
+        Args:
+            secret_key: Secret key for encryption (should be SECRET_KEY from config)
+        """
+        if not secret_key:
+            raise ValueError("secret_key is required for token encryption")
+        # Derive Fernet key from secret using SHA256
+        # Note: In production, consider using HKDF for key derivation
+        key = base64.urlsafe_b64encode(hashlib.sha256(secret_key.encode()).digest())
+        try:
+            self.cipher = Fernet(key)
+        except Exception as e:
+            raise ValueError(f"Failed to initialize encryption cipher: {e!s}") from e
+
+    def encrypt_token(self, token: str) -> str:
+        """
+        Encrypt a token for storage.
+
+        Args:
+            token: Plaintext token to encrypt
+
+        Returns:
+            Encrypted token string
+        """
+        if not token:
+            return token
+        try:
+            return self.cipher.encrypt(token.encode()).decode()
+        except Exception as e:
+            logger.error(f"Failed to encrypt token: {e!s}")
+            raise ValueError(f"Token encryption failed: {e!s}") from e
+
+    def decrypt_token(self, encrypted_token: str) -> str:
+        """
+        Decrypt a stored token.
+
+        Args:
+            encrypted_token: Encrypted token string
+
+        Returns:
+            Decrypted plaintext token
+        """
+        if not encrypted_token:
+            return encrypted_token
+        try:
+            return self.cipher.decrypt(encrypted_token.encode()).decode()
+        except Exception as e:
+            logger.error(f"Failed to decrypt token: {e!s}")
+            raise ValueError(f"Token decryption failed: {e!s}") from e
+
+    def is_encrypted(self, token: str) -> bool:
+        """
+        Check if a token appears to be encrypted.
+
+        Args:
+            token: Token string to check
+
+        Returns:
+            True if token appears encrypted, False otherwise
+        """
+        if not token:
+            return False
+        # Encrypted tokens are base64-encoded and have specific format
+        # This is a heuristic check - encrypted tokens are longer and base64-like
+        try:
+            # Try to decode as base64
+            base64.urlsafe_b64decode(token.encode())
+            # If it's base64 and reasonably long, likely encrypted
+            return len(token) > 20
+        except Exception:
+            return False
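Since these two classes carry the security-critical logic of this patch, here is a quick roundtrip of them — a usage sketch, not app code. The secret and token values are placeholders; in the application the key comes from `config.SECRET_KEY`:

```python
# Usage sketch for the module above; secret/token values are placeholders.
from uuid import uuid4

from app.utils.oauth_security import OAuthStateManager, TokenEncryption

secret = "change-me"  # placeholder; the app passes config.SECRET_KEY

# Signed state survives the OAuth redirect and rejects tampering and replay.
states = OAuthStateManager(secret, max_age_seconds=600)
state = states.generate_secure_state(space_id=1, user_id=uuid4())
payload = states.validate_state(state)  # raises HTTPException on failure
assert payload["space_id"] == 1

# Tokens are encrypted at rest with a Fernet key derived from the same secret.
crypto = TokenEncryption(secret)
stored = crypto.encrypt_token("example-access-token")  # fake token
assert crypto.is_encrypted(stored)
assert crypto.decrypt_token(stored) == "example-access-token"
```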
diff --git a/surfsense_backend/app/utils/validators.py b/surfsense_backend/app/utils/validators.py
index 6b69fb3e1..d6622bafd 100644
--- a/surfsense_backend/app/utils/validators.py
+++ b/surfsense_backend/app/utils/validators.py
@@ -514,10 +514,6 @@ def validate_initial_urls() -> None:
         "validators": {},
     },
     "SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}},
-    "NOTION_CONNECTOR": {
-        "required": ["NOTION_INTEGRATION_TOKEN"],
-        "validators": {},
-    },
     "GITHUB_CONNECTOR": {
         "required": ["GITHUB_PAT", "repo_full_names"],
         "validators": {
@@ -526,7 +522,6 @@ def validate_initial_urls() -> None:
             )
         },
     },
-    "LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}},
     "DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}},
     "JIRA_CONNECTOR": {
         "required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"],
diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
index 35a096497..b1abd647f 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
@@ -20,7 +20,7 @@ import {
 } from "@/atoms/chat/mentioned-documents.atom";
 import {
 	clearPlanOwnerRegistry,
-	extractWriteTodosFromContent,
+	// extractWriteTodosFromContent,
 	hydratePlanStateAtom,
 } from "@/atoms/chat/plan-state.atom";
 import { Thread } from "@/components/assistant-ui/thread";
@@ -30,7 +30,7 @@ import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
 import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
 import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
 import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
-import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
+// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
 import { getBearerToken } from "@/lib/auth-utils";
 import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
 import {
@@ -199,7 +199,7 @@ const TOOLS_WITH_UI = new Set([
 	"link_preview",
 	"display_image",
 	"scrape_webpage",
-	"write_todos",
+	// "write_todos", // Disabled for now
 ]);
@@ -291,10 +291,11 @@ export default function NewChatPage() {
 				restoredThinkingSteps.set(`msg-${msg.id}`, steps);
 			}
 			// Hydrate write_todos plan state from persisted tool calls
-			const writeTodosCalls = extractWriteTodosFromContent(msg.content);
-			for (const todoData of writeTodosCalls) {
-				hydratePlanState(todoData);
-			}
+			// Disabled for now
+			// const writeTodosCalls = extractWriteTodosFromContent(msg.content);
+			// for (const todoData of writeTodosCalls) {
+			// 	hydratePlanState(todoData);
+			// }
 		}
 		if (msg.role === "user") {
 			const docs = extractMentionedDocuments(msg.content);
@@ -911,7 +912,7 @@ export default function NewChatPage() {
-
+			{/* Disabled for now */}
 			)}
-
+
 			{/* YouTube Crawler View - shown when adding YouTube videos */}
 			{isYouTubeView && searchSpaceId ? (
@@ -272,7 +272,7 @@ export const ConnectorIndicator: FC = () => {
 			{/* Content */}
-
+
 = ({
 	id,
 	title,
@@ -86,13 +125,13 @@ export const ConnectorCard: FC = ({
 		// Show last indexed date for connected connectors
 		if (lastIndexedAt) {
 			return (
-
-					Last indexed: {format(new Date(lastIndexedAt), "MMM d, yyyy")}
+
+					Last indexed: {formatLastIndexedDate(lastIndexedAt)}
 			);
 		}

 		// Fallback for connected but never indexed
-		return Never indexed;
+		return Never indexed;
 	}

 	return description;
@@ -113,9 +152,9 @@ export const ConnectorCard: FC = ({
 				{title}
-
-					{getStatusContent()}
+
+					{getStatusContent()}
 				{isConnected && documentCount !== undefined && (
-
+
 						{formatDocumentCount(documentCount)}
 				)}
@@ -130,12 +169,10 @@ export const ConnectorCard: FC = ({
 						!isConnected && "shadow-xs"
 					)}
 					onClick={isConnected ? onManage : onConnect}
-					disabled={isConnecting || isIndexing}
+					disabled={isConnecting}
 				>
 					{isConnecting ? (
-					) : isIndexing ? (
-						"Syncing..."
 					) : isConnected ? (
 						"Manage"
 					) : id === "youtube-crawler" ? (
diff --git a/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx b/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx
index a18c79a1f..34e1ae2e9 100644
--- a/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx
+++ b/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx
@@ -24,20 +24,20 @@ export const ConnectorDialogHeader: FC = ({
 	return (
-
+
 			Connectors
-
+
 			Search across all your apps and data in one place.
-
+
 = ({
-
+
 = ({
+			{/* Back button - only show if not from OAuth */}
+			{!isFromOAuth && (
+
+			)}
 			{/* Success header */}
@@ -187,15 +193,7 @@
 			{/* Fixed Footer - Action buttons */}
-
-
+
@@ -163,9 +199,8 @@ export const ActiveConnectorsTab: FC = ({
 					size="sm"
 					className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
 					onClick={onManage ? () => onManage(connector) : undefined}
-					disabled={isIndexing}
 				>
-					{isIndexing ? "Syncing..." : "Manage"}
+					Manage
 	);
diff --git a/surfsense_web/components/assistant-ui/document-upload-popup.tsx b/surfsense_web/components/assistant-ui/document-upload-popup.tsx
index d1fa208d2..da3b820e5 100644
--- a/surfsense_web/components/assistant-ui/document-upload-popup.tsx
+++ b/surfsense_web/components/assistant-ui/document-upload-popup.tsx
@@ -1,5 +1,6 @@
 "use client";

+import { Upload } from "lucide-react";
 import { useAtomValue } from "jotai";
 import { useRouter } from "next/navigation";
 import {
@@ -85,6 +86,7 @@
 }> = ({ isOpen, onOpenChange }) => {
 	const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
 	const router = useRouter();
+	const [isAccordionExpanded, setIsAccordionExpanded] = useState(false);

 	if (!searchSpaceId) return null;
@@ -95,16 +97,40 @@
 	return (
-
+
 				Upload Document
+
+			{/* Fixed Header */}
+				{/* Upload header */}
+
+
+
+
+
+					Upload Documents
+
+
+
+					Upload and sync your documents to your search space
+
+
+
+
+
+			{/* Scrollable Content */}
-
-
-
+
+
-			{/* Bottom fade shadow */}
-
+			{/* Bottom fade shadow - only show when scrolling */}
+			{isAccordionExpanded && (
+
+			)}
diff --git a/surfsense_web/components/editConnector/types.ts b/surfsense_web/components/editConnector/types.ts
index e17a3513a..43fab23e0 100644
--- a/surfsense_web/components/editConnector/types.ts
+++ b/surfsense_web/components/editConnector/types.ts
@@ -36,7 +36,6 @@ export const editConnectorSchema = z.object({
 	SEARXNG_LANGUAGE: z.string().optional(),
 	SEARXNG_SAFESEARCH: z.string().optional(),
 	SEARXNG_VERIFY_SSL: z.string().optional(),
-	LINEAR_API_KEY: z.string().optional(),
 	LINKUP_API_KEY: z.string().optional(),
 	DISCORD_BOT_TOKEN: z.string().optional(),
 	CONFLUENCE_BASE_URL: z.string().optional(),
diff --git a/surfsense_web/components/new-chat/document-mention-picker.tsx b/surfsense_web/components/new-chat/document-mention-picker.tsx
index 2d5d46267..7a9e7aaa5 100644
--- a/surfsense_web/components/new-chat/document-mention-picker.tsx
+++ b/surfsense_web/components/new-chat/document-mention-picker.tsx
@@ -12,7 +12,7 @@ import {
 	useState,
 } from "react";
 import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
-import type { Document } from "@/contracts/types/document.types";
+import type { Document, GetDocumentsResponse } from "@/contracts/types/document.types";
 import { documentsApiService } from "@/lib/apis/documents-api.service";
 import { cacheKeys } from "@/lib/query-client/cache-keys";
 import { cn } from "@/lib/utils";
@@ -31,6 +31,8 @@ interface DocumentMentionPickerProps {
 	externalSearch?: string;
 }

+const PAGE_SIZE = 20;
+
 function useDebounced(value: T, delay = 300) {
 	const [debounced, setDebounced] = useState(value);
 	useEffect(() => {
@@ -52,12 +54,29 @@
 	const debouncedSearch = useDebounced(search, 150);
 	const [highlightedIndex, setHighlightedIndex] = useState(0);
 	const itemRefs = useRef>(new Map());
+	const scrollContainerRef = useRef(null);
+
+	// State for pagination
+	const [accumulatedDocuments, setAccumulatedDocuments] = useState([]);
+	const [currentPage, setCurrentPage] = useState(0);
+	const [hasMore, setHasMore] = useState(false);
+	const [isLoadingMore, setIsLoadingMore] = useState(false);
+
+	// Reset pagination when search or search space changes
+	// biome-ignore lint/correctness/useExhaustiveDependencies: intentionally reset pagination when search/space changes
+	useEffect(() => {
+		setAccumulatedDocuments([]);
+		setCurrentPage(0);
+		setHasMore(false);
+		setHighlightedIndex(0);
+	}, [debouncedSearch, searchSpaceId]);

+	// Query params for initial fetch (page 0)
 	const fetchQueryParams = useMemo(
 		() => ({
 			search_space_id: searchSpaceId,
 			page: 0,
-			page_size: 20,
+			page_size: PAGE_SIZE,
 		}),
 		[searchSpaceId]
 	);
@@ -66,31 +85,97 @@
 		return {
 			search_space_id: searchSpaceId,
 			page: 0,
-			page_size: 20,
+			page_size: PAGE_SIZE,
 			title: debouncedSearch,
 		};
 	}, [debouncedSearch, searchSpaceId]);

-	// Use query for fetching documents
+	// Use query for fetching first page of documents
 	const { data: documents, isLoading: isDocumentsLoading } = useQuery({
 		queryKey: cacheKeys.documents.withQueryParams(fetchQueryParams),
 		queryFn: () => documentsApiService.getDocuments({ queryParams: fetchQueryParams }),
 		staleTime: 3 * 60 * 1000,
-		enabled: !!searchSpaceId && !debouncedSearch.trim(),
+		enabled: !!searchSpaceId && !debouncedSearch.trim() && currentPage === 0,
 	});

-	// Searching
+	// Searching - first page
 	const { data: searchedDocuments, isLoading: isSearchedDocumentsLoading } = useQuery({
 		queryKey: cacheKeys.documents.withQueryParams(searchQueryParams),
 		queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }),
 		staleTime: 3 * 60 * 1000,
-		enabled: !!searchSpaceId && !!debouncedSearch.trim(),
+		enabled: !!searchSpaceId && !!debouncedSearch.trim() && currentPage === 0,
 	});

-	const actualDocuments = debouncedSearch.trim()
-		? searchedDocuments?.items || []
-		: documents?.items || [];
-	const actualLoading = debouncedSearch.trim() ? isSearchedDocumentsLoading : isDocumentsLoading;
+	// Update accumulated documents when first page loads
+	useEffect(() => {
+		if (currentPage === 0) {
+			if (debouncedSearch.trim()) {
+				if (searchedDocuments) {
+					setAccumulatedDocuments(searchedDocuments.items);
+					setHasMore(searchedDocuments.has_more);
+				}
+			} else {
+				if (documents) {
+					setAccumulatedDocuments(documents.items);
+					setHasMore(documents.has_more);
+				}
+			}
+		}
+	}, [documents, searchedDocuments, debouncedSearch, currentPage]);
+
+	// Function to load next page
+	const loadNextPage = useCallback(async () => {
+		if (isLoadingMore || !hasMore) return;
+
+		const nextPage = currentPage + 1;
+		setIsLoadingMore(true);
+
+		try {
+			let response: GetDocumentsResponse;
+			if (debouncedSearch.trim()) {
+				const queryParams = {
+					search_space_id: searchSpaceId,
+					page: nextPage,
+					page_size: PAGE_SIZE,
+					title: debouncedSearch,
+				};
+				response = await documentsApiService.searchDocuments({ queryParams });
+			} else {
+				const queryParams = {
+					search_space_id: searchSpaceId,
+					page: nextPage,
+					page_size: PAGE_SIZE,
+				};
+				response = await documentsApiService.getDocuments({ queryParams });
+			}
+
+			setAccumulatedDocuments((prev) => [...prev, ...response.items]);
+			setHasMore(response.has_more);
+			setCurrentPage(nextPage);
+		} catch (error) {
+			console.error("Failed to load next page:", error);
+		} finally {
+			setIsLoadingMore(false);
+		}
+	}, [currentPage, hasMore, isLoadingMore, debouncedSearch, searchSpaceId]);
+
+	// Infinite scroll handler
+	const handleScroll = useCallback(
+		(e: React.UIEvent) => {
+			const target = e.currentTarget;
+			const scrollBottom = target.scrollHeight - target.scrollTop - target.clientHeight;
+
+			// Load more when within 50px of bottom
+			if (scrollBottom < 50 && hasMore && !isLoadingMore) {
+				loadNextPage();
+			}
+		},
+		[hasMore, isLoadingMore, loadNextPage]
+	);
+
+	const actualDocuments = accumulatedDocuments;
+	const actualLoading =
+		(debouncedSearch.trim() ? isSearchedDocumentsLoading : isDocumentsLoading) && currentPage === 0;

 	// Track already selected document IDs
 	const selectedIds = useMemo(
@@ -184,8 +269,12 @@
 			role="listbox"
 			tabIndex={-1}
 		>
-			{/* Document List - Shows max 3 items on mobile, 5 items on desktop */}
-
+			{/* Document List - Shows max 5 items on mobile, 7-8 items on desktop */}
+
 			{actualLoading ? (
@@ -235,6 +324,12 @@
 					);
 				})}

+				{/* Loading indicator for additional pages */}
+				{isLoadingMore && (
+
+
+				)}
 			)}
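The infinite-scroll logic above leans on each page of the documents API returning `items` plus a `has_more` flag, which `loadNextPage` accumulates until exhausted. A rough Python sketch of that paging contract, with the caveat that the endpoint path and auth are assumptions — only the query parameters and the `GetDocumentsResponse` shape come from the code above:

```python
# Rough sketch of the page/has_more paging contract the picker relies on;
# the /api/v1/documents path is an assumed endpoint, not confirmed by the diff.
import httpx

BASE_URL = "http://localhost:8000"  # placeholder


def fetch_all_documents(search_space_id: int, page_size: int = 20) -> list[dict]:
    """Walk pages until has_more is False, mirroring the picker's accumulation."""
    items: list[dict] = []
    page = 0
    while True:
        resp = httpx.get(
            f"{BASE_URL}/api/v1/documents",  # assumed endpoint
            params={
                "search_space_id": search_space_id,
                "page": page,
                "page_size": page_size,
            },
        )
        resp.raise_for_status()
        data = resp.json()
        items.extend(data["items"])
        if not data["has_more"]:
            return items
        page += 1
```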
diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx
index 5280ea850..0b7f7b51f 100644
--- a/surfsense_web/components/sources/DocumentUploadTab.tsx
+++ b/surfsense_web/components/sources/DocumentUploadTab.tsx
@@ -31,6 +31,7 @@ import { GridPattern } from "./GridPattern";
 interface DocumentUploadTabProps {
 	searchSpaceId: string;
 	onSuccess?: () => void;
+	onAccordionStateChange?: (isExpanded: boolean) => void;
 }

 const audioFileTypes = {
@@ -109,11 +110,16 @@
 const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5";

-export function DocumentUploadTab({ searchSpaceId, onSuccess }: DocumentUploadTabProps) {
+export function DocumentUploadTab({
+	searchSpaceId,
+	onSuccess,
+	onAccordionStateChange,
+}: DocumentUploadTabProps) {
 	const t = useTranslations("upload_documents");
 	const router = useRouter();
 	const [files, setFiles] = useState([]);
 	const [uploadProgress, setUploadProgress] = useState(0);
+	const [accordionValue, setAccordionValue] = useState("");
 	const [uploadDocumentMutation] = useAtom(uploadDocumentMutationAtom);
 	const { mutate: uploadDocuments, isPending: isUploading } = uploadDocumentMutation;
 	const fileInputRef = useRef(null);
@@ -154,6 +160,15 @@
 	const totalFileSize = files.reduce((total, file) => total + file.size, 0);

+	// Track accordion state changes
+	const handleAccordionChange = useCallback(
+		(value: string) => {
+			setAccordionValue(value);
+			onAccordionStateChange?.(value === "supported-file-types");
+		},
+		[onAccordionStateChange]
+	);
+
 	const handleUpload = async () => {
 		setUploadProgress(0);
 		trackDocumentUploadStarted(Number(searchSpaceId), files.length, totalFileSize);
@@ -190,11 +205,13 @@
 			initial={{ opacity: 0, y: 20 }}
 			animate={{ opacity: 1, y: 0 }}
 			transition={{ duration: 0.3 }}
-			className="space-y-3 sm:space-y-6 max-w-4xl mx-auto"
+			className="space-y-3 sm:space-y-6 max-w-4xl mx-auto pt-0"
 		>
-
-
-			{t("file_size_limit")}
+
+
+
+			{t("file_size_limit")}
+
@@ -366,11 +383,13 @@
-
-
+
+
diff --git a/surfsense_web/hooks/use-connector-edit-page.ts b/surfsense_web/hooks/use-connector-edit-page.ts
index 8fa690d04..3beb80247 100644
--- a/surfsense_web/hooks/use-connector-edit-page.ts
+++ b/surfsense_web/hooks/use-connector-edit-page.ts
@@ -86,7 +86,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string)
 		SEARXNG_LANGUAGE: "",
 		SEARXNG_SAFESEARCH: "",
 		SEARXNG_VERIFY_SSL: "",
-		LINEAR_API_KEY: "",
 		DISCORD_BOT_TOKEN: "",
 		CONFLUENCE_BASE_URL: "",
 		CONFLUENCE_EMAIL: "",
@@ -134,7 +133,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string)
 			config.SEARXNG_VERIFY_SSL !== undefined && config.SEARXNG_VERIFY_SSL !== null
 				? String(config.SEARXNG_VERIFY_SSL)
 				: "",
-		LINEAR_API_KEY: config.LINEAR_API_KEY || "",
 		LINKUP_API_KEY: config.LINKUP_API_KEY || "",
 		DISCORD_BOT_TOKEN: config.DISCORD_BOT_TOKEN || "",
 		CONFLUENCE_BASE_URL: config.CONFLUENCE_BASE_URL || "",
@@ -384,16 +382,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string)
 				break;
 			}

-			case "LINEAR_CONNECTOR":
-				if (formData.LINEAR_API_KEY !== originalConfig.LINEAR_API_KEY) {
-					if (!formData.LINEAR_API_KEY) {
-						toast.error("Linear API Key cannot be empty.");
-						setIsSaving(false);
-						return;
-					}
-					newConfig = { LINEAR_API_KEY: formData.LINEAR_API_KEY };
-				}
-				break;
 			case "LINKUP_API":
 				if (formData.LINKUP_API_KEY !== originalConfig.LINKUP_API_KEY) {
 					if (!formData.LINKUP_API_KEY) {
@@ -599,8 +587,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string)
 				"SEARXNG_VERIFY_SSL",
 				verifyValue === null ? "" : String(verifyValue)
 			);
-		} else if (connector.connector_type === "LINEAR_CONNECTOR") {
-			editForm.setValue("LINEAR_API_KEY", newlySavedConfig.LINEAR_API_KEY || "");
 		} else if (connector.connector_type === "LINKUP_API") {
 			editForm.setValue("LINKUP_API_KEY", newlySavedConfig.LINKUP_API_KEY || "");
 		} else if (connector.connector_type === "DISCORD_CONNECTOR") {