diff --git a/README.md b/README.md index aeabc762..cb8a3032 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,19 @@ LLM-based nodes require a model configured in `models.yaml` and runtime paramete As of now, LLM inference is supported for TGI, vLLM, OpenAI, Azure, Azure OpenAI, Ollama and Triton compatible servers. Model deployment is external and configured in `models.yaml`. +## SyGra as a Platform + +SyGra can be used as a reusable platform to build different categories of tasks on top of the same graph execution engine, node types, processors, and metric infrastructure. + +### Eval + +Evaluation tasks live under `tasks/eval` and provide a standard pattern for: + +- Computing **unit metrics** per record during graph execution +- Computing **aggregator metrics** after the run via graph post-processing + +See: [`tasks/eval/README.md`](https://github.com/ServiceNow/SyGra/blob/main/tasks/eval/README.md) + diff --git a/sygra/core/base_task_executor.py b/sygra/core/base_task_executor.py index 7601fc45..8dfef8aa 100644 --- a/sygra/core/base_task_executor.py +++ b/sygra/core/base_task_executor.py @@ -544,7 +544,9 @@ def _repeat_to_merge_sequentially( # merge the primary and secondary dataframe horizontally by randomlly picking one and adding into primary # primary : M rows(a columns), secondary: N rows(b columns), merged: M rows(a+b columns) - def _shuffle_and_extend(self, primary_df, secondary_df) -> pd.DataFrame: + def _shuffle_and_extend( + self, primary_df: pd.DataFrame, secondary_df: pd.DataFrame + ) -> pd.DataFrame: max_len = len(primary_df) # Shuffle the secondary dataframe shuffled_secondary = secondary_df.sample(frac=1).reset_index(drop=True) @@ -560,7 +562,7 @@ def _shuffle_and_extend(self, primary_df, secondary_df) -> pd.DataFrame: final_secondary = pd.concat([shuffled_secondary, extra_rows], ignore_index=True) # now both dataset are same length, merge and return - return pd.concat([primary_df, final_secondary], axis=1) + return cast(pd.DataFrame, pd.concat([primary_df, final_secondary], axis=1)) def _load_source_data( self, data_config: dict @@ -587,8 +589,8 @@ def _load_source_data( full_data = self.apply_transforms(source_config_obj, full_data) elif isinstance(source_config, list): # if multiple dataset configured as list - dataset_list = [] - primary_df = None + dataset_list: list[dict[str, Any]] = [] + primary_df: Optional[pd.DataFrame] = None primary_config = None # if multiple dataset, verify if join_type and alias is defined in each config(@source and @sink) if isinstance(source_config, list): @@ -650,6 +652,9 @@ def _load_source_data( ds_conf: dict[str, Any] = ds.get("conf", {}) join_type = ds_conf.get(constants.DATASET_JOIN_TYPE) current_df = ds.get("dataset") + if current_df is None or not isinstance(current_df, pd.DataFrame): + logger.error("Dataset is missing or not a dataframe") + continue if join_type == constants.JOIN_TYPE_COLUMN: sec_alias_name = ds_conf.get(constants.DATASET_ALIAS) pri_alias_name = ( @@ -665,22 +670,26 @@ def _load_source_data( # where_clause = ds.get("conf").get("where_clause") primary_df = pd.merge( primary_df, - current_df, + cast(pd.DataFrame, current_df), left_on=primary_column, right_on=join_column, how="left", ) elif join_type == constants.JOIN_TYPE_SEQUENTIAL: - primary_df = self._repeat_to_merge_sequentially(primary_df, current_df) + primary_df = self._repeat_to_merge_sequentially( + primary_df, cast(pd.DataFrame, current_df) + ) elif join_type == constants.JOIN_TYPE_CROSS: - primary_df = primary_df.merge(current_df, how="cross") + 
primary_df = primary_df.merge(cast(pd.DataFrame, current_df), how="cross") elif join_type == constants.JOIN_TYPE_RANDOM: - primary_df = self._shuffle_and_extend(primary_df, current_df) + primary_df = self._shuffle_and_extend( + primary_df, cast(pd.DataFrame, current_df) + ) else: logger.error("Not implemented join_type") # now convert dataframe to list of dict (full_data) - full_data = primary_df.to_dict(orient="records") + full_data = cast(list[dict[str, Any]], primary_df.to_dict(orient="records")) else: logger.error("Unsupported source config type.") diff --git a/sygra/core/eval/metrics/aggregator_metrics/aggregator_metric_registry.py b/sygra/core/eval/metrics/aggregator_metrics/aggregator_metric_registry.py index e321b969..63b6461a 100644 --- a/sygra/core/eval/metrics/aggregator_metrics/aggregator_metric_registry.py +++ b/sygra/core/eval/metrics/aggregator_metrics/aggregator_metric_registry.py @@ -7,6 +7,8 @@ # Avoid circular imports from __future__ import annotations +import importlib +import pkgutil from typing import TYPE_CHECKING, Dict, List, Type from sygra.logger.logger_config import logger @@ -42,6 +44,33 @@ class AggregatorMetricRegistry: # Class-level storage (create singleton to have central control) _metrics: Dict[str, Type[BaseAggregatorMetric]] = {} + _discovered: bool = False + + @classmethod + def _ensure_discovered(cls) -> None: + if cls._discovered: + return + + try: + import sygra.core.eval.metrics.aggregator_metrics as aggregator_metrics_pkg + + for module_info in pkgutil.iter_modules( + aggregator_metrics_pkg.__path__, aggregator_metrics_pkg.__name__ + "." + ): + module_name = module_info.name + if module_name.endswith( + ( + ".base_aggregator_metric", + ".aggregator_metric_registry", + ) + ): + continue + importlib.import_module(module_name) + + cls._discovered = True + except Exception as e: + logger.error(f"Failed to auto-discover aggregator metrics: {e}") + cls._discovered = True @classmethod def register(cls, name: str, metric_class: Type[BaseAggregatorMetric]) -> None: @@ -105,6 +134,8 @@ def get_metric(cls, name: str, **kwargs) -> BaseAggregatorMetric: # Get metric with custom parameters topk = AggregatorMetricRegistry.get_metric("top_k_accuracy", k=5) """ + cls._ensure_discovered() + if name not in cls._metrics: available = cls.list_metrics() raise KeyError( @@ -135,6 +166,7 @@ def list_metrics(cls) -> List[str]: AggregatorMetricRegistry.list_metrics() ['accuracy', 'confusion_matrix', 'f1', 'precision', 'recall'] """ + cls._ensure_discovered() return sorted(cls._metrics.keys()) @classmethod @@ -149,6 +181,7 @@ def has_metric(cls, name: str) -> bool: if AggregatorMetricRegistry.has_metric("f1"): metric = AggregatorMetricRegistry.get_metric("f1") """ + cls._ensure_discovered() return name in cls._metrics @classmethod @@ -163,6 +196,8 @@ def get_metric_class(cls, name: str) -> Type[BaseAggregatorMetric]: Raises: KeyError: If metric name is not registered """ + cls._ensure_discovered() + if name not in cls._metrics: available = cls.list_metrics() raise KeyError( diff --git a/sygra/core/eval/metrics/aggregator_metrics/f1_score.py b/sygra/core/eval/metrics/aggregator_metrics/f1_score.py index 08299b71..73f5bbe2 100644 --- a/sygra/core/eval/metrics/aggregator_metrics/f1_score.py +++ b/sygra/core/eval/metrics/aggregator_metrics/f1_score.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry import 
aggregator_metric from sygra.core.eval.metrics.aggregator_metrics.base_aggregator_metric import BaseAggregatorMetric @@ -23,14 +23,6 @@ class F1ScoreMetricConfig(BaseModel): predicted_key: str = Field(..., min_length=1, description="Key in predicted dict to check") golden_key: str = Field(..., min_length=1, description="Key in golden dict to check") - positive_class: Any = Field(..., description="Value representing positive class") - - @field_validator("positive_class") - @classmethod - def validate_positive_class(cls, v): - if v is None: - raise ValueError("positive_class is required (cannot be None)") - return v @aggregator_metric("f1_score") @@ -43,7 +35,6 @@ class F1ScoreMetric(BaseAggregatorMetric): Required configuration: predicted_key: Key in predicted dict to check (e.g., "tool") golden_key: Key in golden dict to check (e.g., "event") - positive_class: Value representing the positive class (e.g., "click") """ def __init__(self, **config): @@ -60,15 +51,10 @@ def validate_config(self): # Store validated fields as instance attributes self.predicted_key = config_obj.predicted_key self.golden_key = config_obj.golden_key - self.positive_class = config_obj.positive_class # Create precision and recall metrics (reuse implementations) - self.precision_metric = PrecisionMetric( - predicted_key=self.predicted_key, positive_class=self.positive_class - ) - self.recall_metric = RecallMetric( - golden_key=self.golden_key, positive_class=self.positive_class - ) + self.precision_metric = PrecisionMetric(predicted_key=self.predicted_key) + self.recall_metric = RecallMetric(golden_key=self.golden_key) def get_metadata(self) -> BaseMetricMetadata: """Return metadata for F1 score metric""" @@ -93,16 +79,27 @@ def calculate(self, results: List[UnitMetricResult]) -> Dict[str, Any]: """ if not results: logger.warning(f"{self.__class__.__name__}: No results provided") - return {"f1_score": 0.0} + return {"average_f1_score": 0.0, "f1_score_per_class": {}} + f1_score = dict() # Reuse existing metric implementations precision_result = self.precision_metric.calculate(results) recall_result = self.recall_metric.calculate(results) - precision = precision_result.get("precision", 0.0) - recall = recall_result.get("recall", 0.0) - # Calculate F1 as harmonic mean of precision and recall - f1_score = self._safe_divide(2 * precision * recall, precision + recall) + average_precision = precision_result.get("average_precision", 0.0) + average_recall = recall_result.get("average_recall", 0.0) + average_f1_score = self._safe_divide( + 2 * average_precision * average_recall, average_precision + average_recall + ) + + precision_classes = set(precision_result.get("precision_per_class", {}).keys()) + recall_classes = set(recall_result.get("recall_per_class", {}).keys()) + all_classes = precision_classes.union(recall_classes) + + for class_ in all_classes: + precision = precision_result.get("precision_per_class", {}).get(class_, 0.0) + recall = recall_result.get("recall_per_class", {}).get(class_, 0.0) + f1_score[class_] = self._safe_divide(2 * precision * recall, precision + recall) - return {"f1_score": f1_score} + return {"average_f1_score": average_f1_score, "f1_score_per_class": f1_score} diff --git a/sygra/core/eval/metrics/aggregator_metrics/precision.py b/sygra/core/eval/metrics/aggregator_metrics/precision.py index 2122ac0a..c9ae0241 100644 --- a/sygra/core/eval/metrics/aggregator_metrics/precision.py +++ b/sygra/core/eval/metrics/aggregator_metrics/precision.py @@ -5,9 +5,10 @@ Measures: Of all predicted positives, 
how many were actually positive? """ -from typing import Any, Dict, List +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry import aggregator_metric from sygra.core.eval.metrics.aggregator_metrics.base_aggregator_metric import BaseAggregatorMetric @@ -20,14 +21,6 @@ class PrecisionMetricConfig(BaseModel): """Configuration for Precision Metric""" predicted_key: str = Field(..., min_length=1, description="Key in predicted dict to check") - positive_class: Any = Field(..., description="Value representing positive class") - - @field_validator("positive_class") - @classmethod - def validate_positive_class(cls, v): - if v is None: - raise ValueError("positive_class is required (cannot be None)") - return v @aggregator_metric("precision") @@ -39,12 +32,12 @@ class PrecisionMetric(BaseAggregatorMetric): Required configuration: predicted_key: Key in predicted dict to check (e.g., "tool") - positive_class: Value representing the positive class (e.g., "click") """ def __init__(self, **config): """Initialize precision metric with two-phase initialization.""" super().__init__(**config) + self.predicted_key = None self.validate_config() self.metadata = self.get_metadata() @@ -55,7 +48,6 @@ def validate_config(self): # Store validated fields as instance attributes self.predicted_key = config_obj.predicted_key - self.positive_class = config_obj.positive_class def get_metadata(self) -> BaseMetricMetadata: """Return metadata for precision metric""" @@ -76,23 +68,54 @@ def calculate(self, results: List[UnitMetricResult]) -> Dict[str, Any]: results: List of UnitMetricResult Returns: - dict: {"precision": float (0.0 to 1.0)} + dict: { + "average_precision": float (0.0 to 1.0) + "precision_per_class": { + "class_1": float (0.0 to 1.0), + "class_2": float (0.0 to 1.0), + ... 
+ "class_n": float (0.0 to 1.0) + } + } """ if not results: logger.warning(f"{self.__class__.__name__}: No results provided") - return {"precision": 0.0} - - # Calculate TP and FP - tp = sum( - 1 - for r in results - if r.predicted.get(self.predicted_key) == self.positive_class and r.correct - ) - fp = sum( - 1 - for r in results - if r.predicted.get(self.predicted_key) == self.positive_class and not r.correct + return {"average_precision": 0.0, "precision_per_class": {}} + + predicted_count: DefaultDict[str, int] = defaultdict(int) + true_positive: DefaultDict[str, int] = defaultdict(int) + + for r in results: + try: + predicted_key = self.predicted_key + if predicted_key is None: + logger.warning(f"{self.__class__.__name__}: predicted_key is not configured") + continue + label = r.predicted[predicted_key] + except KeyError: + logger.warning( + f"{self.__class__.__name__}: Missing predicted_key '{self.predicted_key}' in result" + ) + continue + + if not isinstance(label, str): + label = str(label) + + predicted_count[label] += 1 + if r.correct: + true_positive[label] += 1 + + precision_per_class = { + label: self._safe_divide(true_positive[label], count) + for label, count in predicted_count.items() + } + + average_precision = self._safe_divide( + sum(precision_per_class.values()), + len(precision_per_class), ) - precision = self._safe_divide(tp, tp + fp) - return {"precision": precision} + return { + "average_precision": average_precision, + "precision_per_class": precision_per_class, + } diff --git a/sygra/core/eval/metrics/aggregator_metrics/recall.py b/sygra/core/eval/metrics/aggregator_metrics/recall.py index 22e183b8..7aa4c37c 100644 --- a/sygra/core/eval/metrics/aggregator_metrics/recall.py +++ b/sygra/core/eval/metrics/aggregator_metrics/recall.py @@ -5,9 +5,10 @@ Measures: Of all actual positives, how many were predicted correctly? 
""" -from typing import Any, Dict, List +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry import aggregator_metric from sygra.core.eval.metrics.aggregator_metrics.base_aggregator_metric import BaseAggregatorMetric @@ -20,14 +21,6 @@ class RecallMetricConfig(BaseModel): """Configuration for Recall Metric""" golden_key: str = Field(..., min_length=1, description="Key in golden dict to check") - positive_class: Any = Field(..., description="Value representing positive class") - - @field_validator("positive_class") - @classmethod - def validate_positive_class(cls, v): - if v is None: - raise ValueError("positive_class is required (cannot be None)") - return v @aggregator_metric("recall") @@ -39,7 +32,6 @@ class RecallMetric(BaseAggregatorMetric): Required configuration: golden_key: Key in golden dict to check (e.g., "event") - positive_class: Value representing the positive class (e.g., "click") """ def __init__(self, **config): @@ -55,7 +47,6 @@ def validate_config(self): # Store validated fields as instance attributes self.golden_key = config_obj.golden_key - self.positive_class = config_obj.positive_class def get_metadata(self) -> BaseMetricMetadata: """Return metadata for recall metric""" @@ -80,17 +71,38 @@ def calculate(self, results: List[UnitMetricResult]) -> Dict[str, Any]: """ if not results: logger.warning(f"{self.__class__.__name__}: No results provided") - return {"recall": 0.0} - - # Calculate TP and FN - tp = sum( - 1 for r in results if r.golden.get(self.golden_key) == self.positive_class and r.correct - ) - fn = sum( - 1 - for r in results - if r.golden.get(self.golden_key) == self.positive_class and not r.correct + return {"average_recall": 0.0, "recall_per_class": {}} + + golden_count: DefaultDict[str, int] = defaultdict(int) + true_positive: DefaultDict[str, int] = defaultdict(int) + + for r in results: + try: + label = r.golden[self.golden_key] + except KeyError: + logger.warning( + f"{self.__class__.__name__}: Missing golden_key '{self.golden_key}' in result" + ) + continue + + if not isinstance(label, str): + label = str(label) + + golden_count[label] += 1 + if r.correct: + true_positive[label] += 1 + + recall_per_class = { + label: self._safe_divide(true_positive[label], count) + for label, count in golden_count.items() + } + + average_recall = self._safe_divide( + sum(recall_per_class.values()), + len(recall_per_class), ) - recall = self._safe_divide(tp, tp + fn) - return {"recall": recall} + return { + "average_recall": average_recall, + "recall_per_class": recall_per_class, + } diff --git a/sygra/core/eval/metrics/unit_metrics/action_within_bbox.py b/sygra/core/eval/metrics/unit_metrics/action_within_bbox.py index ad93bc65..d35f3f81 100644 --- a/sygra/core/eval/metrics/unit_metrics/action_within_bbox.py +++ b/sygra/core/eval/metrics/unit_metrics/action_within_bbox.py @@ -10,6 +10,7 @@ from sygra.core.eval.metrics.base_metric_metadata import BaseMetricMetadata from sygra.core.eval.metrics.unit_metrics.base_unit_metric import BaseUnitMetric +from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import unit_metric from sygra.core.eval.metrics.unit_metrics.unit_metric_result import UnitMetricResult from sygra.logger.logger_config import logger @@ -24,6 +25,7 @@ class ActionWithinBboxMetricConfig(BaseModel): ) +@unit_metric("action_within_bbox") class 
ActionWithinBboxMetric(BaseUnitMetric): """Validate that the predicted (x, y) is within the golden bbox.""" diff --git a/sygra/core/eval/metrics/unit_metrics/exact_match.py b/sygra/core/eval/metrics/unit_metrics/exact_match.py index 5489381a..078313ec 100644 --- a/sygra/core/eval/metrics/unit_metrics/exact_match.py +++ b/sygra/core/eval/metrics/unit_metrics/exact_match.py @@ -11,6 +11,7 @@ from sygra.core.eval.metrics.base_metric_metadata import BaseMetricMetadata from sygra.core.eval.metrics.unit_metrics.base_unit_metric import BaseUnitMetric +from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import unit_metric from sygra.core.eval.metrics.unit_metrics.unit_metric_result import UnitMetricResult from sygra.logger.logger_config import logger @@ -27,6 +28,7 @@ class ExactMatchMetricConfig(BaseModel): ) +@unit_metric("exact_match") class ExactMatchMetric(BaseUnitMetric): """ Exact Match metric. diff --git a/sygra/core/eval/metrics/unit_metrics/scroll_amount.py b/sygra/core/eval/metrics/unit_metrics/scroll_amount.py index 0d14e4e9..841cb77c 100644 --- a/sygra/core/eval/metrics/unit_metrics/scroll_amount.py +++ b/sygra/core/eval/metrics/unit_metrics/scroll_amount.py @@ -9,6 +9,7 @@ from sygra.core.eval.metrics.base_metric_metadata import BaseMetricMetadata from sygra.core.eval.metrics.unit_metrics.base_unit_metric import BaseUnitMetric +from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import unit_metric from sygra.core.eval.metrics.unit_metrics.unit_metric_result import UnitMetricResult from sygra.logger.logger_config import logger @@ -34,6 +35,7 @@ class ScrollAmountMetricConfig(BaseModel): ) +@unit_metric("scroll_amount") class ScrollAmountMetric(BaseUnitMetric): """Validate scroll amount correctness within tolerance.""" diff --git a/sygra/core/eval/metrics/unit_metrics/scroll_direction.py b/sygra/core/eval/metrics/unit_metrics/scroll_direction.py index 733e706b..c58ce2e2 100644 --- a/sygra/core/eval/metrics/unit_metrics/scroll_direction.py +++ b/sygra/core/eval/metrics/unit_metrics/scroll_direction.py @@ -9,6 +9,7 @@ from sygra.core.eval.metrics.base_metric_metadata import BaseMetricMetadata from sygra.core.eval.metrics.unit_metrics.base_unit_metric import BaseUnitMetric +from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import unit_metric from sygra.core.eval.metrics.unit_metrics.unit_metric_result import UnitMetricResult from sygra.logger.logger_config import logger @@ -24,6 +25,7 @@ class ScrollDirectionMetricConfig(BaseModel): ) +@unit_metric("scroll_direction") class ScrollDirectionMetric(BaseUnitMetric): """Validate scroll direction correctness.""" diff --git a/sygra/core/eval/metrics/unit_metrics/typed_value_match.py b/sygra/core/eval/metrics/unit_metrics/typed_value_match.py index cdfdbb53..5db5637b 100644 --- a/sygra/core/eval/metrics/unit_metrics/typed_value_match.py +++ b/sygra/core/eval/metrics/unit_metrics/typed_value_match.py @@ -11,6 +11,7 @@ from sygra.core.eval.metrics.base_metric_metadata import BaseMetricMetadata from sygra.core.eval.metrics.unit_metrics.base_unit_metric import BaseUnitMetric +from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import unit_metric from sygra.core.eval.metrics.unit_metrics.unit_metric_result import UnitMetricResult from sygra.logger.logger_config import logger @@ -32,6 +33,7 @@ class TypedValueMatchMetricConfig(BaseModel): ) +@unit_metric("typed_value_match") class TypedValueMatchMetric(BaseUnitMetric): """Validate typed value correctness using exact and fuzzy matching.""" diff --git 
a/sygra/core/eval/metrics/unit_metrics/unit_metric_registry.py b/sygra/core/eval/metrics/unit_metrics/unit_metric_registry.py new file mode 100644 index 00000000..6518a519 --- /dev/null +++ b/sygra/core/eval/metrics/unit_metrics/unit_metric_registry.py @@ -0,0 +1,298 @@ +""" +Unit Metric Registry +Singleton registry for discovering and instantiating unit metrics. +Provides centralized service locator for all metrics (built-in and custom). +""" + +# Avoid circular imports +from __future__ import annotations + +import importlib +import pkgutil +from typing import TYPE_CHECKING, Dict, List, Type + +from sygra.logger.logger_config import logger + +# This prevents circular imports while still providing type safety. +if TYPE_CHECKING: + from sygra.core.eval.metrics.unit_metrics.base_unit_metric import ( + BaseUnitMetric, + ) + + +class UnitMetricRegistry: + """ + This registry maintains a mapping of metric names to metric classes, + allowing runtime discovery without hard coding. + Features: + 1. Auto-registration using @register_unit_metric decorator + 2. Runtime metric discovery(and use case being read from graph_config) + 3. Factory method for metric instantiation + 4. List available metrics + 5. Check metric existence + Usage: + # Register a metric (add decorator) + UnitMetricRegistry.register("precision", PrecisionMetric) + # Get metric instance + metric = UnitMetricRegistry.get_metric("precision") + # List all available metrics + all_metrics = UnitMetricRegistry.list_metrics() + # Check if metric exists, for example + if UnitMetricRegistry.has_metric("f1"): + metric = UnitMetricRegistry.get_metric("f1") + """ + + # Class-level storage (create singleton to have central control) + _metrics: Dict[str, Type[BaseUnitMetric]] = {} + _discovered: bool = False + + @classmethod + def _ensure_discovered(cls) -> None: + if cls._discovered: + return + + try: + import sygra.core.eval.metrics.unit_metrics as unit_metrics_pkg + + for module_info in pkgutil.iter_modules( + unit_metrics_pkg.__path__, unit_metrics_pkg.__name__ + "." + ): + module_name = module_info.name + if module_name.endswith( + ( + ".base_unit_metric", + ".unit_metric_registry", + ".unit_metric_result", + ) + ): + continue + importlib.import_module(module_name) + + cls._discovered = True + except Exception as e: + logger.error(f"Failed to auto-discover unit metrics: {e}") + cls._discovered = True + + @classmethod + def register(cls, name: str, metric_class: Type[BaseUnitMetric]) -> None: + """ + Register a unit metric class. + This method is typically called automatically by the @register_unit_metric + decorator, but can also be called manually if needed. 
+ Args: + name: Unique identifier for the metric (e.g., "precision", "f1") + metric_class: Class that implements BaseUnitMetric + Raises: + ValueError: If name is empty or metric_class is invalid + Example: + UnitMetricRegistry.register("precision", PrecisionMetric) + """ + # Validation + if not name or not isinstance(name, str): + raise ValueError("Metric name must be a non-empty string") + + if not isinstance(metric_class, type): + raise ValueError(f"metric_class must be a class, got {type(metric_class)}") + + # Import at runtime (inside function) instead of at module level to avoid circular dependency + from sygra.core.eval.metrics.unit_metrics.base_unit_metric import ( + BaseUnitMetric, + ) + + if not issubclass(metric_class, BaseUnitMetric): + raise ValueError( + f"metric_class must inherit from BaseUnitMetric, " f"got {metric_class.__name__}" + ) + + # Check for duplicate registration + if name in cls._metrics: + logger.warning( + f"Unit metric '{name}' is already registered. " + f"Overwriting {cls._metrics[name].__name__} with {metric_class.__name__}" + ) + + # Register + cls._metrics[name] = metric_class + logger.debug(f"Registered unit metric: '{name}' -> {metric_class.__name__}") + + @classmethod + def get_metric(cls, name: str, **kwargs) -> BaseUnitMetric: + """ + Get an instance of a registered metric. + This is a factory method that creates and returns a metric instance + without the caller needing to know the concrete class. + Args: + name: Metric name (e.g., "precision", "recall", "f1") + **kwargs: Optional arguments to pass to metric constructor + Returns: + Instance of the requested metric + Raises: + KeyError: If metric name is not registered + Example: + # Get metric with default parameters + precision = UnitMetricRegistry.get_metric("precision") + # Get metric with custom parameters + topk = UnitMetricRegistry.get_metric("top_k_accuracy", k=5) + """ + cls._ensure_discovered() + + if name not in cls._metrics: + available = cls.list_metrics() + raise KeyError( + f"Unit metric '{name}' not found in registry. " f"Available metrics: {available}" + ) + + metric_class = cls._metrics[name] + + try: + # Instantiate metric with optional kwargs + metric_instance = metric_class(**kwargs) + logger.debug(f"Instantiated unit metric: '{name}'") + return metric_instance + except Exception as e: + logger.error( + f"Failed to instantiate metric '{name}' " f"({metric_class.__name__}): {e}" + ) + raise + + @classmethod + def list_metrics(cls) -> List[str]: + """ + List all registered metric names. + Returns: + List of metric names + Example: + UnitMetricRegistry.list_metrics() + ['accuracy', 'confusion_matrix', 'f1', 'precision', 'recall'] + """ + cls._ensure_discovered() + return sorted(cls._metrics.keys()) + + @classmethod + def has_metric(cls, name: str) -> bool: + """ + Check if a metric is registered. + Args: + name: Metric name to check + Returns: + True if metric is registered, False otherwise + Example: + if UnitMetricRegistry.has_metric("f1"): + metric = UnitMetricRegistry.get_metric("f1") + """ + cls._ensure_discovered() + return name in cls._metrics + + @classmethod + def get_metric_class(cls, name: str) -> Type[BaseUnitMetric]: + """ + Get the class (not instance) of a registered metric. + Adding this for now for inspection purposes on which metric is being used. 
+ Args: + name: Metric name + Returns: + Metric class + Raises: + KeyError: If metric name is not registered + """ + cls._ensure_discovered() + + if name not in cls._metrics: + available = cls.list_metrics() + raise KeyError( + f"Unit metric '{name}' not found in registry. " f"Available metrics: {available}" + ) + return cls._metrics[name] + + @classmethod + def unregister(cls, name: str) -> bool: + """ + Unregister a metric. This is added feature if we want to deprecate or test, there could be a better way to achieve this using decorator. + Args: + name: Metric name to unregister + Returns: + True if metric was unregistered, False if it wasn't registered + Example: + UnitMetricRegistry.unregister("old_metric") + """ + if name in cls._metrics: + del cls._metrics[name] + logger.debug(f"Unregistered unit metric: '{name}'") + return True + return False + + @classmethod + def clear(cls) -> None: + """ + Clear all registered metrics. + Adding this because it is standard practice to have an evict option to test registry in unit testing. + """ + cls._metrics.clear() + logger.warning("Cleared all registered unit metrics") + + @classmethod + def get_metrics_info(cls) -> Dict[str, Dict[str, str]]: + """ + Get information about all registered metrics in dict format, basically the registered name and module path where code is written for it. + This is just for debugging purposes for now, may have some use case in the future. + Returns: + dict: {metric_name: {"class": class_name, "module": module_name}} + Example: + UnitMetricRegistry.get_metrics_info() + { + 'precision': { + 'class': 'PrecisionMetric', + 'module': 'core.unit_metrics.precision' + }, + 'recall': { + 'class': 'RecallMetric', + 'module': 'core.unit_metrics.recall' + } + } + """ + info = {} + for name, metric_class in cls._metrics.items(): + info[name] = {"class": metric_class.__name__, "module": metric_class.__module__} + return info + + +# Decorator for metric registration +def unit_metric(name: str): + """ + Decorator to auto-register unit metrics with the registry. + + Usage: + @unit_metric("precision") + class PrecisionMetric(BaseUnitMetric): + def calculate(self, results): + # Implementation + pass + + Args: + name: Unique name for the metric (used for registry lookup) + + Returns: + Decorator function that registers the class + """ + + def decorator(cls): + # Import at runtime when decorator is applied (not at module load time) + # Metrics use this decorator, so they import this registry file. + # If we imported BaseUnitMetric at the top, we'd have: + # metric.py -> registry.py -> base.py (circular dependency) + # By importing here, the import happens when the class is decorated, + # after all modules have loaded. + from sygra.core.eval.metrics.unit_metrics.base_unit_metric import ( + BaseUnitMetric, + ) + + # Validate that class inherits from BaseUnitMetric + if not issubclass(cls, BaseUnitMetric): + raise TypeError( + f"{cls.__name__} must inherit from BaseUnitMetric to use @unit_metric decorator" + ) + + UnitMetricRegistry.register(name, cls) + return cls + + return decorator diff --git a/tasks/eval/README.md b/tasks/eval/README.md new file mode 100644 index 00000000..6aec30a4 --- /dev/null +++ b/tasks/eval/README.md @@ -0,0 +1,308 @@ +# Evaluation Tasks (`tasks/eval`) + +SyGra is a graph-oriented workflow framework for synthetic data generation **and evaluation**. 
Evaluation tasks are implemented as standard SyGra tasks (a `graph_config.yaml` + optional Python utilities) and produce: + +- **Per-record outputs** (including unit metric results) +- **Aggregated evaluation reports** (aggregator metrics computed after the run) + +## Quick start + +Run an eval task through `main.py`: + +```bash +uv run python main.py --task tasks.eval.question_answering.simpleqa --num_records 50 +``` + +You can also omit the `tasks.` prefix: + +```bash +uv run python main.py --task eval/classification/simpleqa --num_records 50 +``` + +To control where artifacts are written: + +```bash +uv run python main.py \ + --task tasks.eval.question_answering.simpleqa \ + --num_records 50 \ + --output_dir /abs/path/to/my_eval_outputs +``` + +## Concepts + +SyGra evaluation is organized into two metric layers: + +- **Unit metrics** + - Per-record metrics computed *inside the graph*. + - Stored back into the record/state (e.g., `exact_match_result`). +- **Aggregator metrics** + - Dataset-level metrics computed *after the run* by aggregating unit metric outputs (e.g., accuracy, precision, recall, F1). + - Produced by graph-level post-processing. + +### Key files and registries + +- **Generic eval utilities**: `tasks/eval/utils.py` + - `UnitMetrics`: reusable `lambda` node implementation that computes unit metrics via `UnitMetricRegistry`. + - `MetricCollatorPostProcessor`: reusable graph post-processor that computes aggregator metrics via `AggregatorMetricRegistry`. +- **Unit metric registry**: `sygra.core.eval.metrics.unit_metrics.unit_metric_registry.UnitMetricRegistry` +- **Aggregator metric registry**: `sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry.AggregatorMetricRegistry` + +## Evaluation lifecycle + +At a high level, an eval task runs as follows: + +1. **Load task config** + - Task configs live under `tasks//graph_config.yaml`. + +2. **Load dataset** + - Configured in `data_config.source`. + - Common sources: HuggingFace (`type: hf`) or local disk. + +3. **Execute the graph per record** + - Nodes (LLM / lambda / sampler / etc.) run and update the `SygraState`. + +4. **Compute unit metrics per record (inside the graph)** + - Typically implemented as a `lambda` node. + - Example used in eval tasks: `tasks.eval.utils.UnitMetrics` + - Unit metric outputs are stored back into the record (e.g., `exact_match_result`). + +5. **Write the per-record output file** + - The framework writes a file named like `output_*.json` or `output_*.jsonl`. + +6. **Compute aggregator metrics (graph post-processing)** + - Configured via `graph_post_process` in the task YAML. + - Framework reads the output file back and runs each configured graph post processor. + - For eval tasks this is commonly: `tasks.eval.utils.MetricCollatorPostProcessor`. + +7. **Write aggregator results file(s)** + - For each post processor, SyGra writes a new file by replacing the `output` prefix in the filename with the post-processor class name. + - Example: `output_2026-01-27_12-31-46.json` -> `MetricCollatorPostProcessor_2026-01-27_12-31-46.json` + +## Using `tasks/eval/utils.py` in eval graphs + +### Unit metrics (generic `UnitMetrics` lambda) + +`tasks.eval.utils.UnitMetrics` is a generic lambda node that: + +- Reads `golden_key` and `predicted_key` from the graph state. +- Iterates the configured `unit_metrics_map`. +- Instantiates metrics via `UnitMetricRegistry.get_metric(name, **params)`. +- Writes per-record results back into state as `_result`. 
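+
+For reference, here is a minimal sketch (illustrative only, not part of any task graph) of what this lambda does for a single record, using the same registry calls shown in `tasks/eval/utils.py`; the literal `"Paris"` values stand in for whatever `state[golden_key]` / `state[predicted_key]` hold:
+
+```python
+# Mirrors tasks.eval.utils.UnitMetrics for one record, assuming the
+# "exact_match" unit metric configured with key="text".
+from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import UnitMetricRegistry
+
+golden = [{"text": "Paris"}]     # value taken from state[golden_key]
+predicted = [{"text": "Paris"}]  # value taken from state[predicted_key]
+
+metric = UnitMetricRegistry.get_metric("exact_match", key="text")
+results = metric.evaluate(golden=golden, predicted=predicted)
+
+# UnitMetrics serializes the first result back into the state,
+# e.g. state["exact_match_result"] = results[0].to_dict()
+print(results[0].to_dict())
+```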
+ +This is why most eval graphs: + +- Add a `unit_metrics` lambda node. +- Include the corresponding `*_result` field in `output_config.output_map`. + +### Aggregator metrics (generic `MetricCollatorPostProcessor`) + +`tasks.eval.utils.MetricCollatorPostProcessor` is a generic graph-level post-processor that: + +- Loads the run output file (`output_*.json`). +- For each configured entry in `aggregator_metrics_map`: + - Selects the unit-metric column specified by `unit_metrics_results` (commonly `exact_match_result`). + - Converts dicts to `UnitMetricResult` objects. + - Instantiates the aggregator metric via `AggregatorMetricRegistry.get_metric(name, **params)`. + - Computes the aggregated metric over all rows. + +It writes a new JSON file named by replacing `output` in the filename with `MetricCollatorPostProcessor`. + +## Outputs and artifacts + +An eval run typically produces **two primary artifacts**: + +### 1) Per-record output (includes unit metric results) + +This is the main run output file written during graph execution (typically named `output_*.json` or `output_*.jsonl`). + +You’ll find, per record: + +- The original fields you mapped in `output_config.output_map` +- The model outputs (e.g., `predicted_answer`) +- **Unit metric result fields** (e.g., `exact_match_result`) + +In the eval examples, `UnitMetrics` stores metric results in the state using: + +- `state[unit_metric_name + "_result"] = results[0].to_dict()` + +So if you configure `exact_match`, you’ll typically see: + +- `exact_match_result`: a dict representing a `UnitMetricResult` + +### 2) Aggregated evaluation report (aggregator metrics) + +After the run completes, SyGra runs graph post-processors and writes additional JSON files. + +For eval tasks using `tasks.eval.utils.MetricCollatorPostProcessor`, the report output is: + +- `MetricCollatorPostProcessor_*.json` + +This file contains a list with a single “report object”: + +- `evaluation_summary`: counts/status +- `results`: a dict keyed by aggregator metric name (e.g., `accuracy`, `f1_score`) + +## `MetricCollatorPostProcessor_*.json` file format + +The file is a JSON list with a single report object: + +```json +[ + { + "evaluation_summary": { + "total_records": 1000, + "status": "success" + }, + "results": { + "accuracy": { + "accuracy": 0.689 + }, + "precision": { + "average_precision": 0.54, + "precision_per_class": { + "Music": 0.91, + "Politics": 0.92, + "Other": 0.45 + } + }, + "recall": { + "average_recall": 0.72, + "recall_per_class": { + "Music": 0.81, + "Politics": 0.56, + "Other": 0.72 + } + }, + "f1_score": { + "average_f1_score": 0.62, + "f1_score_per_class": { + "Music": 0.86, + "Politics": 0.69, + "Other": 0.55 + } + } + } + } +] +``` + +Notes: + +- **Top-level** is always a list (currently a single-item list). +- `evaluation_summary.status` is typically: + - `success` + - `no_data` + - `fatal_error` (includes an `error` string) +- `results` keys match the `aggregator_metrics_map[].name` entries from your task config. +- `results` can contain **multiple aggregator metrics** in a single report (e.g., `accuracy`, `precision`, `recall`, `f1_score`, `pass_at_k`, etc.). +- For classification-style metrics, per-class fields (e.g., `precision_per_class`) are keyed by the task’s **label classes** (for example: `Music`, `Politics`, `Other`). +- The exact shape of each metric payload is **metric-dependent**: + - Example: `accuracy` returns `{ "accuracy": }`. 
+ - Example: `precision`/`recall`/`f1_score` may return both macro/average values and per-class breakdowns. + +## Configuring metrics in `graph_config.yaml` + +Eval tasks typically implement **unit metrics** as a lambda node and **aggregator metrics** as a graph post-processor. + +### Unit metric example (per record) + +From `tasks/eval/question_answering/simpleqa/graph_config.yaml`: + +- `lambda: tasks.eval.utils.UnitMetrics` +- `unit_metrics_map`: list of unit metrics to run + +Example: + +```yaml +unit_metrics: + node_type: lambda + lambda: tasks.eval.utils.UnitMetrics + golden_key: "answer" + predicted_key: "predicted_answer" + unit_metrics_map: + - name: "exact_match" + params: + key: "text" + output_keys: + - exact_match_result +``` + +This writes `exact_match_result` into each output record. + +### Aggregator metric example (dataset-level) + +From `tasks/eval/classification/simpleqa/graph_config.yaml`: + +```yaml +graph_post_process: + - processor: tasks.eval.utils.MetricCollatorPostProcessor + params: + aggregator_metrics_map: + - name: "accuracy" + params: + key: "text" + unit_metrics_results: + - "exact_match_result" + - name: "f1_score" + params: + predicted_key: "text" + golden_key: "text" + unit_metrics_results: + - "exact_match_result" +``` + +`MetricCollatorPostProcessor` will: + +- Build a DataFrame from the run output (`output_*.json`) +- For each aggregator metric: + - Pull the configured `unit_metrics_results` column (e.g., `exact_match_result`) + - Convert dicts to `UnitMetricResult` + - Instantiate the metric via `AggregatorMetricRegistry.get_metric(name, **params)` + - Compute `metric.calculate(unit_metrics_results)` + +## Programmatic access + +- **Unit metric results** + - Read the per-record output file and inspect fields ending in `_result` (e.g., `exact_match_result`). + - These fields are dict-serialized `UnitMetricResult` values. +- **Aggregator metric results** + - Read `MetricCollatorPostProcessor_*.json` and access: + - `"results"` for metrics + - `"evaluation_summary"` for counts/status + +## Extending evaluation + +### Add a new unit metric + +- Implement a class under `sygra.core.eval.metrics.unit_metrics.*` inheriting the unit metric base. +- Decorate it with the unit metric decorator so it registers (see `UnitMetricRegistry` in `sygra.core.eval.metrics.unit_metrics.unit_metric_registry`). +- Reference it by name in your task YAML under `unit_metrics_map`. + +### Add a new aggregator metric + +- Implement a class under `sygra.core.eval.metrics.aggregator_metrics.*` inheriting the aggregator metric base. +- Decorate/register it (see `AggregatorMetricRegistry` in `sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry`). +- Reference it by name in `graph_post_process` -> `MetricCollatorPostProcessor` -> `aggregator_metrics_map`. + +### Add a new graph-level report + +Graph post-processing is generic. + +- Implement a `GraphPostProcessor` (`sygra.core.graph.graph_postprocessor.GraphPostProcessor`). +- Add it to `graph_post_process` in the task config. +- SyGra will automatically write a new file by replacing `output` with your post-processor class name. + +## Troubleshooting / notes + +- **Graph post-processing reads the full output into memory** + - `GraphPostProcessor` runs after the graph completes and loads the entire output file. This is appropriate for evaluation, but avoid it for very large generations. 
+- **`MetricCollatorPostProcessor` expects the unit-metric field to exist** + - If the configured `unit_metrics_results` field is missing (or not included in `output_config.output_map`), metrics will fail with a `KeyError` listing available columns. +- **JSON vs JSONL** + - Current graph post-processing reads the output with `json.load(...)`, so it expects the run output to be a JSON array (`output_*.json`). + - If your run produces `output_*.jsonl`, you will need to either: + - Configure the task to write JSON (preferred for eval), or + - Implement a custom graph post-processor that can stream/parse JSONL. + +--- + diff --git a/tasks/eval/classification/simpleqa/graph_config.yaml b/tasks/eval/classification/simpleqa/graph_config.yaml new file mode 100644 index 00000000..5a0227cc --- /dev/null +++ b/tasks/eval/classification/simpleqa/graph_config.yaml @@ -0,0 +1,105 @@ +data_config: + source: + type: "hf" + repo_id: "google/simpleqa-verified" + config_name: "simpleqa_verified" + split: "eval" + + transformations: + - transform: sygra.processors.data_transform.RenameFieldsTransform + params: + mapping: + task_id: original_index + overwrite: false + +graph_config: + nodes: + generate_taxonomy: + node_type: llm + prompt: + - system: | + You are an assistant tasked with identifying the topic of a problem. + You should respond only with the topic out of the below list: + Politics + Art + TV shows + History + Music + Other + Geography + Sports + Science and technology + Video games + + - user: | + {problem} + output_keys: + - predicted_topic + post_process: tasks.eval.classification.simpleqa.task_executor.GenerateTopicPostProcessor + model: + name: eval_model + parameters: + temperature: 0.1 + structured_output: + schema: + fields: + predicted_topic: + type: str + description: "Predicted topic of the problem" + unit_metrics: + node_type: lambda + lambda: tasks.eval.utils.UnitMetrics + golden_key: "topic" + predicted_key: "predicted_topic" + unit_metrics_map: + - name: "exact_match" + params: + key: "text" + output_keys: + - exact_match_result + edges: + - from: START + to: generate_taxonomy + - from: generate_taxonomy + to: unit_metrics + - from: unit_metrics + to: END + +output_config: + output_map: + original_index: + from: "original_index" + problem: + from: "problem" + topic: + from: "topic" + predicted_topic: + from: "predicted_topic" + exact_match_result: + from: "exact_match_result" + +graph_post_process: + - processor: tasks.eval.utils.MetricCollatorPostProcessor + params: + aggregator_metrics_map: + - name: "accuracy" + params: + key: "text" + unit_metrics_results: + - "exact_match_result" + - name: "precision" + params: + predicted_key: "text" + unit_metrics_results: + - "exact_match_result" + - name: "recall" + params: + golden_key: "text" + unit_metrics_results: + - "exact_match_result" + - name: "f1_score" + params: + predicted_key: "text" + golden_key: "text" + unit_metrics_results: + - "exact_match_result" diff --git a/tasks/eval/classification/simpleqa/task_executor.py b/tasks/eval/classification/simpleqa/task_executor.py new file mode 100644 index 00000000..609b7619 --- /dev/null +++ b/tasks/eval/classification/simpleqa/task_executor.py @@ -0,0 +1,25 @@ +"""Post-processors for SimpleQA classification evaluation tasks.""" + +from sygra.core.graph.functions.node_processor import NodePostProcessorWithState +from sygra.core.graph.sygra_message import SygraMessage +from sygra.core.graph.sygra_state import SygraState +from tasks.eval.utils import parse_response_as_json + + +class 
GenerateTopicPostProcessor(NodePostProcessorWithState): + """Extract `predicted_topic` from a model response and store it in the state.""" + + def apply(self, response: SygraMessage, state: SygraState) -> SygraState: + """Parse `response` as JSON and update `state` with `predicted_topic`.""" + content = response.message.content + json_data = parse_response_as_json(content) + if json_data: + output_dict = { + "predicted_topic": json_data.get("predicted_topic", ""), + } + state.update(output_dict) + return state + + else: + state.update({"predicted_topic": ""}) + return state diff --git a/tasks/eval/question_answering/simpleqa/graph_config.yaml b/tasks/eval/question_answering/simpleqa/graph_config.yaml new file mode 100644 index 00000000..9a28d3ed --- /dev/null +++ b/tasks/eval/question_answering/simpleqa/graph_config.yaml @@ -0,0 +1,79 @@ +data_config: + source: + type: "hf" + repo_id: "google/simpleqa-verified" + config_name: "simpleqa_verified" + split: "eval" + + transformations: + - transform: sygra.processors.data_transform.RenameFieldsTransform + params: + mapping: + task_id: original_index + overwrite: false + +graph_config: + nodes: + predict_answer: + node_type: llm + prompt: + - system: | + You are an assistant tasked with identifying the answer of a problem. + You should respond only with the answer of the problem without any extra text. + - user: | + {problem} + output_keys: + - predicted_answer + post_process: tasks.eval.question_answering.simpleqa.task_executor.GenerateAnswerPostProcessor + model: + name: eval_model + parameters: + temperature: 0.1 + structured_output: + schema: + fields: + predicted_answer: + type: str + description: "Predicted answer of the problem" + unit_metrics: + node_type: lambda + lambda: tasks.eval.utils.UnitMetrics + golden_key: "answer" + predicted_key: "predicted_answer" + unit_metrics_map: + - name: "exact_match" + params: + key: "text" + output_keys: + - exact_match_result + edges: + - from: START + to: predict_answer + - from: predict_answer + to: unit_metrics + - from: unit_metrics + to: END + +output_config: + output_map: + original_index: + from: "original_index" + problem: + from: "problem" + answer: + from: "answer" + predicted_answer: + from: "predicted_answer" + exact_match_result: + from: "exact_match_result" + + +graph_post_process: + - processor: tasks.eval.utils.MetricCollatorPostProcessor + params: + aggregator_metrics_map: + - name: "accuracy" + params: + key: "text" + unit_metrics_results: + - "exact_match_result" \ No newline at end of file diff --git a/tasks/eval/question_answering/simpleqa/task_executor.py b/tasks/eval/question_answering/simpleqa/task_executor.py new file mode 100644 index 00000000..bfe2fb7c --- /dev/null +++ b/tasks/eval/question_answering/simpleqa/task_executor.py @@ -0,0 +1,25 @@ +"""Post-processors for SimpleQA question answering evaluation tasks.""" + +from sygra.core.graph.functions.node_processor import NodePostProcessorWithState +from sygra.core.graph.sygra_message import SygraMessage +from sygra.core.graph.sygra_state import SygraState +from tasks.eval.utils import parse_response_as_json + + +class GenerateAnswerPostProcessor(NodePostProcessorWithState): + """Extract `predicted_answer` from a model response and store it in the state.""" + + def apply(self, response: SygraMessage, state: SygraState) -> SygraState: + """Parse `response` as JSON and update `state` with `predicted_answer`.""" + content = response.message.content + json_data = parse_response_as_json(content) + if json_data: + output_dict = { 
+ "predicted_answer": json_data.get("predicted_answer", ""), + } + state.update(output_dict) + return state + + else: + state.update({"predicted_answer": ""}) + return state diff --git a/tasks/eval/utils.py b/tasks/eval/utils.py new file mode 100644 index 00000000..c2a3ab2f --- /dev/null +++ b/tasks/eval/utils.py @@ -0,0 +1,166 @@ +import json + +""" +Utilities for evaluation tasks. + +Includes helpers to parse model responses as JSON, run unit metrics in a graph node, +and collate metric results into an evaluation report. +""" + +from typing import Any, Optional + +import pandas as pd +import regex + +from sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry import AggregatorMetricRegistry +from sygra.core.eval.metrics.unit_metrics.unit_metric_registry import UnitMetricRegistry +from sygra.core.eval.metrics.unit_metrics.unit_metric_result import UnitMetricResult +from sygra.core.graph.functions.lambda_function import LambdaFunction +from sygra.core.graph.graph_postprocessor import GraphPostProcessor +from sygra.core.graph.sygra_state import SygraState +from sygra.logger.logger_config import logger + + +def parse_response_as_json(s: Any) -> Optional[dict[str, Any]]: + """Parse a model response into a JSON object. + + This helper first attempts to parse the full string as JSON. If that fails, it + falls back to extracting the first balanced JSON object substring (supports + nested braces) and parsing that. + + Returns `None` if parsing fails. + """ + + JSON_REGEX_PATTERN = regex.compile(r"\{(?:[^{}]|(?R))*\}") + + if s is None: + return None + + text = s if isinstance(s, str) else str(s) + try: + parsed = json.loads(text) + return parsed if isinstance(parsed, dict) else {"value": parsed} + except json.decoder.JSONDecodeError as e: + match = JSON_REGEX_PATTERN.search(text) + if not match: + logger.error("No json string found: " + e.msg) + logger.error(text) + return None + try: + parsed = json.loads(match[0]) + return parsed if isinstance(parsed, dict) else {"value": parsed} + except json.decoder.JSONDecodeError as e2: + logger.error("Unable to parse json string: " + e2.msg) + logger.error(text) + return None + + +class UnitMetrics(LambdaFunction): + """Graph lambda that evaluates configured unit metrics and stores results in state.""" + + @staticmethod + def apply(lambda_node_dict: dict, state: SygraState) -> SygraState: + golden_topic = [{"text": state[lambda_node_dict["golden_key"]]}] + predicted_answer = [{"text": state[lambda_node_dict["predicted_key"]]}] + for unit_metric in lambda_node_dict.get("unit_metrics_map", []): + unit_metric_name = unit_metric["name"] + unit_metric_params = unit_metric.get("params", {}) + validator = UnitMetricRegistry.get_metric(unit_metric_name, **unit_metric_params) + results = validator.evaluate(golden=golden_topic, predicted=predicted_answer) + if results: + state[unit_metric_name + "_result"] = results[0].to_dict() + return state + + +class MetricCollatorPostProcessor(GraphPostProcessor): + """ + Post-processor that calculates evaluation metrics from the task. + + Note: Records with structural errors (raised as StructuralError exceptions during + preprocessing) will not reach this post-processor and will be automatically skipped + from metrics calculation. 
+ """ + + def __init__( + self, + aggregator_metrics_map: Optional[list[dict[str, Any]]] = None, + unit_metrics_results: str = "text" + ): + self.aggregator_metrics_map = aggregator_metrics_map or [] + self.unit_metrics_results = unit_metrics_results + + def process(self, data: list, metadata: dict) -> list: + """ + Calculate comprehensive metrics from evaluation data. + Handles invalid records gracefully - skips them and continues with valid data. + + Args: + data: List of evaluation records with model_responses and golden_response + metadata: Any extract information to pass from dataset_processor like model name or file name etc + + Returns: + Metrics report list for downstream usage (always returns a report, even if no valid data) + """ + logger.info(f"MetricCollatorPostProcessor: Starting metrics calculation for {len(data)} records") + + try: + if not data: + logger.warning("MetricCollatorPostProcessor: No data provided") + return [{ + "evaluation_summary": { + "total_records": 0, + "status": "no_data" + }, + "results": {} + }] + + df = pd.DataFrame(data) + results: dict[str, Any] = {} + for aggregator_metric in self.aggregator_metrics_map: + aggregator_metric_name = aggregator_metric["name"] + aggregator_metric_params = aggregator_metric.get("params", {}) + + unit_metrics_field = aggregator_metric.get( + "unit_metrics_results", self.unit_metrics_results + ) + if isinstance(unit_metrics_field, list): + unit_metrics_field = unit_metrics_field[0] if unit_metrics_field else "" + + if not unit_metrics_field or unit_metrics_field not in df.columns: + raise KeyError( + f"Missing unit metric results field '{unit_metrics_field}' in data. " + f"Available columns: {list(df.columns)}" + ) + + unit_metrics_results = ( + df[unit_metrics_field] + .apply(lambda d: UnitMetricResult(**d) if isinstance(d, dict) else d) + .tolist() + ) + + metric = AggregatorMetricRegistry.get_metric( + aggregator_metric_name, **aggregator_metric_params + ) + metric_result = metric.calculate(unit_metrics_results) + + results[aggregator_metric_name] = metric_result + + return [{ + "evaluation_summary": { + "total_records": len(data), + "status": "success" + }, + "results": results + }] + + except Exception as e: + logger.error(f"MetricCollatorPostProcessor: Fatal error calculating metrics: {e}") + # Return error report but don't fail completely - downstream still gets something + return [{ + "evaluation_summary": { + "total_records": len(data) if data else 0, + "status": "fatal_error", + "error": str(e) + }, + "results": {} + }] diff --git a/tests/core/eval/metrics/aggregator_metrics/test_f1_score.py b/tests/core/eval/metrics/aggregator_metrics/test_f1_score.py index 5a6adbdf..726b4dfc 100644 --- a/tests/core/eval/metrics/aggregator_metrics/test_f1_score.py +++ b/tests/core/eval/metrics/aggregator_metrics/test_f1_score.py @@ -22,43 +22,26 @@ class TestF1ScoreMetric: def test_get_metric_name(self): """Test that metric name is 'f1_score'""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") + metric = F1ScoreMetric(predicted_key="class", golden_key="class") assert metric.get_metric_name() == "f1_score" - def test_initialization_with_parameters(self): - """Test initialization with custom parameters""" - metric = F1ScoreMetric(predicted_key="tool", golden_key="event", positive_class="click") - assert metric.predicted_key == "tool" - assert metric.golden_key == "event" - assert metric.positive_class == "click" - def test_initialization_requires_parameters(self): - """Test that 
initialization requires predicted_key, golden_key, and positive_class""" - # Should raise ValidationError when parameters are missing - with pytest.raises(ValidationError): - F1ScoreMetric(golden_key="class", positive_class="A") - - with pytest.raises(ValidationError): - F1ScoreMetric(predicted_key="class", positive_class="A") - + """Test that initialization requires predicted_key and golden_key""" with pytest.raises(ValidationError): - F1ScoreMetric(predicted_key="class", golden_key="class") + F1ScoreMetric(golden_key="class") - # Should raise ValidationError when predicted_key is empty with pytest.raises(ValidationError): - F1ScoreMetric(predicted_key="", golden_key="class", positive_class="A") + F1ScoreMetric(predicted_key="class") - # Should raise ValidationError when golden_key is empty with pytest.raises(ValidationError): - F1ScoreMetric(predicted_key="class", golden_key="", positive_class="A") + F1ScoreMetric(predicted_key="", golden_key="class") - # Should raise ValidationError when positive_class is None with pytest.raises(ValidationError): - F1ScoreMetric(predicted_key="class", golden_key="class", positive_class=None) + F1ScoreMetric(predicted_key="class", golden_key="") def test_initialization_creates_precision_and_recall_metrics(self): """Test that initialization creates precision and recall metric instances""" - metric = F1ScoreMetric(predicted_key="tool", golden_key="event", positive_class="click") + metric = F1ScoreMetric(predicted_key="tool", golden_key="event") assert metric.precision_metric is not None assert metric.recall_metric is not None assert metric.precision_metric.predicted_key == "tool" @@ -66,333 +49,75 @@ def test_initialization_creates_precision_and_recall_metrics(self): def test_calculate_empty_results(self): """Test calculate with empty results list""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") - results = [] - output = metric.calculate(results) - - assert "f1_score" in output - assert output["f1_score"] == 0.0 - - def test_calculate_perfect_f1_score(self): - """Test calculate with perfect F1 score (precision=1.0, recall=1.0)""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - ] - output = metric.calculate(results) - - assert "f1_score" in output - assert output["f1_score"] == 1.0 - - def test_calculate_zero_f1_score(self): - """Test calculate with zero F1 score (no true positives)""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="click") - results = [ - # False Positives - UnitMetricResult( - correct=False, - golden={"class": "type"}, - predicted={"class": "click"}, - ), - # False Negatives - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "type"}, - ), - ] - output = metric.calculate(results) - - assert "f1_score" in output - assert output["f1_score"] == 0.0 - - def test_calculate_balanced_f1_score(self): - """Test calculate with balanced precision and recall""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - 
), - # False Positive - UnitMetricResult( - correct=False, - golden={"class": "type"}, - predicted={"class": "click"}, - ), - # False Negative - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "type"}, - ), - ] - output = metric.calculate(results) - - # TP=1, FP=1, FN=1 - # Precision = 1/(1+1) = 0.5 - # Recall = 1/(1+1) = 0.5 - # F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5 - assert "f1_score" in output - assert output["f1_score"] == 0.5 - - def test_calculate_high_precision_low_recall(self): - """Test calculate with high precision but low recall""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # False Negatives (missed positives) - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "type"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "scroll"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "hover"}, - ), - ] - output = metric.calculate(results) - - # TP=1, FP=0, FN=3 - # Precision = 1/(1+0) = 1.0 - # Recall = 1/(1+3) = 0.25 - # F1 = 2 * (1.0 * 0.25) / (1.0 + 0.25) = 0.4 - assert "f1_score" in output - assert output["f1_score"] == 0.4 - - def test_calculate_low_precision_high_recall(self): - """Test calculate with low precision but high recall""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # False Positives (wrong predictions) - UnitMetricResult( - correct=False, - golden={"class": "type"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "scroll"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "hover"}, - predicted={"class": "click"}, - ), - ] - output = metric.calculate(results) - - # TP=1, FP=3, FN=0 - # Precision = 1/(1+3) = 0.25 - # Recall = 1/(1+0) = 1.0 - # F1 = 2 * (0.25 * 1.0) / (0.25 + 1.0) = 0.4 - assert "f1_score" in output - assert output["f1_score"] == 0.4 - - def test_calculate_with_different_keys(self): - """Test calculate with different predicted_key and golden_key""" - metric = F1ScoreMetric(predicted_key="tool", golden_key="event", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"event": "click"}, - predicted={"tool": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"event": "type"}, - predicted={"tool": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"event": "click"}, - predicted={"tool": "type"}, - ), - ] - output = metric.calculate(results) + metric = F1ScoreMetric(predicted_key="class", golden_key="class") + output = metric.calculate([]) - # TP=1, FP=1, FN=1 - # Precision = 1/(1+1) = 0.5 - # Recall = 1/(1+1) = 0.5 - # F1 = 0.5 - assert "f1_score" in output - assert output["f1_score"] == 0.5 + assert output == {"average_f1_score": 0.0, "f1_score_per_class": {}} - def test_calculate_with_numeric_positive_class(self): - """Test calculate with numeric positive class""" - metric = F1ScoreMetric(predicted_key="label", golden_key="label", positive_class=1) + def test_calculate_f1_per_class_and_average(self): + """Test per-class F1 and macro-average F1""" + metric = F1ScoreMetric(predicted_key="class", 
golden_key="class") results = [ - UnitMetricResult(correct=True, golden={"label": 1}, predicted={"label": 1}), - UnitMetricResult(correct=True, golden={"label": 1}, predicted={"label": 1}), - UnitMetricResult(correct=False, golden={"label": 0}, predicted={"label": 1}), - UnitMetricResult(correct=False, golden={"label": 1}, predicted={"label": 0}), - ] - output = metric.calculate(results) - - # TP=2, FP=1, FN=1 - # Precision = 2/(2+1) = 0.666... - # Recall = 2/(2+1) = 0.666... - # F1 = 2 * (0.666 * 0.666) / (0.666 + 0.666) = 0.666... - assert "f1_score" in output - assert output["f1_score"] == pytest.approx(0.666, rel=1e-2) - - def test_calculate_with_boolean_positive_class(self): - """Test calculate with boolean positive class""" - metric = F1ScoreMetric(predicted_key="is_valid", golden_key="is_valid", positive_class=True) - results = [ - UnitMetricResult( - correct=True, - golden={"is_valid": True}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=True, - golden={"is_valid": True}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=False, - golden={"is_valid": False}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=False, - golden={"is_valid": True}, - predicted={"is_valid": False}, - ), - ] - output = metric.calculate(results) - - # TP=2, FP=1, FN=1 - # Precision = 2/(2+1) = 0.666... - # Recall = 2/(2+1) = 0.666... - # F1 = 0.666... - assert "f1_score" in output - assert output["f1_score"] == pytest.approx(0.666, rel=1e-2) - - def test_calculate_single_true_positive(self): - """Test calculate with single true positive""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") - results = [UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"})] - output = metric.calculate(results) - - # TP=1, FP=0, FN=0 - # Precision = 1.0, Recall = 1.0, F1 = 1.0 - assert "f1_score" in output - assert output["f1_score"] == 1.0 - - def test_calculate_with_true_negatives(self): - """Test calculate with true negatives (shouldn't affect F1)""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") - results = [ - # True Positive UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - # True Negatives (shouldn't affect F1) - UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), - UnitMetricResult(correct=True, golden={"class": "C"}, predicted={"class": "C"}), - # False Positive + UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), ] output = metric.calculate(results) - # TP=1, FP=1, FN=0 - # Precision = 1/(1+1) = 0.5 - # Recall = 1/(1+0) = 1.0 - # F1 = 2 * (0.5 * 1.0) / (0.5 + 1.0) = 0.666... 
- assert "f1_score" in output - assert output["f1_score"] == pytest.approx(0.666, rel=1e-2) + assert set(output.keys()) == {"average_f1_score", "f1_score_per_class"} + assert output["f1_score_per_class"]["A"] == pytest.approx(0.5) + assert output["f1_score_per_class"]["B"] == pytest.approx(0.5) + assert output["average_f1_score"] == pytest.approx(0.5) - def test_calculate_various_f1_values(self): - """Test calculate with various F1 score values""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") - - # F1 = 0.8 (Precision=0.8, Recall=0.8) - # TP=4, FP=1, FN=1 + def test_calculate_union_of_classes(self): + """Test that per-class output uses union of precision and recall classes""" + metric = F1ScoreMetric(predicted_key="class", golden_key="class") results = [ - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), ] output = metric.calculate(results) - assert output["f1_score"] == pytest.approx(0.8, rel=1e-9) - def test_calculate_harmonic_mean_property(self): - """Test that F1 is indeed the harmonic mean of precision and recall""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") + assert output["f1_score_per_class"] == {"A": 0.0, "B": 0.0} + assert output["average_f1_score"] == 0.0 + + def test_calculate_skips_rows_missing_keys(self): + """Test that rows missing predicted or golden keys are skipped by underlying metrics""" + metric = F1ScoreMetric(predicted_key="class", golden_key="class") results = [ + UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"other": "A"}), + UnitMetricResult(correct=True, golden={"other": "A"}, predicted={"class": "A"}), UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "C"}), ] output = metric.calculate(results) - # TP=2, FP=1, FN=2 - # Precision = 2/(2+1) = 0.666... - # Recall = 2/(2+2) = 0.5 - # F1 = 2 * (0.666 * 0.5) / (0.666 + 0.5) = 0.571... 
- precision = 2 / 3 - recall = 2 / 4 - expected_f1 = 2 * (precision * recall) / (precision + recall) - - assert "f1_score" in output - assert output["f1_score"] == pytest.approx(expected_f1, rel=1e-2) + assert output["f1_score_per_class"] == {"A": 1.0} + assert output["average_f1_score"] == 1.0 - def test_calculate_when_precision_or_recall_is_zero(self): - """Test calculate when either precision or recall is zero""" - metric = F1ScoreMetric(predicted_key="class", golden_key="class", positive_class="A") - - # Only false positives (precision=0, recall undefined) + def test_calculate_multi_class_f1(self): + """Test multi-class F1 computation across 3 classes""" + metric = F1ScoreMetric(predicted_key="class", golden_key="class") results = [ + # Class A: + # predicted A: 2 total, 1 correct => P(A)=0.5 + # golden A: 2 total, 1 correct => R(A)=0.5 + UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), + UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "C"}, predicted={"class": "A"}), + # Class B: + # predicted B: 2 total, 1 correct => P(B)=0.5 + # golden B: 2 total, 1 correct => R(B)=0.5 + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), + # Class C: + # predicted C: 1 total, 0 correct => P(C)=0 + # golden C: 1 total, 0 correct => R(C)=0 + UnitMetricResult(correct=False, golden={"class": "C"}, predicted={"class": "C"}), ] output = metric.calculate(results) - assert output["f1_score"] == 0.0 - # Only false negatives (precision undefined, recall=0) - results = [ - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "C"}), - ] - output = metric.calculate(results) - assert output["f1_score"] == 0.0 + assert output["f1_score_per_class"] == { + "A": pytest.approx(0.5), + "B": pytest.approx(0.5), + "C": 0.0, + } + assert output["average_f1_score"] == pytest.approx((0.5 + 0.5 + 0.0) / 3) diff --git a/tests/core/eval/metrics/aggregator_metrics/test_precision.py b/tests/core/eval/metrics/aggregator_metrics/test_precision.py index 509f4812..24f9ac76 100644 --- a/tests/core/eval/metrics/aggregator_metrics/test_precision.py +++ b/tests/core/eval/metrics/aggregator_metrics/test_precision.py @@ -22,311 +22,83 @@ class TestPrecisionMetric: def test_get_metric_name(self): """Test that metric name is 'precision'""" - metric = PrecisionMetric(predicted_key="class", positive_class="A") + metric = PrecisionMetric(predicted_key="class") assert metric.get_metric_name() == "precision" - def test_initialization_with_parameters(self): - """Test initialization with custom parameters""" - metric = PrecisionMetric(predicted_key="tool", positive_class="click") - assert metric.predicted_key == "tool" - assert metric.positive_class == "click" - - def test_initialization_requires_parameters(self): - """Test that initialization requires both predicted_key and positive_class""" - # Should raise ValidationError when predicted_key is missing - with pytest.raises(ValidationError): - PrecisionMetric(positive_class="A") - - # Should raise ValidationError when positive_class is missing + def test_initialization_requires_predicted_key(self): + """Test that initialization requires predicted_key""" with pytest.raises(ValidationError): - PrecisionMetric(predicted_key="class") + PrecisionMetric() - # Should raise 
ValidationError when predicted_key is empty with pytest.raises(ValidationError): - PrecisionMetric(predicted_key="", positive_class="A") - - # Should raise ValidationError when positive_class is None - with pytest.raises(ValidationError): - PrecisionMetric(predicted_key="class", positive_class=None) + PrecisionMetric(predicted_key="") def test_calculate_empty_results(self): """Test calculate with empty results list""" - metric = PrecisionMetric(predicted_key="class", positive_class="A") - results = [] - output = metric.calculate(results) - - assert "precision" in output - assert output["precision"] == 0.0 - - def test_calculate_perfect_precision(self): - """Test calculate with perfect precision (all positive predictions are correct)""" - metric = PrecisionMetric(predicted_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - ] - output = metric.calculate(results) - - assert "precision" in output - assert output["precision"] == 1.0 - - def test_calculate_zero_precision(self): - """Test calculate with zero precision (all positive predictions are wrong)""" - metric = PrecisionMetric(predicted_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=False, - golden={"class": "type"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "scroll"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "hover"}, - predicted={"class": "click"}, - ), - ] - output = metric.calculate(results) - - assert "precision" in output - assert output["precision"] == 0.0 - - def test_calculate_mixed_precision(self): - """Test calculate with mixed true positives and false positives""" - metric = PrecisionMetric(predicted_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # False Positive - UnitMetricResult( - correct=False, - golden={"class": "type"}, - predicted={"class": "click"}, - ), - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # False Positive - UnitMetricResult( - correct=False, - golden={"class": "scroll"}, - predicted={"class": "click"}, - ), - ] - output = metric.calculate(results) - - # TP = 2, FP = 2, Precision = 2/(2+2) = 0.5 - assert "precision" in output - assert output["precision"] == 0.5 - - def test_calculate_with_negative_class_predictions(self): - """Test calculate when some predictions are not the positive class""" - metric = PrecisionMetric(predicted_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # True Negative (not predicted as positive class) - UnitMetricResult( - correct=True, - golden={"class": "type"}, - predicted={"class": "type"}, - ), - # False Positive - UnitMetricResult( - correct=False, - golden={"class": "scroll"}, - predicted={"class": "click"}, - ), - # True Negative - UnitMetricResult( - correct=True, - golden={"class": "hover"}, - predicted={"class": "hover"}, - ), - ] - output = metric.calculate(results) - - # TP = 1, FP = 1, Precision = 1/(1+1) = 0.5 - # 
True negatives don't affect precision - assert "precision" in output - assert output["precision"] == 0.5 + metric = PrecisionMetric(predicted_key="class") + output = metric.calculate([]) - def test_calculate_no_positive_predictions(self): - """Test calculate when no predictions are the positive class""" - metric = PrecisionMetric(predicted_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"class": "type"}, - predicted={"class": "type"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "scroll"}, - predicted={"class": "scroll"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "hover"}, - predicted={"class": "hover"}, - ), - ] - output = metric.calculate(results) - - # TP = 0, FP = 0, Precision = 0/0 = 0.0 (safe divide) - assert "precision" in output - assert output["precision"] == 0.0 + assert output == {"average_precision": 0.0, "precision_per_class": {}} - def test_calculate_with_different_predicted_key(self): - """Test calculate with different predicted_key""" - metric = PrecisionMetric(predicted_key="tool", positive_class="click") + def test_calculate_precision_per_class_and_average(self): + """Test precision per class and macro-average precision""" + metric = PrecisionMetric(predicted_key="class") results = [ - UnitMetricResult( - correct=True, - golden={"event": "click"}, - predicted={"tool": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"event": "type"}, - predicted={"tool": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"event": "click"}, - predicted={"tool": "click"}, - ), + UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), + UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), ] output = metric.calculate(results) - # TP = 2, FP = 1, Precision = 2/(2+1) = 0.666... - assert "precision" in output - assert output["precision"] == pytest.approx(0.666, rel=1e-2) + # Predicted counts: A=2, B=2 + # True positives by predicted label: A=1, B=2 + # precision(A)=1/2, precision(B)=2/2 + assert output["precision_per_class"] == {"A": 0.5, "B": 1.0} + assert output["average_precision"] == pytest.approx((0.5 + 1.0) / 2) - def test_calculate_with_numeric_positive_class(self): - """Test calculate with numeric positive class""" - metric = PrecisionMetric(predicted_key="label", positive_class=1) + def test_calculate_skips_rows_missing_predicted_key(self): + """Test that rows missing predicted_key are skipped""" + metric = PrecisionMetric(predicted_key="class") results = [ - UnitMetricResult(correct=True, golden={"label": 1}, predicted={"label": 1}), - UnitMetricResult(correct=False, golden={"label": 0}, predicted={"label": 1}), - UnitMetricResult(correct=True, golden={"label": 1}, predicted={"label": 1}), - UnitMetricResult(correct=True, golden={"label": 0}, predicted={"label": 0}), + UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"other": "A"}), + UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), ] output = metric.calculate(results) - # TP = 2, FP = 1, Precision = 2/(2+1) = 0.666... 
- assert "precision" in output - assert output["precision"] == pytest.approx(0.666, rel=1e-2) + assert output["precision_per_class"] == {"A": 1.0} + assert output["average_precision"] == 1.0 - def test_calculate_with_boolean_positive_class(self): - """Test calculate with boolean positive class""" - metric = PrecisionMetric(predicted_key="is_valid", positive_class=True) + def test_calculate_returns_zero_when_all_rows_missing_predicted_key(self): + """Test safe behavior when nothing is usable for calculation""" + metric = PrecisionMetric(predicted_key="class") results = [ - UnitMetricResult( - correct=True, - golden={"is_valid": True}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=True, - golden={"is_valid": True}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=False, - golden={"is_valid": False}, - predicted={"is_valid": True}, - ), + UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"other": "A"}), + UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"other": "B"}), ] output = metric.calculate(results) - # TP = 2, FP = 1, Precision = 2/(2+1) = 0.666... - assert "precision" in output - assert output["precision"] == pytest.approx(0.666, rel=1e-2) + assert output == {"average_precision": 0.0, "precision_per_class": {}} - def test_calculate_single_true_positive(self): - """Test calculate with single true positive""" - metric = PrecisionMetric(predicted_key="class", positive_class="A") - results = [UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"})] - output = metric.calculate(results) - - assert "precision" in output - assert output["precision"] == 1.0 - - def test_calculate_single_false_positive(self): - """Test calculate with single false positive""" - metric = PrecisionMetric(predicted_key="class", positive_class="A") - results = [UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"})] - output = metric.calculate(results) - - assert "precision" in output - assert output["precision"] == 0.0 - - def test_calculate_with_missing_predicted_key(self): - """Test calculate when predicted dict doesn't have the key""" - metric = PrecisionMetric(predicted_key="class", positive_class="A") + def test_calculate_multi_class_precision(self): + """Test multi-class precision computation across 3 classes""" + metric = PrecisionMetric(predicted_key="class") results = [ - UnitMetricResult( - correct=False, - golden={"class": "A"}, - predicted={"other_key": "B"}, # Missing 'class' key - ), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - ] - output = metric.calculate(results) - - # Only 1 TP (second result), 0 FP - assert "precision" in output - assert output["precision"] == 1.0 - - def test_calculate_various_precision_values(self): - """Test calculate with various precision percentages""" - # 80% precision (4 TP, 1 FP) - metric = PrecisionMetric(predicted_key="class", positive_class="A") - results = [ - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), + # Predicted A: 2 total, 1 correct => precision(A)=0.5 UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), + # Predicted B: 3 total, 2 correct => precision(B)=2/3 + UnitMetricResult(correct=True, 
golden={"class": "B"}, predicted={"class": "B"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), + UnitMetricResult(correct=False, golden={"class": "C"}, predicted={"class": "B"}), + # Predicted C: 1 total, 0 correct => precision(C)=0.0 + UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "C"}), ] output = metric.calculate(results) - assert output["precision"] == 0.8 - # 25% precision (1 TP, 3 FP) - results = [ - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "C"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "D"}, predicted={"class": "A"}), - ] - output = metric.calculate(results) - assert output["precision"] == 0.25 + assert output["precision_per_class"] == { + "A": 0.5, + "B": pytest.approx(2 / 3), + "C": 0.0, + } + assert output["average_precision"] == pytest.approx((0.5 + (2 / 3) + 0.0) / 3) diff --git a/tests/core/eval/metrics/aggregator_metrics/test_recall.py b/tests/core/eval/metrics/aggregator_metrics/test_recall.py index f2656090..a7c12410 100644 --- a/tests/core/eval/metrics/aggregator_metrics/test_recall.py +++ b/tests/core/eval/metrics/aggregator_metrics/test_recall.py @@ -22,329 +22,83 @@ class TestRecallMetric: def test_get_metric_name(self): """Test that metric name is 'recall'""" - metric = RecallMetric(golden_key="class", positive_class="A") + metric = RecallMetric(golden_key="class") assert metric.get_metric_name() == "recall" - def test_initialization_with_parameters(self): - """Test initialization with custom parameters""" - metric = RecallMetric(golden_key="event", positive_class="click") - assert metric.golden_key == "event" - assert metric.positive_class == "click" - - def test_initialization_requires_parameters(self): - """Test that initialization requires both golden_key and positive_class""" - # Should raise ValidationError when golden_key is missing - with pytest.raises(ValidationError): - RecallMetric(positive_class="A") - - # Should raise ValidationError when positive_class is missing - with pytest.raises(ValidationError): - RecallMetric(golden_key="class") - - # Should raise ValidationError when golden_key is empty + def test_initialization_requires_golden_key(self): + """Test that initialization requires golden_key""" with pytest.raises(ValidationError): - RecallMetric(golden_key="", positive_class="A") + RecallMetric() - # Should raise ValidationError when positive_class is None with pytest.raises(ValidationError): - RecallMetric(golden_key="class", positive_class=None) + RecallMetric(golden_key="") def test_calculate_empty_results(self): """Test calculate with empty results list""" - metric = RecallMetric(golden_key="class", positive_class="A") - results = [] - output = metric.calculate(results) - - assert "recall" in output - assert output["recall"] == 0.0 - - def test_calculate_perfect_recall(self): - """Test calculate with perfect recall (all actual positives are found)""" - metric = RecallMetric(golden_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - ] - output = metric.calculate(results) - 
- assert "recall" in output - assert output["recall"] == 1.0 - - def test_calculate_zero_recall(self): - """Test calculate with zero recall (all actual positives are missed)""" - metric = RecallMetric(golden_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "type"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "scroll"}, - ), - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "hover"}, - ), - ] - output = metric.calculate(results) - - assert "recall" in output - assert output["recall"] == 0.0 - - def test_calculate_mixed_recall(self): - """Test calculate with mixed true positives and false negatives""" - metric = RecallMetric(golden_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # False Negative - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "type"}, - ), - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # False Negative - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "scroll"}, - ), - ] - output = metric.calculate(results) - - # TP = 2, FN = 2, Recall = 2/(2+2) = 0.5 - assert "recall" in output - assert output["recall"] == 0.5 - - def test_calculate_with_negative_class_in_golden(self): - """Test calculate when some golden values are not the positive class""" - metric = RecallMetric(golden_key="class", positive_class="click") - results = [ - # True Positive - UnitMetricResult( - correct=True, - golden={"class": "click"}, - predicted={"class": "click"}, - ), - # True Negative (golden is not positive class) - UnitMetricResult( - correct=True, - golden={"class": "type"}, - predicted={"class": "type"}, - ), - # False Negative - UnitMetricResult( - correct=False, - golden={"class": "click"}, - predicted={"class": "scroll"}, - ), - # True Negative - UnitMetricResult( - correct=True, - golden={"class": "hover"}, - predicted={"class": "hover"}, - ), - ] - output = metric.calculate(results) - - # TP = 1, FN = 1, Recall = 1/(1+1) = 0.5 - # True negatives don't affect recall - assert "recall" in output - assert output["recall"] == 0.5 - - def test_calculate_no_actual_positives(self): - """Test calculate when no golden values are the positive class""" - metric = RecallMetric(golden_key="class", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"class": "type"}, - predicted={"class": "type"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "scroll"}, - predicted={"class": "scroll"}, - ), - UnitMetricResult( - correct=True, - golden={"class": "hover"}, - predicted={"class": "hover"}, - ), - ] - output = metric.calculate(results) - - # TP = 0, FN = 0, Recall = 0/0 = 0.0 (safe divide) - assert "recall" in output - assert output["recall"] == 0.0 - - def test_calculate_with_different_golden_key(self): - """Test calculate with different golden_key""" - metric = RecallMetric(golden_key="event", positive_class="click") - results = [ - UnitMetricResult( - correct=True, - golden={"event": "click"}, - predicted={"tool": "click"}, - ), - UnitMetricResult( - correct=False, - golden={"event": "click"}, - predicted={"tool": "type"}, - ), - UnitMetricResult( - correct=True, - golden={"event": "click"}, - predicted={"tool": 
"click"}, - ), - ] - output = metric.calculate(results) + metric = RecallMetric(golden_key="class") + output = metric.calculate([]) - # TP = 2, FN = 1, Recall = 2/(2+1) = 0.666... - assert "recall" in output - assert output["recall"] == pytest.approx(0.666, rel=1e-2) + assert output == {"average_recall": 0.0, "recall_per_class": {}} - def test_calculate_with_numeric_positive_class(self): - """Test calculate with numeric positive class""" - metric = RecallMetric(golden_key="label", positive_class=1) + def test_calculate_recall_per_class_and_average(self): + """Test recall per class and macro-average recall""" + metric = RecallMetric(golden_key="class") results = [ - UnitMetricResult(correct=True, golden={"label": 1}, predicted={"label": 1}), - UnitMetricResult(correct=False, golden={"label": 1}, predicted={"label": 0}), - UnitMetricResult(correct=True, golden={"label": 1}, predicted={"label": 1}), - UnitMetricResult(correct=True, golden={"label": 0}, predicted={"label": 0}), - ] - output = metric.calculate(results) - - # TP = 2, FN = 1, Recall = 2/(2+1) = 0.666... - assert "recall" in output - assert output["recall"] == pytest.approx(0.666, rel=1e-2) - - def test_calculate_with_boolean_positive_class(self): - """Test calculate with boolean positive class""" - metric = RecallMetric(golden_key="is_valid", positive_class=True) - results = [ - UnitMetricResult( - correct=True, - golden={"is_valid": True}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=True, - golden={"is_valid": True}, - predicted={"is_valid": True}, - ), - UnitMetricResult( - correct=False, - golden={"is_valid": True}, - predicted={"is_valid": False}, - ), + UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), + UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), ] output = metric.calculate(results) - # TP = 2, FN = 1, Recall = 2/(2+1) = 0.666... 
- assert "recall" in output - assert output["recall"] == pytest.approx(0.666, rel=1e-2) - - def test_calculate_single_true_positive(self): - """Test calculate with single true positive""" - metric = RecallMetric(golden_key="class", positive_class="A") - results = [UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"})] - output = metric.calculate(results) - - assert "recall" in output - assert output["recall"] == 1.0 - - def test_calculate_single_false_negative(self): - """Test calculate with single false negative""" - metric = RecallMetric(golden_key="class", positive_class="A") - results = [UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"})] - output = metric.calculate(results) - - assert "recall" in output - assert output["recall"] == 0.0 + # Golden counts: A=2, B=2 + # True positives by golden label: A=1, B=2 + # recall(A)=1/2, recall(B)=2/2 + assert output["recall_per_class"] == {"A": 0.5, "B": 1.0} + assert output["average_recall"] == pytest.approx((0.5 + 1.0) / 2) - def test_calculate_with_missing_golden_key(self): - """Test calculate when golden dict doesn't have the key""" - metric = RecallMetric(golden_key="class", positive_class="A") + def test_calculate_skips_rows_missing_golden_key(self): + """Test that rows missing golden_key are skipped""" + metric = RecallMetric(golden_key="class") results = [ - UnitMetricResult( - correct=False, - golden={"other_key": "B"}, # Missing 'class' key - predicted={"class": "A"}, - ), + UnitMetricResult(correct=False, golden={"other": "A"}, predicted={"class": "A"}), UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), ] output = metric.calculate(results) - # Only 1 TP (second result), 0 FN - assert "recall" in output - assert output["recall"] == 1.0 + assert output["recall_per_class"] == {"A": 1.0} + assert output["average_recall"] == 1.0 - def test_calculate_various_recall_values(self): - """Test calculate with various recall percentages""" - # 80% recall (4 TP, 1 FN) - metric = RecallMetric(golden_key="class", positive_class="A") + def test_calculate_returns_zero_when_all_rows_missing_golden_key(self): + """Test safe behavior when nothing is usable for calculation""" + metric = RecallMetric(golden_key="class") results = [ - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), + UnitMetricResult(correct=True, golden={"other": "A"}, predicted={"class": "A"}), + UnitMetricResult(correct=False, golden={"other": "B"}, predicted={"class": "B"}), ] output = metric.calculate(results) - assert output["recall"] == 0.8 - # 25% recall (1 TP, 3 FN) - results = [ - UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "C"}), - UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "D"}), - ] - output = metric.calculate(results) - assert output["recall"] == 0.25 + assert output == {"average_recall": 0.0, "recall_per_class": {}} - def test_calculate_with_false_positives_not_affecting_recall(self): - """Test that false 
positives don't affect recall calculation""" - metric = RecallMetric(golden_key="class", positive_class="A") + def test_calculate_multi_class_recall(self): + """Test multi-class recall computation across 3 classes""" + metric = RecallMetric(golden_key="class") results = [ - # True Positive + # Golden A: 2 total, 1 correct => recall(A)=0.5 UnitMetricResult(correct=True, golden={"class": "A"}, predicted={"class": "A"}), - # False Positive (doesn't affect recall) - UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "A"}), - # False Positive (doesn't affect recall) + UnitMetricResult(correct=False, golden={"class": "A"}, predicted={"class": "B"}), + # Golden B: 3 total, 2 correct => recall(B)=2/3 + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), + UnitMetricResult(correct=True, golden={"class": "B"}, predicted={"class": "B"}), + UnitMetricResult(correct=False, golden={"class": "B"}, predicted={"class": "C"}), + # Golden C: 1 total, 0 correct => recall(C)=0.0 UnitMetricResult(correct=False, golden={"class": "C"}, predicted={"class": "A"}), ] output = metric.calculate(results) - # TP = 1, FN = 0, Recall = 1/(1+0) = 1.0 - # False positives don't affect recall - assert "recall" in output - assert output["recall"] == 1.0 + assert output["recall_per_class"] == { + "A": 0.5, + "B": pytest.approx(2 / 3), + "C": 0.0, + } + assert output["average_recall"] == pytest.approx((0.5 + (2 / 3) + 0.0) / 3)
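
Note on the behaviour exercised above: the updated tests drop the single `positive_class` parameter and instead assert per-class scores plus a macro average, keyed as `precision_per_class`/`average_precision`, `recall_per_class`/`average_recall`, and `f1_score_per_class`/`average_f1_score`. The snippet below is a minimal, self-contained sketch of that aggregation, written only against what the tests themselves assert (the `correct`, `golden`, and `predicted` fields of a result and the output key names); `Result`, `_per_class_ratio`, and `macro_f1` are placeholder names for this sketch, not the repository's actual classes or module paths.

```python
# Sketch only: mirrors the per-class / macro-averaged behaviour the updated
# tests appear to expect. Names here are hypothetical stand-ins.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Result:
    """Minimal stand-in for a unit metric result: a correctness flag plus
    golden/predicted payload dicts keyed by configurable field names."""

    correct: bool
    golden: Dict[str, Any] = field(default_factory=dict)
    predicted: Dict[str, Any] = field(default_factory=dict)


def _per_class_ratio(results: List[Result], key: str, source: str) -> Dict[Any, float]:
    """Group rows by the label stored under `key` in either the predicted or
    golden payload, skip rows that do not carry the key, and return the
    correct/total ratio per label (precision when grouped by predicted label,
    recall when grouped by golden label)."""
    totals: Dict[Any, int] = {}
    hits: Dict[Any, int] = {}
    for r in results:
        payload = r.predicted if source == "predicted" else r.golden
        if key not in payload:
            continue  # rows missing the configured key are ignored, not counted as wrong
        label = payload[key]
        totals[label] = totals.get(label, 0) + 1
        hits[label] = hits.get(label, 0) + int(r.correct)
    return {label: hits[label] / totals[label] for label in totals}


def macro_f1(results: List[Result], predicted_key: str, golden_key: str) -> Dict[str, Any]:
    """Per-class F1 over the union of classes seen by precision and recall,
    plus the macro average; returns the same shape the tests assert."""
    precision = _per_class_ratio(results, predicted_key, "predicted")
    recall = _per_class_ratio(results, golden_key, "golden")
    per_class: Dict[Any, float] = {}
    for label in set(precision) | set(recall):
        p, r = precision.get(label, 0.0), recall.get(label, 0.0)
        per_class[label] = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    average = sum(per_class.values()) / len(per_class) if per_class else 0.0
    return {"average_f1_score": average, "f1_score_per_class": per_class}


if __name__ == "__main__":
    rows = [
        Result(correct=True, golden={"class": "A"}, predicted={"class": "A"}),
        Result(correct=False, golden={"class": "A"}, predicted={"class": "B"}),
        Result(correct=False, golden={"class": "B"}, predicted={"class": "A"}),
        Result(correct=True, golden={"class": "B"}, predicted={"class": "B"}),
    ]
    # Same data as test_calculate_f1_per_class_and_average; expected output:
    # {'average_f1_score': 0.5, 'f1_score_per_class': {'A': 0.5, 'B': 0.5}}
    print(macro_f1(rows, predicted_key="class", golden_key="class"))
```

Two design points that the tests above pin down and the sketch reflects: rows missing the configured key are skipped rather than treated as errors (the "skips_rows_missing" and "returns_zero_when_all_rows_missing" cases), and the macro average weights every observed class equally, falling back to `0.0` with an empty per-class dict when no rows are usable.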