diff --git a/src/parxy_cli/commands/parse.py b/src/parxy_cli/commands/parse.py index eace15c..af982df 100644 --- a/src/parxy_cli/commands/parse.py +++ b/src/parxy_cli/commands/parse.py @@ -1,5 +1,7 @@ """Command line interface for Parxy document processing.""" +import json +import tomllib from datetime import timedelta from pathlib import Path from typing import Optional, List, Annotated @@ -17,6 +19,75 @@ console = Console() +def _load_middleware_from_config(config_path: Path) -> List[str]: + """Load middleware class paths from a config file. + + Supports JSON, TOML, YAML and YML. The expected structure is either: + - A top-level list: ``["path.to.Middleware1", "path.to.Middleware2"]`` + - An object with a ``middleware`` key: ``{"middleware": ["path.to.Middleware"]}`` + """ + if not config_path.exists(): + raise typer.BadParameter(f'Middleware config file not found: {config_path}') + + suffix = config_path.suffix.lower() + + if suffix == '.json': + raw_data = json.loads(config_path.read_text(encoding='utf-8')) + elif suffix == '.toml': + raw_data = tomllib.loads(config_path.read_text(encoding='utf-8')) + elif suffix in {'.yaml', '.yml'}: + try: + import yaml + except ImportError as exc: + raise typer.BadParameter( + 'YAML config requires PyYAML. Install pyyaml or use JSON/TOML config.' + ) from exc + raw_data = yaml.safe_load(config_path.read_text(encoding='utf-8')) + else: + raise typer.BadParameter( + 'Unsupported middleware config format. Use .json, .toml, .yaml or .yml' + ) + + if isinstance(raw_data, list): + middleware_list = raw_data + elif isinstance(raw_data, dict): + middleware_list = raw_data.get('middleware', []) + if not isinstance(middleware_list, list): + raise typer.BadParameter( + 'middleware_config: "middleware" key must be a list of class paths.' + ) + else: + raise typer.BadParameter( + 'Middleware config must be a list or an object with a "middleware" key.' 
+ ) + + if not all(isinstance(item, str) for item in middleware_list): + raise typer.BadParameter('Middleware class paths must be strings.') + + return middleware_list + + +def configure_middleware( + middleware: Optional[List[str]], + config_path: Optional[Path], +) -> None: + """Configure global middleware from inline class paths and/or a config file.""" + paths: List[str] = list(middleware or []) + + if config_path is not None: + paths.extend(_load_middleware_from_config(config_path)) + + if not paths: + return + + Parxy.clear_middleware() + Parxy.with_middleware(paths) + + console.info( + f'Using {len(paths)} middleware class{"es" if len(paths) != 1 else ""}.' + ) + + def collect_files_with_depth( directory: Path, pattern: str, max_depth: int, current_depth: int = 0 ) -> List[Path]: @@ -261,6 +332,22 @@ def parse( min=1, ), ] = None, + middleware: Annotated[ + Optional[List[str]], + typer.Option( + '--middleware', + '-p', + help='Middleware class path(s) to apply. Can be specified multiple times (e.g. --middleware my.pkg.MyMiddleware).', + ), + ] = None, + middleware_config: Annotated[ + Optional[str], + typer.Option( + '--middleware-config', + envvar='PARXY_MIDDLEWARE_CONFIG', + help='Path to a .json/.toml/.yaml file with a list of middleware class paths to apply. Appended after inline middleware with --middleware', + ), + ] = None, ): """ Parse documents using one or more drivers. 
@@ -312,6 +399,11 @@ def parse( # Calculate total tasks total_tasks = len(files) * len(drivers) + configure_middleware( + middleware=middleware, + config_path=Path(middleware_config) if middleware_config else None, + ) + error_count = 0 # Show info diff --git a/src/parxy_core/drivers/abstract_driver.py b/src/parxy_core/drivers/abstract_driver.py index 38180c6..4f476f0 100644 --- a/src/parxy_core/drivers/abstract_driver.py +++ b/src/parxy_core/drivers/abstract_driver.py @@ -1,15 +1,14 @@ -import base64 import hashlib import io import time from abc import ABC, abstractmethod from logging import Logger -from typing import Dict, Any, Self, Tuple, Optional +from typing import Dict, Any, Self, Tuple, Optional, List, Union import requests import validators -from parxy_core.models import Document +from parxy_core.models import Document, ParsingRequest from parxy_core.exceptions import ( FileNotFoundException, ParsingException, @@ -21,6 +20,7 @@ from parxy_core.models.config import BaseConfig from parxy_core.logging import create_null_logger from parxy_core.tracing import tracer +from parxy_core.middleware import Middleware class Driver(ABC): @@ -50,6 +50,9 @@ class Driver(ABC): _logger: Logger + _middleware: List[Middleware] + """Driver-specific middleware list""" + def __new__(cls, config: Dict[str, Any] = [], logger: Logger = None): instance = super().__new__(cls) instance.__init__(config=config, logger=logger) @@ -62,6 +65,7 @@ def __init__(self, config: Dict[str, Any] = None, logger: Logger = None): logger = create_null_logger(name=f'parxy.{self.__class__.__name__}') self._logger = logger + self._middleware = [] # Initialize empty middleware list self._initialize_driver() def parse( @@ -110,25 +114,30 @@ def parse( driver=self.__class__.__name__, level=level, **kwargs, - ) as span: + ): self._validate_level(level) + middleware_list = self._resolve_middleware() try: - # Start timing start_time = time.perf_counter() - document = self._handle(file=file, level=level, 
def _resolve_middleware(self) -> List[Middleware]:
    """Build the middleware chain for one parse call.

    Global middleware from the ``DriverFactory`` comes first, followed by the
    middleware registered on this driver instance.

    Returns
    -------
    List[Middleware]
        A freshly built list; changing it does not touch either registry.
    """
    # The factory module imports this module, so the import is done lazily.
    from parxy_core.drivers.factory import DriverFactory

    return [*DriverFactory.build().get_middleware(), *self._middleware]


def _parse_with_middleware(
    self,
    file: str | io.BytesIO | bytes,
    level: str,
    middleware_list: List[Middleware],
    **kwargs,
) -> Document:
    """Run the parse through the middleware chain.

    ``handle()`` is invoked front to back, each middleware deciding whether to
    call the next stage; the stage after the last middleware is the driver's
    own ``_handle``. ``terminate()`` is then invoked back to front on the
    resulting document.

    Parameters
    ----------
    file : str | io.BytesIO | bytes
        Path, URL or stream of the file to parse.
    level : str
        Desired extraction level.
    middleware_list : List[Middleware]
        Middleware to apply, in ``handle()`` order.
    **kwargs : dict
        Extra keyword arguments; carried in the request's ``config`` and
        forwarded to ``_handle``.

    Returns
    -------
    Document
        The document returned by the end of the ``terminate()`` chain.
    """
    request = ParsingRequest(
        driver=self.__class__.__name__,
        file=file,
        level=level,
        config=kwargs,
    )
    chain_length = len(middleware_list)

    def run_handle(position: int, current: ParsingRequest) -> Document:
        # Past the last middleware: hand the (possibly altered) request to the driver.
        if position >= chain_length:
            return self._handle(
                file=current.file, level=current.level, **current.config
            )

        stage = middleware_list[position]

        def proceed(next_request: ParsingRequest) -> Document:
            return run_handle(position + 1, next_request)

        with tracer.span(
            'middleware.handle',
            middleware=stage.__class__.__name__,
            index=position,
        ):
            return stage.handle(current, proceed)

    def run_terminate(position: int, parsed: Document) -> Document:
        # Walked from the last middleware down to index 0.
        if position < 0:
            return parsed

        stage = middleware_list[position]

        def proceed(next_document: Document) -> Document:
            return run_terminate(position - 1, next_document)

        with tracer.span(
            'middleware.terminate',
            middleware=stage.__class__.__name__,
            index=position,
        ):
            return stage.terminate(parsed, proceed)

    with tracer.span('middleware-chain', count=chain_length):
        handled = run_handle(0, request)
        return run_terminate(chain_length - 1, handled)


def _initialize_driver(self) -> Self:
    """Initialize driver internal logic. It is called automatically during class initialization"""
    return self


def with_middleware(self, middleware: Union[Middleware, List[Middleware]]) -> Self:
    """Register middleware on this driver instance.

    Parameters
    ----------
    middleware : Union[Middleware, List[Middleware]]
        One middleware instance, or a list of them, appended in order.

    Returns
    -------
    Self
        This driver, for chaining.

    Example
    -------
    >>> driver = Parxy.driver('pymupdf')
    >>> driver.with_middleware(LoggingMiddleware())
    >>> doc = driver.parse('document.pdf')
    """
    additions = middleware if isinstance(middleware, list) else [middleware]
    self._middleware.extend(additions)
    return self


def clear_middleware(self) -> Self:
    """Remove every middleware registered on this driver instance.

    Returns
    -------
    Self
        This driver, for chaining.
    """
    del self._middleware[:]
    return self


def get_middleware(self) -> List[Middleware]:
    """Return the middleware registered on this driver instance.

    Returns
    -------
    List[Middleware]
        A copy of the driver's middleware list.
    """
    return self._middleware.copy()
@classmethod
def reset(cls):
    """Reset the factory instance and clear all state.

    This clears middleware, drivers, and custom creators.
    Useful for testing and isolation between test cases.
    """
    cls.__instance = None
    cls.__config_middleware = []
    cls.__middleware = []
    cls.__drivers = {}
    cls.__custom_creators = {}


def _load_middleware_from_config(self) -> None:
    """Load middleware from ParxyConfig.middleware into the config layer.

    Config middleware is kept separate from runtime middleware so it
    survives clear_middleware() calls. Loading is best-effort: an entry
    that cannot be imported, validated or instantiated is logged as a
    warning and skipped.
    """
    if not self._config.middleware:
        return

    for middleware_path in self._config.middleware:
        try:
            middleware = self._import_middleware(middleware_path)
        except (ImportError, ValueError, TypeError) as e:
            # TypeError covers e.g. a middleware class whose constructor
            # needs arguments; one bad entry must not abort initialization.
            self._logger.warning(
                f'Failed to load middleware from config: {middleware_path} - {e}'
            )
            continue

        self.__config_middleware.append(middleware)
        self._logger.info(f'Loaded middleware from config: {middleware_path}')


def _import_middleware(self, middleware_path: str) -> Middleware:
    """Import a middleware class from a string path and instantiate it.

    Parameters
    ----------
    middleware_path : str
        Absolute dot-notation path to the middleware class
        (e.g., 'parxy_core.middleware.PIIScanner')

    Returns
    -------
    Middleware
        An instance of the middleware class, built with no arguments

    Raises
    ------
    ImportError
        If the module or class cannot be imported
    ValueError
        If the path is not of the form ``module.ClassName`` or the imported
        object is not a Middleware subclass
    """
    module_path, dot, class_name = middleware_path.rpartition('.')
    # Reject 'NoDot', 'trailing.' and relative '.mod.Cls' up front; the latter
    # would make importlib.import_module raise TypeError.
    if not (dot and class_name and module_path) or module_path.startswith('.'):
        raise ValueError(
            f"Invalid middleware path (expected 'module.ClassName'): {middleware_path}"
        )

    try:
        module = importlib.import_module(module_path)
        middleware_class = getattr(module, class_name)

        # issubclass() raises TypeError for non-classes, so check that first.
        if not (
            isinstance(middleware_class, type)
            and issubclass(middleware_class, Middleware)
        ):
            raise ValueError(f'{middleware_path} is not a Middleware subclass')

        return middleware_class()
    except (ImportError, AttributeError) as e:
        raise ImportError(f'Failed to import middleware: {middleware_path}') from e


def with_middleware(
    self, middleware: Union[Middleware, List[Middleware]]
) -> 'DriverFactory':
    """Add middleware to the global middleware registry.

    Middleware added here will be applied to all drivers.

    Parameters
    ----------
    middleware : Union[Middleware, List[Middleware]]
        A middleware instance or list of middleware instances to add

    Returns
    -------
    DriverFactory
        Returns self for chaining

    Example
    -------
    >>> factory = DriverFactory.build()
    >>> factory.with_middleware([LoggingMiddleware(), PIIScannerMiddleware()])
    """
    if isinstance(middleware, list):
        self.__middleware.extend(middleware)
    else:
        self.__middleware.append(middleware)

    return self
@classmethod
def _import_middleware(cls, middleware_path: str) -> Middleware:
    """Import a middleware class from a string path and instantiate it.

    Parameters
    ----------
    middleware_path : str
        Absolute dot-notation path to the middleware class
        (e.g., 'parxy_core.middleware.PIIScanner')

    Returns
    -------
    Middleware
        An instance of the middleware class, built with no arguments

    Raises
    ------
    ImportError
        If the module or class cannot be imported
    ValueError
        If the path is not of the form ``module.ClassName`` or the imported
        object is not a Middleware subclass
    """
    module_path, dot, class_name = middleware_path.rpartition('.')
    # Reject 'NoDot', 'trailing.' and relative '.mod.Cls' up front; the latter
    # would make importlib.import_module raise TypeError.
    if not (dot and class_name and module_path) or module_path.startswith('.'):
        raise ValueError(
            f"Invalid middleware path (expected 'module.ClassName'): {middleware_path}"
        )

    try:
        module = importlib.import_module(module_path)
        middleware_class = getattr(module, class_name)

        # issubclass() raises TypeError for non-classes, so check that first.
        if not (
            isinstance(middleware_class, type)
            and issubclass(middleware_class, Middleware)
        ):
            raise ValueError(f'{middleware_path} is not a Middleware subclass')

        return middleware_class()
    except (ImportError, AttributeError) as e:
        raise ImportError(f'Failed to import middleware: {middleware_path}') from e


@classmethod
def with_middleware(cls, middleware: List[Union[str, Middleware]]) -> 'Parxy':
    """Add middleware to the global middleware chain.

    Middleware can be specified as:
    - Middleware instances
    - String paths to middleware classes (for loading from config)

    The middleware are executed in the order they are added. All entries are
    resolved before any is registered, so a failing entry leaves the global
    chain unchanged.

    Parameters
    ----------
    middleware : List[Union[str, Middleware]]
        List of middleware to add. Each can be a Middleware instance
        or a string path like 'parxy_core.middleware.PIIScanner'.
        A single instance or string path is also accepted.

    Returns
    -------
    Parxy
        The ``Parxy`` class itself, for chaining

    Raises
    ------
    ImportError
        If a string path cannot be imported
    ValueError
        If a string path is malformed or does not name a Middleware subclass
    TypeError
        If an entry is neither a string nor a Middleware instance

    Example
    -------
    >>> Parxy.with_middleware(
    ...     [MyCustomMiddleware(), 'parxy_core.middleware.SimpleMiddleware']
    ... )
    >>> doc = Parxy.parse('document.pdf')
    """
    factory = cls._get_factory()

    # A bare string is iterable; treat it as one path, not as characters.
    if isinstance(middleware, (str, Middleware)):
        entries = [middleware]
    else:
        entries = list(middleware)

    resolved: List[Middleware] = []
    for mw in entries:
        if isinstance(mw, str):
            resolved.append(cls._import_middleware(mw))
        elif isinstance(mw, Middleware):
            resolved.append(mw)
        else:
            raise TypeError(f'Invalid middleware type: {type(mw)}')

    factory.with_middleware(resolved)

    return cls
class Middleware(ABC):
    """Base class for document processing middleware.

    A middleware can hook into the parsing pipeline at two points:
    - `handle()`: runs BEFORE parsing and receives the ParsingRequest
    - `terminate()`: runs AFTER parsing and receives the parsed Document

    Both hooks receive a ``next`` callable that continues the chain; the
    default implementations simply call it.

    Chain order, as driven by ``Driver._parse_with_middleware``:
    - handle(): first registered middleware -> ... -> last -> driver
    - terminate(): last registered middleware -> ... -> first
    """

    def handle(
        self,
        request: ParsingRequest,
        next: Callable[[ParsingRequest], Document],
    ) -> Document:
        """Pre-parse hook; the default forwards the request unchanged."""
        result = next(request)
        return result

    def terminate(
        self,
        document: Document,
        next: Callable[[Document], Document],
    ) -> Document:
        """Post-parse hook; the default forwards the document unchanged."""
        result = next(document)
        return result

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f'{name}()'


class SimpleMiddleware(Middleware):
    """Convenience base for middleware that overrides only one of the two hooks."""

    def handle(
        self,
        request: ParsingRequest,
        next: Callable[[ParsingRequest], Document],
    ) -> Document:
        """Pass-through default: continue the chain with the same request."""
        forwarded = next(request)
        return forwarded
@field_validator('middleware', mode='before')
@classmethod
def parse_middleware(cls, v: Any) -> Any:
    """Normalize a string ``middleware`` value into a list of class paths.

    Non-string input is returned untouched. A string starting with ``[``
    (after stripping) is decoded as JSON; any other string is split on
    commas, with surrounding whitespace and empty items dropped.

    NOTE(review): for values coming from environment variables,
    pydantic-settings may try to JSON-decode list fields before this
    validator runs -- confirm the comma-separated form works via env.
    """
    if isinstance(v, str):
        text = v.strip()
        if text.startswith('['):
            return json.loads(text)
        return [entry for part in text.split(',') if (entry := part.strip())]
    return v


class ParsingRequest(BaseModel):
    """Parameters of one parse call, as seen by the middleware chain.

    An instance is built by the driver and handed to each middleware's
    ``handle()`` in turn.

    Attributes
    ----------
    driver : str
        Name of the driver performing the parse (required).
    file : Union[str, BytesIO, bytes]
        The file to parse (path, URL, or binary data).
    level : str
        The extraction level (e.g., 'page', 'block', 'line').
    config : dict[str, Any]
        Additional options for the parse; defaults to an empty dict.
    """

    # BytesIO is not a pydantic-native type, hence arbitrary types are allowed.
    model_config = {'arbitrary_types_allowed': True}

    driver: str
    file: Union[str, BytesIO, bytes]
    level: str
    config: Dict[str, Any] = Field(default_factory=dict)
patch('parxy_cli.commands.parse.Parxy') as mock_parxy: + mock_parxy.default_driver.return_value = 'pymupdf' + mock_parxy.batch_iter.return_value = iter( + [ + BatchResult( + file=str(test_file), + driver='pymupdf', + document=mock_document, + error=None, + ) + ] + ) + + result = runner.invoke( + app, + [str(test_file), '--middleware-config', str(config_file)], + ) + + assert result.exit_code == 0 + mock_parxy.clear_middleware.assert_called_once() + mock_parxy.with_middleware.assert_called_once_with( + ['parxy_core.middleware.SimpleMiddleware'] + ) + + +def test_parse_command_with_middleware_config_missing_file_fails( + runner, mock_document, tmp_path +): + """Test that a missing middleware config file returns a CLI error.""" + + test_file = tmp_path / 'test.pdf' + test_file.write_text('dummy pdf content') + + with patch('parxy_cli.commands.parse.Parxy') as mock_parxy: + mock_parxy.default_driver.return_value = 'pymupdf' + + result = runner.invoke( + app, + [str(test_file), '--middleware-config', str(tmp_path / 'nonexistent.json')], + ) + + assert result.exit_code != 0 + mock_parxy.batch_iter.assert_not_called() + + def test_collect_files_non_recursive(tmp_path): """Test that collect_files only finds files in the given directory when not recursive.""" diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..77edd82 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,252 @@ +"""Tests for middleware base classes and pipeline integration.""" + +from parxy_core.models import Document, ParsingRequest, Page +from parxy_core.models.config import ParxyConfig +from parxy_core.middleware import Middleware, SimpleMiddleware +from parxy_core.facade import Parxy +from parxy_core.drivers import Driver + + +class TestMiddlewareBase: + """Test middleware base class functionality.""" + + def test_middleware_with_only_handle(self): + """Test middleware with only handle() method.""" + + class TestHandleMiddleware(SimpleMiddleware): + def 
handle(self, request: ParsingRequest, next) -> Document: + request.config['test_flag'] = True + return next(request) + + middleware = TestHandleMiddleware() + assert hasattr(middleware, 'handle') + assert hasattr(middleware, 'terminate') + + def test_middleware_with_only_terminate(self): + """Test middleware with only terminate() method.""" + + class TestTerminateMiddleware(SimpleMiddleware): + def terminate(self, document: Document, next) -> Document: + document.parsing_metadata = document.parsing_metadata or {} + document.parsing_metadata['test'] = 'processed' + return next(document) + + middleware = TestTerminateMiddleware() + assert hasattr(middleware, 'handle') + assert hasattr(middleware, 'terminate') + + def test_middleware_with_both(self): + """Test middleware with both handle() and terminate() methods.""" + + class TestBothMiddleware(Middleware): + def handle(self, request: ParsingRequest, next) -> Document: + request.config['handle_called'] = True + return next(request) + + def terminate(self, document: Document, next) -> Document: + document.parsing_metadata = document.parsing_metadata or {} + document.parsing_metadata['terminate_called'] = True + return next(document) + + middleware = TestBothMiddleware() + assert hasattr(middleware, 'handle') + assert hasattr(middleware, 'terminate') + + +class TestMiddlewareRegistry: + """Test Parxy-level middleware registration.""" + + def test_global_middleware_registry_configuration(self): + """Test that global middleware can be registered and cleared.""" + Parxy.clear_middleware() + + class TestMiddleware(Middleware): + def handle(self, request: ParsingRequest, next) -> Document: + return next(request) + + Parxy.with_middleware([TestMiddleware()]) + + assert len(Parxy.get_middleware()) == 1 + + Parxy.clear_middleware() + + def test_execution_order(self): + """Test that middleware are stored in registration order.""" + Parxy.clear_middleware() + + class LoggingMiddleware(Middleware): + def __init__(self, name: 
str): + self.name = name + + def handle(self, request: ParsingRequest, next) -> Document: + return next(request) + + Parxy.with_middleware( + [LoggingMiddleware('A'), LoggingMiddleware('B'), LoggingMiddleware('C')] + ) + + middleware_list = Parxy.get_middleware() + assert len(middleware_list) == 3 + assert middleware_list[0].name == 'A' + assert middleware_list[1].name == 'B' + assert middleware_list[2].name == 'C' + + Parxy.clear_middleware() + + def test_string_loading(self): + """Test loading middleware from a dotted string path.""" + Parxy.clear_middleware() + + Parxy.with_middleware(['parxy_core.middleware.SimpleMiddleware']) + + middleware_list = Parxy.get_middleware() + assert len(middleware_list) == 1 + assert isinstance(middleware_list[0], SimpleMiddleware) + + Parxy.clear_middleware() + + def test_driver_middleware_persists_on_singleton(self): + """Test that driver middleware persists across singleton lookups.""" + driver = Parxy.driver('pymupdf') + + class TestDriverMiddleware(Middleware): + def handle(self, request: ParsingRequest, next) -> Document: + return next(request) + + driver.with_middleware(TestDriverMiddleware()) + assert len(driver.get_middleware()) == 1 + + driver2 = Parxy.driver('pymupdf') + assert len(driver2.get_middleware()) == 1 + + driver.clear_middleware() + assert len(driver.get_middleware()) == 0 + + +class TestMiddlewarePipelineOrder: + """Test runtime handle/terminate execution order.""" + + def test_driver_parse_executes_middleware_chain_in_expected_order(self): + """Global handle → driver handle → driver terminate → global terminate.""" + execution_log = [] + + class RecordingMiddleware(Middleware): + def __init__(self, name: str): + self.name = name + + def handle(self, request: ParsingRequest, next) -> Document: + execution_log.append(f'{self.name}.handle') + request.config[f'{self.name}_seen'] = True + return next(request) + + def terminate(self, document: Document, next) -> Document: + 
execution_log.append(f'{self.name}.terminate') + return next(document) + + class DummyDriver(Driver): + supported_levels = ['block'] + + def _handle(self, file, level='block', **kwargs) -> Document: + assert kwargs.get('global_seen') is True + assert kwargs.get('driver_seen') is True + return Document(pages=[Page(number=1, text='ok', blocks=[])]) + + Parxy.clear_middleware() + Parxy.with_middleware([RecordingMiddleware('global')]) + + driver = DummyDriver(config={}) + driver.with_middleware(RecordingMiddleware('driver')) + + driver.parse(file=b'dummy-bytes', level='block') + + assert execution_log == [ + 'global.handle', + 'driver.handle', + 'driver.terminate', + 'global.terminate', + ] + + driver.clear_middleware() + Parxy.clear_middleware() + + +class TestMiddlewareConfig: + """Test ParxyConfig middleware field.""" + + def test_config_stores_middleware_class_paths(self): + """Test that ParxyConfig accepts and stores middleware string paths.""" + config = ParxyConfig(middleware=['parxy_core.middleware.SimpleMiddleware']) + + assert config.middleware is not None + assert len(config.middleware) == 1 + assert config.middleware[0] == 'parxy_core.middleware.SimpleMiddleware' + + def test_config_middleware_defaults_to_none(self): + """Test that middleware is None when not configured.""" + config = ParxyConfig() + assert config.middleware is None + + def test_config_middleware_from_json_string(self): + """Test that middleware accepts a JSON array string.""" + config = ParxyConfig(middleware='["parxy_core.middleware.SimpleMiddleware"]') + assert config.middleware == ['parxy_core.middleware.SimpleMiddleware'] + + def test_config_middleware_from_comma_separated_string(self): + """Test that middleware accepts a comma-separated string.""" + config = ParxyConfig( + middleware='parxy_core.middleware.SimpleMiddleware, parxy_core.middleware.SimpleMiddleware' + ) + assert config.middleware == [ + 'parxy_core.middleware.SimpleMiddleware', + 'parxy_core.middleware.SimpleMiddleware', 
+ ] + + def test_config_middleware_comma_separated_single_entry(self): + """Test that a single comma-separated entry (no comma) is parsed correctly.""" + config = ParxyConfig(middleware='parxy_core.middleware.SimpleMiddleware') + assert config.middleware == ['parxy_core.middleware.SimpleMiddleware'] + + def test_config_middleware_survives_clear(self): + """Config-layer middleware must not be removed by clear_middleware().""" + from parxy_core.drivers import DriverFactory + + DriverFactory.reset() + try: + factory = DriverFactory.__new__(DriverFactory).initialize( + ParxyConfig(middleware=['parxy_core.middleware.SimpleMiddleware']) + ) + + assert len(factory.get_middleware()) == 1 + + factory.clear_middleware() + + # Config middleware is preserved; runtime layer was empty so count stays 1. + assert len(factory.get_middleware()) == 1 + assert isinstance(factory.get_middleware()[0], SimpleMiddleware) + finally: + DriverFactory.reset() + + def test_runtime_middleware_cleared_config_middleware_preserved(self): + """clear_middleware() removes runtime entries but keeps config entries.""" + from parxy_core.drivers import DriverFactory + + DriverFactory.reset() + try: + factory = DriverFactory.__new__(DriverFactory).initialize( + ParxyConfig(middleware=['parxy_core.middleware.SimpleMiddleware']) + ) + + class ExtraMiddleware(Middleware): + def handle(self, request, next): + return next(request) + + factory.with_middleware([ExtraMiddleware()]) + assert len(factory.get_middleware()) == 2 + + factory.clear_middleware() + + remaining = factory.get_middleware() + assert len(remaining) == 1 + assert isinstance(remaining[0], SimpleMiddleware) + finally: + DriverFactory.reset()