fix: add plugin before/after tool callbacks to live execution path #4709

OiPunk wants to merge 2 commits into google:main
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses an inconsistency in how tool execution callbacks were handled between live (bidi streaming) and non-live sessions. Previously, plugin-level `before_tool_callback` and `after_tool_callback` hooks were never invoked on the live path.
Code Review
This pull request correctly adds the missing plugin before_tool_callback and after_tool_callback calls to the live execution path, ensuring its behavior matches the non-live path. However, it introduces a regression in error handling where the live path does not catch exceptions during tool execution, potentially leading to session crashes. Furthermore, the stop_streaming tool lacks input validation for the function_name argument, which could be exploited for a denial of service. Additionally, while the accompanying tests are a good start, refactoring is suggested to improve their maintainability and robustness by reducing code duplication and expanding coverage through parameterization.
```python
# --- Plugin callback tests for live mode ---


class MockPlugin(BasePlugin):
  """A mock plugin for testing plugin callbacks in live mode."""

  before_tool_response = {"MockPlugin": "before_tool_response from MockPlugin"}
  after_tool_response = {"MockPlugin": "after_tool_response from MockPlugin"}

  def __init__(self, name="mock_plugin"):
    self.name = name
    self.enable_before_tool_callback = False
    self.enable_after_tool_callback = False

  async def before_tool_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Optional[dict]:
    if not self.enable_before_tool_callback:
      return None
    return self.before_tool_response

  async def after_tool_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      result: dict,
  ) -> Optional[dict]:
    if not self.enable_after_tool_callback:
      return None
    return self.after_tool_response


async def invoke_tool_with_plugin_live(
    mock_plugin,
) -> Optional[Event]:
  """Invokes a tool with a plugin using live mode."""

  def simple_fn(**kwargs) -> Dict[str, Any]:
    return {"initial": "response"}

  tool = FunctionTool(simple_fn)
  model = testing_utils.MockModel.create(responses=[])
  agent = Agent(
      name="agent",
      model=model,
      tools=[tool],
  )
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content="", plugins=[mock_plugin]
  )
  function_call = types.FunctionCall(name=tool.name, args={})
  content = types.Content(parts=[types.Part(function_call=function_call)])
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=content,
  )
  tools_dict = {tool.name: tool}
  return await handle_function_calls_live(
      invocation_context,
      event,
      tools_dict,
  )


@pytest.mark.asyncio
async def test_live_plugin_before_tool_callback():
  """Test that plugin before_tool_callback is called in live mode."""
  plugin = MockPlugin()
  plugin.enable_before_tool_callback = True

  result_event = await invoke_tool_with_plugin_live(plugin)

  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == plugin.before_tool_response


@pytest.mark.asyncio
async def test_live_plugin_after_tool_callback():
  """Test that plugin after_tool_callback is called in live mode."""
  plugin = MockPlugin()
  plugin.enable_after_tool_callback = True

  result_event = await invoke_tool_with_plugin_live(plugin)

  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == plugin.after_tool_response


@pytest.mark.asyncio
async def test_live_plugin_before_tool_callback_disabled():
  """Test that disabled plugin before_tool_callback allows normal tool execution."""
  plugin = MockPlugin()
  plugin.enable_before_tool_callback = False

  result_event = await invoke_tool_with_plugin_live(plugin)

  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == {"initial": "response"}


@pytest.mark.asyncio
async def test_live_plugin_callbacks_match_async_behavior():
  """Test that plugin callbacks in live mode match the async (non-live) behavior."""
  from google.adk.flows.llm_flows.functions import handle_function_calls_async

  def simple_fn(**kwargs) -> Dict[str, Any]:
    return {"initial": "response"}

  tool = FunctionTool(simple_fn)
  model = testing_utils.MockModel.create(responses=[])
  agent = Agent(
      name="agent",
      model=model,
      tools=[tool],
  )

  # Test with plugin before_tool_callback enabled
  plugin = MockPlugin()
  plugin.enable_before_tool_callback = True

  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content="", plugins=[plugin]
  )
  function_call = types.FunctionCall(name=tool.name, args={})
  content = types.Content(parts=[types.Part(function_call=function_call)])
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=content,
  )
  tools_dict = {tool.name: tool}

  async_result = await handle_function_calls_async(
      invocation_context, event, tools_dict
  )
  live_result = await handle_function_calls_live(
      invocation_context, event, tools_dict
  )

  assert async_result is not None
  assert live_result is not None
  async_response = async_result.content.parts[0].function_response.response
  live_response = live_result.content.parts[0].function_response.response
  assert async_response == live_response == plugin.before_tool_response
```
The new tests for plugin callbacks are a great addition. However, they could be made more robust and maintainable with a couple of improvements:

- Reduce Code Duplication: There's significant setup code duplicated between `invoke_tool_with_plugin_live` and `test_live_plugin_callbacks_match_async_behavior`. This can be extracted into a shared test harness function.
- Increase Test Coverage: The parity test `test_live_plugin_callbacks_match_async_behavior` only validates the `before_tool_callback` scenario. To fully ensure that the live and async paths produce identical results, it would be beneficial to cover more cases (e.g., `after_tool_callback` enabled, both enabled, none enabled).

I've provided a suggestion that refactors this entire test section to create a shared setup harness and uses `pytest.mark.parametrize` to cover all relevant scenarios. This makes the tests more comprehensive and easier to maintain.
```python
# --- Plugin callback tests for live mode ---


async def _setup_plugin_test_harness(plugin: "MockPlugin"):
  """Creates a common test setup for plugin callback tests."""

  def simple_fn(**kwargs) -> Dict[str, Any]:
    return {"initial": "response"}

  tool = FunctionTool(simple_fn)
  model = testing_utils.MockModel.create(responses=[])
  agent = Agent(
      name="agent",
      model=model,
      tools=[tool],
  )
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content="", plugins=[plugin]
  )
  function_call = types.FunctionCall(name=tool.name, args={})
  content = types.Content(parts=[types.Part(function_call=function_call)])
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=content,
  )
  tools_dict = {tool.name: tool}
  return invocation_context, event, tools_dict


class MockPlugin(BasePlugin):
  """A mock plugin for testing plugin callbacks in live mode."""

  before_tool_response = {"MockPlugin": "before_tool_response from MockPlugin"}
  after_tool_response = {"MockPlugin": "after_tool_response from MockPlugin"}

  def __init__(self, name="mock_plugin"):
    self.name = name
    self.enable_before_tool_callback = False
    self.enable_after_tool_callback = False

  async def before_tool_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Optional[dict]:
    if not self.enable_before_tool_callback:
      return None
    return self.before_tool_response

  async def after_tool_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      result: dict,
  ) -> Optional[dict]:
    if not self.enable_after_tool_callback:
      return None
    return self.after_tool_response


async def invoke_tool_with_plugin_live(
    mock_plugin,
) -> Optional[Event]:
  """Invokes a tool with a plugin using live mode."""
  invocation_context, event, tools_dict = await _setup_plugin_test_harness(
      mock_plugin
  )
  return await handle_function_calls_live(
      invocation_context,
      event,
      tools_dict,
  )


@pytest.mark.asyncio
async def test_live_plugin_before_tool_callback():
  """Test that plugin before_tool_callback is called in live mode."""
  plugin = MockPlugin()
  plugin.enable_before_tool_callback = True
  result_event = await invoke_tool_with_plugin_live(plugin)
  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == plugin.before_tool_response


@pytest.mark.asyncio
async def test_live_plugin_after_tool_callback():
  """Test that plugin after_tool_callback is called in live mode."""
  plugin = MockPlugin()
  plugin.enable_after_tool_callback = True
  result_event = await invoke_tool_with_plugin_live(plugin)
  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == plugin.after_tool_response


@pytest.mark.asyncio
async def test_live_plugin_before_tool_callback_disabled():
  """Test that disabled plugin before_tool_callback allows normal tool execution."""
  plugin = MockPlugin()
  plugin.enable_before_tool_callback = False
  result_event = await invoke_tool_with_plugin_live(plugin)
  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == {"initial": "response"}


@pytest.mark.parametrize(
    "enable_before, enable_after, get_expected_response",
    [
        (True, False, lambda p: p.before_tool_response),
        (False, True, lambda p: p.after_tool_response),
        (True, True, lambda p: p.before_tool_response),
        (False, False, lambda p: {"initial": "response"}),
    ],
    ids=["before_only", "after_only", "both_enabled", "none_enabled"],
)
@pytest.mark.asyncio
async def test_live_plugin_callbacks_match_async_behavior(
    enable_before, enable_after, get_expected_response
):
  """Test that plugin callbacks in live mode match the async (non-live) behavior."""
  from google.adk.flows.llm_flows.functions import handle_function_calls_async

  plugin = MockPlugin()
  plugin.enable_before_tool_callback = enable_before
  plugin.enable_after_tool_callback = enable_after

  invocation_context, event, tools_dict = await _setup_plugin_test_harness(plugin)

  async_result = await handle_function_calls_async(
      invocation_context, event, tools_dict
  )
  live_result = await handle_function_calls_live(
      invocation_context, event, tools_dict
  )

  assert async_result is not None
  assert live_result is not None
  async_response = async_result.content.parts[0].function_response.response
  live_response = live_result.content.parts[0].function_response.response
  expected_response = get_expected_response(plugin)
  assert async_response == live_response == expected_response
```

`_execute_single_function_call_live()` was only calling agent-level canonical callbacks but skipping `plugin_manager.run_before_tool_callback()` and `plugin_manager.run_after_tool_callback()`. This meant plugins registered with `before_tool_callback` / `after_tool_callback` were never invoked during live (bidi streaming) sessions. Align the live path with the non-live `_execute_single_function_call_async()` by inserting the plugin manager calls in the same order:

1. plugin before_tool_callback
2. canonical before_tool_callbacks
3. tool execution
4. plugin after_tool_callback
5. canonical after_tool_callbacks

Added tests verifying plugin callbacks fire in live mode and produce the same results as the non-live path.

Fixes google#4704
Restore try/except around `_process_function_live_helper()` in `_execute_single_function_call_live` so that `on_tool_error_callbacks` fire when tool execution raises, matching the non-live async path.

Also address review feedback:

- Extract `_setup_plugin_test_harness` to reduce test code duplication
- Parametrize plugin parity test to cover all callback combinations
- Add tests for `on_tool_error_callback` during live tool execution
Force-pushed from d357b0c to 2060fb3
Update: After rebasing onto the latest main, this PR now focuses on contributing test coverage for the live tool callbacks feature. If this test coverage is not needed, please feel free to close this PR. Otherwise, happy to iterate on the tests.
Summary

- `_execute_single_function_call_live()` was only calling agent-level canonical callbacks but skipping `plugin_manager.run_before_tool_callback()` and `plugin_manager.run_after_tool_callback()`, causing plugins to miss tool execution events during live (bidi streaming) sessions
- Align the live path with the non-live `_execute_single_function_call_async()`

Details
The non-live path in `_execute_single_function_call_async()` follows this order:

1. `plugin_manager.run_before_tool_callback()`
2. `agent.canonical_before_tool_callbacks`
3. Tool execution
4. `plugin_manager.run_after_tool_callback()`
5. `agent.canonical_after_tool_callbacks`

The live path in `_execute_single_function_call_live()` was missing steps 1 and 4, going straight to the canonical callbacks. This PR adds the missing plugin manager calls so both paths behave identically.

Test plan
- All `test_live_tool_callbacks.py` tests pass
- New tests:
  - `test_live_plugin_before_tool_callback`: plugin `before_tool_callback` overrides response
  - `test_live_plugin_after_tool_callback`: plugin `after_tool_callback` overrides response
  - `test_live_plugin_before_tool_callback_disabled`: disabled plugin allows normal execution
  - `test_live_plugin_callbacks_match_async_behavior`: live and async paths produce identical results
- Tests under `tests/unittests/flows/llm_flows/` pass
- Formatted with `pyink` and `isort`

Fixes #4704
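The callback ordering described above can be sketched as follows. The `PluginManager` stub and `execute_tool_live` helper are placeholders for illustration, not the real ADK classes; the agent-level canonical callbacks are elided to comments.

```python
import asyncio


class PluginManager:
  """Minimal stand-in holding optional before/after tool callbacks."""

  def __init__(self, before=None, after=None):
    self._before = before
    self._after = after

  async def run_before_tool_callback(self, tool_args):
    return self._before(tool_args) if self._before else None

  async def run_after_tool_callback(self, result):
    return self._after(result) if self._after else None


async def execute_tool_live(tool_fn, tool_args, plugins):
  # 1. Plugin before_tool_callback may replace the tool response entirely.
  response = await plugins.run_before_tool_callback(tool_args)
  # 2. (Canonical agent-level before-callbacks would run here.)
  # 3. Tool execution, only if nothing short-circuited it.
  if response is None:
    response = tool_fn(**tool_args)
  # 4. Plugin after_tool_callback may rewrite the result.
  altered = await plugins.run_after_tool_callback(response)
  # 5. (Canonical agent-level after-callbacks would run here.)
  return altered if altered is not None else response
```

With a before-callback installed, the sketch returns the plugin's dict and never runs the tool, mirroring the short-circuit behavior the before-callback tests assert.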