diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 66274d3dd1..a4bca19cb0 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -147,7 +147,9 @@ async def _call_tool_in_thread_pool( # For sync FunctionTool, call the underlying function directly def run_sync_tool(): if isinstance(tool, FunctionTool): - args_to_call = tool._preprocess_args(args) + args_to_call, validation_errors = tool._preprocess_args(args) + if validation_errors: + return tool._build_validation_error_response(validation_errors) signature = inspect.signature(tool.func) valid_params = {param for param in signature.parameters} if tool._context_param_name in valid_params: diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index fca8ba9f50..0c056558fd 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -73,8 +73,12 @@ async def run_async( duplicates, but is re-added if the function signature explicitly requires it as a parameter. """ - # Preprocess arguments (includes Pydantic model conversion) - args_to_call = self._preprocess_args(args) + # Preprocess arguments (includes Pydantic model conversion and type + # validation) + args_to_call, validation_errors = self._preprocess_args(args) + + if validation_errors: + return self._build_validation_error_response(validation_errors) signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 10e32a5473..da5397e024 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -18,8 +18,6 @@ import logging from typing import Any from typing import Callable -from typing import get_args -from typing import get_origin from typing import Optional from typing import Union @@ -85,6 +83,7 @@ def __init__( self._context_param_name = find_context_parameter(func) or 'tool_context' self._ignore_params = [self._context_param_name, 'input_stream'] self._require_confirmation = require_confirmation + self._type_adapter_cache: dict[Any, pydantic.TypeAdapter] = {} @override def _get_declaration(self) -> Optional[types.FunctionDeclaration]: @@ -100,68 +99,94 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]: return function_decl - def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: - """Preprocess and convert function arguments before invocation. + def _preprocess_args( + self, args: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: + """Preprocess, validate, and convert function arguments before invocation. - Currently handles: + Handles: - Converting JSON dictionaries to Pydantic model instances where expected - - Future extensions could include: - - Type coercion for other complex types - - Validation and sanitization - - Custom conversion logic + - Validating and coercing primitive types (int, float, str, bool) + - Validating enum values + - Validating container types (list[int], dict[str, float], etc.) Args: args: Raw arguments from the LLM tool call Returns: - Processed arguments ready for function invocation + A tuple of (processed_args, validation_errors). If validation_errors is + non-empty, the caller should return the errors to the LLM instead of + invoking the function. """ signature = inspect.signature(self.func) converted_args = args.copy() + validation_errors = [] for param_name, param in signature.parameters.items(): - if param_name in args and param.annotation != inspect.Parameter.empty: - target_type = param.annotation - - # Handle Optional[PydanticModel] types - if get_origin(param.annotation) is Union: - union_args = get_args(param.annotation) - # Find the non-None type in Optional[T] (which is Union[T, None]) - non_none_types = [arg for arg in union_args if arg is not type(None)] - if len(non_none_types) == 1: - target_type = non_none_types[0] - - # Check if the target type is a Pydantic model - if inspect.isclass(target_type) and issubclass( - target_type, pydantic.BaseModel - ): - # Skip conversion if the value is None and the parameter is Optional - if args[param_name] is None: - continue - - # Convert to Pydantic model if it's not already the correct type - if not isinstance(args[param_name], target_type): - try: - converted_args[param_name] = target_type.model_validate( - args[param_name] - ) - except Exception as e: - logger.warning( - f"Failed to convert argument '{param_name}' to Pydantic model" - f' {target_type.__name__}: {e}' - ) - # Keep the original value if conversion fails - pass - - return converted_args + if ( + param_name not in args + or param.annotation is inspect.Parameter.empty + or param_name in self._ignore_params + ): + continue + + target_type = param.annotation + + # Validate and coerce using TypeAdapter. Handles primitives, enums, + # Pydantic models, Optional[T], T | None, and container types natively. + try: + try: + adapter = self._type_adapter_cache[target_type] + except TypeError: + adapter = pydantic.TypeAdapter(target_type) + except KeyError: + adapter = pydantic.TypeAdapter(target_type) + self._type_adapter_cache[target_type] = adapter + converted_args[param_name] = adapter.validate_python( + args[param_name] + ) + except pydantic.ValidationError as e: + validation_errors.append( + f"Parameter '{param_name}': expected type '{getattr(target_type, '__name__', target_type)}'," + f' validation error: {e}' + ) + except (TypeError, NameError) as e: + # TypeAdapter could not handle this annotation (e.g. a forward + # reference string). Skip validation but log a warning. + logger.warning( + "Skipping validation for parameter '%s' due to unhandled" + " annotation type '%s': %s", + param_name, + target_type, + e, + ) + + return converted_args, validation_errors + + def _build_validation_error_response( + self, validation_errors: list[str] + ) -> dict[str, str]: + """Formats validation errors into an error dict for the LLM.""" + validation_errors_str = '\n'.join(validation_errors) + return { + 'error': ( + f'Invoking `{self.name}()` failed due to argument validation' + f' errors:\n{validation_errors_str}\nYou could retry calling' + ' this tool with corrected argument types.' + ) + } @override async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext ) -> Any: - # Preprocess arguments (includes Pydantic model conversion) - args_to_call = self._preprocess_args(args) + # Preprocess arguments (includes Pydantic model conversion and type + # validation). Validation errors are returned to the LLM so it can + # self-correct and retry with proper argument types. + args_to_call, validation_errors = self._preprocess_args(args) + + if validation_errors: + return self._build_validation_error_response(validation_errors) signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} diff --git a/tests/unittests/tools/test_function_tool_arg_validation.py b/tests/unittests/tools/test_function_tool_arg_validation.py new file mode 100644 index 0000000000..7b9779c619 --- /dev/null +++ b/tests/unittests/tools/test_function_tool_arg_validation.py @@ -0,0 +1,213 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for FunctionTool argument type validation and coercion.""" + +from enum import Enum +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.sessions.session import Session +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +def int_func(num: int) -> int: + return num + + +def float_func(val: float) -> float: + return val + + +def bool_func(flag: bool) -> bool: + return flag + + +def enum_func(color: Color) -> str: + return color.value + + +def list_int_func(nums: list[int]) -> list[int]: + return nums + + +def optional_int_func(num: Optional[int] = None) -> Optional[int]: + return num + + +def multi_param_func(name: str, count: int, flag: bool) -> dict: + return {"name": name, "count": count, "flag": flag} + + +# --- _preprocess_args coercion tests --- + + +class TestArgCoercion: + + def test_string_to_int(self): + tool = FunctionTool(int_func) + args, errors = tool._preprocess_args({"num": "42"}) + assert errors == [] + assert args["num"] == 42 + assert isinstance(args["num"], int) + + def test_float_to_int(self): + """Pydantic lax mode truncates float to int.""" + tool = FunctionTool(int_func) + args, errors = tool._preprocess_args({"num": 3.0}) + assert errors == [] + assert args["num"] == 3 + assert isinstance(args["num"], int) + + def test_string_to_float(self): + tool = FunctionTool(float_func) + args, errors = tool._preprocess_args({"val": "3.14"}) + assert errors == [] + assert abs(args["val"] - 3.14) < 1e-9 + + def test_int_to_float(self): + tool = FunctionTool(float_func) + args, errors = tool._preprocess_args({"val": 5}) + assert errors == [] + assert args["val"] == 5.0 + assert isinstance(args["val"], float) + + def test_enum_valid_value(self): + tool = FunctionTool(enum_func) + args, errors = tool._preprocess_args({"color": "red"}) + assert errors == [] + assert args["color"] == Color.RED + + def test_enum_invalid_value(self): + tool = FunctionTool(enum_func) + args, errors = tool._preprocess_args({"color": "purple"}) + assert len(errors) == 1 + assert "color" in errors[0] + + def test_list_int_coercion(self): + tool = FunctionTool(list_int_func) + args, errors = tool._preprocess_args({"nums": ["1", "2", "3"]}) + assert errors == [] + assert args["nums"] == [1, 2, 3] + + def test_optional_none_skipped(self): + tool = FunctionTool(optional_int_func) + args, errors = tool._preprocess_args({"num": None}) + assert errors == [] + assert args["num"] is None + + def test_optional_value_coerced(self): + tool = FunctionTool(optional_int_func) + args, errors = tool._preprocess_args({"num": "7"}) + assert errors == [] + assert args["num"] == 7 + + def test_bool_from_int(self): + tool = FunctionTool(bool_func) + args, errors = tool._preprocess_args({"flag": 1}) + assert errors == [] + assert args["flag"] is True + + +# --- _preprocess_args validation error tests --- + + +class TestArgValidationErrors: + + def test_string_for_int_returns_error(self): + tool = FunctionTool(int_func) + args, errors = tool._preprocess_args({"num": "foobar"}) + assert len(errors) == 1 + assert "num" in errors[0] + + def test_none_for_required_int_returns_error(self): + """None for a non-Optional int should be flagged.""" + tool = FunctionTool(int_func) + # None passed for a required int param. The Optional unwrap won't + # trigger because the annotation is plain `int`, not Optional[int]. + # TypeAdapter(int).validate_python(None) raises ValidationError. + args, errors = tool._preprocess_args({"num": None}) + assert len(errors) == 1 + assert "num" in errors[0] + + def test_multiple_param_errors(self): + tool = FunctionTool(multi_param_func) + args, errors = tool._preprocess_args( + {"name": 123, "count": "not_a_number", "flag": "not_a_bool"} + ) + # All three fail: pydantic rejects int->str, "not_a_number"->int, + # and "not_a_bool"->bool. + assert len(errors) == 3 + assert any("name" in e for e in errors) + assert any("count" in e for e in errors) + assert any("flag" in e for e in errors) + + +# --- run_async integration tests --- + + +def _make_tool_context(): + tool_context_mock = MagicMock(spec=ToolContext) + invocation_context_mock = MagicMock(spec=InvocationContext) + session_mock = MagicMock(spec=Session) + invocation_context_mock.session = session_mock + tool_context_mock.invocation_context = invocation_context_mock + return tool_context_mock + + +class TestRunAsyncValidation: + + @pytest.mark.asyncio + async def test_invalid_arg_returns_error_to_llm(self): + tool = FunctionTool(int_func) + result = await tool.run_async( + args={"num": "foobar"}, tool_context=_make_tool_context() + ) + assert isinstance(result, dict) + assert "error" in result + assert "validation error" in result["error"].lower() + + @pytest.mark.asyncio + async def test_valid_coercion_invokes_function(self): + tool = FunctionTool(int_func) + result = await tool.run_async( + args={"num": "42"}, tool_context=_make_tool_context() + ) + assert result == 42 + + @pytest.mark.asyncio + async def test_enum_invalid_returns_error(self): + tool = FunctionTool(enum_func) + result = await tool.run_async( + args={"color": "purple"}, tool_context=_make_tool_context() + ) + assert isinstance(result, dict) + assert "error" in result + + @pytest.mark.asyncio + async def test_enum_valid_invokes_function(self): + tool = FunctionTool(enum_func) + result = await tool.run_async( + args={"color": "green"}, tool_context=_make_tool_context() + ) + assert result == "green" diff --git a/tests/unittests/tools/test_function_tool_pydantic.py b/tests/unittests/tools/test_function_tool_pydantic.py index 82f5631a35..8ffebcf21f 100644 --- a/tests/unittests/tools/test_function_tool_pydantic.py +++ b/tests/unittests/tools/test_function_tool_pydantic.py @@ -97,7 +97,7 @@ def test_preprocess_args_with_dict_to_pydantic_conversion(): "user": {"name": "Alice", "age": 30, "email": "alice@example.com"} } - processed_args = tool._preprocess_args(input_args) + processed_args, _ = tool._preprocess_args(input_args) # Check that the dict was converted to a Pydantic model assert "user" in processed_args @@ -116,7 +116,7 @@ def test_preprocess_args_with_existing_pydantic_model(): existing_user = UserModel(name="Bob", age=25) input_args = {"user": existing_user} - processed_args = tool._preprocess_args(input_args) + processed_args, _ = tool._preprocess_args(input_args) # Check that the existing model was not changed (same object) assert "user" in processed_args @@ -132,7 +132,7 @@ def test_preprocess_args_with_optional_pydantic_model_none(): input_args = {"user": {"name": "Charlie", "age": 35}, "preferences": None} - processed_args = tool._preprocess_args(input_args) + processed_args, _ = tool._preprocess_args(input_args) # Check user conversion assert isinstance(processed_args["user"], UserModel) @@ -151,7 +151,7 @@ def test_preprocess_args_with_optional_pydantic_model_dict(): "preferences": {"theme": "dark", "notifications": False}, } - processed_args = tool._preprocess_args(input_args) + processed_args, _ = tool._preprocess_args(input_args) # Check both conversions assert isinstance(processed_args["user"], UserModel) @@ -172,7 +172,7 @@ def test_preprocess_args_with_mixed_types(): "count": 10, } - processed_args = tool._preprocess_args(input_args) + processed_args, _ = tool._preprocess_args(input_args) # Check that only Pydantic model was converted assert processed_args["name"] == "test_name" # string unchanged @@ -184,17 +184,18 @@ def test_preprocess_args_with_mixed_types(): assert processed_args["user"].age == 40 -def test_preprocess_args_with_invalid_data_graceful_failure(): - """Test _preprocess_args handles invalid data gracefully.""" +def test_preprocess_args_with_invalid_data_returns_error(): + """Test _preprocess_args returns validation error for invalid Pydantic data.""" tool = FunctionTool(sync_function_with_pydantic_model) # Invalid data that can't be converted to UserModel input_args = {"user": "invalid_string"} # string instead of dict/model - processed_args = tool._preprocess_args(input_args) + _, errors = tool._preprocess_args(input_args) - # Should keep original value when conversion fails - assert processed_args["user"] == "invalid_string" + # Should return a validation error for the LLM to self-correct + assert len(errors) == 1 + assert "user" in errors[0] def test_preprocess_args_with_non_pydantic_parameters(): @@ -206,7 +207,7 @@ def simple_function(name: str, age: int) -> dict: tool = FunctionTool(simple_function) input_args = {"name": "test", "age": 25} - processed_args = tool._preprocess_args(input_args) + processed_args, _ = tool._preprocess_args(input_args) # Should remain unchanged (no Pydantic models to convert) assert processed_args == input_args