From b3bae32b8577974a113ad1dc3effb5d3a3db4fe0 Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Fri, 6 Mar 2026 17:55:27 -0800 Subject: [PATCH] feat: "global" endpoint supports "grpc" transport PiperOrigin-RevId: 879899627 --- google/cloud/aiplatform/initializer.py | 4 ---- tests/unit/aiplatform/test_initializer.py | 14 ++++++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index deea0611b5..d9d360d1bb 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -245,10 +245,6 @@ def init( # Set api_transport as "rest" if location is "global". if location == "global" and not api_transport: self._api_transport = "rest" - elif location == "global" and api_transport == "grpc": - raise ValueError( - "api_transport cannot be 'grpc' when location is 'global'." - ) if experiment_description and experiment is None: raise ValueError( "Experiment needs to be set in `init` in order to add experiment" diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 65903f2f51..9fdb3eed6a 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -306,10 +306,16 @@ def test_create_client_with_global_location(self): assert client._transport._host == f"https://{constants.API_BASE_PATH}" def test_create_client_with_global_location_and_grpc_transport(self): - with pytest.raises(ValueError): - initializer.global_config.init( - project=_TEST_PROJECT, location="global", api_transport="grpc" - ) + initializer.global_config.init( + project=_TEST_PROJECT, location="global", api_transport="grpc" + ) + client = initializer.global_config.create_client( + client_class=utils.PredictionClientWithOverride + ) + assert initializer.global_config.location == "global" + assert initializer.global_config._api_transport == "grpc" + assert isinstance(client, utils.PredictionClientWithOverride) + assert client._transport._host == f"{constants.API_BASE_PATH}:443" def test_create_client_with_api_key_and_grpc_transport(self): with pytest.raises(ValueError):