From 161c96d0614fdd765f8765a8d65e37520d3c93f5 Mon Sep 17 00:00:00 2001 From: Ankit Pasayat Date: Sat, 6 Dec 2025 22:13:17 +0530 Subject: [PATCH] test(backend): add unit tests for core modules Tests added: - test_config.py: Configuration and environment variable tests - test_schemas.py: Pydantic schema validation tests - test_validators.py: Input validation tests - test_db_models.py: SQLAlchemy model tests - test_blocknote_converter.py: BlockNote conversion tests - test_document_converters.py: Document converter tests - test_permissions.py: Permission system tests - test_rbac.py: Role-based access control tests - test_rbac_schemas.py: RBAC schema tests - test_rbac_utils.py: RBAC utility function tests These tests have minimal external dependencies and can run without a database or external services. --- .../tests/test_blocknote_converter.py | 380 ++++++++++++ surfsense_backend/tests/test_config.py | 86 +++ surfsense_backend/tests/test_db_models.py | 325 ++++++++++ .../tests/test_document_converters.py | 513 ++++++++++++++++ surfsense_backend/tests/test_permissions.py | 270 +++++++++ surfsense_backend/tests/test_rbac.py | 355 +++++++++++ surfsense_backend/tests/test_rbac_schemas.py | 392 ++++++++++++ surfsense_backend/tests/test_rbac_utils.py | 340 +++++++++++ surfsense_backend/tests/test_schemas.py | 569 ++++++++++++++++++ surfsense_backend/tests/test_validators.py | 441 ++++++++++++++ 10 files changed, 3671 insertions(+) create mode 100644 surfsense_backend/tests/test_blocknote_converter.py create mode 100644 surfsense_backend/tests/test_config.py create mode 100644 surfsense_backend/tests/test_db_models.py create mode 100644 surfsense_backend/tests/test_document_converters.py create mode 100644 surfsense_backend/tests/test_permissions.py create mode 100644 surfsense_backend/tests/test_rbac.py create mode 100644 surfsense_backend/tests/test_rbac_schemas.py create mode 100644 surfsense_backend/tests/test_rbac_utils.py create mode 100644 
# File: surfsense_backend/tests/test_blocknote_converter.py
"""
Tests for the blocknote_converter utility module.

These tests validate:
1. Empty/invalid input is handled gracefully (returns None, not crash)
2. API failures don't crash the application
3. Response structure is correctly parsed
4. Network errors are properly handled
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import httpx

# Skip these tests if app dependencies aren't installed
pytest.importorskip("yaml")

from app.utils.blocknote_converter import (
    convert_markdown_to_blocknote,
    convert_blocknote_to_markdown,
)

# Frontend base URL injected into the patched config object by every
# test that reaches the (mocked) HTTP layer.
FRONTEND_URL = "http://localhost:3000"


def _install_client(mock_client_class, *, response=None, error=None):
    """
    Make the patched ``httpx.AsyncClient`` class hand out a fake client.

    Args:
        mock_client_class: The mock that replaces ``httpx.AsyncClient``.
        response: Object returned by ``await client.post(...)``.
        error: Exception raised by ``await client.post(...)``; takes
            precedence over ``response`` when given.

    Returns:
        The fake client, usable as ``async with AsyncClient() as client``.
    """
    client = AsyncMock()
    if error is not None:
        client.post = AsyncMock(side_effect=error)
    else:
        client.post = AsyncMock(return_value=response)
    client.__aenter__ = AsyncMock(return_value=client)
    # __aexit__ must resolve to a falsy value. A bare AsyncMock() resolves to
    # a truthy mock, which tells ``async with`` to swallow the exception, so
    # the failure would never reach the error handling under test.
    client.__aexit__ = AsyncMock(return_value=False)
    mock_client_class.return_value = client
    return client


def _json_response(payload):
    """Build a fake successful HTTP response whose ``.json()`` is *payload*."""
    response = MagicMock()
    response.json.return_value = payload
    response.raise_for_status = MagicMock()
    return response


def _http_500_error():
    """Build the ``httpx.HTTPStatusError`` used to simulate a 500 response."""
    response = MagicMock()
    response.status_code = 500
    response.text = "Internal Server Error"
    return httpx.HTTPStatusError(
        "Server error",
        request=MagicMock(),
        response=response,
    )


class TestMarkdownToBlocknoteInputValidation:
    """
    Tests validating input handling for markdown to BlockNote conversion.
    """

    @pytest.mark.asyncio
    async def test_empty_string_returns_none(self):
        """
        Empty markdown must return None, not error.
        This is a common edge case when content hasn't been written yet.
        """
        result = await convert_markdown_to_blocknote("")
        assert result is None

    @pytest.mark.asyncio
    async def test_whitespace_only_returns_none(self):
        """
        Whitespace-only content should be treated as empty.
        Spaces, tabs, newlines alone don't constitute content.
        """
        test_cases = [" ", "\t\t", "\n\n", " \n \t "]

        for whitespace in test_cases:
            result = await convert_markdown_to_blocknote(whitespace)
            assert result is None, f"Expected None for whitespace: {repr(whitespace)}"

    @pytest.mark.asyncio
    async def test_very_short_content_returns_fallback(self):
        """
        Very short content should return a fallback document.
        Content too short to convert meaningfully should still return something.
        """
        result = await convert_markdown_to_blocknote("x")

        assert result is not None
        assert isinstance(result, list)
        assert len(result) > 0
        # Fallback document should be a paragraph
        assert result[0]["type"] == "paragraph"


class TestMarkdownToBlocknoteNetworkResilience:
    """
    Tests validating network error handling.
    The converter should never crash on network issues.
    """

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_timeout_returns_none_not_exception(
        self, mock_config, mock_client_class
    ):
        """
        Network timeout must return None, not raise exception.
        Timeouts are common and shouldn't crash the application.
        """
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(mock_client_class, error=httpx.TimeoutException("Timeout"))

        # Long enough content to trigger API call
        result = await convert_markdown_to_blocknote(
            "# Heading\n\nThis is a paragraph with enough content."
        )

        assert result is None  # Not an exception

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_server_error_returns_none_not_exception(
        self, mock_config, mock_client_class
    ):
        """
        HTTP 5xx errors must return None, not raise exception.
        Server errors shouldn't crash the caller.
        """
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(mock_client_class, error=_http_500_error())

        result = await convert_markdown_to_blocknote(
            "# Heading\n\nThis is a paragraph with enough content."
        )

        assert result is None

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_connection_error_returns_none(self, mock_config, mock_client_class):
        """
        Connection errors (server unreachable) must return None.
        """
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(
            mock_client_class, error=httpx.ConnectError("Connection refused")
        )

        result = await convert_markdown_to_blocknote(
            "# Heading\n\nThis is a paragraph with enough content."
        )

        assert result is None


class TestMarkdownToBlocknoteSuccessfulConversion:
    """
    Tests for successful conversion scenarios.
    """

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_successful_conversion_returns_document(
        self, mock_config, mock_client_class
    ):
        """
        Successful API response should return the BlockNote document.
        """
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL

        expected_document = [{"type": "paragraph", "content": [{"text": "Test"}]}]
        _install_client(
            mock_client_class,
            response=_json_response({"blocknote_document": expected_document}),
        )

        result = await convert_markdown_to_blocknote(
            "# This is a heading\n\nThis is a paragraph with enough content."
        )

        assert result == expected_document

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_empty_api_response_returns_none(
        self, mock_config, mock_client_class
    ):
        """
        If API returns null/empty document, function should return None.
        """
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(
            mock_client_class,
            response=_json_response({"blocknote_document": None}),
        )

        result = await convert_markdown_to_blocknote(
            "# Heading\n\nSome content that is long enough."
        )

        assert result is None


class TestBlocknoteToMarkdownInputValidation:
    """
    Tests validating input handling for BlockNote to markdown conversion.
    """

    @pytest.mark.asyncio
    async def test_none_document_returns_none(self):
        """None input must return None, not crash."""
        result = await convert_blocknote_to_markdown(None)
        assert result is None

    @pytest.mark.asyncio
    async def test_empty_dict_returns_none(self):
        """Empty dict should be treated as no content."""
        result = await convert_blocknote_to_markdown({})
        assert result is None

    @pytest.mark.asyncio
    async def test_empty_list_returns_none(self):
        """Empty list should be treated as no content."""
        result = await convert_blocknote_to_markdown([])
        assert result is None


class TestBlocknoteToMarkdownNetworkResilience:
    """
    Tests validating network error handling for BlockNote to markdown.
    """

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_timeout_returns_none(self, mock_config, mock_client_class):
        """Timeout must return None, not exception."""
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(mock_client_class, error=httpx.TimeoutException("Timeout"))

        blocknote_doc = [{"type": "paragraph", "content": []}]
        result = await convert_blocknote_to_markdown(blocknote_doc)

        assert result is None

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_server_error_returns_none(self, mock_config, mock_client_class):
        """HTTP errors must return None, not exception."""
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(mock_client_class, error=_http_500_error())

        blocknote_doc = [{"type": "paragraph", "content": []}]
        result = await convert_blocknote_to_markdown(blocknote_doc)

        assert result is None


class TestBlocknoteToMarkdownSuccessfulConversion:
    """
    Tests for successful BlockNote to markdown conversion.
    """

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_successful_conversion_returns_markdown(
        self, mock_config, mock_client_class
    ):
        """Successful conversion should return markdown string."""
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL

        expected_markdown = "# Converted Heading\n\nParagraph text."
        _install_client(
            mock_client_class,
            response=_json_response({"markdown": expected_markdown}),
        )

        blocknote_doc = [
            {"type": "heading", "content": [{"type": "text", "text": "Test"}]}
        ]
        result = await convert_blocknote_to_markdown(blocknote_doc)

        assert result == expected_markdown

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_null_markdown_response_returns_none(
        self, mock_config, mock_client_class
    ):
        """If API returns null markdown, function should return None."""
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL
        _install_client(
            mock_client_class,
            response=_json_response({"markdown": None}),
        )

        blocknote_doc = [{"type": "paragraph", "content": []}]
        result = await convert_blocknote_to_markdown(blocknote_doc)

        assert result is None

    @pytest.mark.asyncio
    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
    @patch("app.utils.blocknote_converter.config")
    async def test_list_document_is_handled(self, mock_config, mock_client_class):
        """
        List documents (multiple blocks) should be handled correctly.
        """
        mock_config.NEXT_FRONTEND_URL = FRONTEND_URL

        expected_markdown = "- Item 1\n- Item 2"
        _install_client(
            mock_client_class,
            response=_json_response({"markdown": expected_markdown}),
        )

        blocknote_doc = [
            {
                "type": "bulletListItem",
                "content": [{"type": "text", "text": "Item 1"}],
            },
            {
                "type": "bulletListItem",
                "content": [{"type": "text", "text": "Item 2"}],
            },
        ]
        result = await convert_blocknote_to_markdown(blocknote_doc)

        assert result == expected_markdown
# ---------------------------------------------------------------------------
# File: surfsense_backend/tests/test_config.py
# Tests for config module.
# Tests application configuration and environment variable handling.
# ---------------------------------------------------------------------------


class TestConfigEnvironmentVariables:
    """Tests for config environment variable handling."""

    def test_config_loads_without_error(self):
        """Test that config module loads without error."""
        from app.config import config

        # Config should be an object
        assert config is not None

    def test_config_has_expected_attributes(self):
        """Test optional settings are well-formed when they are defined."""
        from app.config import config

        # These settings are optional, so their absence is not a failure.
        # The previous ``hasattr(...) or True`` form could never fail.
        # NOTE(review): assumes a configured value is a plain string --
        # confirm against app.config.
        for name in ("DATABASE_URL", "SECRET_KEY"):
            value = getattr(config, name, None)
            assert value is None or isinstance(value, str), (
                f"config.{name} should be a string when set, "
                f"got {type(value).__name__}"
            )


class TestGlobalLLMConfigs:
    """Tests for global LLM configurations."""

    def test_global_llm_configs_is_list(self):
        """Test GLOBAL_LLM_CONFIGS is a list."""
        from app.config import config

        assert isinstance(config.GLOBAL_LLM_CONFIGS, list)

    def test_global_llm_configs_have_required_fields(self):
        """Test each global config has required fields."""
        from app.config import config

        required_fields = {"id", "name", "provider", "model_name"}

        for cfg in config.GLOBAL_LLM_CONFIGS:
            for field in required_fields:
                assert field in cfg, f"Missing field {field} in global config"

    def test_global_llm_configs_have_negative_ids(self):
        """Test all global configs have negative IDs."""
        from app.config import config

        for cfg in config.GLOBAL_LLM_CONFIGS:
            assert cfg["id"] < 0, f"Global config {cfg['name']} should have negative ID"


class TestEmbeddingModelInstance:
    """Tests for embedding model instance."""

    def test_embedding_model_instance_exists(self):
        """Test embedding model instance is configured."""
        from app.config import config

        # Should have an embedding model instance
        assert hasattr(config, "embedding_model_instance")

    def test_embedding_model_has_embed_method(self):
        """Test embedding model has embed method."""
        from app.config import config

        # Only checked when a model is configured; None is tolerated.
        if config.embedding_model_instance is not None:
            assert hasattr(config.embedding_model_instance, "embed")


class TestAuthConfiguration:
    """Tests for authentication configuration."""

    def test_auth_type_is_string(self):
        """Test AUTH_TYPE is a string."""
        from app.config import config

        if hasattr(config, "AUTH_TYPE"):
            assert isinstance(config.AUTH_TYPE, str)

    def test_registration_enabled_is_boolean(self):
        """Test REGISTRATION_ENABLED is boolean."""
        from app.config import config

        if hasattr(config, "REGISTRATION_ENABLED"):
            assert isinstance(config.REGISTRATION_ENABLED, bool)


# ---------------------------------------------------------------------------
# File: surfsense_backend/tests/test_db_models.py
# Tests for database models and functions.
# Tests SQLAlchemy models, enums, and database utility functions.
# ---------------------------------------------------------------------------

from app.db import (
    DocumentType,
    LiteLLMProvider,
    SearchSourceConnectorType,
    Permission,
    SearchSpace,
    Document,
    Chunk,
    Chat,
    Podcast,
    LLMConfig,
    SearchSourceConnector,
    SearchSpaceRole,
    SearchSpaceMembership,
    SearchSpaceInvite,
    User,
    LogLevel,
    LogStatus,
    ChatType,
)


def _assert_has_attrs(obj, names):
    """
    Assert that *obj* exposes every attribute named in *names*.

    All missing names are collected before asserting, so one failure
    reports the complete gap instead of stopping at the first attribute.
    """
    missing = [name for name in names if not hasattr(obj, name)]
    label = getattr(obj, "__name__", repr(obj))
    assert not missing, f"{label} is missing attribute(s): {missing}"


class TestDocumentType:
    """Tests for DocumentType enum."""

    def test_all_document_types_are_strings(self):
        """Test all document types have string values."""
        for doc_type in DocumentType:
            assert isinstance(doc_type.value, str)

    def test_extension_type(self):
        """Test EXTENSION document type."""
        assert DocumentType.EXTENSION.value == "EXTENSION"

    def test_file_type(self):
        """Test FILE document type."""
        assert DocumentType.FILE.value == "FILE"

    def test_youtube_video_type(self):
        """Test YOUTUBE_VIDEO document type."""
        assert DocumentType.YOUTUBE_VIDEO.value == "YOUTUBE_VIDEO"

    def test_crawled_url_type(self):
        """Test CRAWLED_URL document type."""
        assert DocumentType.CRAWLED_URL.value == "CRAWLED_URL"

    def test_connector_types_exist(self):
        """Test connector document types exist."""
        connector_types = [
            "SLACK_CONNECTOR",
            "NOTION_CONNECTOR",
            "GITHUB_CONNECTOR",
            "JIRA_CONNECTOR",
            "CONFLUENCE_CONNECTOR",
            "LINEAR_CONNECTOR",
            "DISCORD_CONNECTOR",
        ]

        _assert_has_attrs(DocumentType, connector_types)


class TestLiteLLMProvider:
    """Tests for LiteLLMProvider enum."""

    def test_openai_provider(self):
        """Test OPENAI provider."""
        assert LiteLLMProvider.OPENAI.value == "OPENAI"

    def test_anthropic_provider(self):
        """Test ANTHROPIC provider."""
        assert LiteLLMProvider.ANTHROPIC.value == "ANTHROPIC"

    def test_google_provider(self):
        """Test GOOGLE provider."""
        assert LiteLLMProvider.GOOGLE.value == "GOOGLE"

    def test_ollama_provider(self):
        """Test OLLAMA provider."""
        assert LiteLLMProvider.OLLAMA.value == "OLLAMA"

    def test_all_providers_are_strings(self):
        """Test all providers have string values."""
        for provider in LiteLLMProvider:
            assert isinstance(provider.value, str)


class TestSearchSourceConnectorType:
    """Tests for SearchSourceConnectorType enum."""

    def test_tavily_api(self):
        """Test TAVILY_API connector type."""
        assert SearchSourceConnectorType.TAVILY_API.value == "TAVILY_API"

    def test_searxng_api(self):
        """Test SEARXNG_API connector type."""
        assert SearchSourceConnectorType.SEARXNG_API.value == "SEARXNG_API"

    def test_slack_connector(self):
        """Test SLACK_CONNECTOR connector type."""
        assert SearchSourceConnectorType.SLACK_CONNECTOR.value == "SLACK_CONNECTOR"

    def test_notion_connector(self):
        """Test NOTION_CONNECTOR connector type."""
        assert SearchSourceConnectorType.NOTION_CONNECTOR.value == "NOTION_CONNECTOR"

    def test_all_connector_types_are_strings(self):
        """Test all connector types have string values."""
        for conn_type in SearchSourceConnectorType:
            assert isinstance(conn_type.value, str)


class TestPermission:
    """Tests for Permission enum."""

    def test_full_access_permission(self):
        """Test FULL_ACCESS permission."""
        assert Permission.FULL_ACCESS.value == "*"

    def test_document_permissions(self):
        """Test document permissions exist."""
        doc_permissions = [
            "DOCUMENTS_CREATE",
            "DOCUMENTS_READ",
            "DOCUMENTS_UPDATE",
            "DOCUMENTS_DELETE",
        ]

        _assert_has_attrs(Permission, doc_permissions)

    def test_chat_permissions(self):
        """Test chat permissions exist."""
        chat_permissions = [
            "CHATS_CREATE",
            "CHATS_READ",
            "CHATS_UPDATE",
            "CHATS_DELETE",
        ]

        _assert_has_attrs(Permission, chat_permissions)

    def test_llm_config_permissions(self):
        """Test LLM config permissions exist."""
        llm_permissions = [
            "LLM_CONFIGS_CREATE",
            "LLM_CONFIGS_READ",
            "LLM_CONFIGS_UPDATE",
            "LLM_CONFIGS_DELETE",
        ]

        _assert_has_attrs(Permission, llm_permissions)

    def test_settings_permissions(self):
        """Test settings permissions exist."""
        settings_permissions = [
            "SETTINGS_VIEW",
            "SETTINGS_UPDATE",
            "SETTINGS_DELETE",
        ]

        _assert_has_attrs(Permission, settings_permissions)


class TestSearchSpaceModel:
    """Tests for SearchSpace model."""

    def test_search_space_has_required_fields(self):
        """Test SearchSpace has required fields."""
        # Check that the model has expected columns
        _assert_has_attrs(SearchSpace, ["id", "name", "user_id", "created_at"])


class TestDocumentModel:
    """Tests for Document model."""

    def test_document_has_required_fields(self):
        """Test Document has required fields."""
        _assert_has_attrs(
            Document,
            ["id", "title", "document_type", "content", "search_space_id"],
        )

    def test_document_has_chunks_relationship(self):
        """Test Document has chunks relationship."""
        assert hasattr(Document, "chunks")


class TestChunkModel:
    """Tests for Chunk model."""

    def test_chunk_has_required_fields(self):
        """Test Chunk has required fields."""
        _assert_has_attrs(Chunk, ["id", "content", "document_id"])

    def test_chunk_has_embedding_field(self):
        """Test Chunk has embedding field."""
        assert hasattr(Chunk, "embedding")


class TestChatModel:
    """Tests for Chat model."""

    def test_chat_has_required_fields(self):
        """Test Chat has required fields."""
        _assert_has_attrs(Chat, ["id", "title", "search_space_id"])


class TestChatType:
    """Tests for ChatType enum."""

    def test_chat_type_values(self):
        """Test ChatType values."""
        assert hasattr(ChatType, "QNA")


class TestLogLevel:
    """Tests for LogLevel enum."""

    def test_log_level_values(self):
        """Test LogLevel values exist."""
        _assert_has_attrs(LogLevel, ["INFO", "WARNING", "ERROR"])


class TestLogStatus:
    """Tests for LogStatus enum."""

    def test_log_status_values(self):
        """Test LogStatus values exist."""
        _assert_has_attrs(LogStatus, ["IN_PROGRESS", "SUCCESS", "FAILED"])
        assert LogStatus.IN_PROGRESS.value == "IN_PROGRESS"


class TestLLMConfigModel:
    """Tests for LLMConfig model."""

    def test_llm_config_has_required_fields(self):
        """Test LLMConfig has required fields."""
        _assert_has_attrs(
            LLMConfig,
            ["id", "name", "provider", "model_name", "api_key", "search_space_id"],
        )


class TestSearchSourceConnectorModel:
    """Tests for SearchSourceConnector model."""

    def test_connector_has_required_fields(self):
        """Test SearchSourceConnector has required fields."""
        _assert_has_attrs(
            SearchSourceConnector,
            ["id", "connector_type", "config", "search_space_id"],
        )


class TestRBACModels:
    """Tests for RBAC models."""

    def test_search_space_role_has_required_fields(self):
        """Test SearchSpaceRole has required fields."""
        _assert_has_attrs(
            SearchSpaceRole,
            ["id", "name", "permissions", "search_space_id"],
        )

    def test_search_space_membership_has_required_fields(self):
        """Test SearchSpaceMembership has required fields."""
        _assert_has_attrs(
            SearchSpaceMembership,
            ["id", "user_id", "search_space_id", "role_id", "is_owner"],
        )

    def test_search_space_invite_has_required_fields(self):
        """Test SearchSpaceInvite has required fields."""
        _assert_has_attrs(
            SearchSpaceInvite,
            ["id", "invite_code", "search_space_id", "role_id"],
        )


class TestUserModel:
    """Tests for User model."""

    def test_user_has_required_fields(self):
        """Test User has required fields."""
        _assert_has_attrs(User, ["id", "email"])

    def test_user_has_page_limit_fields(self):
        """Test User has page limit fields."""
        _assert_has_attrs(User, ["pages_used", "pages_limit"])


class TestPodcastModel:
    """Tests for Podcast model."""

    def test_podcast_has_required_fields(self):
        """Test Podcast has required fields."""
        _assert_has_attrs(Podcast, ["id", "title", "search_space_id"])
--- /dev/null +++ b/surfsense_backend/tests/test_document_converters.py @@ -0,0 +1,513 @@ +""" +Tests for document_converters utility module. + +This module tests the document conversion functions including +content hash generation, markdown conversion, and chunking utilities. +""" + +import hashlib +from unittest.mock import MagicMock + +import pytest + +from app.db import DocumentType +from app.utils.document_converters import ( + convert_chunks_to_langchain_documents, + convert_document_to_markdown, + convert_element_to_markdown, + generate_content_hash, + generate_unique_identifier_hash, +) + + +class TestGenerateContentHash: + """Tests for generate_content_hash function.""" + + def test_generates_sha256_hash(self): + """Test that function generates SHA-256 hash.""" + content = "Test content" + search_space_id = 1 + result = generate_content_hash(content, search_space_id) + + # Verify it's a valid SHA-256 hash (64 hex characters) + assert len(result) == 64 + assert all(c in "0123456789abcdef" for c in result) + + def test_combines_content_and_search_space_id(self): + """Test that hash is generated from combined data.""" + content = "Test content" + search_space_id = 1 + + # Manually compute expected hash + combined_data = f"{search_space_id}:{content}" + expected_hash = hashlib.sha256(combined_data.encode("utf-8")).hexdigest() + + result = generate_content_hash(content, search_space_id) + assert result == expected_hash + + def test_different_content_produces_different_hash(self): + """Test that different content produces different hashes.""" + hash1 = generate_content_hash("Content 1", 1) + hash2 = generate_content_hash("Content 2", 1) + assert hash1 != hash2 + + def test_different_search_space_produces_different_hash(self): + """Test that different search space ID produces different hashes.""" + hash1 = generate_content_hash("Same content", 1) + hash2 = generate_content_hash("Same content", 2) + assert hash1 != hash2 + + def 
test_same_input_produces_same_hash(self): + """Test that same input always produces same hash.""" + content = "Consistent content" + search_space_id = 42 + + hash1 = generate_content_hash(content, search_space_id) + hash2 = generate_content_hash(content, search_space_id) + assert hash1 == hash2 + + def test_empty_content(self): + """Test with empty content.""" + result = generate_content_hash("", 1) + assert len(result) == 64 # Still produces valid hash + + def test_unicode_content(self): + """Test with unicode content.""" + result = generate_content_hash("こんにちは世界 🌍", 1) + assert len(result) == 64 + + +class TestGenerateUniqueIdentifierHash: + """Tests for generate_unique_identifier_hash function.""" + + def test_generates_sha256_hash(self): + """Test that function generates SHA-256 hash.""" + result = generate_unique_identifier_hash( + DocumentType.SLACK_CONNECTOR, + "message123", + 1, + ) + assert len(result) == 64 + assert all(c in "0123456789abcdef" for c in result) + + def test_combines_all_parameters(self): + """Test that hash is generated from all parameters.""" + doc_type = DocumentType.SLACK_CONNECTOR + unique_id = "message123" + search_space_id = 42 + + # Manually compute expected hash + combined_data = f"{doc_type.value}:{unique_id}:{search_space_id}" + expected_hash = hashlib.sha256(combined_data.encode("utf-8")).hexdigest() + + result = generate_unique_identifier_hash(doc_type, unique_id, search_space_id) + assert result == expected_hash + + def test_different_document_types_produce_different_hashes(self): + """Test different document types produce different hashes.""" + hash1 = generate_unique_identifier_hash(DocumentType.SLACK_CONNECTOR, "id123", 1) + hash2 = generate_unique_identifier_hash(DocumentType.NOTION_CONNECTOR, "id123", 1) + assert hash1 != hash2 + + def test_different_identifiers_produce_different_hashes(self): + """Test different identifiers produce different hashes.""" + hash1 = 
generate_unique_identifier_hash(DocumentType.SLACK_CONNECTOR, "id123", 1) + hash2 = generate_unique_identifier_hash(DocumentType.SLACK_CONNECTOR, "id456", 1) + assert hash1 != hash2 + + def test_integer_identifier(self): + """Test with integer unique identifier.""" + result = generate_unique_identifier_hash(DocumentType.JIRA_CONNECTOR, 12345, 1) + assert len(result) == 64 + + def test_float_identifier(self): + """Test with float unique identifier (e.g., Slack timestamps).""" + result = generate_unique_identifier_hash( + DocumentType.SLACK_CONNECTOR, + 1234567890.123456, + 1, + ) + assert len(result) == 64 + + def test_consistency(self): + """Test that same inputs always produce same hash.""" + params = (DocumentType.GITHUB_CONNECTOR, "pr-123", 5) + + hash1 = generate_unique_identifier_hash(*params) + hash2 = generate_unique_identifier_hash(*params) + assert hash1 == hash2 + + +class TestConvertElementToMarkdown: + """Tests for convert_element_to_markdown function.""" + + @pytest.mark.asyncio + async def test_formula_element(self): + """Test Formula element conversion.""" + element = MagicMock() + element.metadata = {"category": "Formula"} + element.page_content = "E = mc^2" + + result = await convert_element_to_markdown(element) + assert "```math" in result + assert "E = mc^2" in result + + @pytest.mark.asyncio + async def test_figure_caption_element(self): + """Test FigureCaption element conversion.""" + element = MagicMock() + element.metadata = {"category": "FigureCaption"} + element.page_content = "Figure 1: Test image" + + result = await convert_element_to_markdown(element) + assert "*Figure:" in result + + @pytest.mark.asyncio + async def test_narrative_text_element(self): + """Test NarrativeText element conversion.""" + element = MagicMock() + element.metadata = {"category": "NarrativeText"} + element.page_content = "This is a paragraph of text." + + result = await convert_element_to_markdown(element) + assert "This is a paragraph of text." 
in result + assert result.endswith("\n\n") + + @pytest.mark.asyncio + async def test_list_item_element(self): + """Test ListItem element conversion.""" + element = MagicMock() + element.metadata = {"category": "ListItem"} + element.page_content = "Item one" + + result = await convert_element_to_markdown(element) + assert result.startswith("- ") + assert "Item one" in result + + @pytest.mark.asyncio + async def test_title_element(self): + """Test Title element conversion.""" + element = MagicMock() + element.metadata = {"category": "Title"} + element.page_content = "Document Title" + + result = await convert_element_to_markdown(element) + assert result.startswith("# ") + assert "Document Title" in result + + @pytest.mark.asyncio + async def test_address_element(self): + """Test Address element conversion.""" + element = MagicMock() + element.metadata = {"category": "Address"} + element.page_content = "123 Main St" + + result = await convert_element_to_markdown(element) + assert result.startswith("> ") + + @pytest.mark.asyncio + async def test_email_address_element(self): + """Test EmailAddress element conversion.""" + element = MagicMock() + element.metadata = {"category": "EmailAddress"} + element.page_content = "test@example.com" + + result = await convert_element_to_markdown(element) + assert "`test@example.com`" in result + + @pytest.mark.asyncio + async def test_table_element(self): + """Test Table element conversion.""" + element = MagicMock() + element.metadata = {"category": "Table", "text_as_html": "
data
"} + element.page_content = "Table content" + + result = await convert_element_to_markdown(element) + assert "```html" in result + assert "" in result + + @pytest.mark.asyncio + async def test_header_element(self): + """Test Header element conversion.""" + element = MagicMock() + element.metadata = {"category": "Header"} + element.page_content = "Section Header" + + result = await convert_element_to_markdown(element) + assert result.startswith("## ") + + @pytest.mark.asyncio + async def test_code_snippet_element(self): + """Test CodeSnippet element conversion.""" + element = MagicMock() + element.metadata = {"category": "CodeSnippet"} + element.page_content = "print('hello')" + + result = await convert_element_to_markdown(element) + assert "```" in result + assert "print('hello')" in result + + @pytest.mark.asyncio + async def test_page_number_element(self): + """Test PageNumber element conversion.""" + element = MagicMock() + element.metadata = {"category": "PageNumber"} + element.page_content = "42" + + result = await convert_element_to_markdown(element) + assert "*Page 42*" in result + + @pytest.mark.asyncio + async def test_page_break_element(self): + """Test PageBreak element conversion.""" + element = MagicMock() + element.metadata = {"category": "PageBreak"} + # PageBreak with content returns horizontal rule + element.page_content = "page break content" + + result = await convert_element_to_markdown(element) + assert "---" in result + + @pytest.mark.asyncio + async def test_empty_content(self): + """Test element with empty content.""" + element = MagicMock() + element.metadata = {"category": "NarrativeText"} + element.page_content = "" + + result = await convert_element_to_markdown(element) + assert result == "" + + @pytest.mark.asyncio + async def test_uncategorized_element(self): + """Test UncategorizedText element conversion.""" + element = MagicMock() + element.metadata = {"category": "UncategorizedText"} + element.page_content = "Some uncategorized 
text" + + result = await convert_element_to_markdown(element) + assert "Some uncategorized text" in result + + +class TestConvertDocumentToMarkdown: + """Tests for convert_document_to_markdown function.""" + + @pytest.mark.asyncio + async def test_converts_multiple_elements(self): + """Test converting multiple elements.""" + elements = [] + + # Title element + title = MagicMock() + title.metadata = {"category": "Title"} + title.page_content = "Document Title" + elements.append(title) + + # Narrative text element + para = MagicMock() + para.metadata = {"category": "NarrativeText"} + para.page_content = "This is a paragraph." + elements.append(para) + + result = await convert_document_to_markdown(elements) + + assert "# Document Title" in result + assert "This is a paragraph." in result + + @pytest.mark.asyncio + async def test_empty_elements(self): + """Test with empty elements list.""" + result = await convert_document_to_markdown([]) + assert result == "" + + @pytest.mark.asyncio + async def test_preserves_order(self): + """Test that element order is preserved.""" + elements = [] + + for i in range(3): + elem = MagicMock() + elem.metadata = {"category": "NarrativeText"} + elem.page_content = f"Paragraph {i}" + elements.append(elem) + + result = await convert_document_to_markdown(elements) + + # Check order is preserved + pos0 = result.find("Paragraph 0") + pos1 = result.find("Paragraph 1") + pos2 = result.find("Paragraph 2") + + assert pos0 < pos1 < pos2 + + +class TestConvertChunksToLangchainDocuments: + """Tests for convert_chunks_to_langchain_documents function.""" + + def test_converts_basic_chunks(self): + """Test converting basic chunk structure.""" + chunks = [ + { + "chunk_id": 1, + "content": "This is chunk content", + "score": 0.95, + "document": { + "id": 10, + "title": "Test Document", + "document_type": "FILE", + "metadata": {"url": "https://example.com"}, + }, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert len(result) 
== 1 + assert "This is chunk content" in result[0].page_content + assert result[0].metadata["chunk_id"] == 1 + assert result[0].metadata["document_id"] == 10 + assert result[0].metadata["document_title"] == "Test Document" + + def test_includes_source_id_in_content(self): + """Test that source_id is included in XML content.""" + chunks = [ + { + "chunk_id": 1, + "content": "Test content", + "score": 0.9, + "document": { + "id": 5, + "title": "Doc", + "document_type": "FILE", + "metadata": {}, + }, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert "5" in result[0].page_content + + def test_extracts_source_url(self): + """Test source URL extraction from metadata.""" + chunks = [ + { + "chunk_id": 1, + "content": "Content", + "score": 0.9, + "document": { + "id": 1, + "title": "Doc", + "document_type": "CRAWLED_URL", + "metadata": {"url": "https://example.com/page"}, + }, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert result[0].metadata["source"] == "https://example.com/page" + + def test_extracts_source_url_alternate_key(self): + """Test source URL extraction with sourceURL key.""" + chunks = [ + { + "chunk_id": 1, + "content": "Content", + "score": 0.9, + "document": { + "id": 1, + "title": "Doc", + "document_type": "CRAWLED_URL", + "metadata": {"sourceURL": "https://example.com/alternate"}, + }, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert result[0].metadata["source"] == "https://example.com/alternate" + + def test_handles_missing_document(self): + """Test handling chunks without document info.""" + chunks = [ + { + "chunk_id": 1, + "content": "Content without document", + "score": 0.8, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert len(result) == 1 + assert "Content without document" in result[0].page_content + + def test_prefixes_document_metadata(self): + """Test document metadata is prefixed.""" + chunks = [ + { + "chunk_id": 1, + 
"content": "Content", + "score": 0.9, + "document": { + "id": 1, + "title": "Doc", + "document_type": "FILE", + "metadata": {"custom_field": "custom_value"}, + }, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert result[0].metadata["doc_meta_custom_field"] == "custom_value" + + def test_handles_rank_field(self): + """Test handling of rank field when present.""" + chunks = [ + { + "chunk_id": 1, + "content": "Content", + "score": 0.9, + "rank": 1, + "document": { + "id": 1, + "title": "Doc", + "document_type": "FILE", + "metadata": {}, + }, + } + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert result[0].metadata["rank"] == 1 + + def test_empty_chunks_list(self): + """Test with empty chunks list.""" + result = convert_chunks_to_langchain_documents([]) + assert result == [] + + def test_multiple_chunks(self): + """Test converting multiple chunks.""" + chunks = [ + { + "chunk_id": i, + "content": f"Content {i}", + "score": 0.9 - (i * 0.1), + "document": { + "id": i, + "title": f"Doc {i}", + "document_type": "FILE", + "metadata": {}, + }, + } + for i in range(3) + ] + + result = convert_chunks_to_langchain_documents(chunks) + + assert len(result) == 3 + for i, doc in enumerate(result): + assert f"Content {i}" in doc.page_content diff --git a/surfsense_backend/tests/test_permissions.py b/surfsense_backend/tests/test_permissions.py new file mode 100644 index 000000000..3dbef13b6 --- /dev/null +++ b/surfsense_backend/tests/test_permissions.py @@ -0,0 +1,270 @@ +""" +Tests for permission functions in db.py. + +This module tests the permission checking functions used in RBAC. 
+""" + +from app.db import ( + DEFAULT_ROLE_PERMISSIONS, + Permission, + get_default_roles_config, + has_all_permissions, + has_any_permission, + has_permission, +) + + +class TestHasPermission: + """Tests for has_permission function.""" + + def test_has_permission_with_exact_match(self): + """Test has_permission returns True for exact permission match.""" + permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value] + assert has_permission(permissions, Permission.DOCUMENTS_READ.value) is True + + def test_has_permission_with_no_match(self): + """Test has_permission returns False when permission not in list.""" + permissions = [Permission.DOCUMENTS_READ.value] + assert has_permission(permissions, Permission.DOCUMENTS_CREATE.value) is False + + def test_has_permission_with_full_access(self): + """Test has_permission returns True for any permission when user has FULL_ACCESS.""" + permissions = [Permission.FULL_ACCESS.value] + assert has_permission(permissions, Permission.DOCUMENTS_CREATE.value) is True + assert has_permission(permissions, Permission.SETTINGS_DELETE.value) is True + assert has_permission(permissions, Permission.MEMBERS_MANAGE_ROLES.value) is True + + def test_has_permission_with_empty_list(self): + """Test has_permission returns False for empty permission list.""" + assert has_permission([], Permission.DOCUMENTS_READ.value) is False + + def test_has_permission_with_none(self): + """Test has_permission returns False for None.""" + assert has_permission(None, Permission.DOCUMENTS_READ.value) is False + + +class TestHasAnyPermission: + """Tests for has_any_permission function.""" + + def test_has_any_permission_with_one_match(self): + """Test has_any_permission returns True when at least one permission matches.""" + user_permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value] + required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value] + assert has_any_permission(user_permissions, required) is 
True + + def test_has_any_permission_with_all_match(self): + """Test has_any_permission returns True when all permissions match.""" + user_permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value] + required = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value] + assert has_any_permission(user_permissions, required) is True + + def test_has_any_permission_with_no_match(self): + """Test has_any_permission returns False when no permissions match.""" + user_permissions = [Permission.DOCUMENTS_READ.value] + required = [Permission.CHATS_CREATE.value, Permission.SETTINGS_UPDATE.value] + assert has_any_permission(user_permissions, required) is False + + def test_has_any_permission_with_full_access(self): + """Test has_any_permission returns True with FULL_ACCESS.""" + user_permissions = [Permission.FULL_ACCESS.value] + required = [Permission.SETTINGS_DELETE.value] + assert has_any_permission(user_permissions, required) is True + + def test_has_any_permission_with_empty_user_permissions(self): + """Test has_any_permission returns False with empty user permissions.""" + assert has_any_permission([], [Permission.DOCUMENTS_READ.value]) is False + + def test_has_any_permission_with_none(self): + """Test has_any_permission returns False with None.""" + assert has_any_permission(None, [Permission.DOCUMENTS_READ.value]) is False + + +class TestHasAllPermissions: + """Tests for has_all_permissions function.""" + + def test_has_all_permissions_with_all_match(self): + """Test has_all_permissions returns True when all permissions match.""" + user_permissions = [ + Permission.DOCUMENTS_READ.value, + Permission.DOCUMENTS_CREATE.value, + Permission.CHATS_READ.value, + ] + required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value] + assert has_all_permissions(user_permissions, required) is True + + def test_has_all_permissions_with_partial_match(self): + """Test has_all_permissions returns False when only some permissions match.""" + 
user_permissions = [Permission.DOCUMENTS_READ.value] + required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value] + assert has_all_permissions(user_permissions, required) is False + + def test_has_all_permissions_with_no_match(self): + """Test has_all_permissions returns False when no permissions match.""" + user_permissions = [Permission.CHATS_READ.value] + required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value] + assert has_all_permissions(user_permissions, required) is False + + def test_has_all_permissions_with_full_access(self): + """Test has_all_permissions returns True with FULL_ACCESS.""" + user_permissions = [Permission.FULL_ACCESS.value] + required = [ + Permission.DOCUMENTS_READ.value, + Permission.DOCUMENTS_CREATE.value, + Permission.SETTINGS_DELETE.value, + ] + assert has_all_permissions(user_permissions, required) is True + + def test_has_all_permissions_with_empty_user_permissions(self): + """Test has_all_permissions returns False with empty user permissions.""" + assert has_all_permissions([], [Permission.DOCUMENTS_READ.value]) is False + + def test_has_all_permissions_with_none(self): + """Test has_all_permissions returns False with None.""" + assert has_all_permissions(None, [Permission.DOCUMENTS_READ.value]) is False + + def test_has_all_permissions_with_empty_required(self): + """Test has_all_permissions returns True with empty required list.""" + user_permissions = [Permission.DOCUMENTS_READ.value] + assert has_all_permissions(user_permissions, []) is True + + +class TestPermissionEnum: + """Tests for Permission enum values.""" + + def test_permission_values_are_strings(self): + """Test all permission values are strings.""" + for perm in list(Permission): + assert isinstance(perm.value, str) + + def test_permission_document_values(self): + """Test document permission values.""" + assert Permission.DOCUMENTS_CREATE.value == "documents:create" + assert Permission.DOCUMENTS_READ.value == 
"documents:read" + assert Permission.DOCUMENTS_UPDATE.value == "documents:update" + assert Permission.DOCUMENTS_DELETE.value == "documents:delete" + + def test_permission_chat_values(self): + """Test chat permission values.""" + assert Permission.CHATS_CREATE.value == "chats:create" + assert Permission.CHATS_READ.value == "chats:read" + assert Permission.CHATS_UPDATE.value == "chats:update" + assert Permission.CHATS_DELETE.value == "chats:delete" + + def test_permission_llm_config_values(self): + """Test LLM config permission values.""" + assert Permission.LLM_CONFIGS_CREATE.value == "llm_configs:create" + assert Permission.LLM_CONFIGS_READ.value == "llm_configs:read" + assert Permission.LLM_CONFIGS_UPDATE.value == "llm_configs:update" + assert Permission.LLM_CONFIGS_DELETE.value == "llm_configs:delete" + + def test_permission_members_values(self): + """Test member permission values.""" + assert Permission.MEMBERS_INVITE.value == "members:invite" + assert Permission.MEMBERS_VIEW.value == "members:view" + assert Permission.MEMBERS_REMOVE.value == "members:remove" + assert Permission.MEMBERS_MANAGE_ROLES.value == "members:manage_roles" + + def test_permission_full_access_value(self): + """Test FULL_ACCESS permission value.""" + assert Permission.FULL_ACCESS.value == "*" + + +class TestDefaultRolePermissions: + """Tests for DEFAULT_ROLE_PERMISSIONS configuration.""" + + def test_owner_has_full_access(self): + """Test Owner role has full access.""" + assert Permission.FULL_ACCESS.value in DEFAULT_ROLE_PERMISSIONS["Owner"] + + def test_admin_permissions(self): + """Test Admin role has appropriate permissions.""" + admin_perms = DEFAULT_ROLE_PERMISSIONS["Admin"] + # Admin should have document permissions + assert Permission.DOCUMENTS_CREATE.value in admin_perms + assert Permission.DOCUMENTS_READ.value in admin_perms + assert Permission.DOCUMENTS_UPDATE.value in admin_perms + assert Permission.DOCUMENTS_DELETE.value in admin_perms + # Admin should NOT have settings:delete 
+ assert Permission.SETTINGS_DELETE.value not in admin_perms + + def test_editor_permissions(self): + """Test Editor role has appropriate permissions.""" + editor_perms = DEFAULT_ROLE_PERMISSIONS["Editor"] + # Editor should have document CRUD + assert Permission.DOCUMENTS_CREATE.value in editor_perms + assert Permission.DOCUMENTS_READ.value in editor_perms + assert Permission.DOCUMENTS_UPDATE.value in editor_perms + assert Permission.DOCUMENTS_DELETE.value in editor_perms + # Editor should have chat CRUD + assert Permission.CHATS_CREATE.value in editor_perms + assert Permission.CHATS_READ.value in editor_perms + # Editor should NOT have member management + assert Permission.MEMBERS_REMOVE.value not in editor_perms + + def test_viewer_permissions(self): + """Test Viewer role has read-only permissions.""" + viewer_perms = DEFAULT_ROLE_PERMISSIONS["Viewer"] + # Viewer should have read permissions + assert Permission.DOCUMENTS_READ.value in viewer_perms + assert Permission.CHATS_READ.value in viewer_perms + assert Permission.LLM_CONFIGS_READ.value in viewer_perms + # Viewer should NOT have create/update/delete permissions + assert Permission.DOCUMENTS_CREATE.value not in viewer_perms + assert Permission.DOCUMENTS_UPDATE.value not in viewer_perms + assert Permission.DOCUMENTS_DELETE.value not in viewer_perms + assert Permission.CHATS_CREATE.value not in viewer_perms + + +class TestGetDefaultRolesConfig: + """Tests for get_default_roles_config function.""" + + def test_returns_list(self): + """Test get_default_roles_config returns a list.""" + config = get_default_roles_config() + assert isinstance(config, list) + + def test_contains_four_roles(self): + """Test get_default_roles_config returns 4 roles.""" + config = get_default_roles_config() + assert len(config) == 4 + + def test_role_names(self): + """Test get_default_roles_config contains expected role names.""" + config = get_default_roles_config() + role_names = [role["name"] for role in config] + assert "Owner" in 
role_names + assert "Admin" in role_names + assert "Editor" in role_names + assert "Viewer" in role_names + + def test_all_roles_are_system_roles(self): + """Test all default roles are system roles.""" + config = get_default_roles_config() + for role in config: + assert role["is_system_role"] is True + + def test_editor_is_default_role(self): + """Test Editor is the default role for new members.""" + config = get_default_roles_config() + editor_role = next(role for role in config if role["name"] == "Editor") + assert editor_role["is_default"] is True + + def test_owner_is_not_default_role(self): + """Test Owner is not the default role.""" + config = get_default_roles_config() + owner_role = next(role for role in config if role["name"] == "Owner") + assert owner_role["is_default"] is False + + def test_role_structure(self): + """Test each role has required fields.""" + config = get_default_roles_config() + required_fields = ["name", "description", "permissions", "is_default", "is_system_role"] + for role in config: + for field in required_fields: + assert field in role, f"Role {role.get('name')} missing field {field}" + + def test_owner_role_permissions(self): + """Test Owner role has full access permission.""" + config = get_default_roles_config() + owner_role = next(role for role in config if role["name"] == "Owner") + assert Permission.FULL_ACCESS.value in owner_role["permissions"] diff --git a/surfsense_backend/tests/test_rbac.py b/surfsense_backend/tests/test_rbac.py new file mode 100644 index 000000000..1e0c35bdd --- /dev/null +++ b/surfsense_backend/tests/test_rbac.py @@ -0,0 +1,355 @@ +""" +Tests for the RBAC (Role-Based Access Control) utility functions. + +These tests validate the security-critical RBAC behavior: +1. Users without membership should NEVER access resources +2. Permission checks must be strict - no false positives +3. Owners must have full access +4. 
Role permissions must be properly enforced +""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +# Skip these tests if app dependencies aren't installed +pytest.importorskip("sqlalchemy") +pytest.importorskip("fastapi_users") + +from app.db import Permission, SearchSpaceMembership, SearchSpaceRole +from app.utils.rbac import ( + check_permission, + check_search_space_access, + generate_invite_code, + get_default_role, + get_owner_role, + get_user_permissions, + is_search_space_owner, +) + + +class TestSecurityCriticalAccessControl: + """ + Critical security tests - these MUST pass to prevent unauthorized access. + """ + + @pytest.mark.asyncio + async def test_non_member_cannot_access_search_space(self, mock_session, mock_user): + """ + SECURITY: Non-members must be denied access with 403. + This is critical - allowing access would be a security breach. + """ + search_space_id = 1 + + # Simulate user not being a member + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + with pytest.raises(HTTPException) as exc_info: + await check_search_space_access(mock_session, mock_user, search_space_id) + + # Must be 403 Forbidden, not 404 or other + assert exc_info.value.status_code == 403 + assert "access" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_member_without_permission_is_denied(self, mock_session, mock_user): + """ + SECURITY: Members without specific permission must be denied. + Having membership alone is insufficient for sensitive operations. 
+ """ + search_space_id = 1 + + # Member exists but has limited permissions (only read, not write) + mock_role = MagicMock(spec=SearchSpaceRole) + mock_role.permissions = ["documents:read"] # Does NOT have write + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = False + mock_membership.role = mock_role + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + # Attempt to access a write operation - must fail + with patch("app.utils.rbac.has_permission", return_value=False): + with pytest.raises(HTTPException) as exc_info: + await check_permission( + mock_session, + mock_user, + search_space_id, + "documents:write", + ) + + assert exc_info.value.status_code == 403 + + @pytest.mark.asyncio + async def test_owner_has_full_access_regardless_of_operation( + self, mock_session, mock_user + ): + """ + SECURITY: Owners must have full access to all operations. + This ensures owners can always manage their search spaces. 
+ """ + search_space_id = 1 + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = True + mock_membership.role = None # Owners may not have explicit roles + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + # Owner should pass permission check with FULL_ACCESS + with patch("app.utils.rbac.has_permission", return_value=True) as mock_has_perm: + result = await check_permission( + mock_session, + mock_user, + search_space_id, + "any:permission", + ) + + assert result == mock_membership + # Verify FULL_ACCESS was checked + mock_has_perm.assert_called_once() + call_args = mock_has_perm.call_args[0] + assert Permission.FULL_ACCESS.value in call_args[0] + + +class TestGetUserPermissions: + """Tests for permission retrieval - validates correct permission inheritance.""" + + @pytest.mark.asyncio + async def test_non_member_has_no_permissions(self, mock_session): + """Non-members must have zero permissions.""" + user_id = uuid.uuid4() + search_space_id = 1 + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_user_permissions(mock_session, user_id, search_space_id) + + assert result == [] + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_owner_gets_full_access_permission(self, mock_session): + """Owners must receive FULL_ACCESS permission.""" + user_id = uuid.uuid4() + search_space_id = 1 + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = True + mock_membership.role = None + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_user_permissions(mock_session, user_id, search_space_id) + + assert Permission.FULL_ACCESS.value in 
result + + @pytest.mark.asyncio + async def test_member_gets_only_role_permissions(self, mock_session): + """Members should get exactly the permissions from their role - no more, no less.""" + user_id = uuid.uuid4() + search_space_id = 1 + + expected_permissions = ["documents:read", "chats:read"] + + mock_role = MagicMock(spec=SearchSpaceRole) + mock_role.permissions = expected_permissions.copy() + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = False + mock_membership.role = mock_role + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_user_permissions(mock_session, user_id, search_space_id) + + # Must match exactly - no extra permissions sneaking in + assert set(result) == set(expected_permissions) + assert len(result) == len(expected_permissions) + + @pytest.mark.asyncio + async def test_member_without_role_has_no_permissions(self, mock_session): + """Members without an assigned role must have empty permissions.""" + user_id = uuid.uuid4() + search_space_id = 1 + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = False + mock_membership.role = None + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_user_permissions(mock_session, user_id, search_space_id) + + assert result == [] + + +class TestOwnershipChecks: + """Tests for ownership verification.""" + + @pytest.mark.asyncio + async def test_is_owner_returns_true_only_for_actual_owner(self, mock_session): + """is_search_space_owner must return True ONLY for actual owners.""" + user_id = uuid.uuid4() + search_space_id = 1 + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = True + + mock_result = MagicMock() + 
mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await is_search_space_owner(mock_session, user_id, search_space_id) + + assert result is True + + @pytest.mark.asyncio + async def test_is_owner_returns_false_for_non_owner_member(self, mock_session): + """Regular members must NOT be identified as owners.""" + user_id = uuid.uuid4() + search_space_id = 1 + + mock_membership = MagicMock(spec=SearchSpaceMembership) + mock_membership.is_owner = False + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await is_search_space_owner(mock_session, user_id, search_space_id) + + assert result is False + + @pytest.mark.asyncio + async def test_is_owner_returns_false_for_non_member(self, mock_session): + """Non-members must NOT be identified as owners.""" + user_id = uuid.uuid4() + search_space_id = 1 + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await is_search_space_owner(mock_session, user_id, search_space_id) + + assert result is False + + +class TestInviteCodeSecurity: + """Tests for invite code generation - validates security requirements.""" + + def test_invite_codes_are_cryptographically_unique(self): + """ + Invite codes must be cryptographically random to prevent guessing. + Generate many codes and verify no collisions. + """ + codes = set() + num_codes = 1000 + + for _ in range(num_codes): + code = generate_invite_code() + codes.add(code) + + # All codes must be unique - any collision indicates weak randomness + assert len(codes) == num_codes + + def test_invite_code_has_sufficient_entropy(self): + """ + Invite codes must have sufficient length for security. + 32 characters of URL-safe base64 = ~192 bits of entropy. 
+ """ + code = generate_invite_code() + + # Minimum 32 characters for adequate security + assert len(code) >= 32 + + def test_invite_code_is_url_safe(self): + """Invite codes must be safe for use in URLs without encoding.""" + import re + + code = generate_invite_code() + + # Must only contain URL-safe characters + assert re.match(r"^[A-Za-z0-9_-]+$", code) is not None + + def test_invite_codes_are_unpredictable(self): + """ + Sequential invite codes must not be predictable. + Verify no obvious patterns in consecutive codes. + """ + codes = [generate_invite_code() for _ in range(10)] + + # No two consecutive codes should share significant prefixes + for i in range(len(codes) - 1): + # First 8 chars should differ between consecutive codes + assert codes[i][:8] != codes[i + 1][:8] + + +class TestRoleRetrieval: + """Tests for role lookup functions.""" + + @pytest.mark.asyncio + async def test_get_default_role_returns_correct_role(self, mock_session): + """Default role lookup must return the role marked as default.""" + search_space_id = 1 + + mock_role = MagicMock(spec=SearchSpaceRole) + mock_role.name = "Viewer" + mock_role.is_default = True + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_role + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_default_role(mock_session, search_space_id) + + assert result is not None + assert result.is_default is True + + @pytest.mark.asyncio + async def test_get_default_role_returns_none_when_no_default(self, mock_session): + """Must return None if no default role exists - not raise an error.""" + search_space_id = 1 + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_default_role(mock_session, search_space_id) + + assert result is None + + @pytest.mark.asyncio + async def test_get_owner_role_returns_owner_named_role(self, mock_session): 
+ """Owner role lookup must return the role named 'Owner'.""" + search_space_id = 1 + + mock_role = MagicMock(spec=SearchSpaceRole) + mock_role.name = "Owner" + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_role + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_owner_role(mock_session, search_space_id) + + assert result is not None + assert result.name == "Owner" diff --git a/surfsense_backend/tests/test_rbac_schemas.py b/surfsense_backend/tests/test_rbac_schemas.py new file mode 100644 index 000000000..1d4a336c2 --- /dev/null +++ b/surfsense_backend/tests/test_rbac_schemas.py @@ -0,0 +1,392 @@ +""" +Tests for RBAC schemas. + +This module tests the Pydantic schemas used for role-based access control. +""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from app.schemas.rbac_schemas import ( + InviteAcceptRequest, + InviteAcceptResponse, + InviteBase, + InviteCreate, + InviteInfoResponse, + InviteRead, + InviteUpdate, + MembershipBase, + MembershipRead, + MembershipReadWithUser, + MembershipUpdate, + PermissionInfo, + PermissionsListResponse, + RoleBase, + RoleCreate, + RoleRead, + RoleUpdate, + UserSearchSpaceAccess, +) + + +class TestRoleSchemas: + """Tests for role-related schemas.""" + + def test_role_base_minimal(self): + """Test RoleBase with minimal data.""" + role = RoleBase(name="TestRole") + assert role.name == "TestRole" + assert role.description is None + assert role.permissions == [] + assert role.is_default is False + + def test_role_base_full(self): + """Test RoleBase with all fields.""" + role = RoleBase( + name="Admin", + description="Administrator role", + permissions=["documents:read", "documents:write"], + is_default=True, + ) + assert role.name == "Admin" + assert role.description == "Administrator role" + assert len(role.permissions) == 2 + assert role.is_default is True + + def 
test_role_base_name_validation(self): + """Test RoleBase name length validation.""" + # Empty name should fail + with pytest.raises(ValidationError): + RoleBase(name="") + + # Name at max length should work + role = RoleBase(name="x" * 100) + assert len(role.name) == 100 + + # Name over max length should fail + with pytest.raises(ValidationError): + RoleBase(name="x" * 101) + + def test_role_base_description_validation(self): + """Test RoleBase description length validation.""" + # Description at max length should work + role = RoleBase(name="Test", description="x" * 500) + assert len(role.description) == 500 + + # Description over max length should fail + with pytest.raises(ValidationError): + RoleBase(name="Test", description="x" * 501) + + def test_role_create(self): + """Test RoleCreate schema.""" + role = RoleCreate( + name="Editor", + permissions=["documents:create", "documents:read"], + ) + assert role.name == "Editor" + + def test_role_update_partial(self): + """Test RoleUpdate with partial data.""" + update = RoleUpdate(name="NewName") + assert update.name == "NewName" + assert update.description is None + assert update.permissions is None + assert update.is_default is None + + def test_role_update_full(self): + """Test RoleUpdate with all fields.""" + update = RoleUpdate( + name="UpdatedRole", + description="Updated description", + permissions=["chats:read"], + is_default=True, + ) + assert update.permissions == ["chats:read"] + + def test_role_read(self): + """Test RoleRead schema.""" + now = datetime.now(timezone.utc) + role = RoleRead( + id=1, + name="Viewer", + description="View-only access", + permissions=["documents:read"], + is_default=False, + search_space_id=5, + is_system_role=True, + created_at=now, + ) + assert role.id == 1 + assert role.is_system_role is True + assert role.search_space_id == 5 + + +class TestMembershipSchemas: + """Tests for membership-related schemas.""" + + def test_membership_base(self): + """Test MembershipBase schema.""" 
+ membership = MembershipBase() + assert membership is not None + + def test_membership_update(self): + """Test MembershipUpdate schema.""" + update = MembershipUpdate(role_id=5) + assert update.role_id == 5 + + def test_membership_update_optional(self): + """Test MembershipUpdate with no data.""" + update = MembershipUpdate() + assert update.role_id is None + + def test_membership_read(self): + """Test MembershipRead schema.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + membership = MembershipRead( + id=1, + user_id=user_id, + search_space_id=10, + role_id=2, + is_owner=False, + joined_at=now, + created_at=now, + role=None, + ) + assert membership.user_id == user_id + assert membership.search_space_id == 10 + assert membership.is_owner is False + + def test_membership_read_with_role(self): + """Test MembershipRead with nested role.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + role = RoleRead( + id=2, + name="Editor", + permissions=["documents:create"], + is_default=True, + search_space_id=10, + is_system_role=True, + created_at=now, + ) + membership = MembershipRead( + id=1, + user_id=user_id, + search_space_id=10, + role_id=2, + is_owner=False, + joined_at=now, + created_at=now, + role=role, + ) + assert membership.role.name == "Editor" + + def test_membership_read_with_user(self): + """Test MembershipReadWithUser schema.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + membership = MembershipReadWithUser( + id=1, + user_id=user_id, + search_space_id=10, + role_id=2, + is_owner=True, + joined_at=now, + created_at=now, + user_email="test@example.com", + user_is_active=True, + ) + assert membership.user_email == "test@example.com" + assert membership.user_is_active is True + + +class TestInviteSchemas: + """Tests for invite-related schemas.""" + + def test_invite_base_minimal(self): + """Test InviteBase with minimal data.""" + invite = InviteBase() + assert invite.name is None + assert invite.role_id is None + assert 
invite.expires_at is None + assert invite.max_uses is None + + def test_invite_base_full(self): + """Test InviteBase with all fields.""" + expires = datetime.now(timezone.utc) + invite = InviteBase( + name="Team Invite", + role_id=3, + expires_at=expires, + max_uses=10, + ) + assert invite.name == "Team Invite" + assert invite.max_uses == 10 + + def test_invite_base_max_uses_validation(self): + """Test InviteBase max_uses must be >= 1.""" + with pytest.raises(ValidationError): + InviteBase(max_uses=0) + + # Valid minimum + invite = InviteBase(max_uses=1) + assert invite.max_uses == 1 + + def test_invite_create(self): + """Test InviteCreate schema.""" + invite = InviteCreate( + name="Dev Team", + role_id=2, + max_uses=5, + ) + assert invite.name == "Dev Team" + + def test_invite_update_partial(self): + """Test InviteUpdate with partial data.""" + update = InviteUpdate(is_active=False) + assert update.is_active is False + assert update.name is None + + def test_invite_update_full(self): + """Test InviteUpdate with all fields.""" + expires = datetime.now(timezone.utc) + update = InviteUpdate( + name="Updated Invite", + role_id=4, + expires_at=expires, + max_uses=20, + is_active=True, + ) + assert update.name == "Updated Invite" + + def test_invite_read(self): + """Test InviteRead schema.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + invite = InviteRead( + id=1, + invite_code="abc123xyz", + search_space_id=5, + created_by_id=user_id, + uses_count=3, + is_active=True, + created_at=now, + ) + assert invite.invite_code == "abc123xyz" + assert invite.uses_count == 3 + + def test_invite_accept_request(self): + """Test InviteAcceptRequest schema.""" + request = InviteAcceptRequest(invite_code="valid-code-123") + assert request.invite_code == "valid-code-123" + + def test_invite_accept_request_validation(self): + """Test InviteAcceptRequest requires non-empty code.""" + with pytest.raises(ValidationError): + InviteAcceptRequest(invite_code="") + + def 
test_invite_accept_response(self): + """Test InviteAcceptResponse schema.""" + response = InviteAcceptResponse( + message="Successfully joined", + search_space_id=10, + search_space_name="My Workspace", + role_name="Editor", + ) + assert response.message == "Successfully joined" + assert response.search_space_name == "My Workspace" + + def test_invite_info_response(self): + """Test InviteInfoResponse schema.""" + response = InviteInfoResponse( + search_space_name="Public Space", + role_name="Viewer", + is_valid=True, + message=None, + ) + assert response.is_valid is True + + def test_invite_info_response_invalid(self): + """Test InviteInfoResponse for invalid invite.""" + response = InviteInfoResponse( + search_space_name="", + role_name=None, + is_valid=False, + message="Invite has expired", + ) + assert response.is_valid is False + assert response.message == "Invite has expired" + + +class TestPermissionSchemas: + """Tests for permission-related schemas.""" + + def test_permission_info(self): + """Test PermissionInfo schema.""" + perm = PermissionInfo( + value="documents:create", + name="Create Documents", + category="Documents", + ) + assert perm.value == "documents:create" + assert perm.category == "Documents" + + def test_permissions_list_response(self): + """Test PermissionsListResponse schema.""" + perms = [ + PermissionInfo(value="documents:read", name="Read Documents", category="Documents"), + PermissionInfo(value="chats:read", name="Read Chats", category="Chats"), + ] + response = PermissionsListResponse(permissions=perms) + assert len(response.permissions) == 2 + + def test_permissions_list_response_empty(self): + """Test PermissionsListResponse with empty list.""" + response = PermissionsListResponse(permissions=[]) + assert response.permissions == [] + + +class TestUserAccessSchemas: + """Tests for user access schemas.""" + + def test_user_search_space_access(self): + """Test UserSearchSpaceAccess schema.""" + access = UserSearchSpaceAccess( + 
search_space_id=5, + search_space_name="My Workspace", + is_owner=True, + role_name="Owner", + permissions=["*"], + ) + assert access.search_space_id == 5 + assert access.is_owner is True + assert "*" in access.permissions + + def test_user_search_space_access_member(self): + """Test UserSearchSpaceAccess for regular member.""" + access = UserSearchSpaceAccess( + search_space_id=10, + search_space_name="Team Space", + is_owner=False, + role_name="Editor", + permissions=["documents:create", "documents:read", "chats:create"], + ) + assert access.is_owner is False + assert access.role_name == "Editor" + assert len(access.permissions) == 3 + + def test_user_search_space_access_no_role(self): + """Test UserSearchSpaceAccess with no role.""" + access = UserSearchSpaceAccess( + search_space_id=15, + search_space_name="Guest Space", + is_owner=False, + role_name=None, + permissions=[], + ) + assert access.role_name is None + assert access.permissions == [] diff --git a/surfsense_backend/tests/test_rbac_utils.py b/surfsense_backend/tests/test_rbac_utils.py new file mode 100644 index 000000000..193baada6 --- /dev/null +++ b/surfsense_backend/tests/test_rbac_utils.py @@ -0,0 +1,340 @@ +""" +Tests for RBAC utility functions. + +This module tests the RBAC helper functions used for access control. 
+""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from app.db import Permission +from app.utils.rbac import ( + check_permission, + check_search_space_access, + generate_invite_code, + get_user_membership, + get_user_permissions, + is_search_space_owner, +) + + +class TestGenerateInviteCode: + """Tests for generate_invite_code function.""" + + def test_generates_string(self): + """Test that function generates a string.""" + code = generate_invite_code() + assert isinstance(code, str) + + def test_generates_unique_codes(self): + """Test that function generates unique codes.""" + codes = {generate_invite_code() for _ in range(100)} + assert len(codes) == 100 # All unique + + def test_code_is_url_safe(self): + """Test that generated code is URL-safe.""" + code = generate_invite_code() + # URL-safe characters: alphanumeric, hyphen, underscore + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in code) + + def test_code_length(self): + """Test that generated code has expected length.""" + code = generate_invite_code() + # token_urlsafe(24) produces ~32 characters + assert len(code) == 32 + + +class TestGetUserMembership: + """Tests for get_user_membership function.""" + + @pytest.mark.asyncio + async def test_returns_membership(self): + """Test returns membership when found.""" + mock_membership = MagicMock() + mock_membership.is_owner = True + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_membership + + mock_session = AsyncMock() + mock_session.execute.return_value = mock_result + + user_id = uuid4() + result = await get_user_membership(mock_session, user_id, 1) + + assert result == mock_membership + assert result.is_owner is True + + @pytest.mark.asyncio + async def test_returns_none_when_not_found(self): + """Test returns None when membership not found.""" + 
mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + + mock_session = AsyncMock() + mock_session.execute.return_value = mock_result + + user_id = uuid4() + result = await get_user_membership(mock_session, user_id, 999) + + assert result is None + + +class TestGetUserPermissions: + """Tests for get_user_permissions function.""" + + @pytest.mark.asyncio + async def test_owner_has_full_access(self): + """Test owner gets FULL_ACCESS permission.""" + mock_membership = MagicMock() + mock_membership.is_owner = True + mock_membership.role = None + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + user_id = uuid4() + + permissions = await get_user_permissions(mock_session, user_id, 1) + + assert Permission.FULL_ACCESS.value in permissions + + @pytest.mark.asyncio + async def test_member_gets_role_permissions(self): + """Test member gets permissions from their role.""" + mock_role = MagicMock() + mock_role.permissions = ["documents:read", "chats:create"] + + mock_membership = MagicMock() + mock_membership.is_owner = False + mock_membership.role = mock_role + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + user_id = uuid4() + + permissions = await get_user_permissions(mock_session, user_id, 1) + + assert permissions == ["documents:read", "chats:create"] + + @pytest.mark.asyncio + async def test_no_membership_returns_empty(self): + """Test no membership returns empty permissions.""" + with patch("app.utils.rbac.get_user_membership", return_value=None): + mock_session = AsyncMock() + user_id = uuid4() + + permissions = await get_user_permissions(mock_session, user_id, 1) + + assert permissions == [] + + @pytest.mark.asyncio + async def test_no_role_returns_empty(self): + """Test member without role returns empty permissions.""" + mock_membership = MagicMock() + mock_membership.is_owner = False + mock_membership.role 
= None + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + user_id = uuid4() + + permissions = await get_user_permissions(mock_session, user_id, 1) + + assert permissions == [] + + +class TestCheckPermission: + """Tests for check_permission function.""" + + @pytest.mark.asyncio + async def test_owner_passes_any_permission(self): + """Test owner passes any permission check.""" + mock_membership = MagicMock() + mock_membership.is_owner = True + mock_membership.role = None + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = uuid4() + + result = await check_permission( + mock_session, + mock_user, + 1, + Permission.SETTINGS_DELETE.value, + ) + + assert result == mock_membership + + @pytest.mark.asyncio + async def test_member_with_permission_passes(self): + """Test member with required permission passes.""" + mock_role = MagicMock() + mock_role.permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value] + + mock_membership = MagicMock() + mock_membership.is_owner = False + mock_membership.role = mock_role + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = uuid4() + + result = await check_permission( + mock_session, + mock_user, + 1, + Permission.DOCUMENTS_READ.value, + ) + + assert result == mock_membership + + @pytest.mark.asyncio + async def test_member_without_permission_raises(self): + """Test member without required permission raises HTTPException.""" + mock_role = MagicMock() + mock_role.permissions = [Permission.DOCUMENTS_READ.value] + + mock_membership = MagicMock() + mock_membership.is_owner = False + mock_membership.role = mock_role + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + mock_user = MagicMock() + 
mock_user.id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + await check_permission( + mock_session, + mock_user, + 1, + Permission.DOCUMENTS_DELETE.value, + ) + + assert exc_info.value.status_code == 403 + + @pytest.mark.asyncio + async def test_no_membership_raises(self): + """Test user without membership raises HTTPException.""" + with patch("app.utils.rbac.get_user_membership", return_value=None): + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + await check_permission( + mock_session, + mock_user, + 1, + Permission.DOCUMENTS_READ.value, + ) + + assert exc_info.value.status_code == 403 + assert "access to this search space" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_custom_error_message(self): + """Test custom error message is used.""" + mock_role = MagicMock() + mock_role.permissions = [] + + mock_membership = MagicMock() + mock_membership.is_owner = False + mock_membership.role = mock_role + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + await check_permission( + mock_session, + mock_user, + 1, + Permission.DOCUMENTS_DELETE.value, + error_message="Custom error message", + ) + + assert exc_info.value.detail == "Custom error message" + + +class TestCheckSearchSpaceAccess: + """Tests for check_search_space_access function.""" + + @pytest.mark.asyncio + async def test_member_has_access(self): + """Test member with any membership has access.""" + mock_membership = MagicMock() + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = uuid4() + + result = await check_search_space_access(mock_session, mock_user, 1) + + assert result == mock_membership + + @pytest.mark.asyncio + async 
def test_no_membership_raises(self): + """Test user without membership raises HTTPException.""" + with patch("app.utils.rbac.get_user_membership", return_value=None): + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + await check_search_space_access(mock_session, mock_user, 1) + + assert exc_info.value.status_code == 403 + + +class TestIsSearchSpaceOwner: + """Tests for is_search_space_owner function.""" + + @pytest.mark.asyncio + async def test_returns_true_for_owner(self): + """Test returns True when user is owner.""" + mock_membership = MagicMock() + mock_membership.is_owner = True + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + user_id = uuid4() + + result = await is_search_space_owner(mock_session, user_id, 1) + + assert result is True + + @pytest.mark.asyncio + async def test_returns_false_for_non_owner(self): + """Test returns False when user is not owner.""" + mock_membership = MagicMock() + mock_membership.is_owner = False + + with patch("app.utils.rbac.get_user_membership", return_value=mock_membership): + mock_session = AsyncMock() + user_id = uuid4() + + result = await is_search_space_owner(mock_session, user_id, 1) + + assert result is False + + @pytest.mark.asyncio + async def test_returns_false_for_no_membership(self): + """Test returns False when user has no membership.""" + with patch("app.utils.rbac.get_user_membership", return_value=None): + mock_session = AsyncMock() + user_id = uuid4() + + result = await is_search_space_owner(mock_session, user_id, 1) + + assert result is False diff --git a/surfsense_backend/tests/test_schemas.py b/surfsense_backend/tests/test_schemas.py new file mode 100644 index 000000000..1aabc63bd --- /dev/null +++ b/surfsense_backend/tests/test_schemas.py @@ -0,0 +1,569 @@ +""" +Tests for Pydantic schema models. 
+ +This module tests schema validation, serialization, and deserialization +for all schema models used in the application. +""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from app.db import ChatType, DocumentType, LiteLLMProvider +from app.schemas.base import IDModel, TimestampModel +from app.schemas.chats import ( + AISDKChatRequest, + ChatBase, + ChatCreate, + ChatRead, + ChatReadWithoutMessages, + ChatUpdate, + ClientAttachment, + ToolInvocation, +) +from app.schemas.chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate +from app.schemas.documents import ( + DocumentBase, + DocumentRead, + DocumentsCreate, + DocumentUpdate, + DocumentWithChunksRead, + ExtensionDocumentContent, + ExtensionDocumentMetadata, + PaginatedResponse, +) +from app.schemas.llm_config import ( + LLMConfigBase, + LLMConfigCreate, + LLMConfigRead, + LLMConfigUpdate, +) +from app.schemas.search_space import ( + SearchSpaceBase, + SearchSpaceCreate, + SearchSpaceRead, + SearchSpaceUpdate, + SearchSpaceWithStats, +) + + +class TestBaseSchemas: + """Tests for base schema models.""" + + def test_timestamp_model(self): + """Test TimestampModel with valid datetime.""" + now = datetime.now(timezone.utc) + model = TimestampModel(created_at=now) + assert model.created_at == now + + def test_id_model(self): + """Test IDModel with valid ID.""" + model = IDModel(id=1) + assert model.id == 1 + + def test_id_model_with_zero(self): + """Test IDModel accepts zero.""" + model = IDModel(id=0) + assert model.id == 0 + + +class TestChatSchemas: + """Tests for chat-related schema models.""" + + def test_chat_base_valid(self): + """Test ChatBase with valid data.""" + chat = ChatBase( + type=ChatType.QNA, + title="Test Chat", + messages=[{"role": "user", "content": "Hello"}], + search_space_id=1, + ) + assert chat.type == ChatType.QNA + assert chat.title == "Test Chat" + assert chat.search_space_id == 1 + assert 
chat.state_version == 1 + + def test_chat_base_with_connectors(self): + """Test ChatBase with initial connectors.""" + chat = ChatBase( + type=ChatType.QNA, + title="Test Chat", + initial_connectors=["slack", "notion"], + messages=[], + search_space_id=1, + ) + assert chat.initial_connectors == ["slack", "notion"] + + def test_chat_base_default_state_version(self): + """Test ChatBase default state_version.""" + chat = ChatBase( + type=ChatType.QNA, + title="Test Chat", + messages=[], + search_space_id=1, + ) + assert chat.state_version == 1 + + def test_chat_create(self): + """Test ChatCreate schema.""" + chat = ChatCreate( + type=ChatType.QNA, + title="New Chat", + messages=[{"role": "user", "content": "Test"}], + search_space_id=1, + ) + assert chat.title == "New Chat" + + def test_chat_update(self): + """Test ChatUpdate schema.""" + chat = ChatUpdate( + type=ChatType.QNA, + title="Updated Chat", + messages=[{"role": "user", "content": "Updated"}], + search_space_id=1, + state_version=2, + ) + assert chat.state_version == 2 + + def test_chat_read(self): + """Test ChatRead schema.""" + now = datetime.now(timezone.utc) + chat = ChatRead( + id=1, + type=ChatType.QNA, + title="Read Chat", + messages=[], + search_space_id=1, + created_at=now, + ) + assert chat.id == 1 + assert chat.created_at == now + + def test_chat_read_without_messages(self): + """Test ChatReadWithoutMessages schema.""" + now = datetime.now(timezone.utc) + chat = ChatReadWithoutMessages( + id=1, + type=ChatType.QNA, + title="Chat Without Messages", + search_space_id=1, + created_at=now, + ) + assert chat.id == 1 + assert not hasattr(chat, "messages") or "messages" not in chat.model_fields + + def test_client_attachment(self): + """Test ClientAttachment schema.""" + attachment = ClientAttachment( + name="test.pdf", + content_type="application/pdf", + url="https://example.com/test.pdf", + ) + assert attachment.name == "test.pdf" + assert attachment.content_type == "application/pdf" + + def 
test_tool_invocation(self): + """Test ToolInvocation schema.""" + tool = ToolInvocation( + tool_call_id="tc_123", + tool_name="search", + args={"query": "test"}, + result={"results": []}, + ) + assert tool.tool_call_id == "tc_123" + assert tool.tool_name == "search" + + def test_aisdk_chat_request(self): + """Test AISDKChatRequest schema.""" + request = AISDKChatRequest( + messages=[{"role": "user", "content": "Hello"}], + data={"search_space_id": 1}, + ) + assert len(request.messages) == 1 + assert request.data["search_space_id"] == 1 + + def test_aisdk_chat_request_no_data(self): + """Test AISDKChatRequest without data.""" + request = AISDKChatRequest(messages=[{"role": "user", "content": "Hello"}]) + assert request.data is None + + +class TestChunkSchemas: + """Tests for chunk-related schema models.""" + + def test_chunk_base(self): + """Test ChunkBase schema.""" + chunk = ChunkBase(content="Test content", document_id=1) + assert chunk.content == "Test content" + assert chunk.document_id == 1 + + def test_chunk_create(self): + """Test ChunkCreate schema.""" + chunk = ChunkCreate(content="New chunk content", document_id=1) + assert chunk.content == "New chunk content" + + def test_chunk_update(self): + """Test ChunkUpdate schema.""" + chunk = ChunkUpdate(content="Updated content", document_id=1) + assert chunk.content == "Updated content" + + def test_chunk_read(self): + """Test ChunkRead schema.""" + now = datetime.now(timezone.utc) + chunk = ChunkRead( + id=1, + content="Read chunk", + document_id=1, + created_at=now, + ) + assert chunk.id == 1 + assert chunk.created_at == now + + +class TestDocumentSchemas: + """Tests for document-related schema models.""" + + def test_extension_document_metadata(self): + """Test ExtensionDocumentMetadata schema.""" + metadata = ExtensionDocumentMetadata( + BrowsingSessionId="session123", + VisitedWebPageURL="https://example.com", + VisitedWebPageTitle="Example Page", + 
VisitedWebPageDateWithTimeInISOString="2024-01-01T00:00:00Z", + VisitedWebPageReffererURL="https://google.com", + VisitedWebPageVisitDurationInMilliseconds="5000", + ) + assert metadata.BrowsingSessionId == "session123" + assert metadata.VisitedWebPageURL == "https://example.com" + + def test_extension_document_content(self): + """Test ExtensionDocumentContent schema.""" + metadata = ExtensionDocumentMetadata( + BrowsingSessionId="session123", + VisitedWebPageURL="https://example.com", + VisitedWebPageTitle="Example Page", + VisitedWebPageDateWithTimeInISOString="2024-01-01T00:00:00Z", + VisitedWebPageReffererURL="https://google.com", + VisitedWebPageVisitDurationInMilliseconds="5000", + ) + content = ExtensionDocumentContent( + metadata=metadata, + pageContent="This is the page content", + ) + assert content.pageContent == "This is the page content" + assert content.metadata.VisitedWebPageTitle == "Example Page" + + def test_document_base_with_string_content(self): + """Test DocumentBase with string content.""" + doc = DocumentBase( + document_type=DocumentType.FILE, + content="This is document content", + search_space_id=1, + ) + assert doc.content == "This is document content" + + def test_document_base_with_list_content(self): + """Test DocumentBase with list content.""" + doc = DocumentBase( + document_type=DocumentType.FILE, + content=["Part 1", "Part 2"], + search_space_id=1, + ) + assert len(doc.content) == 2 + + def test_documents_create(self): + """Test DocumentsCreate schema.""" + doc = DocumentsCreate( + document_type=DocumentType.CRAWLED_URL, + content="Crawled content", + search_space_id=1, + ) + assert doc.document_type == DocumentType.CRAWLED_URL + + def test_document_update(self): + """Test DocumentUpdate schema.""" + doc = DocumentUpdate( + document_type=DocumentType.FILE, + content="Updated content", + search_space_id=1, + ) + assert doc.content == "Updated content" + + def test_document_read(self): + """Test DocumentRead schema.""" + now = 
datetime.now(timezone.utc) + doc = DocumentRead( + id=1, + title="Test Document", + document_type=DocumentType.FILE, + document_metadata={"key": "value"}, + content="Content", + created_at=now, + search_space_id=1, + ) + assert doc.id == 1 + assert doc.title == "Test Document" + assert doc.document_metadata["key"] == "value" + + def test_document_with_chunks_read(self): + """Test DocumentWithChunksRead schema.""" + now = datetime.now(timezone.utc) + doc = DocumentWithChunksRead( + id=1, + title="Test Document", + document_type=DocumentType.FILE, + document_metadata={}, + content="Content", + created_at=now, + search_space_id=1, + chunks=[ + ChunkRead(id=1, content="Chunk 1", document_id=1, created_at=now), + ChunkRead(id=2, content="Chunk 2", document_id=1, created_at=now), + ], + ) + assert len(doc.chunks) == 2 + + def test_paginated_response(self): + """Test PaginatedResponse schema.""" + response = PaginatedResponse[dict]( + items=[{"id": 1}, {"id": 2}], + total=10, + ) + assert len(response.items) == 2 + assert response.total == 10 + + +class TestLLMConfigSchemas: + """Tests for LLM config schema models.""" + + def test_llm_config_base(self): + """Test LLMConfigBase schema.""" + config = LLMConfigBase( + name="GPT-4 Config", + provider=LiteLLMProvider.OPENAI, + model_name="gpt-4", + api_key="sk-test123", + ) + assert config.name == "GPT-4 Config" + assert config.provider == LiteLLMProvider.OPENAI + assert config.language == "English" # Default value + + def test_llm_config_base_with_custom_provider(self): + """Test LLMConfigBase with custom provider.""" + config = LLMConfigBase( + name="Custom LLM", + provider=LiteLLMProvider.CUSTOM, + custom_provider="my-provider", + model_name="my-model", + api_key="test-key", + api_base="https://my-api.com/v1", + ) + assert config.custom_provider == "my-provider" + assert config.api_base == "https://my-api.com/v1" + + def test_llm_config_base_with_litellm_params(self): + """Test LLMConfigBase with litellm params.""" + config 
= LLMConfigBase( + name="Config with Params", + provider=LiteLLMProvider.ANTHROPIC, + model_name="claude-3-opus", + api_key="test-key", + litellm_params={"temperature": 0.7, "max_tokens": 1000}, + ) + assert config.litellm_params["temperature"] == 0.7 + + def test_llm_config_create(self): + """Test LLMConfigCreate schema.""" + config = LLMConfigCreate( + name="New Config", + provider=LiteLLMProvider.GROQ, + model_name="llama-3", + api_key="gsk-test", + search_space_id=1, + ) + assert config.search_space_id == 1 + + def test_llm_config_update_partial(self): + """Test LLMConfigUpdate with partial data.""" + update = LLMConfigUpdate(name="Updated Name") + assert update.name == "Updated Name" + assert update.provider is None + assert update.model_name is None + + def test_llm_config_update_full(self): + """Test LLMConfigUpdate with full data.""" + update = LLMConfigUpdate( + name="Full Update", + provider=LiteLLMProvider.MISTRAL, + model_name="mistral-large", + api_key="new-key", + language="French", + ) + assert update.language == "French" + + def test_llm_config_read(self): + """Test LLMConfigRead schema.""" + now = datetime.now(timezone.utc) + config = LLMConfigRead( + id=1, + name="Read Config", + provider=LiteLLMProvider.OPENAI, + model_name="gpt-4", + api_key="sk-test", + created_at=now, + search_space_id=1, + ) + assert config.id == 1 + assert config.created_at == now + + def test_llm_config_read_global(self): + """Test LLMConfigRead for global config (no search_space_id).""" + config = LLMConfigRead( + id=-1, + name="Global Config", + provider=LiteLLMProvider.OPENAI, + model_name="gpt-4", + api_key="sk-global", + created_at=None, + search_space_id=None, + ) + assert config.id == -1 + assert config.search_space_id is None + + +class TestSearchSpaceSchemas: + """Tests for search space schema models.""" + + def test_search_space_base(self): + """Test SearchSpaceBase schema.""" + space = SearchSpaceBase(name="My Search Space") + assert space.name == "My Search 
Space" + assert space.description is None + + def test_search_space_base_with_description(self): + """Test SearchSpaceBase with description.""" + space = SearchSpaceBase( + name="My Search Space", + description="A space for searching", + ) + assert space.description == "A space for searching" + + def test_search_space_create_defaults(self): + """Test SearchSpaceCreate with default values.""" + space = SearchSpaceCreate(name="New Space") + assert space.citations_enabled is True + assert space.qna_custom_instructions is None + + def test_search_space_create_custom(self): + """Test SearchSpaceCreate with custom values.""" + space = SearchSpaceCreate( + name="Custom Space", + description="Custom description", + citations_enabled=False, + qna_custom_instructions="Be concise", + ) + assert space.citations_enabled is False + assert space.qna_custom_instructions == "Be concise" + + def test_search_space_update_partial(self): + """Test SearchSpaceUpdate with partial data.""" + update = SearchSpaceUpdate(name="Updated Name") + assert update.name == "Updated Name" + assert update.description is None + assert update.citations_enabled is None + + def test_search_space_update_full(self): + """Test SearchSpaceUpdate with all fields.""" + update = SearchSpaceUpdate( + name="Full Update", + description="New description", + citations_enabled=True, + qna_custom_instructions="New instructions", + ) + assert update.qna_custom_instructions == "New instructions" + + def test_search_space_read(self): + """Test SearchSpaceRead schema.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + space = SearchSpaceRead( + id=1, + name="Read Space", + description="Description", + created_at=now, + user_id=user_id, + citations_enabled=True, + qna_custom_instructions=None, + ) + assert space.id == 1 + assert space.user_id == user_id + + def test_search_space_with_stats(self): + """Test SearchSpaceWithStats schema.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + space = 
SearchSpaceWithStats( + id=1, + name="Space with Stats", + created_at=now, + user_id=user_id, + citations_enabled=True, + member_count=5, + is_owner=True, + ) + assert space.member_count == 5 + assert space.is_owner is True + + def test_search_space_with_stats_defaults(self): + """Test SearchSpaceWithStats default values.""" + now = datetime.now(timezone.utc) + user_id = uuid4() + space = SearchSpaceWithStats( + id=1, + name="Default Stats Space", + created_at=now, + user_id=user_id, + citations_enabled=True, + ) + assert space.member_count == 1 + assert space.is_owner is False + + +class TestSchemaValidation: + """Tests for schema validation errors.""" + + def test_chat_base_missing_required(self): + """Test ChatBase raises error for missing required fields.""" + with pytest.raises(ValidationError): + ChatBase(type=ChatType.QNA, title="Test") # Missing messages and search_space_id + + def test_llm_config_name_too_long(self): + """Test LLMConfigBase validates name length.""" + with pytest.raises(ValidationError): + LLMConfigBase( + name="x" * 101, # Exceeds max_length of 100 + provider=LiteLLMProvider.OPENAI, + model_name="gpt-4", + api_key="test", + ) + + def test_llm_config_model_name_too_long(self): + """Test LLMConfigBase validates model_name length.""" + with pytest.raises(ValidationError): + LLMConfigBase( + name="Valid Name", + provider=LiteLLMProvider.OPENAI, + model_name="x" * 101, # Exceeds max_length of 100 + api_key="test", + ) + + def test_document_read_missing_required(self): + """Test DocumentRead raises error for missing required fields.""" + with pytest.raises(ValidationError): + DocumentRead( + id=1, + title="Test", + # Missing document_type, document_metadata, content, created_at, search_space_id + ) diff --git a/surfsense_backend/tests/test_validators.py b/surfsense_backend/tests/test_validators.py new file mode 100644 index 000000000..8baa690d4 --- /dev/null +++ b/surfsense_backend/tests/test_validators.py @@ -0,0 +1,441 @@ +""" +Tests for the 
validators module. +""" + +import pytest +from fastapi import HTTPException + +from app.utils.validators import ( + validate_connectors, + validate_document_ids, + validate_email, + validate_messages, + validate_research_mode, + validate_search_mode, + validate_search_space_id, + validate_top_k, + validate_url, + validate_uuid, +) + + +class TestValidateSearchSpaceId: + """Tests for validate_search_space_id function.""" + + def test_valid_integer(self): + """Test valid integer input.""" + assert validate_search_space_id(1) == 1 + assert validate_search_space_id(100) == 100 + assert validate_search_space_id(999999) == 999999 + + def test_valid_string(self): + """Test valid string input.""" + assert validate_search_space_id("1") == 1 + assert validate_search_space_id("100") == 100 + assert validate_search_space_id(" 50 ") == 50 # Trimmed + + def test_none_raises_error(self): + """Test that None raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_space_id(None) + assert exc_info.value.status_code == 400 + assert "required" in exc_info.value.detail + + def test_zero_raises_error(self): + """Test that zero raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_space_id(0) + assert exc_info.value.status_code == 400 + assert "positive" in exc_info.value.detail + + def test_negative_raises_error(self): + """Test that negative values raise HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_space_id(-1) + assert exc_info.value.status_code == 400 + assert "positive" in exc_info.value.detail + + def test_boolean_raises_error(self): + """Test that boolean raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_space_id(True) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + def test_empty_string_raises_error(self): + """Test that empty string raises HTTPException.""" + with 
pytest.raises(HTTPException) as exc_info: + validate_search_space_id("") + assert exc_info.value.status_code == 400 + + def test_invalid_string_raises_error(self): + """Test that invalid string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_space_id("abc") + assert exc_info.value.status_code == 400 + + def test_float_raises_error(self): + """Test that float raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_space_id(1.5) + assert exc_info.value.status_code == 400 + + +class TestValidateDocumentIds: + """Tests for validate_document_ids function.""" + + def test_none_returns_empty_list(self): + """Test that None returns empty list.""" + assert validate_document_ids(None) == [] + + def test_empty_list_returns_empty_list(self): + """Test that empty list returns empty list.""" + assert validate_document_ids([]) == [] + + def test_valid_integer_list(self): + """Test valid integer list.""" + assert validate_document_ids([1, 2, 3]) == [1, 2, 3] + + def test_valid_string_list(self): + """Test valid string list.""" + assert validate_document_ids(["1", "2", "3"]) == [1, 2, 3] + + def test_mixed_valid_types(self): + """Test mixed valid types.""" + assert validate_document_ids([1, "2", 3]) == [1, 2, 3] + + def test_not_list_raises_error(self): + """Test that non-list raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_document_ids("not a list") + assert exc_info.value.status_code == 400 + assert "must be a list" in exc_info.value.detail + + def test_negative_id_raises_error(self): + """Test that negative ID raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_document_ids([1, -2, 3]) + assert exc_info.value.status_code == 400 + assert "positive" in exc_info.value.detail + + def test_zero_id_raises_error(self): + """Test that zero ID raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + 
validate_document_ids([0]) + assert exc_info.value.status_code == 400 + assert "positive" in exc_info.value.detail + + def test_boolean_in_list_raises_error(self): + """Test that boolean in list raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_document_ids([1, True, 3]) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + +class TestValidateConnectors: + """Tests for validate_connectors function.""" + + def test_none_returns_empty_list(self): + """Test that None returns empty list.""" + assert validate_connectors(None) == [] + + def test_empty_list_returns_empty_list(self): + """Test that empty list returns empty list.""" + assert validate_connectors([]) == [] + + def test_valid_connectors(self): + """Test valid connector names.""" + assert validate_connectors(["slack", "github"]) == ["slack", "github"] + + def test_connector_with_underscore(self): + """Test connector names with underscores.""" + assert validate_connectors(["google_calendar"]) == ["google_calendar"] + + def test_connector_with_hyphen(self): + """Test connector names with hyphens.""" + assert validate_connectors(["google-calendar"]) == ["google-calendar"] + + def test_not_list_raises_error(self): + """Test that non-list raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_connectors("not a list") + assert exc_info.value.status_code == 400 + assert "must be a list" in exc_info.value.detail + + def test_non_string_in_list_raises_error(self): + """Test that non-string in list raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_connectors(["slack", 123]) + assert exc_info.value.status_code == 400 + assert "must be a string" in exc_info.value.detail + + def test_empty_string_raises_error(self): + """Test that empty string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_connectors(["slack", ""]) + assert exc_info.value.status_code 
== 400 + assert "cannot be empty" in exc_info.value.detail + + def test_invalid_characters_raises_error(self): + """Test that invalid characters raise HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_connectors(["slack@connector"]) + assert exc_info.value.status_code == 400 + assert "invalid characters" in exc_info.value.detail + + +class TestValidateResearchMode: + """Tests for validate_research_mode function.""" + + def test_none_returns_default(self): + """Test that None returns default value.""" + assert validate_research_mode(None) == "QNA" + + def test_valid_mode(self): + """Test valid mode.""" + assert validate_research_mode("QNA") == "QNA" + assert validate_research_mode("qna") == "QNA" # Case insensitive + + def test_non_string_raises_error(self): + """Test that non-string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_research_mode(123) + assert exc_info.value.status_code == 400 + assert "must be a string" in exc_info.value.detail + + def test_invalid_mode_raises_error(self): + """Test that invalid mode raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_research_mode("INVALID") + assert exc_info.value.status_code == 400 + assert "must be one of" in exc_info.value.detail + + +class TestValidateSearchMode: + """Tests for validate_search_mode function.""" + + def test_none_returns_default(self): + """Test that None returns default value.""" + assert validate_search_mode(None) == "CHUNKS" + + def test_valid_modes(self): + """Test valid modes.""" + assert validate_search_mode("CHUNKS") == "CHUNKS" + assert validate_search_mode("DOCUMENTS") == "DOCUMENTS" + assert validate_search_mode("chunks") == "CHUNKS" # Case insensitive + + def test_non_string_raises_error(self): + """Test that non-string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_mode(123) + assert exc_info.value.status_code == 400 + assert "must be a 
string" in exc_info.value.detail + + def test_invalid_mode_raises_error(self): + """Test that invalid mode raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_search_mode("INVALID") + assert exc_info.value.status_code == 400 + assert "must be one of" in exc_info.value.detail + + +class TestValidateTopK: + """Tests for validate_top_k function.""" + + def test_none_returns_default(self): + """Test that None returns default value.""" + assert validate_top_k(None) == 10 + + def test_valid_integer(self): + """Test valid integer input.""" + assert validate_top_k(1) == 1 + assert validate_top_k(50) == 50 + assert validate_top_k(100) == 100 + + def test_valid_string(self): + """Test valid string input.""" + assert validate_top_k("5") == 5 + assert validate_top_k(" 10 ") == 10 + + def test_zero_raises_error(self): + """Test that zero raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_top_k(0) + assert exc_info.value.status_code == 400 + assert "positive" in exc_info.value.detail + + def test_negative_raises_error(self): + """Test that negative values raise HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_top_k(-1) + assert exc_info.value.status_code == 400 + assert "positive" in exc_info.value.detail + + def test_exceeds_max_raises_error(self): + """Test that values over 100 raise HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_top_k(101) + assert exc_info.value.status_code == 400 + assert "exceed 100" in exc_info.value.detail + + def test_boolean_raises_error(self): + """Test that boolean raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_top_k(True) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + +class TestValidateMessages: + """Tests for validate_messages function.""" + + def test_valid_messages(self): + """Test valid messages.""" + messages = [ + {"role": "user", 
"content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + result = validate_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "user" + assert result[1]["role"] == "assistant" + + def test_trims_content(self): + """Test that content is trimmed.""" + messages = [{"role": "user", "content": " Hello "}] + result = validate_messages(messages) + assert result[0]["content"] == "Hello" + + def test_system_message_valid(self): + """Test that system messages are valid.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ] + result = validate_messages(messages) + assert result[0]["role"] == "system" + + def test_not_list_raises_error(self): + """Test that non-list raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_messages("not a list") + assert exc_info.value.status_code == 400 + assert "must be a list" in exc_info.value.detail + + def test_empty_list_raises_error(self): + """Test that empty list raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_messages([]) + assert exc_info.value.status_code == 400 + assert "cannot be empty" in exc_info.value.detail + + def test_missing_role_raises_error(self): + """Test that missing role raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_messages([{"content": "Hello"}]) + assert exc_info.value.status_code == 400 + assert "role" in exc_info.value.detail + + def test_missing_content_raises_error(self): + """Test that missing content raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_messages([{"role": "user"}]) + assert exc_info.value.status_code == 400 + assert "content" in exc_info.value.detail + + def test_invalid_role_raises_error(self): + """Test that invalid role raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_messages([{"role": "invalid", "content": "Hello"}]) + 
assert exc_info.value.status_code == 400 + assert "role" in exc_info.value.detail + + def test_empty_content_raises_error(self): + """Test that empty content raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_messages([{"role": "user", "content": " "}]) + assert exc_info.value.status_code == 400 + assert "cannot be empty" in exc_info.value.detail + + +class TestValidateEmail: + """Tests for validate_email function.""" + + def test_valid_email(self): + """Test valid email addresses.""" + assert validate_email("test@example.com") == "test@example.com" + assert validate_email("user.name@domain.co.uk") == "user.name@domain.co.uk" + + def test_trims_whitespace(self): + """Test that whitespace is trimmed.""" + assert validate_email(" test@example.com ") == "test@example.com" + + def test_empty_raises_error(self): + """Test that empty string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_email("") + assert exc_info.value.status_code == 400 + + def test_invalid_format_raises_error(self): + """Test that invalid format raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_email("not-an-email") + assert exc_info.value.status_code == 400 + assert "Invalid email" in exc_info.value.detail + + +class TestValidateUrl: + """Tests for validate_url function.""" + + def test_valid_url(self): + """Test valid URLs.""" + assert validate_url("https://example.com") == "https://example.com" + assert ( + validate_url("http://sub.domain.com/path") + == "http://sub.domain.com/path" + ) + + def test_trims_whitespace(self): + """Test that whitespace is trimmed.""" + assert validate_url(" https://example.com ") == "https://example.com" + + def test_empty_raises_error(self): + """Test that empty string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_url("") + assert exc_info.value.status_code == 400 + + def test_invalid_format_raises_error(self): + """Test 
that invalid format raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_url("not-a-url") + assert exc_info.value.status_code == 400 + assert "Invalid URL" in exc_info.value.detail + + +class TestValidateUuid: + """Tests for validate_uuid function.""" + + def test_valid_uuid(self): + """Test valid UUIDs.""" + uuid_str = "123e4567-e89b-12d3-a456-426614174000" + assert validate_uuid(uuid_str) == uuid_str + + def test_trims_whitespace(self): + """Test that whitespace is trimmed.""" + uuid_str = " 123e4567-e89b-12d3-a456-426614174000 " + assert validate_uuid(uuid_str) == "123e4567-e89b-12d3-a456-426614174000" + + def test_empty_raises_error(self): + """Test that empty string raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_uuid("") + assert exc_info.value.status_code == 400 + + def test_invalid_format_raises_error(self): + """Test that invalid format raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + validate_uuid("not-a-uuid") + assert exc_info.value.status_code == 400 + assert "Invalid UUID" in exc_info.value.detail