diff --git a/py/bin/sanitize_schema_typing.py b/py/bin/sanitize_schema_typing.py index 8e6e2fb20f..0d35b253df 100644 --- a/py/bin/sanitize_schema_typing.py +++ b/py/bin/sanitize_schema_typing.py @@ -43,6 +43,7 @@ from __future__ import annotations import ast +import json import re import sys from datetime import datetime @@ -53,10 +54,16 @@ class ClassTransformer(ast.NodeTransformer): """AST transformer that modifies class definitions.""" - def __init__(self) -> None: - """Initialize the ClassTransformer.""" + def __init__(self, models_allowing_extra: set[str] | None = None) -> None: + """Initialize the ClassTransformer. + + Args: + models_allowing_extra: Set of model names that have additionalProperties: true + in the JSON schema and should use extra='allow' instead of extra='forbid'. + """ self.modified = False self.schema_fields_to_suppress: list[ast.AnnAssign] = [] + self.models_allowing_extra = models_allowing_extra or set() def is_rootmodel_class(self, node: ast.ClassDef) -> bool: """Check if a class definition is a RootModel class.""" @@ -70,20 +77,27 @@ def is_rootmodel_class(self, node: ast.ClassDef) -> bool: return False def create_model_config( - self, existing_config: ast.Call | None = None, frozen: bool = False, has_schema_field: bool = False + self, + existing_config: ast.Call | None = None, + frozen: bool = False, + has_schema_field: bool = False, + allow_extra: bool = False, ) -> ast.AnnAssign: """Create or update a model_config assignment with proper type annotation. Creates: model_config: ClassVar[ConfigDict] = ConfigDict(...) - Ensures alias_generator=to_camel, populate_by_name=True, and extra='forbid', - keeping other existing settings. + Ensures alias_generator=to_camel, populate_by_name=True, and the appropriate + 'extra' setting based on the JSON schema's additionalProperties. Args: existing_config: Existing ConfigDict call to preserve settings from. frozen: Whether to add frozen=True for immutable models. has_schema_field: Whether the class has a 'schema' field. If True, adds protected_namespaces=() to allow using 'schema' as a field name. + allow_extra: Whether to use extra='allow' (True) or extra='forbid' (False). + This corresponds to additionalProperties: true/false in JSON schema. + In Zod (JS), this is .passthrough() vs .strict(). """ keywords = [] found_populate = False @@ -102,7 +116,7 @@ def create_model_config( ) found_populate = True elif kw.arg == 'extra': - # Skip the existing 'extra', we will enforce 'forbid' + # Skip the existing 'extra', we will set based on allow_extra continue elif kw.arg == 'alias_generator': # Skip existing alias_generator, we will add our own @@ -122,8 +136,11 @@ def create_model_config( else: keywords.append(kw) # Keep other existing settings - # Always add extra='forbid' - keywords.append(ast.keyword(arg='extra', value=ast.Constant(value='forbid'))) + # Add extra based on JSON schema's additionalProperties + # 'allow' corresponds to .passthrough() in Zod (JS) + # 'forbid' corresponds to .strict() or no modifier in Zod + extra_value = 'allow' if allow_extra else 'forbid' + keywords.append(ast.keyword(arg='extra', value=ast.Constant(value=extra_value))) # Add populate_by_name=True if it wasn't found if not found_populate: @@ -314,6 +331,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> object: added_config = False frozen = node.name == 'PathMetadata' has_schema = self.has_schema_field(node) + # Check if this model allows extra properties (additionalProperties: true in JSON schema) + # This corresponds to .passthrough() in Zod (JS) + allow_extra = node.name in self.models_allowing_extra for stmt in node.body[body_start_index:]: # Check for model_config (both Assign and AnnAssign) is_model_config = False @@ -330,7 +350,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> object: if is_model_config: # Update existing model_config updated_config = self.create_model_config( - existing_model_config_call, frozen=frozen, has_schema_field=has_schema + existing_model_config_call, + frozen=frozen, + has_schema_field=has_schema, + allow_extra=allow_extra, ) # Check if the config actually changed if ast.dump(updated_config) != ast.dump(stmt): @@ -354,7 +377,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> object: # Add model_config if it wasn't present # Insert after potential docstring insert_pos = 1 if len(new_body) > 0 and isinstance(new_body[0], ast.Expr) else 0 - new_body.insert(insert_pos, self.create_model_config(frozen=frozen, has_schema_field=has_schema)) + new_body.insert( + insert_pos, + self.create_model_config(frozen=frozen, has_schema_field=has_schema, allow_extra=allow_extra), + ) self.modified = True elif any(isinstance(base, ast.Name) and base.id == 'Enum' for base in node.bases): # Uppercase Enum members @@ -568,7 +594,65 @@ def add_header(content: str) -> str: return final_output -def process_file(filename: str) -> None: +def load_models_allowing_extra(schema_path: Path) -> set[str]: + """Load JSON schema and extract models with additionalProperties: true. + + In Zod (JS), these are models defined with .passthrough() which allows + extra properties beyond the defined schema. In Pydantic (Python), this + corresponds to extra='allow' in the model_config. + + This function checks both: + 1. Top-level $defs with additionalProperties: true + 2. Inline nested objects with additionalProperties: true (e.g., Score.details) + + For inline schemas, datamodel-codegen extracts them as separate classes. + We map from the JSON schema property name to the generated Python class name. + + Args: + schema_path: Path to the genkit-schema.json file. + + Returns: + Set of model names that allow extra properties. + """ + if not schema_path.is_file(): + return set() + + try: + with schema_path.open(encoding='utf-8') as f: + schema = json.load(f) + + defs = schema.get('$defs', {}) + result: set[str] = set() + + # 1. Check top-level $defs for additionalProperties: true + for name, defn in defs.items(): + if isinstance(defn, dict) and defn.get('additionalProperties') is True: + result.add(name) + + # 2. Check inline nested objects in properties + # datamodel-codegen extracts these as separate classes with PascalCase names + # e.g., Score.properties.details -> Details class + # e.g., Operation.properties.error -> Error class + for defn in defs.values(): + if not isinstance(defn, dict): + continue + properties = defn.get('properties', {}) + for prop_name, prop_defn in properties.items(): + if not isinstance(prop_defn, dict): + continue + # Check if this is an inline object with additionalProperties: true + if prop_defn.get('type') == 'object' and prop_defn.get('additionalProperties') is True: + # Convert property name to PascalCase class name + # e.g., 'details' -> 'Details', 'error' -> 'Error' + class_name = prop_name[0].upper() + prop_name[1:] + result.add(class_name) + + return result + except (json.JSONDecodeError, OSError): + return set() + + +def process_file(filename: str, schema_path: Path | None = None) -> None: """Process a Python file to remove model_config from RootModel classes. This function reads a Python file, processes its AST to remove model_config @@ -576,6 +660,8 @@ def process_file(filename: str) -> None: Args: filename: Path to the Python file to process. + schema_path: Path to genkit-schema.json for determining which models + allow extra properties. Raises: FileNotFoundError: If the input file does not exist. @@ -585,12 +671,17 @@ def process_file(filename: str) -> None: if not path.is_file(): sys.exit(1) + # Load models that allow extra properties from JSON schema + models_allowing_extra: set[str] = set() + if schema_path: + models_allowing_extra = load_models_allowing_extra(schema_path) + try: with Path(path).open(encoding='utf-8') as f: source = f.read() tree = ast.parse(source) - class_transformer = ClassTransformer() + class_transformer = ClassTransformer(models_allowing_extra=models_allowing_extra) modified_tree = class_transformer.visit(tree) # Generate source from potentially modified AST @@ -630,7 +721,18 @@ def main() -> None: if len(sys.argv) != 2: sys.exit(1) - process_file(sys.argv[1]) + typing_file = Path(sys.argv[1]) + + # Derive genkit-schema.json path relative to the typing.py file + # typing.py is at: py/packages/genkit/src/genkit/core/typing.py + # schema is at: genkit-tools/genkit-schema.json + # So we go up 6 directories from typing.py to reach repo root, then into genkit-tools/ + schema_path = typing_file.parent + for _ in range(6): # Go up: core -> genkit -> src -> genkit -> packages -> py -> (repo root) + schema_path = schema_path.parent + schema_path = schema_path / 'genkit-tools' / 'genkit-schema.json' + + process_file(sys.argv[1], schema_path=schema_path if schema_path.is_file() else None) if __name__ == '__main__': diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 9a51e3b718..02f5927450 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -79,7 +79,7 @@ class EvalStatusEnum(StrEnum): class Details(BaseModel): """Model for details data.""" - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) reasoning: str | None = None @@ -264,7 +264,7 @@ class ModelInfo(BaseModel): class Error(BaseModel): """Model for error data.""" - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) message: str @@ -363,7 +363,7 @@ class CommonRerankerOptions(BaseModel): class RankedDocumentMetadata(BaseModel): """Model for rankeddocumentmetadata data.""" - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) score: float