diff --git a/py/plugins/google-genai/README.md b/py/plugins/google-genai/README.md index 0a9e0e2ba3..436777263d 100644 --- a/py/plugins/google-genai/README.md +++ b/py/plugins/google-genai/README.md @@ -47,6 +47,77 @@ config = GeminiConfigSchema.model_validate({ }) ``` +### Vertex AI Rerankers + +The VertexAI plugin provides semantic rerankers for improving RAG quality by re-scoring documents based on relevance: + +```python +from genkit import Genkit +from genkit.plugins.google_genai import VertexAI + +ai = Genkit(plugins=[VertexAI(project='my-project')]) + +# Rerank documents after retrieval +ranked_docs = await ai.rerank( + reranker='vertexai/semantic-ranker-default@latest', + query='What is machine learning?', + documents=retrieved_docs, + options={'top_n': 5}, +) +``` + +**Supported Models:** + +| Model | Description | +|-------|-------------| +| `semantic-ranker-default@latest` | Latest default semantic ranker | +| `semantic-ranker-default-004` | Semantic ranker version 004 | +| `semantic-ranker-fast-004` | Fast variant (lower latency) | + +### Vertex AI Evaluators + +Built-in evaluators for assessing model output quality. Evaluators are automatically registered when using the VertexAI plugin and are accessed via `ai.evaluate()`: + +```python +from genkit import Genkit +from genkit.core.typing import BaseDataPoint +from genkit.plugins.google_genai import VertexAI + +ai = Genkit(plugins=[VertexAI(project='my-project')]) + +# Prepare test dataset +dataset = [ + BaseDataPoint( + input='Write about AI.', + output='AI is transforming industries through intelligent automation.', + ), +] + +# Evaluate fluency (scores 1-5) +results = await ai.evaluate( + evaluator='vertexai/fluency', + dataset=dataset, +) + +for result in results.root: + print(f'Score: {result.evaluation.score}') +``` + + +**Supported Metrics:** + +| Metric | Description | +|--------|-------------| +| `BLEU` | Translation quality (compare to reference) | +| `ROUGE` | Summarization quality | +| `FLUENCY` | Language mastery and readability | +| `SAFETY` | Harmful/inappropriate content detection | +| `GROUNDEDNESS` | Hallucination detection | +| `SUMMARIZATION_QUALITY` | Overall summarization ability | + ## Examples -For comprehensive usage examples, see [`py/samples/google-genai-hello/README.md`](../../samples/google-genai-hello/README.md). +For comprehensive usage examples, see: + +- [`py/samples/google-genai-hello/README.md`](../../samples/google-genai-hello/README.md) - Basic Gemini usage +- [`py/samples/vertexai-rerank-eval/README.md`](../../samples/vertexai-rerank-eval/README.md) - Rerankers and evaluators diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/evaluators/__init__.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/evaluators/__init__.py new file mode 100644 index 0000000000..4d352ec857 --- /dev/null +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/evaluators/__init__.py @@ -0,0 +1,146 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Vertex AI Evaluators for the Genkit framework. + +This module provides evaluation metrics using the Vertex AI Evaluation API. +These evaluators assess model outputs for quality metrics like BLEU, ROUGE, +fluency, safety, groundedness, and summarization quality. + +Key Concepts (ELI5):: + + ┌─────────────────────┬────────────────────────────────────────────────────┐ + │ Concept │ ELI5 Explanation │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Evaluator │ A "grader" that scores your AI's answers. │ + │ │ Like a teacher checking homework. │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ BLEU Score │ Compares AI output to a "correct" answer. │ + │ │ Higher = closer to the reference text. │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ ROUGE Score │ Measures how much key info is captured. │ + │ │ Good for checking if summaries hit key points. │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Fluency │ How natural and readable the text is. │ + │ │ Does it sound like a human wrote it? │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Safety │ Is the content appropriate and safe? │ + │ │ No harmful, biased, or inappropriate content. │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Groundedness │ Does the answer stick to the facts given? │ + │ │ No making things up (hallucinations). │ + └─────────────────────┴────────────────────────────────────────────────────┘ + +Data Flow:: + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ EVALUATION PIPELINE │ + │ │ + │ Test Dataset │ + │ [input, output, reference, context] │ + │ │ │ + │ ▼ │ + │ ┌─────────────────────────────────────────────────────────────────┐ │ + │ │ Vertex AI Evaluators │ │ + │ │ │ │ + │ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────────────────┐ │ │ + │ │ │ BLEU │ │ ROUGE │ │ Fluency │ │ Groundedness │ │ │ + │ │ │ (0.72) │ │ (0.68) │ │ (4/5) │ │ (5/5) │ │ │ + │ │ └─────────┘ └─────────┘ └─────────┘ └─────────────────────┘ │ │ + │ │ │ │ + │ │ ┌─────────┐ ┌──────────────┐ ┌───────────────────────────────┐│ │ + │ │ │ Safety │ │ Summarization│ │ Summarization Helpfulness ││ │ + │ │ │ (5/5) │ │ Quality (4/5)│ │ (4/5) ││ │ + │ │ └─────────┘ └──────────────┘ └───────────────────────────────┘│ │ + │ └─────────────────────────────────────────────────────────────────┘ │ + │ │ │ + │ ▼ │ + │ Evaluation Report │ + │ {"score": 0.85, "details": {"reasoning": "..."}} │ + └─────────────────────────────────────────────────────────────────────────┘ + +Overview: + Vertex AI offers built-in evaluation metrics that use machine learning + to score model outputs. 
These evaluators are useful for: + + - **Automated testing**: CI/CD quality gates for LLM outputs + - **Model comparison**: Compare different models or prompts + - **Quality assurance**: Catch regressions in output quality + - **Safety checks**: Ensure outputs meet safety standards + +Available Metrics: + +-----------------------------+-------------------------------------------+ + | Metric | Description | + +-----------------------------+-------------------------------------------+ + | BLEU | Compare output to reference (translation) | + | ROUGE | Compare output to reference (summarization)| + | FLUENCY | Assess language mastery and readability | + | SAFETY | Check for harmful/inappropriate content | + | GROUNDEDNESS | Verify output is grounded in context | + | SUMMARIZATION_QUALITY | Overall summarization ability | + | SUMMARIZATION_HELPFULNESS | Usefulness as a summary substitute | + | SUMMARIZATION_VERBOSITY | Conciseness of the summary | + +-----------------------------+-------------------------------------------+ + +Example: + Running evaluations:: + + from genkit import Genkit + from genkit.plugins.google_genai import VertexAI + from genkit.plugins.google_genai.evaluators import VertexAIEvaluationMetricType + + ai = Genkit(plugins=[VertexAI(project='my-project')]) + + # Prepare test dataset + dataset = [ + { + 'input': 'Summarize this article about AI...', + 'output': 'AI is transforming industries...', + 'reference': 'The article discusses how AI impacts...', + 'context': ['Article content here...'], + } + ] + + # Run fluency evaluation + results = await ai.evaluate( + evaluator='vertexai/fluency', + dataset=dataset, + ) + + for result in results: + print(f'Score: {result.evaluation.score}') + print(f'Reasoning: {result.evaluation.details.get("reasoning")}') + +Caveats: + - Requires Google Cloud project with Vertex AI API enabled + - Evaluators are billed per API call + - Some metrics require specific fields (e.g., GROUNDEDNESS needs context) + - Scores are subjective assessments, not ground truth + +See Also: + - Vertex AI Evaluation API: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation + - Genkit evaluation docs: https://genkit.dev/docs/evaluation +""" + +from genkit.plugins.google_genai.evaluators.evaluation import ( + VertexAIEvaluationMetricType, + create_vertex_evaluators, +) + +__all__ = [ + 'VertexAIEvaluationMetricType', + 'create_vertex_evaluators', +] diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/evaluators/evaluation.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/evaluators/evaluation.py new file mode 100644 index 0000000000..4c02eb8034 --- /dev/null +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/evaluators/evaluation.py @@ -0,0 +1,499 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Vertex AI Evaluation implementation. 
+ +This module implements the Vertex AI Evaluation API for evaluating model outputs +using built-in metrics like BLEU, ROUGE, fluency, safety, and more. + +Architecture:: + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ Vertex AI Evaluators Module │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ Types & Configuration │ + │ ├── VertexAIEvaluationMetricType (enum) - Available metrics │ + │ └── VertexAIEvaluationMetricConfig - Per-metric configuration │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ EvaluatorFactory │ + │ ├── evaluate_instances() - Async API call to evaluateInstances │ + │ └── create_evaluator_fn() - Creates evaluator function for metric │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ Evaluator Configurations (per metric) │ + │ ├── BLEU - to_request(), response_handler() │ + │ ├── ROUGE - to_request(), response_handler() │ + │ ├── FLUENCY - to_request(), response_handler() │ + │ ├── SAFETY - to_request(), response_handler() │ + │ ├── GROUNDEDNESS - to_request(), response_handler() │ + │ ├── SUMMARIZATION_QUALITY - to_request(), response_handler() │ + │ ├── SUMMARIZATION_HELPFULNESS - to_request(), response_handler() │ + │ └── SUMMARIZATION_VERBOSITY - to_request(), response_handler() │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ Plugin Integration │ + │ └── create_vertex_evaluators() - Register evaluators with Genkit │ + └─────────────────────────────────────────────────────────────────────────┘ + +Implementation Notes: + - Uses Google Cloud Application Default Credentials (ADC) for auth + - Calls the Vertex AI Platform evaluateInstances v1beta1 endpoint + - Each metric has a specific request format and response handler + - Supports custom metric_spec for fine-tuning metric behavior +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + +from google.auth import default as google_auth_default +from google.auth.transport.requests import Request +from pydantic import BaseModel, ConfigDict + +from genkit.ai import GENKIT_CLIENT_HEADER +from genkit.blocks.evaluator import EvalFnResponse +from genkit.core.action import Action +from genkit.core.error import GenkitError +from genkit.core.http_client import get_cached_client +from genkit.core.typing import BaseDataPoint, Details, Score + +if TYPE_CHECKING: + from genkit.ai._registry import GenkitRegistry + + +class VertexAIEvaluationMetricType(StrEnum): + """Vertex AI Evaluation metric types. + + See API documentation for more information: + https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation#parameter-list + """ + + BLEU = 'BLEU' + ROUGE = 'ROUGE' + FLUENCY = 'FLUENCY' # Note: JS has typo 'FLEUNCY' but we use correct spelling + SAFETY = 'SAFETY' + GROUNDEDNESS = 'GROUNDEDNESS' + SUMMARIZATION_QUALITY = 'SUMMARIZATION_QUALITY' + SUMMARIZATION_HELPFULNESS = 'SUMMARIZATION_HELPFULNESS' + SUMMARIZATION_VERBOSITY = 'SUMMARIZATION_VERBOSITY' + + +class VertexAIEvaluationMetricConfig(BaseModel): + """Configuration for a Vertex AI evaluation metric. + + Attributes: + type: The metric type. + metric_spec: Additional metric-specific configuration. 
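+
+    Example:
+        A minimal sketch; metric_spec keys follow the Vertex AI
+        evaluateInstances metric specs, and 'rougeType' here is an assumed
+        illustration rather than a verified field name::
+
+            config = VertexAIEvaluationMetricConfig(
+                type=VertexAIEvaluationMetricType.ROUGE,
+                metric_spec={'rougeType': 'rougeL'},
+            )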
+ """ + + model_config: ClassVar[ConfigDict] = ConfigDict( + extra='allow', + populate_by_name=True, + ) + + type: VertexAIEvaluationMetricType + metric_spec: dict[str, Any] | None = None + + +def _create_list_based_score_handler(results_key: str, values_key: str) -> Callable[[dict[str, Any]], Score]: + """Create a response handler for metrics that return a list of scored values. + + This is used for BLEU and ROUGE metrics which have similar response structures. + + Args: + results_key: The key for the results object (e.g., 'bleuResults'). + values_key: The key for the metrics list (e.g., 'bleuMetricValues'). + + Returns: + A function that extracts a Score from the response. + """ + + def handler(response: dict[str, Any]) -> Score: + metrics = response.get(results_key, {}).get(values_key, []) + score = metrics[0].get('score') if metrics else None + return Score(score=score) + + return handler + + +# Union type for metric specification +VertexAIEvaluationMetric = VertexAIEvaluationMetricType | VertexAIEvaluationMetricConfig + + +def _stringify(value: Any) -> str: # noqa: ANN401 + """Convert a value to string for the API.""" + if isinstance(value, str): + return value + return json.dumps(value) + + +def _is_config(metric: VertexAIEvaluationMetric) -> bool: + """Check if metric is a config object.""" + return isinstance(metric, VertexAIEvaluationMetricConfig) + + +class EvaluatorFactory: + """Factory for creating Vertex AI evaluator actions.""" + + def __init__(self, project_id: str, location: str) -> None: + """Initialize the factory. + + Args: + project_id: Google Cloud project ID. + location: Google Cloud location. + """ + self.project_id = project_id + self.location = location + + async def evaluate_instances(self, request_body: dict[str, Any]) -> dict[str, Any]: + """Call the Vertex AI evaluateInstances API. + + Args: + request_body: The request body for the API. + + Returns: + The API response. + + Raises: + GenkitError: If the API call fails. + """ + location_name = f'projects/{self.project_id}/locations/{self.location}' + url = f'https://{self.location}-aiplatform.googleapis.com/v1beta1/{location_name}:evaluateInstances' + + # Get authentication token + # Use asyncio.to_thread to avoid blocking the event loop during token refresh + credentials, _ = google_auth_default() + await asyncio.to_thread(credentials.refresh, Request()) + token = credentials.token + + if not token: + raise GenkitError( + message='Unable to authenticate your request. ' + 'Please ensure you have valid Google Cloud credentials configured.', + status='UNAUTHENTICATED', + ) + + headers = { + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + } + + request = { + 'location': location_name, + **request_body, + } + + # Use cached client for better connection reuse. + # Note: Auth headers are passed per-request since tokens may expire. 
+ client = get_cached_client( + cache_key='vertex-ai-evaluator', + timeout=60.0, + ) + + try: + response = await client.post( + url, + headers=headers, + json=request, + ) + + if response.status_code != 200: + error_message = response.text + try: + error_json = response.json() + if 'error' in error_json and 'message' in error_json['error']: + error_message = error_json['error']['message'] + except json.JSONDecodeError: # noqa: S110 + pass + + raise GenkitError( + message=f'Error calling Vertex AI Evaluation API: [{response.status_code}] {error_message}', + status='INTERNAL', + ) + + return response.json() + + except Exception as e: + if isinstance(e, GenkitError): + raise + raise GenkitError( + message=f'Failed to call Vertex AI Evaluation API: {e}', + status='UNAVAILABLE', + ) from e + + def create_evaluator_fn( + self, + metric_type: VertexAIEvaluationMetricType, + metric_spec: dict[str, Any] | None, + to_request: Any, # noqa: ANN401 + response_handler: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Create an evaluator function. + + Args: + metric_type: The metric type. + metric_spec: Optional metric specification. + to_request: Function to convert datapoint to request. + response_handler: Function to extract score from response. + + Returns: + An async evaluator function. + """ + + async def evaluator_fn( + datapoint: BaseDataPoint, + options: dict[str, Any] | None = None, + ) -> EvalFnResponse: + """Evaluate a single datapoint. + + Args: + datapoint: The evaluation data point. + options: Optional evaluation options. + + Returns: + The evaluation response with score. + """ + request_body = to_request(datapoint, metric_spec or {}) + response = await self.evaluate_instances(request_body) + score = response_handler(response) + + return EvalFnResponse( + evaluation=score, + test_case_id=datapoint.test_case_id or '', + ) + + return evaluator_fn + + +def create_vertex_evaluators( + registry: GenkitRegistry, + metrics: list[VertexAIEvaluationMetric], + project_id: str, + location: str, +) -> list[Action]: + """Create Vertex AI evaluator actions. + + Args: + registry: The Genkit registry. + metrics: List of metrics to create evaluators for. + project_id: Google Cloud project ID. + location: Google Cloud location. + + Returns: + List of created evaluator actions. + """ + factory = EvaluatorFactory(project_id, location) + actions = [] + + for metric in metrics: + if isinstance(metric, VertexAIEvaluationMetricConfig): + metric_type: VertexAIEvaluationMetricType = metric.type + metric_spec: dict[str, Any] | None = metric.metric_spec + else: + metric_type = metric + metric_spec = None + + action = _create_evaluator_for_metric(registry, factory, metric_type, metric_spec or {}) + if action: + actions.append(action) + + return actions + + +def _create_evaluator_for_metric( + registry: GenkitRegistry, + factory: EvaluatorFactory, + metric_type: VertexAIEvaluationMetricType, + metric_spec: dict[str, Any], +) -> Action | None: + """Create an evaluator action for a specific metric. + + Args: + registry: The Genkit registry. + factory: The evaluator factory. + metric_type: The metric type. + metric_spec: The metric specification. + + Returns: + The created action, or None if metric is not supported. 
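+
+    Note:
+        Each supported metric maps to a config dict with 'display_name',
+        'definition', 'to_request', and 'response_handler' keys; the mapping
+        below is the single source of truth for the request shape each
+        metric sends to the evaluateInstances endpoint.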
+ """ + evaluator_configs = { + VertexAIEvaluationMetricType.BLEU: { + 'display_name': 'BLEU', + 'definition': 'Computes the BLEU score by comparing the output against the ground truth', + 'to_request': lambda dp, spec: { + 'bleuInput': { + 'metricSpec': spec, + 'instances': [ + { + 'prediction': _stringify(dp.output), + 'reference': dp.reference, + } + ], + } + }, + 'response_handler': _create_list_based_score_handler('bleuResults', 'bleuMetricValues'), + }, + VertexAIEvaluationMetricType.ROUGE: { + 'display_name': 'ROUGE', + 'definition': 'Computes the ROUGE score by comparing the output against the ground truth', + 'to_request': lambda dp, spec: { + 'rougeInput': { + 'metricSpec': spec, + 'instances': [ + { + 'prediction': _stringify(dp.output), + 'reference': dp.reference, + } + ], + } + }, + 'response_handler': _create_list_based_score_handler('rougeResults', 'rougeMetricValues'), + }, + VertexAIEvaluationMetricType.FLUENCY: { + 'display_name': 'Fluency', + 'definition': 'Assesses the language mastery of an output', + 'to_request': lambda dp, spec: { + 'fluencyInput': { + 'metricSpec': spec, + 'instance': { + 'prediction': _stringify(dp.output), + }, + } + }, + 'response_handler': lambda r: Score( + score=r.get('fluencyResult', {}).get('score'), + details=Details(reasoning=r.get('fluencyResult', {}).get('explanation')), + ), + }, + VertexAIEvaluationMetricType.SAFETY: { + 'display_name': 'Safety', + 'definition': 'Assesses the level of safety of an output', + 'to_request': lambda dp, spec: { + 'safetyInput': { + 'metricSpec': spec, + 'instance': { + 'prediction': _stringify(dp.output), + }, + } + }, + 'response_handler': lambda r: Score( + score=r.get('safetyResult', {}).get('score'), + details=Details(reasoning=r.get('safetyResult', {}).get('explanation')), + ), + }, + VertexAIEvaluationMetricType.GROUNDEDNESS: { + 'display_name': 'Groundedness', + 'definition': 'Assesses the ability to provide or reference information included only in the context', + 'to_request': lambda dp, spec: { + 'groundednessInput': { + 'metricSpec': spec, + 'instance': { + 'prediction': _stringify(dp.output), + 'context': '. '.join(dp.context) if dp.context else None, + }, + } + }, + 'response_handler': lambda r: Score( + score=r.get('groundednessResult', {}).get('score'), + details=Details(reasoning=r.get('groundednessResult', {}).get('explanation')), + ), + }, + VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY: { + 'display_name': 'Summarization quality', + 'definition': 'Assesses the overall ability to summarize text', + 'to_request': lambda dp, spec: { + 'summarizationQualityInput': { + 'metricSpec': spec, + 'instance': { + 'prediction': _stringify(dp.output), + 'instruction': _stringify(dp.input), + 'context': '. '.join(dp.context) if dp.context else None, + }, + } + }, + 'response_handler': lambda r: Score( + score=r.get('summarizationQualityResult', {}).get('score'), + details=Details(reasoning=r.get('summarizationQualityResult', {}).get('explanation')), + ), + }, + VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS: { + 'display_name': 'Summarization helpfulness', + 'definition': 'Assesses ability to provide a summarization with details to substitute the original', + 'to_request': lambda dp, spec: { + 'summarizationHelpfulnessInput': { + 'metricSpec': spec, + 'instance': { + 'prediction': _stringify(dp.output), + 'instruction': _stringify(dp.input), + 'context': '. 
'.join(dp.context) if dp.context else None, + }, + } + }, + 'response_handler': lambda r: Score( + score=r.get('summarizationHelpfulnessResult', {}).get('score'), + details=Details(reasoning=r.get('summarizationHelpfulnessResult', {}).get('explanation')), + ), + }, + VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY: { + 'display_name': 'Summarization verbosity', + 'definition': 'Assesses the ability to provide a succinct summarization', + 'to_request': lambda dp, spec: { + 'summarizationVerbosityInput': { + 'metricSpec': spec, + 'instance': { + 'prediction': _stringify(dp.output), + 'instruction': _stringify(dp.input), + 'context': '. '.join(dp.context) if dp.context else None, + }, + } + }, + 'response_handler': lambda r: Score( + score=r.get('summarizationVerbosityResult', {}).get('score'), + details=Details(reasoning=r.get('summarizationVerbosityResult', {}).get('explanation')), + ), + }, + } + + config = evaluator_configs.get(metric_type) + if not config: + return None + + evaluator_name = f'vertexai/{metric_type.lower()}' + display_name: str = config['display_name'] # type: ignore[assignment] + definition: str = config['definition'] # type: ignore[assignment] + evaluator_fn = factory.create_evaluator_fn( + metric_type, + metric_spec, + config['to_request'], + config['response_handler'], + ) + + return registry.define_evaluator( + name=evaluator_name, + display_name=display_name, + definition=definition, + fn=evaluator_fn, + is_billed=True, # These use Vertex AI API which is billed + ) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 44c8e6529b..588980fc38 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -93,7 +93,7 @@ """ import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from genkit.blocks.background_model import BackgroundAction @@ -105,11 +105,19 @@ import genkit.plugins.google_genai.constants as const from genkit.ai import GENKIT_CLIENT_HEADER, Plugin +from genkit.blocks.document import Document from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata from genkit.blocks.model import model_action_metadata +from genkit.blocks.reranker import reranker_action_metadata from genkit.core.action import Action, ActionMetadata from genkit.core.registry import ActionKind from genkit.core.schema import to_json_schema +from genkit.core.typing import ( + RankedDocumentData, + RankedDocumentMetadata, + RerankerRequest, + RerankerResponse, +) from genkit.plugins.google_genai.models.embedder import ( Embedder, default_embedder_info, @@ -132,6 +140,15 @@ is_veo_model, veo_model_info, ) +from genkit.plugins.google_genai.rerankers.reranker import ( + KNOWN_MODELS as RERANKER_MODELS, + RerankRequest, + VertexRerankerClientOptions, + VertexRerankerConfig, + _from_rerank_response, + _to_reranker_doc, + reranker_rank, +) class GenaiModels: @@ -677,15 +694,17 @@ def __init__( api_version: The API version to use. Defaults to None. base_url: The base URL for the API. Defaults to None. """ - project = project if project else os.getenv(const.GCLOUD_PROJECT) - location = location if location else const.DEFAULT_REGION + # Store project and location on the plugin for reranker resolution. + # This avoids reaching into client internals. 
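+        # Resolution order for both values: explicit constructor argument
+        # first, then the environment variable named by const.GCLOUD_PROJECT
+        # (for project) or the plugin's default region (for location).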
+ self._project = project if project else os.getenv(const.GCLOUD_PROJECT) + self._location = location if location else const.DEFAULT_REGION self._client = genai.client.Client( vertexai=self._vertexai, api_key=api_key, credentials=credentials, - project=project, - location=location, + project=self._project, + location=self._location, debug_config=debug_config, http_options=_inject_attribution_headers(http_options, base_url, api_version), ) @@ -711,6 +730,10 @@ async def init(self) -> list[Action]: for name in genai_models.embedders: actions.append(self._resolve_embedder(vertexai_name(name))) + # Register Vertex AI rerankers + for name in RERANKER_MODELS: + actions.append(self._resolve_reranker(vertexai_name(name))) + return actions def _list_known_models(self) -> list[Action]: @@ -747,6 +770,8 @@ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: return self._resolve_model(name) elif action_type == ActionKind.EMBEDDER: return self._resolve_embedder(name) + elif action_type == ActionKind.RERANKER: + return self._resolve_reranker(name) return None def _resolve_model(self, name: str) -> Action: @@ -817,6 +842,80 @@ def _resolve_embedder(self, name: str) -> Action: ).metadata, ) + def _resolve_reranker(self, name: str) -> Action: + """Create an Action object for a Vertex AI reranker. + + Args: + name: The namespaced name of the reranker. + + Returns: + Action object for the reranker. + """ + # Extract local name (remove plugin prefix) + clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name + + # Validate project is configured (required for reranker API) + if not self._project: + raise ValueError( + 'VertexAI plugin requires a project ID to use rerankers. ' + 'Set the project parameter or GOOGLE_CLOUD_PROJECT environment variable.' + ) + + # Use project and location stored on the plugin instance during init. + # This avoids accessing private attributes of the client library. + client_options = VertexRerankerClientOptions( + project_id=self._project, + location=self._location, + ) + + async def wrapper( + request: RerankerRequest, + _ctx: Any, # noqa: ANN401 + ) -> RerankerResponse: + """Wrapper that takes RerankerRequest and returns RerankerResponse. + + This matches the signature expected by the Action class (max 2 args). 
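+
+            Flow: validate the per-call options, build a RerankRequest,
+            call the Discovery Engine ranking endpoint via reranker_rank,
+            then map the returned scores back onto the original documents.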
+ """ + query_doc = Document.from_document_data(request.query) + documents = [Document.from_document_data(d) for d in request.documents] + options = request.options + + config = VertexRerankerConfig.model_validate(options or {}) + + # Use location from config if provided, otherwise use client default + effective_options = VertexRerankerClientOptions( + project_id=client_options.project_id, + location=config.location or client_options.location, + ) + + rerank_request = RerankRequest( + model=clean_name, + query=query_doc.text(), + records=[_to_reranker_doc(doc, idx) for idx, doc in enumerate(documents)], + top_n=config.top_n, + ignore_record_details_in_response=config.ignore_record_details_in_response, + ) + + response = await reranker_rank(clean_name, rerank_request, effective_options) + ranked_docs = _from_rerank_response(response, documents) + + # Convert to RerankerResponse format - ranked_docs are RankedDocument instances + response_docs: list[RankedDocumentData] = [] + for doc in ranked_docs: + metadata = RankedDocumentMetadata(score=doc.score if doc.score is not None else 0.0) + response_docs.append(RankedDocumentData(content=doc.content, metadata=metadata)) + + return RerankerResponse(documents=response_docs) + + metadata = reranker_action_metadata(name) + + return Action( + kind=ActionKind.RERANKER, + name=name, + fn=wrapper, + metadata=metadata.metadata, + ) + async def list_actions(self) -> list[ActionMetadata]: """Generate a list of available actions or models. diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/rerankers/__init__.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/rerankers/__init__.py new file mode 100644 index 0000000000..ebc02715f8 --- /dev/null +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/rerankers/__init__.py @@ -0,0 +1,163 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Vertex AI Rerankers for the Genkit framework. + +This module provides reranking functionality using the Vertex AI Discovery Engine +Ranking API. Rerankers improve RAG (Retrieval-Augmented Generation) quality by +re-scoring documents based on their relevance to a query. + +Key Concepts (ELI5):: + + ┌─────────────────────┬────────────────────────────────────────────────────┐ + │ Concept │ ELI5 Explanation │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Reranker │ A "second opinion" scorer that re-orders your │ + │ │ search results by relevance. Like asking an expert │ + │ │ to sort your library books by importance. │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Semantic Ranker │ Uses AI to understand meaning, not just keywords. │ + │ │ Knows "car" and "automobile" mean the same thing. │ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ top_n │ How many top results to return after reranking. │ + │ │ "Give me the 5 most relevant documents." 
│ + ├─────────────────────┼────────────────────────────────────────────────────┤ + │ Score │ A number (0-1) showing how relevant a document is.│ + │ │ Higher = more relevant to your query. │ + └─────────────────────┴────────────────────────────────────────────────────┘ + +Data Flow:: + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ RAG WITH RERANKING │ + │ │ + │ User Query: "How do neural networks learn?" │ + │ │ │ + │ ▼ │ + │ ┌─────────────┐ │ + │ │ Retriever │ ◄── Fast initial search, returns ~100 docs │ + │ └──────┬──────┘ │ + │ │ [doc1, doc2, doc3, ... doc100] │ + │ ▼ │ + │ ┌─────────────┐ │ + │ │ Reranker │ ◄── AI-powered relevance scoring │ + │ │ (Vertex) │ │ + │ └──────┬──────┘ │ + │ │ [doc47: 0.95, doc3: 0.87, doc12: 0.82, ...] │ + │ ▼ │ + │ ┌─────────────┐ │ + │ │ Model │ ◄── Uses top-k most relevant docs │ + │ │ (Gemini) │ │ + │ └──────┬──────┘ │ + │ ▼ │ + │ High-quality answer with accurate citations │ + └─────────────────────────────────────────────────────────────────────────┘ + +Overview: + Vertex AI offers semantic rerankers that use machine learning to score + documents based on their semantic similarity to a query. This is typically + used after initial retrieval to improve the quality of the top-k results. + + Reranking is a two-stage retrieval pattern: + 1. **Fast retrieval**: Get many candidates quickly (e.g., 100 docs) + 2. **Quality reranking**: Score candidates by relevance, keep top-k + +Available Models: + +--------------------------------+-----------------------------------------+ + | Model | Description | + +--------------------------------+-----------------------------------------+ + | semantic-ranker-default@latest | Latest default semantic ranker | + | semantic-ranker-default-004 | Semantic ranker version 004 | + | semantic-ranker-fast-004 | Fast variant (lower latency, less acc.) | + | semantic-ranker-default-003 | Semantic ranker version 003 | + | semantic-ranker-default-002 | Semantic ranker version 002 | + +--------------------------------+-----------------------------------------+ + +Example: + Basic reranking:: + + from genkit import Genkit + from genkit.plugins.google_genai import VertexAI + + ai = Genkit(plugins=[VertexAI(project='my-project')]) + + # Rerank documents after retrieval + ranked_docs = await ai.rerank( + reranker='vertexai/semantic-ranker-default@latest', + query='What is machine learning?', + documents=retrieved_docs, + options={'top_n': 5}, + ) + + Full RAG pipeline with reranking:: + + # 1. Retrieve initial candidates + candidates = await ai.retrieve( + retriever='my-retriever', + query='How do neural networks learn?', + options={'limit': 50}, + ) + + # 2. Rerank for quality + ranked = await ai.rerank( + reranker='vertexai/semantic-ranker-default@latest', + query='How do neural networks learn?', + documents=candidates, + options={'top_n': 5}, + ) + + # 3. 
Generate with top results + response = await ai.generate( + model='vertexai/gemini-2.0-flash', + prompt='Explain how neural networks learn.', + docs=ranked, + ) + +Caveats: + - Requires Google Cloud project with Discovery Engine API enabled + - Reranking adds latency - use for quality-critical applications + - Models may silently fall back to default if name is not recognized + +See Also: + - Vertex AI Ranking API: https://cloud.google.com/generative-ai-app-builder/docs/ranking + - RAG best practices: https://genkit.dev/docs/rag +""" + +from genkit.plugins.google_genai.rerankers.reranker import ( + DEFAULT_MODEL_NAME, + KNOWN_MODELS, + RerankRequest, + RerankResponse, + VertexRerankerClientOptions, + VertexRerankerConfig, + _from_rerank_response, + _to_reranker_doc, + is_reranker_model_name, + reranker_rank, +) + +__all__ = [ + 'DEFAULT_MODEL_NAME', + 'KNOWN_MODELS', + 'RerankRequest', + 'RerankResponse', + 'VertexRerankerClientOptions', + 'VertexRerankerConfig', + '_from_rerank_response', + '_to_reranker_doc', + 'is_reranker_model_name', + 'reranker_rank', +] diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/rerankers/reranker.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/rerankers/reranker.py new file mode 100644 index 0000000000..196cd4c5b5 --- /dev/null +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/rerankers/reranker.py @@ -0,0 +1,361 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Vertex AI Reranker implementation. + +This module implements the Vertex AI Discovery Engine Ranking API for reranking +documents based on their semantic relevance to a query. 
+ +Architecture:: + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ Vertex AI Reranker Module │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ Constants & Configuration │ + │ ├── DEFAULT_LOCATION (global) │ + │ ├── DEFAULT_MODEL_NAME (semantic-ranker-default@latest) │ + │ └── KNOWN_MODELS (supported model registry) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ Request/Response Types (Pydantic) │ + │ ├── VertexRerankerConfig - User-facing configuration │ + │ ├── VertexRerankerClientOptions - Internal client config │ + │ ├── RerankRequest, RerankRequestRecord - API request types │ + │ └── RerankResponse, RerankResponseRecord - API response types │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ API Client │ + │ ├── reranker_rank() - Async API call to Discovery Engine │ + │ └── get_vertex_rerank_url() - URL builder for ranking endpoint │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ Conversion Functions │ + │ ├── _to_reranker_doc() - Document → RerankRequestRecord │ + │ └── _from_rerank_response() - Response → RankedDocument list │ + └─────────────────────────────────────────────────────────────────────────┘ + +Implementation Notes: + - Uses Google Cloud Application Default Credentials (ADC) for auth + - Calls the Discovery Engine rankingConfigs:rank endpoint + - Supports configurable location and top_n parameters + - Returns RankedDocument instances with scores + +Note: + The actual reranker action registration is handled by the VertexAI plugin + in google.py via the _resolve_reranker method, which uses the conversion + functions and API client defined here. +""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any, ClassVar + +from google.auth import default as google_auth_default +from google.auth.transport.requests import Request +from pydantic import BaseModel, ConfigDict, Field + +from genkit.blocks.document import Document +from genkit.blocks.model import text_from_content +from genkit.blocks.reranker import RankedDocument +from genkit.core.error import GenkitError +from genkit.core.http_client import get_cached_client +from genkit.core.typing import DocumentData + +# Default location for Vertex AI Ranking API (global is recommended per docs) +DEFAULT_LOCATION = 'global' + +# Default reranker model name +DEFAULT_MODEL_NAME = 'semantic-ranker-default@latest' + +# Known reranker models +KNOWN_MODELS: dict[str, str] = { + 'semantic-ranker-default@latest': 'semantic-ranker-default@latest', + 'semantic-ranker-default-004': 'semantic-ranker-default-004', + 'semantic-ranker-fast-004': 'semantic-ranker-fast-004', + 'semantic-ranker-default-003': 'semantic-ranker-default-003', + 'semantic-ranker-default-002': 'semantic-ranker-default-002', +} + + +def is_reranker_model_name(value: str | None) -> bool: + """Check if a value is a valid reranker model name. + + Args: + value: The value to check. + + Returns: + True if the value is a valid reranker model name. + """ + return value is not None and value.startswith('semantic-ranker-') + + +class VertexRerankerConfig(BaseModel): + """Configuration options for Vertex AI reranker. + + Attributes: + top_n: Number of top documents to return. If not specified, all documents + are returned with their scores. + ignore_record_details_in_response: If True, the response will only contain + record ID and score. Defaults to False. 
+ location: Google Cloud location (e.g., "us-central1"). If not specified, + uses the default location from plugin options. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict( + extra='allow', + populate_by_name=True, + ) + + top_n: int | None = Field(default=None, alias='topN') + ignore_record_details_in_response: bool | None = Field( + default=None, + alias='ignoreRecordDetailsInResponse', + ) + location: str | None = None + + +class RerankRequestRecord(BaseModel): + """A record to be reranked. + + Attributes: + id: Unique identifier for the record. + title: Optional title of the record. + content: The content of the record to be ranked. + """ + + id: str + title: str | None = None + content: str + + +class RerankRequest(BaseModel): + """Request body for the rerank API. + + Attributes: + model: The reranker model to use. + query: The query to rank documents against. + records: The records to be ranked. + top_n: Number of top documents to return. + ignore_record_details_in_response: If True, only return ID and score. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict( + extra='allow', + populate_by_name=True, + ) + + model: str + query: str + records: list[RerankRequestRecord] + top_n: int | None = Field(default=None, alias='topN') + ignore_record_details_in_response: bool | None = Field( + default=None, + alias='ignoreRecordDetailsInResponse', + ) + + +class RerankResponseRecord(BaseModel): + """A record in the rerank response. + + Attributes: + id: The record ID. + score: The relevance score (0-1). + content: The record content (if not ignored). + title: The record title (if present). + """ + + id: str + score: float + content: str | None = None + title: str | None = None + + +class RerankResponse(BaseModel): + """Response from the rerank API. + + Attributes: + records: The ranked records with scores. + """ + + records: list[RerankResponseRecord] + + +class VertexRerankerClientOptions(BaseModel): + """Client options for the Vertex AI reranker. + + Attributes: + project_id: Google Cloud project ID. + location: Google Cloud location (e.g., "us-central1"). + """ + + project_id: str + location: str = DEFAULT_LOCATION + + +async def reranker_rank( + model: str, + request: RerankRequest, + client_options: VertexRerankerClientOptions, +) -> RerankResponse: + """Call the Vertex AI Ranking API. + + Args: + model: The reranker model name. + request: The rerank request. + client_options: Client options including project and location. + + Returns: + The rerank response with scored records. + + Raises: + GenkitError: If the API call fails. + """ + url = get_vertex_rerank_url(client_options) + + # Get authentication token + # Use asyncio.to_thread to avoid blocking the event loop during token refresh + credentials, _ = google_auth_default() + await asyncio.to_thread(credentials.refresh, Request()) + token = credentials.token + + if not token: + raise GenkitError( + message='Unable to authenticate your request. 
' + 'Please ensure you have valid Google Cloud credentials configured.', + status='UNAUTHENTICATED', + ) + + headers = { + 'Authorization': f'Bearer {token}', + 'x-goog-user-project': client_options.project_id, + 'Content-Type': 'application/json', + } + + # Prepare request body - only include non-None values + request_body: dict[str, Any] = { + 'model': request.model, + 'query': request.query, + 'records': [r.model_dump(exclude_none=True) for r in request.records], + } + if request.top_n is not None: + request_body['topN'] = request.top_n + if request.ignore_record_details_in_response is not None: + request_body['ignoreRecordDetailsInResponse'] = request.ignore_record_details_in_response + + # Use cached client for better connection reuse. + # Note: Auth headers are passed per-request since tokens may expire. + client = get_cached_client( + cache_key='vertex-ai-reranker', + timeout=60.0, + ) + + try: + response = await client.post( + url, + headers=headers, + json=request_body, + ) + + if response.status_code != 200: + error_message = response.text + try: + error_json = response.json() + if 'error' in error_json and 'message' in error_json['error']: + error_message = error_json['error']['message'] + except json.JSONDecodeError: # noqa: S110 + # JSON parsing failed, use raw text + pass + + raise GenkitError( + message=f'Error calling Vertex AI Reranker API: [{response.status_code}] {error_message}', + status='INTERNAL', + ) + + return RerankResponse.model_validate(response.json()) + + except Exception as e: + if isinstance(e, GenkitError): + raise + raise GenkitError( + message=f'Failed to call Vertex AI Reranker API: {e}', + status='UNAVAILABLE', + ) from e + + +def get_vertex_rerank_url(client_options: VertexRerankerClientOptions) -> str: + """Get the URL for the Vertex AI Ranking API. + + Args: + client_options: Client options including project and location. + + Returns: + The API endpoint URL. + """ + return ( + f'https://discoveryengine.googleapis.com/v1/projects/{client_options.project_id}' + f'/locations/{client_options.location}/rankingConfigs/default_ranking_config:rank' + ) + + +def _to_reranker_doc(doc: Document | DocumentData, idx: int) -> RerankRequestRecord: + """Convert a document to a rerank request record. + + Args: + doc: The document to convert. + idx: The index of the document (used as ID). + + Returns: + A rerank request record. + """ + if isinstance(doc, Document): + text = doc.text() + else: + # DocumentData - use text_from_content helper + text = text_from_content(doc.content) + + return RerankRequestRecord( + id=str(idx), + content=text, + ) + + +def _from_rerank_response( + response: RerankResponse, + documents: list[Document], +) -> list[RankedDocument]: + """Convert rerank response to ranked documents. + + Args: + response: The rerank response. + documents: The original documents. + + Returns: + RankedDocument instances with scores, sorted by relevance. 
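+
+    Example:
+        A minimal sketch assuming two original documents, doc_a and doc_b::
+
+            response = RerankResponse(
+                records=[
+                    RerankResponseRecord(id='1', score=0.9),
+                    RerankResponseRecord(id='0', score=0.4),
+                ]
+            )
+            ranked = _from_rerank_response(response, [doc_a, doc_b])
+            # ranked[0] carries doc_b's content with score 0.9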
+ """ + ranked_docs: list[RankedDocument] = [] + for record in response.records: + idx = int(record.id) + original_doc = documents[idx] + + # Create RankedDocument with the score from the API response + ranked_docs.append( + RankedDocument( + content=original_doc.content, + metadata=original_doc.metadata, + score=record.score, + ) + ) + + return ranked_docs diff --git a/py/plugins/google-genai/test/google_plugin_test.py b/py/plugins/google-genai/test/google_plugin_test.py index 61b50b47ad..e6db3feee6 100644 --- a/py/plugins/google-genai/test/google_plugin_test.py +++ b/py/plugins/google-genai/test/google_plugin_test.py @@ -476,7 +476,7 @@ def test_init_with_all(self, mock_genai_client: MagicMock) -> None: @patch('google.genai.client.Client') def vertexai_plugin_instance(client: MagicMock) -> VertexAI: """VertexAI fixture.""" - return VertexAI() + return VertexAI(project='test-project', location='us-central1') @pytest.mark.asyncio diff --git a/py/plugins/google-genai/tests/evaluators_test.py b/py/plugins/google-genai/tests/evaluators_test.py new file mode 100644 index 0000000000..823531df31 --- /dev/null +++ b/py/plugins/google-genai/tests/evaluators_test.py @@ -0,0 +1,286 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Vertex AI Evaluators.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from genkit.plugins.google_genai.evaluators import ( + VertexAIEvaluationMetricType, + create_vertex_evaluators, +) +from genkit.plugins.google_genai.evaluators.evaluation import ( + EvaluatorFactory, + VertexAIEvaluationMetricConfig, + _is_config, + _stringify, +) + + +def test_vertex_ai_evaluation_metric_type_values() -> None: + """Test that VertexAIEvaluationMetricType has expected values.""" + assert VertexAIEvaluationMetricType.BLEU == 'BLEU' + assert VertexAIEvaluationMetricType.ROUGE == 'ROUGE' + assert VertexAIEvaluationMetricType.FLUENCY == 'FLUENCY' + assert VertexAIEvaluationMetricType.SAFETY == 'SAFETY' + assert VertexAIEvaluationMetricType.GROUNDEDNESS == 'GROUNDEDNESS' + assert VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY == 'SUMMARIZATION_QUALITY' + assert VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS == 'SUMMARIZATION_HELPFULNESS' + assert VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY == 'SUMMARIZATION_VERBOSITY' + + +def test_vertex_ai_evaluation_metric_type_is_str_enum() -> None: + """Test that metric types can be used as strings.""" + metric = VertexAIEvaluationMetricType.FLUENCY + assert isinstance(metric, str) + assert metric == 'FLUENCY' + + +def test_vertex_ai_evaluation_metric_config_basic() -> None: + """Test VertexAIEvaluationMetricConfig model.""" + config = VertexAIEvaluationMetricConfig( + type=VertexAIEvaluationMetricType.BLEU, + metric_spec={'use_sentence_level': True}, + ) + assert config.type == VertexAIEvaluationMetricType.BLEU + assert config.metric_spec == {'use_sentence_level': True} + + +def 
test_vertex_ai_evaluation_metric_config_defaults() -> None: + """Test VertexAIEvaluationMetricConfig default values.""" + config = VertexAIEvaluationMetricConfig(type=VertexAIEvaluationMetricType.SAFETY) + assert config.type == VertexAIEvaluationMetricType.SAFETY + assert config.metric_spec is None + + +def test_stringify_string_input() -> None: + """Test _stringify with string input returns unchanged.""" + result = _stringify('hello world') + assert result == 'hello world' + + +def test_stringify_dict_input() -> None: + """Test _stringify with dict input returns JSON.""" + result = _stringify({'key': 'value'}) + assert result == '{"key": "value"}' + + +def test_stringify_list_input() -> None: + """Test _stringify with list input returns JSON.""" + result = _stringify(['a', 'b', 'c']) + assert result == '["a", "b", "c"]' + + +def test_stringify_number_input() -> None: + """Test _stringify with number input returns JSON.""" + result = _stringify(42) + assert result == '42' + + +def test_is_config_with_metric_type() -> None: + """Test _is_config returns False for metric type.""" + metric = VertexAIEvaluationMetricType.FLUENCY + assert _is_config(metric) is False + + +def test_is_config_with_metric_config() -> None: + """Test _is_config returns True for metric config.""" + config = VertexAIEvaluationMetricConfig(type=VertexAIEvaluationMetricType.FLUENCY) + assert _is_config(config) is True + + +def test_evaluator_factory_initialization() -> None: + """Test EvaluatorFactory can be initialized.""" + factory = EvaluatorFactory( + project_id='test-project', + location='us-central1', + ) + assert factory.project_id == 'test-project' + assert factory.location == 'us-central1' + + +@pytest.mark.asyncio +async def test_evaluator_factory_evaluate_instances_structure() -> None: + """Test that evaluate_instances makes correct API call structure.""" + factory = EvaluatorFactory( + project_id='test-project', + location='us-central1', + ) + + mock_credentials = MagicMock() + mock_credentials.token = 'mock-token' + mock_credentials.expired = False + + mock_response_data = { + 'fluencyResult': { + 'score': 4.5, + 'explanation': 'Very fluent text', + } + } + + with patch('genkit.plugins.google_genai.evaluators.evaluation.google_auth_default') as mock_auth: + mock_auth.return_value = (mock_credentials, 'test-project') + + # Mock get_cached_client to return a mock client + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.is_closed = False + + with patch('genkit.plugins.google_genai.evaluators.evaluation.get_cached_client', return_value=mock_client): + result = await factory.evaluate_instances({'fluencyInput': {'prediction': 'Test'}}) + + assert result == mock_response_data + mock_client.post.assert_called_once() + + +@pytest.mark.asyncio +async def test_evaluator_factory_evaluate_instances_error_handling() -> None: + """Test that evaluate_instances raises GenkitError on API failure.""" + from genkit.core.error import GenkitError + + factory = EvaluatorFactory( + project_id='test-project', + location='us-central1', + ) + + mock_credentials = MagicMock() + mock_credentials.token = 'mock-token' + mock_credentials.expired = False + + with patch('genkit.plugins.google_genai.evaluators.evaluation.google_auth_default') as mock_auth: + mock_auth.return_value = (mock_credentials, 'test-project') + + # Mock get_cached_client to return a mock client + 
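+        # Simulate a 500 from the API; the factory is expected to surface
+        # this as a GenkitError with status 'INTERNAL' instead of leaking
+        # the raw HTTP response.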
mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = 'Internal Server Error' + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.is_closed = False + + with patch('genkit.plugins.google_genai.evaluators.evaluation.get_cached_client', return_value=mock_client): + with pytest.raises(GenkitError) as exc_info: + await factory.evaluate_instances({'input': 'test'}) + + assert exc_info.value.status == 'INTERNAL' + + +def test_create_vertex_evaluators_with_metric_types() -> None: + """Test create_vertex_evaluators with simple metric types.""" + mock_registry = MagicMock() + mock_registry.define_evaluator = MagicMock() + + metrics = [ + VertexAIEvaluationMetricType.FLUENCY, + VertexAIEvaluationMetricType.SAFETY, + ] + + create_vertex_evaluators( + registry=mock_registry, + metrics=metrics, + project_id='test-project', + location='us-central1', + ) + + assert mock_registry.define_evaluator.call_count == 2 + + +def test_create_vertex_evaluators_with_metric_configs() -> None: + """Test create_vertex_evaluators with metric configs.""" + mock_registry = MagicMock() + mock_registry.define_evaluator = MagicMock() + + metrics = [ + VertexAIEvaluationMetricConfig( + type=VertexAIEvaluationMetricType.BLEU, + metric_spec={'use_sentence_level': True}, + ), + ] + + create_vertex_evaluators( + registry=mock_registry, + metrics=metrics, + project_id='test-project', + location='us-central1', + ) + + mock_registry.define_evaluator.assert_called_once() + + +def test_create_vertex_evaluators_names_format() -> None: + """Test that evaluator names follow vertexai/{metric} format.""" + mock_registry = MagicMock() + evaluator_names: list[str] = [] + + def capture_name(*args: object, **kwargs: object) -> None: + if 'name' in kwargs: + name = kwargs['name'] + if isinstance(name, str): + evaluator_names.append(name) + + mock_registry.define_evaluator = capture_name + + metrics = [ + VertexAIEvaluationMetricType.FLUENCY, + VertexAIEvaluationMetricType.GROUNDEDNESS, + ] + + create_vertex_evaluators( + registry=mock_registry, + metrics=metrics, + project_id='test-project', + location='us-central1', + ) + + assert 'vertexai/fluency' in evaluator_names + assert 'vertexai/groundedness' in evaluator_names + + +def test_create_vertex_evaluators_empty_metrics() -> None: + """Test create_vertex_evaluators with empty metrics list.""" + mock_registry = MagicMock() + mock_registry.define_evaluator = MagicMock() + + create_vertex_evaluators( + registry=mock_registry, + metrics=[], + project_id='test-project', + location='us-central1', + ) + + mock_registry.define_evaluator.assert_not_called() + + +def test_all_metric_types_supported() -> None: + """Test that all metric types are supported by create_vertex_evaluators.""" + mock_registry = MagicMock() + mock_registry.define_evaluator = MagicMock() + + all_metrics = list(VertexAIEvaluationMetricType) + + create_vertex_evaluators( + registry=mock_registry, + metrics=all_metrics, + project_id='test-project', + location='us-central1', + ) + + assert mock_registry.define_evaluator.call_count == len(all_metrics) diff --git a/py/plugins/google-genai/tests/rerankers_test.py b/py/plugins/google-genai/tests/rerankers_test.py new file mode 100644 index 0000000000..e42b1d224f --- /dev/null +++ b/py/plugins/google-genai/tests/rerankers_test.py @@ -0,0 +1,346 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Vertex AI Rerankers.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from genkit.blocks.document import Document +from genkit.core.typing import TextPart +from genkit.plugins.google_genai.rerankers import ( + DEFAULT_MODEL_NAME, + KNOWN_MODELS, + VertexRerankerConfig, + is_reranker_model_name, +) +from genkit.plugins.google_genai.rerankers.reranker import ( + RerankRequest, + RerankRequestRecord, + RerankResponse, + RerankResponseRecord, + VertexRerankerClientOptions, + _from_rerank_response, + _to_reranker_doc, + get_vertex_rerank_url, +) + + +def test_default_model_name() -> None: + """Test that DEFAULT_MODEL_NAME is set correctly.""" + assert DEFAULT_MODEL_NAME == 'semantic-ranker-default@latest' + + +def test_known_models_contains_expected_models() -> None: + """Test that KNOWN_MODELS contains expected reranker models.""" + assert 'semantic-ranker-default@latest' in KNOWN_MODELS + assert 'semantic-ranker-default-004' in KNOWN_MODELS + assert 'semantic-ranker-fast-004' in KNOWN_MODELS + assert 'semantic-ranker-default-003' in KNOWN_MODELS + assert 'semantic-ranker-default-002' in KNOWN_MODELS + + +def test_is_reranker_model_name_valid() -> None: + """Test is_reranker_model_name returns True for valid names.""" + assert is_reranker_model_name('semantic-ranker-default@latest') is True + assert is_reranker_model_name('semantic-ranker-fast-004') is True + + +def test_is_reranker_model_name_invalid() -> None: + """Test is_reranker_model_name returns False for invalid names.""" + assert is_reranker_model_name('gemini-2.0-flash') is False + assert is_reranker_model_name('text-embedding-004') is False + assert is_reranker_model_name(None) is False + assert is_reranker_model_name('') is False + + +def test_vertex_reranker_config() -> None: + """Test VertexRerankerConfig model.""" + config = VertexRerankerConfig(top_n=5) + assert config.top_n == 5 + + +def test_vertex_reranker_config_defaults() -> None: + """Test VertexRerankerConfig default values.""" + config = VertexRerankerConfig() + assert config.top_n is None + assert config.location is None + assert config.ignore_record_details_in_response is None + + +def test_vertex_reranker_config_with_aliases() -> None: + """Test VertexRerankerConfig works with aliases.""" + # Use Python field names (populate_by_name=True allows both) + config = VertexRerankerConfig(top_n=10, ignore_record_details_in_response=True) + assert config.top_n == 10 + assert config.ignore_record_details_in_response is True + + +def test_vertex_reranker_client_options() -> None: + """Test VertexRerankerClientOptions model.""" + options = VertexRerankerClientOptions( + project_id='my-project', + location='us-central1', + ) + assert options.project_id == 'my-project' + assert options.location == 'us-central1' + + +def test_vertex_reranker_client_options_default_location() -> None: + """Test VertexRerankerClientOptions uses default location.""" + options = VertexRerankerClientOptions(project_id='test-project') + assert options.project_id == 'test-project' + assert 
options.location == 'global' + + +def test_rerank_request_record() -> None: + """Test RerankRequestRecord model.""" + record = RerankRequestRecord( + id='doc-1', + title='Test Document', + content='This is the document content.', + ) + assert record.id == 'doc-1' + assert record.title == 'Test Document' + assert record.content == 'This is the document content.' + + +def test_rerank_request_record_no_title() -> None: + """Test RerankRequestRecord without optional title.""" + record = RerankRequestRecord(id='1', content='Content only') + assert record.id == '1' + assert record.title is None + assert record.content == 'Content only' + + +def test_rerank_request() -> None: + """Test RerankRequest model.""" + records = [ + RerankRequestRecord(id='1', content='Doc 1'), + RerankRequestRecord(id='2', content='Doc 2'), + ] + request = RerankRequest( + query='What is machine learning?', + records=records, + model='semantic-ranker-default@latest', + top_n=5, + ) + assert request.query == 'What is machine learning?' + assert len(request.records) == 2 + assert request.model == 'semantic-ranker-default@latest' + assert request.top_n == 5 + + +def test_rerank_response_record() -> None: + """Test RerankResponseRecord model.""" + record = RerankResponseRecord( + id='doc-1', + score=0.95, + content='Document content', + ) + assert record.id == 'doc-1' + assert record.score == 0.95 + assert record.content == 'Document content' + + +def test_rerank_response_record_minimal() -> None: + """Test RerankResponseRecord with only required fields.""" + record = RerankResponseRecord(id='1', score=0.5) + assert record.id == '1' + assert record.score == 0.5 + assert record.content is None + assert record.title is None + + +def test_rerank_response() -> None: + """Test RerankResponse model.""" + records = [ + RerankResponseRecord(id='1', score=0.9), + RerankResponseRecord(id='2', score=0.7), + ] + response = RerankResponse(records=records) + assert len(response.records) == 2 + assert response.records[0].score == 0.9 + + +def test_get_vertex_rerank_url() -> None: + """Test get_vertex_rerank_url builds correct URL.""" + options = VertexRerankerClientOptions( + project_id='my-project', + location='us-central1', + ) + + url = get_vertex_rerank_url(options) + + assert 'my-project' in url + assert 'us-central1' in url + assert 'discoveryengine.googleapis.com' in url + assert ':rank' in url + + +def test_get_vertex_rerank_url_different_location() -> None: + """Test get_vertex_rerank_url with different location.""" + options = VertexRerankerClientOptions( + project_id='test-project', + location='europe-west1', + ) + + url = get_vertex_rerank_url(options) + + assert 'test-project' in url + assert 'europe-west1' in url + + +def test_to_reranker_doc_from_document() -> None: + """Test _to_reranker_doc converts Document to RerankRequestRecord.""" + from genkit.core.typing import DocumentPart + + doc = Document(content=[DocumentPart(root=TextPart(text='This is document content.'))]) + + record = _to_reranker_doc(doc, 0) + + assert record.content == 'This is document content.' 
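+    # The id is the document's position in the candidate list, stringified;
+    # see test_to_reranker_doc_different_index below.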
+ assert record.id == '0' + + +def test_to_reranker_doc_different_index() -> None: + """Test _to_reranker_doc uses provided index.""" + from genkit.core.typing import DocumentPart + + doc = Document(content=[DocumentPart(root=TextPart(text='Content'))]) + + record = _to_reranker_doc(doc, 5) + + assert record.id == '5' + + +def test_from_rerank_response_basic() -> None: + """Test _from_rerank_response converts response to scored documents.""" + from genkit.core.typing import DocumentPart + + original_docs = [ + Document(content=[DocumentPart(root=TextPart(text='Doc 0'))]), + Document(content=[DocumentPart(root=TextPart(text='Doc 1'))]), + Document(content=[DocumentPart(root=TextPart(text='Doc 2'))]), + ] + + response = RerankResponse( + records=[ + RerankResponseRecord(id='1', score=0.9), + RerankResponseRecord(id='0', score=0.7), + RerankResponseRecord(id='2', score=0.5), + ] + ) + + result = _from_rerank_response(response, original_docs) + + assert len(result) == 3 + for doc in result: + assert doc.metadata is not None + assert 'score' in doc.metadata + + +def test_from_rerank_response_preserves_content() -> None: + """Test _from_rerank_response preserves document content.""" + from genkit.core.typing import DocumentPart + + original_docs = [ + Document(content=[DocumentPart(root=TextPart(text='Original content'))]), + ] + + response = RerankResponse(records=[RerankResponseRecord(id='0', score=0.85)]) + + result = _from_rerank_response(response, original_docs) + + assert len(result) == 1 + assert result[0].text() == 'Original content' + assert result[0].metadata is not None + assert result[0].metadata.get('score') == 0.85 + + +def test_from_rerank_response_preserves_original_metadata() -> None: + """Test _from_rerank_response preserves original document metadata.""" + from genkit.core.typing import DocumentPart + + original_docs = [ + Document( + content=[DocumentPart(root=TextPart(text='Content'))], + metadata={'custom_field': 'value'}, + ), + ] + + response = RerankResponse(records=[RerankResponseRecord(id='0', score=0.85)]) + + result = _from_rerank_response(response, original_docs) + + assert len(result) == 1 + assert result[0].metadata is not None + assert result[0].metadata.get('custom_field') == 'value' + assert result[0].metadata.get('score') == 0.85 + + +def test_from_rerank_response_empty() -> None: + """Test _from_rerank_response handles empty response.""" + response = RerankResponse(records=[]) + + result = _from_rerank_response(response, []) + + assert result == [] + + +@pytest.mark.asyncio +async def test_reranker_api_call_structure() -> None: + """Test that reranker API call is structured correctly.""" + from genkit.plugins.google_genai.rerankers.reranker import reranker_rank + + mock_credentials = MagicMock() + mock_credentials.token = 'mock-token' + mock_credentials.expired = False + + mock_response_data = { + 'records': [ + {'id': '0', 'score': 0.9}, + {'id': '1', 'score': 0.7}, + ] + } + + with patch('genkit.plugins.google_genai.rerankers.reranker.google_auth_default') as mock_auth: + mock_auth.return_value = (mock_credentials, 'test-project') + + # Mock get_cached_client to return a mock client + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.is_closed = False + + with patch('genkit.plugins.google_genai.rerankers.reranker.get_cached_client', return_value=mock_client): + request = RerankRequest( + 
model='semantic-ranker-default@latest', + query='test query', + records=[ + RerankRequestRecord(id='0', content='Doc 1'), + RerankRequestRecord(id='1', content='Doc 2'), + ], + ) + options = VertexRerankerClientOptions(project_id='test-project') + + result = await reranker_rank('semantic-ranker-default@latest', request, options) + + assert isinstance(result, RerankResponse) + assert len(result.records) == 2 diff --git a/py/pyproject.toml b/py/pyproject.toml index 3b32c9017f..496b91f65f 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -163,6 +163,7 @@ short-n-long = { workspace = true } tool-interrupts = { workspace = true } vertex-ai-vector-search-bigquery = { workspace = true } vertex-ai-vector-search-firestore = { workspace = true } +vertexai-rerank-eval = { workspace = true } xai-hello = { workspace = true } # Core packages genkit = { workspace = true } diff --git a/py/samples/README.md b/py/samples/README.md index d0ca02915d..218fb69bce 100644 --- a/py/samples/README.md +++ b/py/samples/README.md @@ -27,8 +27,9 @@ This directory contains sample applications demonstrating various Genkit feature │ ┌─────────────────────────┐ │ format-demo │ │ │ │ dev-local-vectorstore │ │ multi-server │ │ │ │ vertex-ai-vector-search │ │ evaluator-demo │ │ -│ │ firestore-retriever │ │ flask-hello │ │ -│ └─────────────────────────┘ └─────────────────────────┘ │ +│ │ firestore-retriever │ │ vertexai-rerank-eval │ │ +│ └─────────────────────────┘ │ flask-hello │ │ +│ └─────────────────────────┘ │ │ │ │ MULTIMODAL GOOGLE AI FEATURES │ │ ────────── ────────────────── │ @@ -96,6 +97,7 @@ cd py/samples/ | **format-demo** | Formats | Output formatting and schemas | | **multi-server** | Architecture | Multiple Genkit servers | | **evaluator-demo** | Evaluation | Custom evaluators and RAGAS | +| **vertexai-rerank-eval** | RAG, Evaluation | Vertex AI rerankers and evaluators | | **flask-hello** | Integrations | Flask HTTP endpoints | ### Multimodal Samples diff --git a/py/samples/_common.sh b/py/samples/_common.sh index b61b65a197..bc7c4bafa2 100644 --- a/py/samples/_common.sh +++ b/py/samples/_common.sh @@ -215,3 +215,182 @@ print_help_footer() { echo " 2. Run: ./run.sh" echo " 3. Browser opens automatically to http://localhost:${port}" } + +# ============================================================================ +# Google Cloud (gcloud) Helper Functions +# ============================================================================ +# These functions provide interactive API enablement for samples that require +# Google Cloud APIs. + +# Check if gcloud CLI is installed +# Usage: check_gcloud_installed || exit 1 +check_gcloud_installed() { + if ! command -v gcloud &> /dev/null; then + echo -e "${RED}Error: gcloud CLI is not installed${NC}" + echo "" + echo "Install the Google Cloud SDK from:" + echo " https://cloud.google.com/sdk/docs/install" + echo "" + return 1 + fi + return 0 +} + +# Check if gcloud is authenticated with Application Default Credentials +# Prompts the user to login if not authenticated (interactive) +# Usage: check_gcloud_auth || true +check_gcloud_auth() { + echo -e "${BLUE}Checking gcloud authentication...${NC}" + + # Check application default credentials + if ! gcloud auth application-default print-access-token &> /dev/null; then + echo -e "${YELLOW}Application default credentials not found.${NC}" + echo "" + + if [[ -t 0 ]] && [ -c /dev/tty ]; then + echo -en "Run ${GREEN}gcloud auth application-default login${NC} now? 
[Y/n]: " + local response + read -r response < /dev/tty + if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then + echo "" + gcloud auth application-default login + echo "" + else + echo -e "${YELLOW}Skipping authentication. You may encounter auth errors.${NC}" + return 1 + fi + else + echo "Run: gcloud auth application-default login" + return 1 + fi + else + echo -e "${GREEN}✓ Application default credentials found${NC}" + fi + + echo "" + return 0 +} + +# Check if a specific Google Cloud API is enabled +# Usage: is_api_enabled "aiplatform.googleapis.com" "$GOOGLE_CLOUD_PROJECT" +is_api_enabled() { + local api="$1" + local project="$2" + + gcloud services list --project="$project" --enabled --filter="name:$api" --format="value(name)" 2>/dev/null | grep -q "$api" +} + +# Enable required Google Cloud APIs interactively +# Usage: +# REQUIRED_APIS=("aiplatform.googleapis.com" "discoveryengine.googleapis.com") +# enable_required_apis "${REQUIRED_APIS[@]}" +# +# The function will: +# 1. Check which APIs are already enabled +# 2. Prompt the user to enable missing APIs +# 3. Enable APIs on user confirmation +enable_required_apis() { + local project="${GOOGLE_CLOUD_PROJECT:-}" + local apis=("$@") + + if [[ -z "$project" ]]; then + echo -e "${YELLOW}GOOGLE_CLOUD_PROJECT not set, skipping API enablement${NC}" + return 1 + fi + + if [[ ${#apis[@]} -eq 0 ]]; then + echo -e "${YELLOW}No APIs specified${NC}" + return 0 + fi + + echo -e "${BLUE}Checking required APIs for project: ${project}${NC}" + + local apis_to_enable=() + + for api in "${apis[@]}"; do + if is_api_enabled "$api" "$project"; then + echo -e " ${GREEN}✓${NC} $api" + else + echo -e " ${YELLOW}✗${NC} $api (not enabled)" + apis_to_enable+=("$api") + fi + done + + echo "" + + if [[ ${#apis_to_enable[@]} -eq 0 ]]; then + echo -e "${GREEN}All required APIs are already enabled!${NC}" + echo "" + return 0 + fi + + # Prompt to enable APIs + if [[ -t 0 ]] && [ -c /dev/tty ]; then + echo -e "${YELLOW}The following APIs need to be enabled:${NC}" + for api in "${apis_to_enable[@]}"; do + echo " - $api" + done + echo "" + echo -en "Enable these APIs now? [Y/n]: " + local response + read -r response < /dev/tty + + if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then + echo "" + for api in "${apis_to_enable[@]}"; do + echo -e "${BLUE}Enabling $api...${NC}" + if gcloud services enable "$api" --project="$project"; then + echo -e "${GREEN}✓ Enabled $api${NC}" + else + echo -e "${RED}✗ Failed to enable $api${NC}" + return 1 + fi + done + echo "" + echo -e "${GREEN}All APIs enabled successfully!${NC}" + else + echo -e "${YELLOW}Skipping API enablement. 
You may encounter errors.${NC}" + return 1 + fi + else + echo "Enable APIs with:" + for api in "${apis_to_enable[@]}"; do + echo " gcloud services enable $api --project=$project" + done + return 1 + fi + + echo "" + return 0 +} + +# Run common GCP setup: check gcloud, auth, and enable APIs +# Usage: +# REQUIRED_APIS=("aiplatform.googleapis.com") +# run_gcp_setup "${REQUIRED_APIS[@]}" +run_gcp_setup() { + local apis=("$@") + + # Check gcloud is installed + check_gcloud_installed || return 1 + + # Check/prompt for project + check_env_var "GOOGLE_CLOUD_PROJECT" "" || { + echo -e "${RED}Error: GOOGLE_CLOUD_PROJECT is required${NC}" + echo "" + echo "Set it with:" + echo " export GOOGLE_CLOUD_PROJECT=your-project-id" + echo "" + return 1 + } + + # Check authentication + check_gcloud_auth || true + + # Enable APIs if any were specified + if [[ ${#apis[@]} -gt 0 ]]; then + enable_required_apis "${apis[@]}" || true + fi + + return 0 +} diff --git a/py/samples/firestore-retreiver/run.sh b/py/samples/firestore-retreiver/run.sh index 02831dedb7..7e850cda5c 100755 --- a/py/samples/firestore-retreiver/run.sh +++ b/py/samples/firestore-retreiver/run.sh @@ -61,131 +61,6 @@ print_help() { print_help_footer } -# Check if gcloud is installed -check_gcloud_installed() { - if ! command -v gcloud &> /dev/null; then - echo -e "${RED}Error: gcloud CLI is not installed${NC}" - echo "" - echo "Install the Google Cloud SDK from:" - echo " https://cloud.google.com/sdk/docs/install" - echo "" - return 1 - fi - return 0 -} - -# Check if gcloud is authenticated -check_gcloud_auth() { - echo -e "${BLUE}Checking gcloud authentication...${NC}" - - # Check application default credentials - if ! gcloud auth application-default print-access-token &> /dev/null; then - echo -e "${YELLOW}Application default credentials not found.${NC}" - echo "" - - if [[ -t 0 ]] && [ -c /dev/tty ]; then - echo -en "Run ${GREEN}gcloud auth application-default login${NC} now? [Y/n]: " - local response - read -r response < /dev/tty - if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then - echo "" - gcloud auth application-default login - echo "" - else - echo -e "${YELLOW}Skipping authentication. 
You may encounter auth errors.${NC}" - return 1 - fi - else - echo "Run: gcloud auth application-default login" - return 1 - fi - else - echo -e "${GREEN}✓ Application default credentials found${NC}" - fi - - echo "" - return 0 -} - -# Check if an API is enabled -is_api_enabled() { - local api="$1" - local project="$2" - - # Use server-side filtering for efficiency - [[ -n "$(gcloud services list --project="$project" --enabled --filter="config.name=$api" --format="value(config.name)" 2>/dev/null)" ]] -} - -# Enable required APIs -enable_required_apis() { - local project="${GOOGLE_CLOUD_PROJECT:-}" - - if [[ -z "$project" ]]; then - echo -e "${YELLOW}GOOGLE_CLOUD_PROJECT not set, skipping API enablement${NC}" - return 1 - fi - - echo -e "${BLUE}Checking required APIs for project: ${project}${NC}" - - local apis_to_enable=() - - for api in "${REQUIRED_APIS[@]}"; do - if is_api_enabled "$api" "$project"; then - echo -e " ${GREEN}✓${NC} $api" - else - echo -e " ${YELLOW}✗${NC} $api (not enabled)" - apis_to_enable+=("$api") - fi - done - - echo "" - - if [[ ${#apis_to_enable[@]} -eq 0 ]]; then - echo -e "${GREEN}All required APIs are already enabled!${NC}" - echo "" - return 0 - fi - - # Prompt to enable APIs - if [[ -t 0 ]] && [ -c /dev/tty ]; then - echo -e "${YELLOW}The following APIs need to be enabled:${NC}" - for api in "${apis_to_enable[@]}"; do - echo " - $api" - done - echo "" - echo -en "Enable these APIs now? [Y/n]: " - local response - read -r response < /dev/tty - - if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then - echo "" - for api in "${apis_to_enable[@]}"; do - echo -e "${BLUE}Enabling $api...${NC}" - if gcloud services enable "$api" --project="$project"; then - echo -e "${GREEN}✓ Enabled $api${NC}" - else - echo -e "${RED}✗ Failed to enable $api${NC}" - return 1 - fi - done - echo "" - echo -e "${GREEN}All APIs enabled successfully!${NC}" - else - echo -e "${YELLOW}Skipping API enablement. 
You may encounter errors.${NC}" - return 1 - fi - else - echo "Enable APIs with:" - for api in "${apis_to_enable[@]}"; do - echo " gcloud services enable $api --project=$project" - done - return 1 - fi - - echo "" - return 0 -} - # Print reminder about Firestore index print_firestore_index_reminder() { echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" @@ -201,39 +76,6 @@ print_firestore_index_reminder() { echo "" } -# Perform all setup checks (shared by run_setup and main execution) -_perform_setup_checks() { - # Check gcloud is installed - check_gcloud_installed || exit 1 - - # Check/prompt for project - check_env_var "GOOGLE_CLOUD_PROJECT" "" || { - echo -e "${RED}Error: GOOGLE_CLOUD_PROJECT is required${NC}" - echo "" - echo "Set it with:" - echo " export GOOGLE_CLOUD_PROJECT=your-project-id" - echo "" - exit 1 - } - - # Check authentication - check_gcloud_auth || true - - # Enable APIs - enable_required_apis || true - - # Remind about Firestore index - print_firestore_index_reminder -} - -# Run full setup for --setup flag -run_setup() { - print_banner "Setup" "⚙️" - _perform_setup_checks - echo -e "${GREEN}Setup complete!${NC}" - echo "" -} - # Main case "${1:-}" in --help|-h) @@ -241,14 +83,22 @@ case "${1:-}" in exit 0 ;; --setup) - run_setup + print_banner "Setup" "⚙️" + run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 + print_firestore_index_reminder + echo -e "${GREEN}Setup complete!${NC}" + echo "" exit 0 ;; esac print_banner "Firestore Retriever Demo" "🔥" -_perform_setup_checks +# Run GCP setup (checks gcloud, auth, enables APIs) +run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 + +# Remind about Firestore index +print_firestore_index_reminder # Install dependencies install_deps diff --git a/py/samples/google-genai-vertexai-hello/run.sh b/py/samples/google-genai-vertexai-hello/run.sh index cc57ec5df1..2147598d0c 100755 --- a/py/samples/google-genai-vertexai-hello/run.sh +++ b/py/samples/google-genai-vertexai-hello/run.sh @@ -48,161 +48,6 @@ print_help() { print_help_footer } -# Check if gcloud is installed -check_gcloud_installed() { - if ! command -v gcloud &> /dev/null; then - echo -e "${RED}Error: gcloud CLI is not installed${NC}" - echo "" - echo "Install the Google Cloud SDK from:" - echo " https://cloud.google.com/sdk/docs/install" - echo "" - return 1 - fi - return 0 -} - -# Check if gcloud is authenticated -check_gcloud_auth() { - echo -e "${BLUE}Checking gcloud authentication...${NC}" - - # Check application default credentials - if ! gcloud auth application-default print-access-token &> /dev/null; then - echo -e "${YELLOW}Application default credentials not found.${NC}" - echo "" - - if [[ -t 0 ]] && [ -c /dev/tty ]; then - echo -en "Run ${GREEN}gcloud auth application-default login${NC} now? [Y/n]: " - local response - read -r response < /dev/tty - if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then - echo "" - gcloud auth application-default login - echo "" - else - echo -e "${YELLOW}Skipping authentication. 
You may encounter auth errors.${NC}" - return 1 - fi - else - echo "Run: gcloud auth application-default login" - return 1 - fi - else - echo -e "${GREEN}✓ Application default credentials found${NC}" - fi - - echo "" - return 0 -} - -# Check if an API is enabled -is_api_enabled() { - local api="$1" - local project="$2" - - # Use server-side filtering for efficiency - [[ -n "$(gcloud services list --project="$project" --enabled --filter="config.name=$api" --format="value(config.name)" 2>/dev/null)" ]] -} - -# Enable required APIs -enable_required_apis() { - local project="${GOOGLE_CLOUD_PROJECT:-}" - - if [[ -z "$project" ]]; then - echo -e "${YELLOW}GOOGLE_CLOUD_PROJECT not set, skipping API enablement${NC}" - return 1 - fi - - echo -e "${BLUE}Checking required APIs for project: ${project}${NC}" - - local apis_to_enable=() - - for api in "${REQUIRED_APIS[@]}"; do - if is_api_enabled "$api" "$project"; then - echo -e " ${GREEN}✓${NC} $api" - else - echo -e " ${YELLOW}✗${NC} $api (not enabled)" - apis_to_enable+=("$api") - fi - done - - echo "" - - if [[ ${#apis_to_enable[@]} -eq 0 ]]; then - echo -e "${GREEN}All required APIs are already enabled!${NC}" - echo "" - return 0 - fi - - # Prompt to enable APIs - if [[ -t 0 ]] && [ -c /dev/tty ]; then - echo -e "${YELLOW}The following APIs need to be enabled:${NC}" - for api in "${apis_to_enable[@]}"; do - echo " - $api" - done - echo "" - echo -en "Enable these APIs now? [Y/n]: " - local response - read -r response < /dev/tty - - if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then - echo "" - for api in "${apis_to_enable[@]}"; do - echo -e "${BLUE}Enabling $api...${NC}" - if gcloud services enable "$api" --project="$project"; then - echo -e "${GREEN}✓ Enabled $api${NC}" - else - echo -e "${RED}✗ Failed to enable $api${NC}" - return 1 - fi - done - echo "" - echo -e "${GREEN}All APIs enabled successfully!${NC}" - else - echo -e "${YELLOW}Skipping API enablement. 
You may encounter errors.${NC}" - return 1 - fi - else - echo "Enable APIs with:" - for api in "${apis_to_enable[@]}"; do - echo " gcloud services enable $api --project=$project" - done - return 1 - fi - - echo "" - return 0 -} - -# Perform all setup checks (shared by run_setup and main execution) -_perform_setup_checks() { - # Check gcloud is installed - check_gcloud_installed || exit 1 - - # Check/prompt for project - check_env_var "GOOGLE_CLOUD_PROJECT" "" || { - echo -e "${RED}Error: GOOGLE_CLOUD_PROJECT is required${NC}" - echo "" - echo "Set it with:" - echo " export GOOGLE_CLOUD_PROJECT=your-project-id" - echo "" - exit 1 - } - - # Check authentication - check_gcloud_auth || true - - # Enable APIs - enable_required_apis || true -} - -# Run full setup for --setup flag -run_setup() { - print_banner "Setup" "⚙️" - _perform_setup_checks - echo -e "${GREEN}Setup complete!${NC}" - echo "" -} - # Main case "${1:-}" in --help|-h) @@ -210,14 +55,18 @@ case "${1:-}" in exit 0 ;; --setup) - run_setup + print_banner "Setup" "⚙️" + run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 + echo -e "${GREEN}Setup complete!${NC}" + echo "" exit 0 ;; esac print_banner "Vertex AI Hello World" "☁️" -_perform_setup_checks +# Run GCP setup (checks gcloud, auth, enables APIs) +run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 # Install dependencies install_deps diff --git a/py/samples/google-genai-vertexai-image/run.sh b/py/samples/google-genai-vertexai-image/run.sh index 29b9c23c21..ebeb7cab39 100755 --- a/py/samples/google-genai-vertexai-image/run.sh +++ b/py/samples/google-genai-vertexai-image/run.sh @@ -48,161 +48,6 @@ print_help() { print_help_footer } -# Check if gcloud is installed -check_gcloud_installed() { - if ! command -v gcloud &> /dev/null; then - echo -e "${RED}Error: gcloud CLI is not installed${NC}" - echo "" - echo "Install the Google Cloud SDK from:" - echo " https://cloud.google.com/sdk/docs/install" - echo "" - return 1 - fi - return 0 -} - -# Check if gcloud is authenticated -check_gcloud_auth() { - echo -e "${BLUE}Checking gcloud authentication...${NC}" - - # Check application default credentials - if ! gcloud auth application-default print-access-token &> /dev/null; then - echo -e "${YELLOW}Application default credentials not found.${NC}" - echo "" - - if [[ -t 0 ]] && [ -c /dev/tty ]; then - echo -en "Run ${GREEN}gcloud auth application-default login${NC} now? [Y/n]: " - local response - read -r response < /dev/tty - if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then - echo "" - gcloud auth application-default login - echo "" - else - echo -e "${YELLOW}Skipping authentication. 
You may encounter auth errors.${NC}" - return 1 - fi - else - echo "Run: gcloud auth application-default login" - return 1 - fi - else - echo -e "${GREEN}✓ Application default credentials found${NC}" - fi - - echo "" - return 0 -} - -# Check if an API is enabled -is_api_enabled() { - local api="$1" - local project="$2" - - # Use server-side filtering for efficiency - [[ -n "$(gcloud services list --project="$project" --enabled --filter="config.name=$api" --format="value(config.name)" 2>/dev/null)" ]] -} - -# Enable required APIs -enable_required_apis() { - local project="${GOOGLE_CLOUD_PROJECT:-}" - - if [[ -z "$project" ]]; then - echo -e "${YELLOW}GOOGLE_CLOUD_PROJECT not set, skipping API enablement${NC}" - return 1 - fi - - echo -e "${BLUE}Checking required APIs for project: ${project}${NC}" - - local apis_to_enable=() - - for api in "${REQUIRED_APIS[@]}"; do - if is_api_enabled "$api" "$project"; then - echo -e " ${GREEN}✓${NC} $api" - else - echo -e " ${YELLOW}✗${NC} $api (not enabled)" - apis_to_enable+=("$api") - fi - done - - echo "" - - if [[ ${#apis_to_enable[@]} -eq 0 ]]; then - echo -e "${GREEN}All required APIs are already enabled!${NC}" - echo "" - return 0 - fi - - # Prompt to enable APIs - if [[ -t 0 ]] && [ -c /dev/tty ]; then - echo -e "${YELLOW}The following APIs need to be enabled:${NC}" - for api in "${apis_to_enable[@]}"; do - echo " - $api" - done - echo "" - echo -en "Enable these APIs now? [Y/n]: " - local response - read -r response < /dev/tty - - if [[ -z "$response" || "$response" =~ ^[Yy] ]]; then - echo "" - for api in "${apis_to_enable[@]}"; do - echo -e "${BLUE}Enabling $api...${NC}" - if gcloud services enable "$api" --project="$project"; then - echo -e "${GREEN}✓ Enabled $api${NC}" - else - echo -e "${RED}✗ Failed to enable $api${NC}" - return 1 - fi - done - echo "" - echo -e "${GREEN}All APIs enabled successfully!${NC}" - else - echo -e "${YELLOW}Skipping API enablement. 
You may encounter errors.${NC}" - return 1 - fi - else - echo "Enable APIs with:" - for api in "${apis_to_enable[@]}"; do - echo " gcloud services enable $api --project=$project" - done - return 1 - fi - - echo "" - return 0 -} - -# Perform all setup checks (shared by run_setup and main execution) -_perform_setup_checks() { - # Check gcloud is installed - check_gcloud_installed || exit 1 - - # Check/prompt for project - check_env_var "GOOGLE_CLOUD_PROJECT" "" || { - echo -e "${RED}Error: GOOGLE_CLOUD_PROJECT is required${NC}" - echo "" - echo "Set it with:" - echo " export GOOGLE_CLOUD_PROJECT=your-project-id" - echo "" - exit 1 - } - - # Check authentication - check_gcloud_auth || true - - # Enable APIs - enable_required_apis || true -} - -# Run full setup for --setup flag -run_setup() { - print_banner "Setup" "⚙️" - _perform_setup_checks - echo -e "${GREEN}Setup complete!${NC}" - echo "" -} - # Main case "${1:-}" in --help|-h) @@ -210,14 +55,18 @@ case "${1:-}" in exit 0 ;; --setup) - run_setup + print_banner "Setup" "⚙️" + run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 + echo -e "${GREEN}Setup complete!${NC}" + echo "" exit 0 ;; esac print_banner "Vertex AI Image Demo" "🖼️" -_perform_setup_checks +# Run GCP setup (checks gcloud, auth, enables APIs) +run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 # Install dependencies install_deps diff --git a/py/samples/vertexai-rerank-eval/LICENSE b/py/samples/vertexai-rerank-eval/LICENSE new file mode 100644 index 0000000000..996f16986d --- /dev/null +++ b/py/samples/vertexai-rerank-eval/LICENSE @@ -0,0 +1,207 @@ +``` + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ +``` + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to the Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +``` + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. +``` + +Copyright \[yyyy] \[name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +``` + http://www.apache.org/licenses/LICENSE-2.0 +``` + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/py/samples/vertexai-rerank-eval/README.md b/py/samples/vertexai-rerank-eval/README.md new file mode 100644 index 0000000000..5c8720d53e --- /dev/null +++ b/py/samples/vertexai-rerank-eval/README.md @@ -0,0 +1,161 @@ +# Vertex AI Rerankers and Evaluators Demo + +Demonstrates using Vertex AI rerankers for RAG quality improvement and evaluators +for assessing model outputs. + +## Features + +### Rerankers + +Semantic document reranking improves RAG quality by re-ordering retrieved documents +based on their semantic relevance to a query. + +* **`rerank_documents`** - Basic document reranking +* **`rag_with_reranking`** - Full RAG pipeline (retrieve → rerank → generate) + +### Evaluators + +Vertex AI evaluators assess model outputs using various quality metrics: + +* **`evaluate_fluency`** - Text fluency (1-5 scale) +* **`evaluate_safety`** - Content safety assessment +* **`evaluate_groundedness`** - Hallucination detection (is output grounded in context?) 
+* **`evaluate_bleu`** - BLEU score for translation quality +* **`evaluate_summarization`** - Summarization quality assessment + +## Quick Start + +```bash +export GOOGLE_CLOUD_PROJECT=your-project-id +./run.sh +``` + +That's it! The script will: + +1. ✓ Prompt for your project ID if not set +2. ✓ Check gcloud authentication (and help you authenticate if needed) +3. ✓ Enable required APIs (with your permission) +4. ✓ Install dependencies +5. ✓ Start the demo and open your browser + +## Manual Setup (if needed) + +If you prefer manual setup or the automatic setup fails: + +### 1. Authentication + +```bash +gcloud auth application-default login +``` + +### 2. Enable Required APIs + +```bash +# Vertex AI API (for models and evaluators) +gcloud services enable aiplatform.googleapis.com + +# Discovery Engine API (for rerankers) +gcloud services enable discoveryengine.googleapis.com +``` + +### 3. Run the Demo + +```bash +./run.sh +``` + +Or manually: + +```bash +genkit start -- uv run src/main.py +``` + +Then open the Dev UI at http://localhost:4000 + +## Testing the Demo + +### Reranker Flows + +1. **`rerank_documents`** + * Input: A query string (default: "How do neural networks learn?") + * Output: Documents sorted by relevance score + * The sample includes irrelevant documents to show how reranking filters them + +2. **`rag_with_reranking`** + * Input: A question (default: "What is machine learning?") + * Output: Generated answer using top-ranked documents as context + * Demonstrates the two-stage retrieval pattern + +### Evaluator Flows + +1. **`evaluate_fluency`** + * Tests text fluency with samples including intentionally poor grammar + * Scores: 1 (poor) to 5 (excellent) + +2. **`evaluate_safety`** + * Tests content safety + * Higher scores = safer content + +3. **`evaluate_groundedness`** + * Tests if outputs are grounded in provided context + * Includes a hallucination example (claims population when not in context) + +4. **`evaluate_bleu`** + * Tests translation quality against reference translations + * Scores: 0 to 1 (higher = closer to reference) + +5. 
**`evaluate_summarization`**
+   * Tests summarization quality
+
+## Supported Reranker Models
+
+| Model | Description |
+|-------|-------------|
+| `semantic-ranker-default@latest` | Latest default semantic ranker |
+| `semantic-ranker-default-004` | Semantic ranker version 004 |
+| `semantic-ranker-fast-004` | Fast variant (lower latency) |
+| `semantic-ranker-default-003` | Semantic ranker version 003 |
+| `semantic-ranker-default-002` | Semantic ranker version 002 |
+
+## Supported Evaluation Metrics
+
+| Metric | Description |
+|--------|-------------|
+| `BLEU` | Translation quality (compare to reference) |
+| `ROUGE` | Summarization quality (compare to reference) |
+| `FLUENCY` | Language mastery and readability |
+| `SAFETY` | Harmful/inappropriate content check |
+| `GROUNDEDNESS` | Factual grounding in context |
+| `SUMMARIZATION_QUALITY` | Overall summarization ability |
+| `SUMMARIZATION_HELPFULNESS` | Usefulness as a summary |
+| `SUMMARIZATION_VERBOSITY` | Conciseness of summary |
+
+## Troubleshooting
+
+### "Discovery Engine API not enabled"
+
+The script should enable this automatically, but if it fails:
+
+```bash
+gcloud services enable discoveryengine.googleapis.com
+```
+
+### "Permission denied"
+
+Ensure your account has the required IAM roles:
+
+* `roles/discoveryengine.admin` (for rerankers)
+* `roles/aiplatform.user` (for evaluators)
+
+### "Project not found"
+
+Verify `GOOGLE_CLOUD_PROJECT` is set correctly:
+
+```bash
+echo $GOOGLE_CLOUD_PROJECT
+```
+
+### "gcloud not found"
+
+Install the Google Cloud SDK from:
+https://cloud.google.com/sdk/docs/install
diff --git a/py/samples/vertexai-rerank-eval/pyproject.toml b/py/samples/vertexai-rerank-eval/pyproject.toml
new file mode 100644
index 0000000000..18c632c381
--- /dev/null
+++ b/py/samples/vertexai-rerank-eval/pyproject.toml
@@ -0,0 +1,58 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# SPDX-License-Identifier: Apache-2.0 + +[project] +authors = [{ name = "Google" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "rich>=13.0.0", + "genkit", + "genkit-plugin-google-genai", + "pydantic>=2.10.5", + "structlog>=25.2.0", + "uvloop>=0.21.0", +] +description = "Vertex AI Rerankers and Evaluators Demo" +license = "Apache-2.0" +name = "vertexai-rerank-eval" +readme = "README.md" +requires-python = ">=3.10" +version = "0.1.0" + +[project.optional-dependencies] +dev = ["watchdog>=6.0.0"] + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src/vertexai_rerank_eval"] diff --git a/py/samples/vertexai-rerank-eval/run.sh b/py/samples/vertexai-rerank-eval/run.sh new file mode 100755 index 0000000000..8bec99c5b4 --- /dev/null +++ b/py/samples/vertexai-rerank-eval/run.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +# Vertex AI Rerankers and Evaluators Demo +# ======================================= +# +# Demonstrates using Vertex AI rerankers for RAG quality improvement +# and evaluators for assessing model outputs. +# +# This script automates most of the setup: +# - Detects/prompts for GOOGLE_CLOUD_PROJECT +# - Checks gcloud authentication +# - Enables required APIs +# - Installs dependencies +# +# Usage: +# ./run.sh # Start the demo with Dev UI +# ./run.sh --setup # Run setup only (check auth, enable APIs) +# ./run.sh --help # Show this help message + +set -euo pipefail + +cd "$(dirname "$0")" +source "../_common.sh" + +# Required APIs for this demo +REQUIRED_APIS=( + "aiplatform.googleapis.com" # Vertex AI API (models and evaluators) + "discoveryengine.googleapis.com" # Discovery Engine API (rerankers) +) + +print_help() { + print_banner "Vertex AI Rerankers & Evaluators" "🔍" + echo "Usage: ./run.sh [options]" + echo "" + echo "Options:" + echo " --help Show this help message" + echo " --setup Run setup only (auth check, enable APIs)" + echo "" + echo "The script will automatically:" + echo " 1. Prompt for GOOGLE_CLOUD_PROJECT if not set" + echo " 2. Check gcloud authentication" + echo " 3. Enable required APIs (with your permission)" + echo " 4. Install dependencies" + echo " 5. 
Start the demo and open the browser" + echo "" + echo "Required APIs (enabled automatically):" + echo " - Vertex AI API (aiplatform.googleapis.com)" + echo " - Discovery Engine API (discoveryengine.googleapis.com)" + print_help_footer +} + +# Main +case "${1:-}" in + --help|-h) + print_help + exit 0 + ;; + --setup) + print_banner "Setup" "⚙️" + run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 + echo -e "${GREEN}Setup complete!${NC}" + echo "" + exit 0 + ;; +esac + +print_banner "Vertex AI Rerankers & Evaluators" "🔍" + +# Run GCP setup (checks gcloud, auth, enables APIs) +run_gcp_setup "${REQUIRED_APIS[@]}" || exit 1 + +# Install dependencies +install_deps + +# Start the demo +genkit_start_with_browser -- \ + uv tool run --from watchdog watchmedo auto-restart \ + -d src \ + -d ../../packages \ + -d ../../plugins \ + -p '*.py;*.prompt;*.json' \ + -R \ + -- uv run src/main.py "$@" diff --git a/py/samples/vertexai-rerank-eval/src/main.py b/py/samples/vertexai-rerank-eval/src/main.py new file mode 100644 index 0000000000..ea6acb47de --- /dev/null +++ b/py/samples/vertexai-rerank-eval/src/main.py @@ -0,0 +1,396 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Vertex AI Rerankers and Evaluators Demo. + +This sample demonstrates: +- Semantic document reranking for RAG quality improvement +- Model output evaluation using Vertex AI metrics (BLEU, ROUGE, fluency, safety, etc.) + +Prerequisites: +- GOOGLE_CLOUD_PROJECT environment variable set +- gcloud auth application-default login +- Discovery Engine API enabled (for rerankers) +- Vertex AI API enabled (for evaluators) +""" + +from typing import Any, cast + +import structlog +from pydantic import BaseModel + +from genkit.ai import Genkit +from genkit.blocks.document import Document +from genkit.core.typing import BaseDataPoint, DocumentData, Score +from genkit.plugins.google_genai import VertexAI + +logger = structlog.get_logger(__name__) + + +ai = Genkit( + plugins=[ + VertexAI(location='us-central1'), + ], + model='vertexai/gemini-2.5-flash', +) + + +# ============================================================================= +# Reranker Examples +# ============================================================================= + + +class RerankResult(BaseModel): + """Result of a rerank operation.""" + + query: str + ranked_documents: list[dict[str, Any]] + + +@ai.flow() +async def rerank_documents(query: str = 'How do neural networks learn?') -> RerankResult: + """Rerank documents based on relevance to query. + + This demonstrates using Vertex AI's semantic reranker to re-order + documents by their semantic relevance to a query. Useful for improving + RAG (Retrieval-Augmented Generation) quality. 
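+
+    The returned RerankResult pairs each surviving document's text with the
+    relevance score assigned by the ranker; with top_n=5, the two least
+    relevant of the seven sample documents are dropped.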
+    """
+    # Sample documents to rerank (in a real app, these would come from a retriever)
+    documents: list[Document] = [
+        Document.from_text('Neural networks learn through backpropagation, adjusting weights based on errors.'),
+        Document.from_text('Python is a popular programming language for machine learning.'),
+        Document.from_text('The gradient descent algorithm minimizes the loss function during training.'),
+        Document.from_text('Cats are popular pets known for their independence.'),
+        Document.from_text('Deep learning models use multiple layers to extract hierarchical features.'),
+        Document.from_text('The weather today is sunny with a high of 75 degrees.'),
+        Document.from_text('Transformers use attention mechanisms to process sequential data efficiently.'),
+    ]
+
+    # Rerank documents using Vertex AI semantic reranker
+    # Document extends DocumentData, so we can cast and pass documents directly
+    ranked_docs = await ai.rerank(
+        reranker='vertexai/semantic-ranker-default@latest',
+        query=query,
+        documents=cast(list[DocumentData], documents),
+        options={'top_n': 5},
+    )
+
+    # Format results
+    results: list[dict[str, Any]] = []
+    for doc in ranked_docs:
+        results.append({
+            'text': doc.text(),
+            'score': doc.score,
+        })
+
+    return RerankResult(query=query, ranked_documents=results)
+
+
+@ai.flow()
+async def rag_with_reranking(question: str = 'What is machine learning?') -> str:
+    """Full RAG pipeline with reranking.
+
+    Demonstrates the retrieve → rerank → generate pattern:
+    1. Initial retrieval (simulated with sample docs)
+    2. Reranking for quality
+    3. Generation using top-k results
+    """
+    # Stage 1: Initial retrieval, simulated here (in production, use a real retriever)
+    retrieved_docs: list[Document] = [
+        Document.from_text('Machine learning is a subset of artificial intelligence.'),
+        Document.from_text('Supervised learning uses labeled data to train models.'),
+        Document.from_text('The stock market closed higher today.'),
+        Document.from_text('ML algorithms can identify patterns in large datasets.'),
+        Document.from_text('Unsupervised learning finds hidden patterns without labels.'),
+        Document.from_text('Pizza is a popular Italian dish.'),
+        Document.from_text('Deep learning uses neural networks with many layers.'),
+        Document.from_text('Reinforcement learning learns from rewards and penalties.'),
+    ]
+
+    # Stage 2: Rerank for quality
+    # Document extends DocumentData, so we can cast and pass documents directly
+    ranked_docs = await ai.rerank(
+        reranker='vertexai/semantic-ranker-default@latest',
+        query=question,
+        documents=cast(list[DocumentData], retrieved_docs),
+        options={'top_n': 3},
+    )
+
+    # Build context from top-ranked documents
+    context = '\n'.join([f'- {doc.text()}' for doc in ranked_docs])
+
+    # Stage 3: Generate answer using reranked context
+    response = await ai.generate(
+        model='vertexai/gemini-2.5-flash',
+        prompt=f"""Answer the following question based on the provided context.
+ +Context: +{context} + +Question: {question} + +Answer:""", + ) + + return response.text + + +# ============================================================================= +# Evaluator Examples +# ============================================================================= + + +class EvalResult(BaseModel): + """Result of an evaluation.""" + + metric: str + scores: list[dict[str, Any]] + + +def _extract_score(evaluation: Score | list[Score]) -> float | str | bool | None: + """Extract score from evaluation result.""" + if isinstance(evaluation, list): + return evaluation[0].score if evaluation else None + return evaluation.score + + +def _extract_reasoning(evaluation: Score | list[Score]) -> str | None: + """Extract reasoning from evaluation result.""" + if isinstance(evaluation, list): + if evaluation and evaluation[0].details: + return evaluation[0].details.reasoning + return None + if evaluation.details: + return evaluation.details.reasoning + return None + + +@ai.flow() +async def evaluate_fluency() -> EvalResult: + """Evaluate text fluency using Vertex AI. + + Fluency measures how natural and readable the text is. + Scores range from 1 (poor) to 5 (excellent). + """ + dataset: list[BaseDataPoint] = [ + BaseDataPoint( + input='Write a sentence about AI.', + output=( + 'Artificial intelligence is transforming industries by enabling ' + 'machines to learn from data and make intelligent decisions.' + ), + ), + BaseDataPoint( + input='Describe machine learning.', + # Intentionally poor grammar for demonstration + output='Machine learning algorithms data patterns the find automatically.', + ), + BaseDataPoint( + input='Explain neural networks.', + output=( + 'Neural networks are computational models inspired by the human brain, ' + 'consisting of interconnected nodes that process information in layers.' + ), + ), + ] + + results = await ai.evaluate( + evaluator='vertexai/fluency', + dataset=dataset, + ) + + scores: list[dict[str, Any]] = [] + for i, result in enumerate(results.root): + scores.append({ + 'sample_index': i, + 'test_case_id': result.test_case_id, + 'score': _extract_score(result.evaluation), + 'reasoning': _extract_reasoning(result.evaluation), + }) + + return EvalResult(metric='fluency', scores=scores) + + +@ai.flow() +async def evaluate_safety() -> EvalResult: + """Evaluate text safety using Vertex AI. + + Safety checks for harmful, biased, or inappropriate content. + Higher scores indicate safer content. + """ + dataset: list[BaseDataPoint] = [ + BaseDataPoint( + input='Write about technology.', + output=( + 'Technology has revolutionized communication, enabling people to connect instantly across the globe.' + ), + ), + BaseDataPoint( + input='Describe a helpful assistant.', + output=( + 'A helpful AI assistant provides accurate information, ' + 'respects user privacy, and declines harmful requests.' + ), + ), + ] + + results = await ai.evaluate( + evaluator='vertexai/safety', + dataset=dataset, + ) + + scores: list[dict[str, Any]] = [] + for i, result in enumerate(results.root): + scores.append({ + 'sample_index': i, + 'test_case_id': result.test_case_id, + 'score': _extract_score(result.evaluation), + }) + + return EvalResult(metric='safety', scores=scores) + + +@ai.flow() +async def evaluate_groundedness() -> EvalResult: + """Evaluate groundedness using Vertex AI. + + Groundedness checks if the output is factually grounded in the provided context. + This helps detect hallucinations in RAG applications. 
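+
+    Each BaseDataPoint below carries the context passages its output is
+    judged against; the second sample deliberately cites a population figure
+    absent from its context, so it should be scored as ungrounded.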
+ """ + dataset: list[BaseDataPoint] = [ + BaseDataPoint( + input='What is the capital of France?', + output='The capital of France is Paris.', + context=[ + 'France is a country in Western Europe. Its capital city is Paris, which is known for the Eiffel Tower.' + ], + ), + BaseDataPoint( + input='What is the population of Paris?', + # Hallucinated - context doesn't mention population + output='Paris has a population of about 12 million people.', + context=['Paris is the capital of France. It is known for art, fashion, and culture.'], + ), + BaseDataPoint( + input='What is France known for?', + output='France is known for wine, cheese, and the Eiffel Tower.', + context=[ + 'France is famous for its cuisine, especially wine and cheese. ' + 'The Eiffel Tower in Paris is a major landmark.' + ], + ), + ] + + results = await ai.evaluate( + evaluator='vertexai/groundedness', + dataset=dataset, + ) + + scores: list[dict[str, Any]] = [] + for i, result in enumerate(results.root): + scores.append({ + 'sample_index': i, + 'test_case_id': result.test_case_id, + 'score': _extract_score(result.evaluation), + 'reasoning': _extract_reasoning(result.evaluation), + }) + + return EvalResult(metric='groundedness', scores=scores) + + +@ai.flow() +async def evaluate_bleu() -> EvalResult: + """Evaluate using BLEU score. + + BLEU (Bilingual Evaluation Understudy) compares output to a reference. + Commonly used for translation and text generation quality. + Scores range from 0 to 1, with higher being better. + """ + dataset: list[BaseDataPoint] = [ + BaseDataPoint( + input='Translate to French: Hello, how are you?', + output='Bonjour, comment allez-vous?', + reference='Bonjour, comment allez-vous?', # Perfect match + ), + BaseDataPoint( + input='Translate to French: Good morning', + output='Bon matin', + reference='Bonjour', # Different but valid translation + ), + ] + + results = await ai.evaluate( + evaluator='vertexai/bleu', + dataset=dataset, + ) + + scores: list[dict[str, Any]] = [] + for i, result in enumerate(results.root): + scores.append({ + 'sample_index': i, + 'test_case_id': result.test_case_id, + 'score': _extract_score(result.evaluation), + }) + + return EvalResult(metric='bleu', scores=scores) + + +@ai.flow() +async def evaluate_summarization() -> EvalResult: + """Evaluate summarization quality using Vertex AI. + + Summarization quality assesses how well a summary captures the key points + of the original text. + """ + dataset: list[BaseDataPoint] = [ + BaseDataPoint( + input='Summarize this article about climate change.', + output='Climate change is causing rising temperatures and extreme weather events globally.', + context=[ + 'Climate change refers to long-term shifts in temperatures and weather patterns. ' + 'Human activities have been the main driver since the 1800s, primarily due to ' + 'burning fossil fuels. This has led to rising global temperatures, melting ice ' + 'caps, rising sea levels, and more frequent extreme weather events like ' + 'hurricanes, droughts, and floods.' 
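+                # Source passage that the one-sentence summary above is judged against.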
+            ],
+        ),
+    ]
+
+    results = await ai.evaluate(
+        evaluator='vertexai/summarization_quality',
+        dataset=dataset,
+    )
+
+    scores: list[dict[str, Any]] = []
+    for i, result in enumerate(results.root):
+        scores.append({
+            'sample_index': i,
+            'test_case_id': result.test_case_id,
+            'score': _extract_score(result.evaluation),
+            'reasoning': _extract_reasoning(result.evaluation),
+        })
+
+    return EvalResult(metric='summarization_quality', scores=scores)
+
+
+async def main() -> None:
+    """Main entry point."""
+    # Example run logic can go here; it can stay empty when running as a pure flow server.
+    pass
+
+
+if __name__ == '__main__':
+    ai.run_main(main())
diff --git a/py/uv.lock b/py/uv.lock
index 7003f045b4..0b6a5c8104 100644
--- a/py/uv.lock
+++ b/py/uv.lock
@@ -64,6 +64,7 @@ members = [
     "tool-interrupts",
     "vertex-ai-vector-search-bigquery",
     "vertex-ai-vector-search-firestore",
+    "vertexai-rerank-eval",
     "xai-hello",
 ]
 
@@ -7925,6 +7926,36 @@ requires-dist = [
 ]
 provides-extras = ["dev"]
 
+[[package]]
+name = "vertexai-rerank-eval"
+version = "0.1.0"
+source = { editable = "samples/vertexai-rerank-eval" }
+dependencies = [
+    { name = "genkit" },
+    { name = "genkit-plugin-google-genai" },
+    { name = "pydantic" },
+    { name = "rich" },
+    { name = "structlog" },
+    { name = "uvloop" },
+]
+
+[package.optional-dependencies]
+dev = [
+    { name = "watchdog" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "genkit", editable = "packages/genkit" },
+    { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" },
+    { name = "pydantic", specifier = ">=2.10.5" },
+    { name = "rich", specifier = ">=13.0.0" },
+    { name = "structlog", specifier = ">=25.2.0" },
+    { name = "uvloop", specifier = ">=0.21.0" },
+    { name = "watchdog", marker = "extra == 'dev'", specifier = ">=6.0.0" },
+]
+provides-extras = ["dev"]
+
 [[package]]
 name = "virtualenv"
 version = "20.36.1"