diff --git a/tests/unit/vertexai/model_garden/test_model_garden.py b/tests/unit/vertexai/model_garden/test_model_garden.py index 931b7e8f2c..1e0ee250cb 100644 --- a/tests/unit/vertexai/model_garden/test_model_garden.py +++ b/tests/unit/vertexai/model_garden/test_model_garden.py @@ -1355,6 +1355,44 @@ def test_list_deployable_models(self, list_publisher_models_mock): "google/gemma-2-2b", ] + def test_list_models(self, list_publisher_models_mock): + """Tests listing models.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + mg_models = model_garden.list_models() + list_publisher_models_mock.assert_called_with( + types.ListPublisherModelsRequest( + parent="publishers/*", + list_all_versions=True, + filter="is_hf_wildcard(false)", + ) + ) + + assert mg_models == [ + "google/paligemma@001", + "google/paligemma@002", + "google/paligemma@003", + "google/paligemma@004", + ] + + hf_models = model_garden.list_models(list_hf_models=True) + list_publisher_models_mock.assert_called_with( + types.ListPublisherModelsRequest( + parent="publishers/*", + list_all_versions=True, + filter="is_hf_wildcard(true)", + ) + ) + assert hf_models == [ + "google/gemma-2-2b", + "google/gemma-2-2b", + "google/gemma-2-2b", + "google/gemma-2-2b", + ] + def test_batch_prediction_success(self, batch_prediction_mock): aiplatform.init( project=_TEST_PROJECT, diff --git a/vertexai/model_garden/__init__.py b/vertexai/model_garden/__init__.py index 512d1aac4f..3cfc87f864 100644 --- a/vertexai/model_garden/__init__.py +++ b/vertexai/model_garden/__init__.py @@ -22,5 +22,6 @@ OpenModel = _model_garden.OpenModel PartnerModel = _model_garden.PartnerModel list_deployable_models = _model_garden.list_deployable_models +list_models = _model_garden.list_models -__all__ = ("OpenModel", "PartnerModel", "list_deployable_models") +__all__ = ("OpenModel", "PartnerModel", "list_deployable_models", "list_models") diff --git a/vertexai/model_garden/_model_garden.py b/vertexai/model_garden/_model_garden.py index 14b2062c39..e67557bc30 100644 --- a/vertexai/model_garden/_model_garden.py +++ b/vertexai/model_garden/_model_garden.py @@ -62,6 +62,7 @@ def list_deployable_models( `{publisher}/{model}@{version}` or Hugging Face model ID in the format of `{organization}/{model}`. """ + filter_str = _NATIVE_MODEL_FILTER if list_hf_models: filter_str = " AND ".join([_HF_WILDCARD_FILTER, _VERIFIED_DEPLOYMENT_FILTER]) @@ -93,6 +94,50 @@ def list_deployable_models( return output +def list_models( + *, list_hf_models: bool = False, model_filter: Optional[str] = None +) -> List[str]: + """Lists the models in Model Garden. + + Args: + list_hf_models: Whether to list the Hugging Face models. + model_filter: Optional. A string to filter the models by. + + Returns: + The names of the models in Model Garden in the format of + `{publisher}/{model}@{version}` or Hugging Face model ID in the format + of `{organization}/{model}`. + """ + filter_str = _NATIVE_MODEL_FILTER + if list_hf_models: + filter_str = _HF_WILDCARD_FILTER + if model_filter: + filter_str = ( + f'{filter_str} AND (model_user_id=~"(?i).*{model_filter}.*" OR' + f' display_name=~"(?i).*{model_filter}.*")' + ) + + request = types.ListPublisherModelsRequest( + parent="publishers/*", + list_all_versions=True, + filter=filter_str, + ) + client = initializer.global_config.create_client( + client_class=_ModelGardenClientWithOverride, + credentials=initializer.global_config.credentials, + location_override="us-central1", + ) + response = client.list_publisher_models(request) + output = [] + for page in response.pages: + for model in page.publisher_models: + output.append( + re.sub(r"publishers/(hf-|)|models/", "", model.name) + + ("" if list_hf_models else ("@" + model.version_id)) + ) + return output + + def _is_hugging_face_model(model_name: str) -> bool: """Returns whether the model is a Hugging Face model.""" return re.match(r"^(?P[^/]+)/(?P[^/@]+)$", model_name)