Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/google/adk/agents/context_cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from __future__ import annotations

from typing import Optional

from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
Expand All @@ -38,10 +41,12 @@ class ContextCacheConfig(BaseModel):
cache_intervals: Maximum number of invocations to reuse the same cache before refreshing it
ttl_seconds: Time-to-live for cache in seconds
min_tokens: Minimum tokens required to enable caching
create_http_options: HTTP options for cache creation API calls
"""

model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
)

cache_intervals: int = Field(
Expand Down Expand Up @@ -72,6 +77,18 @@ class ContextCacheConfig(BaseModel):
),
)

create_http_options: Optional[types.HttpOptions] = Field(
default=None,
description=(
"HTTP options for cache creation API calls. Use this to set a"
" timeout on CachedContent.create() calls (e.g."
" types.HttpOptions(timeout=10000) for a 10-second timeout in"
" milliseconds). When the cache creation call exceeds the timeout,"
" it fails and the request proceeds without caching. None uses the"
" client's default HTTP options."
),
)

@property
def ttl_string(self) -> str:
"""Get TTL as string format for cache creation."""
Expand All @@ -81,5 +98,6 @@ def __str__(self) -> str:
"""String representation for logging."""
return (
f"ContextCacheConfig(cache_intervals={self.cache_intervals}, "
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens})"
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens}, "
f"create_http_options={self.create_http_options})"
)
12 changes: 11 additions & 1 deletion src/google/adk/models/gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ async def handle_context_caching(
)
if cache_metadata:
self._apply_cache_to_request(
llm_request, cache_metadata.cache_name, cache_contents_count
llm_request,
cache_metadata.cache_name,
cache_contents_count,
)
return cache_metadata

Expand All @@ -127,6 +129,7 @@ async def handle_context_caching(
fingerprint_for_all = self._generate_cache_fingerprint(
llm_request, total_contents_count
)

return CacheMetadata(
fingerprint=fingerprint_for_all,
contents_count=total_contents_count,
Expand Down Expand Up @@ -386,6 +389,13 @@ async def _create_gemini_cache(
if llm_request.config and llm_request.config.tool_config:
cache_config.tool_config = llm_request.config.tool_config

# Pass through HTTP options (e.g. timeout) from cache config
if (
llm_request.cache_config
and llm_request.cache_config.create_http_options
):
cache_config.http_options = llm_request.cache_config.create_http_options

span.set_attribute("cache_contents_count", cache_contents_count)
span.set_attribute("model", llm_request.model)
span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds)
Expand Down
28 changes: 13 additions & 15 deletions tests/unittests/agents/test_context_cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,19 @@ def test_str_representation(self):
)

expected = (
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024)"
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024,"
" create_http_options=None)"
)
assert str(config) == expected

def test_str_representation_defaults(self):
"""Test string representation with default values."""
config = ContextCacheConfig()

expected = "ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0)"
expected = (
"ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0,"
" create_http_options=None)"
)
assert str(config) == expected

def test_pydantic_model_validation(self):
Expand All @@ -126,25 +130,19 @@ def test_pydantic_model_validation(self):

def test_field_descriptions(self):
"""Test that fields have proper descriptions."""
config = ContextCacheConfig()
schema = config.model_json_schema()
fields = ContextCacheConfig.model_fields

assert "cache_intervals" in schema["properties"]
assert "cache_intervals" in fields
assert (
"Maximum number of invocations"
in schema["properties"]["cache_intervals"]["description"]
"Maximum number of invocations" in fields["cache_intervals"].description
)

assert "ttl_seconds" in schema["properties"]
assert (
"Time-to-live for cache"
in schema["properties"]["ttl_seconds"]["description"]
)
assert "ttl_seconds" in fields
assert "Time-to-live for cache" in fields["ttl_seconds"].description

assert "min_tokens" in schema["properties"]
assert "min_tokens" in fields
assert (
"Minimum estimated request tokens"
in schema["properties"]["min_tokens"]["description"]
"Minimum estimated request tokens" in fields["min_tokens"].description
)

def test_immutability_config(self):
Expand Down
60 changes: 60 additions & 0 deletions tests/unittests/agents/test_gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,63 @@ async def test_cache_creation_without_token_count(self):
assert result.cache_name is None
assert result.fingerprint == "test_fp"
self.manager.genai_client.aio.caches.create.assert_not_called()

async def test_create_http_options_passthrough(self):
    """Test that create_http_options is passed through to cache creation config.

    Configures a ContextCacheConfig carrying an explicit HttpOptions
    (timeout in milliseconds) and verifies that _create_gemini_cache
    forwards it onto the config of the genai caches.create() call.
    """
    # Stub the async cache-creation API so no real network call happens;
    # the returned object only needs a plausible resource name.
    mock_cached_content = AsyncMock()
    mock_cached_content.name = (
        "projects/test/locations/us-central1/cachedContents/test123"
    )
    self.manager.genai_client.aio.caches.create = AsyncMock(
        return_value=mock_cached_content
    )

    # Create config with http_options (e.g. 10s timeout)
    http_options = types.HttpOptions(timeout=10000)
    cache_config_with_timeout = ContextCacheConfig(
        cache_intervals=10,
        ttl_seconds=1800,
        min_tokens=0,
        create_http_options=http_options,
    )

    llm_request = self.create_llm_request()
    llm_request.cache_config = cache_config_with_timeout

    # All contents except the last are eligible for caching
    # (presumably mirrors the manager's own counting — confirm against
    # _create_gemini_cache if that logic changes).
    cache_contents_count = max(0, len(llm_request.contents) - 1)

    # Fingerprinting is orthogonal to this test; pin it to a constant.
    with patch.object(
        self.manager, "_generate_cache_fingerprint", return_value="test_fp"
    ):
        await self.manager._create_gemini_cache(llm_request, cache_contents_count)

    # Verify cache creation call includes http_options
    create_call = self.manager.genai_client.aio.caches.create.call_args
    assert create_call is not None
    cache_config = create_call[1]["config"]  # kwargs of the create() call
    assert cache_config.http_options is not None
    assert cache_config.http_options.timeout == 10000

async def test_create_without_http_options(self):
    """Test that cache creation works without create_http_options.

    With the default config (create_http_options=None) the forwarded
    cache-creation config must NOT carry http_options, so the client's
    own default HTTP options apply.
    """
    # Stub the async cache-creation API so no real network call happens.
    mock_cached_content = AsyncMock()
    mock_cached_content.name = (
        "projects/test/locations/us-central1/cachedContents/test123"
    )
    self.manager.genai_client.aio.caches.create = AsyncMock(
        return_value=mock_cached_content
    )

    # Default request: no custom cache_config / http_options attached.
    llm_request = self.create_llm_request()
    # All contents except the last are eligible for caching
    # (presumably mirrors the manager's own counting — confirm against
    # _create_gemini_cache if that logic changes).
    cache_contents_count = max(0, len(llm_request.contents) - 1)

    # Fingerprinting is orthogonal to this test; pin it to a constant.
    with patch.object(
        self.manager, "_generate_cache_fingerprint", return_value="test_fp"
    ):
        await self.manager._create_gemini_cache(llm_request, cache_contents_count)

    # Verify cache creation call does not include http_options
    create_call = self.manager.genai_client.aio.caches.create.call_args
    assert create_call is not None
    cache_config = create_call[1]["config"]  # kwargs of the create() call
    assert cache_config.http_options is None
3 changes: 2 additions & 1 deletion tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,8 @@ def test_runner_realistic_cache_config_scenario(self):

# Verify string representation
expected_str = (
"ContextCacheConfig(cache_intervals=30, ttl=14400s, min_tokens=4096)"
"ContextCacheConfig(cache_intervals=30, ttl=14400s, min_tokens=4096,"
" create_http_options=None)"
)
assert str(runner.context_cache_config) == expected_str

Expand Down