Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 115 additions & 13 deletions py/bin/sanitize_schema_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from __future__ import annotations

import ast
import json
import re
import sys
from datetime import datetime
Expand All @@ -53,10 +54,16 @@
class ClassTransformer(ast.NodeTransformer):
"""AST transformer that modifies class definitions."""

def __init__(self) -> None:
"""Initialize the ClassTransformer."""
def __init__(self, models_allowing_extra: set[str] | None = None) -> None:
"""Initialize the ClassTransformer.

Args:
models_allowing_extra: Set of model names that have additionalProperties: true
in the JSON schema and should use extra='allow' instead of extra='forbid'.
"""
self.modified = False
self.schema_fields_to_suppress: list[ast.AnnAssign] = []
self.models_allowing_extra = models_allowing_extra or set()

def is_rootmodel_class(self, node: ast.ClassDef) -> bool:
"""Check if a class definition is a RootModel class."""
Expand All @@ -70,20 +77,27 @@ def is_rootmodel_class(self, node: ast.ClassDef) -> bool:
return False

def create_model_config(
self, existing_config: ast.Call | None = None, frozen: bool = False, has_schema_field: bool = False
self,
existing_config: ast.Call | None = None,
frozen: bool = False,
has_schema_field: bool = False,
allow_extra: bool = False,
) -> ast.AnnAssign:
"""Create or update a model_config assignment with proper type annotation.

Creates: model_config: ClassVar[ConfigDict] = ConfigDict(...)

Ensures alias_generator=to_camel, populate_by_name=True, and extra='forbid',
keeping other existing settings.
Ensures alias_generator=to_camel, populate_by_name=True, and the appropriate
'extra' setting based on the JSON schema's additionalProperties.

Args:
existing_config: Existing ConfigDict call to preserve settings from.
frozen: Whether to add frozen=True for immutable models.
has_schema_field: Whether the class has a 'schema' field. If True,
adds protected_namespaces=() to allow using 'schema' as a field name.
allow_extra: Whether to use extra='allow' (True) or extra='forbid' (False).
This corresponds to additionalProperties: true/false in JSON schema.
In Zod (JS), this is .passthrough() vs .strict().
"""
keywords = []
found_populate = False
Expand All @@ -102,7 +116,7 @@ def create_model_config(
)
found_populate = True
elif kw.arg == 'extra':
# Skip the existing 'extra', we will enforce 'forbid'
# Skip the existing 'extra', we will set based on allow_extra
continue
elif kw.arg == 'alias_generator':
# Skip existing alias_generator, we will add our own
Expand All @@ -122,8 +136,11 @@ def create_model_config(
else:
keywords.append(kw) # Keep other existing settings

# Always add extra='forbid'
keywords.append(ast.keyword(arg='extra', value=ast.Constant(value='forbid')))
# Add extra based on JSON schema's additionalProperties
# 'allow' corresponds to .passthrough() in Zod (JS)
# 'forbid' corresponds to .strict() or no modifier in Zod
extra_value = 'allow' if allow_extra else 'forbid'
keywords.append(ast.keyword(arg='extra', value=ast.Constant(value=extra_value)))

# Add populate_by_name=True if it wasn't found
if not found_populate:
Expand Down Expand Up @@ -314,6 +331,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> object:
added_config = False
frozen = node.name == 'PathMetadata'
has_schema = self.has_schema_field(node)
# Check if this model allows extra properties (additionalProperties: true in JSON schema)
# This corresponds to .passthrough() in Zod (JS)
allow_extra = node.name in self.models_allowing_extra
for stmt in node.body[body_start_index:]:
# Check for model_config (both Assign and AnnAssign)
is_model_config = False
Expand All @@ -330,7 +350,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> object:
if is_model_config:
# Update existing model_config
updated_config = self.create_model_config(
existing_model_config_call, frozen=frozen, has_schema_field=has_schema
existing_model_config_call,
frozen=frozen,
has_schema_field=has_schema,
allow_extra=allow_extra,
)
# Check if the config actually changed
if ast.dump(updated_config) != ast.dump(stmt):
Expand All @@ -354,7 +377,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> object:
# Add model_config if it wasn't present
# Insert after potential docstring
insert_pos = 1 if len(new_body) > 0 and isinstance(new_body[0], ast.Expr) else 0
new_body.insert(insert_pos, self.create_model_config(frozen=frozen, has_schema_field=has_schema))
new_body.insert(
insert_pos,
self.create_model_config(frozen=frozen, has_schema_field=has_schema, allow_extra=allow_extra),
)
self.modified = True
elif any(isinstance(base, ast.Name) and base.id == 'Enum' for base in node.bases):
# Uppercase Enum members
Expand Down Expand Up @@ -568,14 +594,74 @@ def add_header(content: str) -> str:
return final_output


def process_file(filename: str) -> None:
def load_models_allowing_extra(schema_path: Path) -> set[str]:
    """Load JSON schema and extract models with additionalProperties: true.

    In Zod (JS), these are models defined with .passthrough() which allows
    extra properties beyond the defined schema. In Pydantic (Python), this
    corresponds to extra='allow' in the model_config.

    This function checks both:
    1. Top-level $defs with additionalProperties: true
    2. Inline nested objects with additionalProperties: true (e.g., Score.details)

    For inline schemas, datamodel-codegen extracts them as separate classes.
    We map from the JSON schema property name to the generated Python class
    name by upper-casing the first character (e.g. 'details' -> 'Details').

    The lookup is best-effort: a missing, unreadable, undecodable or
    structurally unexpected schema yields an empty set rather than an error,
    so callers fall back to extra='forbid' for every model.

    Args:
        schema_path: Path to the genkit-schema.json file.

    Returns:
        Set of model names that allow extra properties. Empty if the schema
        cannot be read or does not have the expected shape.
    """
    if not schema_path.is_file():
        return set()

    # Keep the try body limited to the I/O and parsing that can actually fail.
    try:
        with schema_path.open(encoding='utf-8') as f:
            schema = json.load(f)
    except (ValueError, OSError):
        # ValueError covers json.JSONDecodeError as well as UnicodeDecodeError
        # raised when the file is not valid UTF-8.
        return set()

    # A valid JSON document is not necessarily an object; guard before .get().
    if not isinstance(schema, dict):
        return set()
    defs = schema.get('$defs')
    if not isinstance(defs, dict):
        return set()

    result: set[str] = set()
    for name, defn in defs.items():
        if not isinstance(defn, dict):
            continue

        # 1. Top-level definition that explicitly allows additional properties.
        if defn.get('additionalProperties') is True:
            result.add(name)

        # 2. Inline nested objects in properties, which datamodel-codegen
        #    extracts as separate classes named after the property
        #    (e.g. Score.properties.details -> Details).
        properties = defn.get('properties')
        if not isinstance(properties, dict):
            continue
        for prop_name, prop_defn in properties.items():
            # An empty property name cannot be mapped to a class name.
            if not prop_name or not isinstance(prop_defn, dict):
                continue
            if prop_defn.get('type') == 'object' and prop_defn.get('additionalProperties') is True:
                # Upper-case only the first character so camelCase names keep
                # their inner capitals ('fooBar' -> 'FooBar').
                result.add(prop_name[0].upper() + prop_name[1:])

    return result


def process_file(filename: str, schema_path: Path | None = None) -> None:
"""Process a Python file to remove model_config from RootModel classes.

This function reads a Python file, processes its AST to remove model_config
from RootModel classes, and writes the modified code back to the file.

Args:
filename: Path to the Python file to process.
schema_path: Path to genkit-schema.json for determining which models
allow extra properties.

Raises:
FileNotFoundError: If the input file does not exist.
Expand All @@ -585,12 +671,17 @@ def process_file(filename: str) -> None:
if not path.is_file():
sys.exit(1)

# Load models that allow extra properties from JSON schema
models_allowing_extra: set[str] = set()
if schema_path:
models_allowing_extra = load_models_allowing_extra(schema_path)

try:
with Path(path).open(encoding='utf-8') as f:
source = f.read()

tree = ast.parse(source)
class_transformer = ClassTransformer()
class_transformer = ClassTransformer(models_allowing_extra=models_allowing_extra)
modified_tree = class_transformer.visit(tree)

# Generate source from potentially modified AST
Expand Down Expand Up @@ -630,7 +721,18 @@ def main() -> None:
if len(sys.argv) != 2:
sys.exit(1)

process_file(sys.argv[1])
typing_file = Path(sys.argv[1])

# Derive genkit-schema.json path relative to the typing.py file
# typing.py is at: py/packages/genkit/src/genkit/core/typing.py
# schema is at: genkit-tools/genkit-schema.json
# So we go up 6 directories from typing.py's own directory (core) to reach repo root, then into genkit-tools/
schema_path = typing_file.parent
for _ in range(6): # Go up: core -> genkit -> src -> genkit -> packages -> py -> (repo root)
schema_path = schema_path.parent
schema_path = schema_path / 'genkit-tools' / 'genkit-schema.json'

process_file(sys.argv[1], schema_path=schema_path if schema_path.is_file() else None)


if __name__ == '__main__':
Expand Down
6 changes: 3 additions & 3 deletions py/packages/genkit/src/genkit/core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class EvalStatusEnum(StrEnum):
class Details(BaseModel):
"""Model for details data."""

model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True)
model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True)
reasoning: str | None = None


Expand Down Expand Up @@ -264,7 +264,7 @@ class ModelInfo(BaseModel):
class Error(BaseModel):
"""Model for error data."""

model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True)
model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True)
message: str


Expand Down Expand Up @@ -363,7 +363,7 @@ class CommonRerankerOptions(BaseModel):
class RankedDocumentMetadata(BaseModel):
"""Model for rankeddocumentmetadata data."""

model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True)
model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True)
score: float


Expand Down
Loading