Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
hasher.update(value.tobytes(order="C"))


def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
def _update_hash_for_value(
hasher: "hashlib._Hash", value: Any, depth: int = 0, max_depth: int = 100
) -> None:
"""Update hasher with a stable representation of a Python value.

Parameters
Expand All @@ -52,29 +54,45 @@ def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
The hasher to update.
value : Any
Value to encode.
depth : int, optional
Current recursion depth (internal use only).
max_depth : int, optional
Maximum allowed recursion depth to prevent stack overflow.

Raises
------
RecursionError
If the recursion depth exceeds max_depth.

"""
if depth > max_depth:
raise RecursionError(
f"Maximum recursion depth ({max_depth}) exceeded while hashing nested "
f"data structure. Consider flattening your data or using a custom "
f"hash_func parameter."
)

if _is_numpy_array(value):
_hash_numpy_array(hasher, value)
return

if isinstance(value, tuple):
hasher.update(b"tuple")
for item in value:
_update_hash_for_value(hasher, item)
_update_hash_for_value(hasher, item, depth + 1, max_depth)
return

if isinstance(value, list):
hasher.update(b"list")
for item in value:
_update_hash_for_value(hasher, item)
_update_hash_for_value(hasher, item, depth + 1, max_depth)
return

if isinstance(value, dict):
hasher.update(b"dict")
for dict_key in sorted(value):
_update_hash_for_value(hasher, dict_key)
_update_hash_for_value(hasher, value[dict_key])
_update_hash_for_value(hasher, dict_key, depth + 1, max_depth)
_update_hash_for_value(hasher, value[dict_key], depth + 1, max_depth)
return

hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
Expand Down
147 changes: 147 additions & 0 deletions tests/test_recursion_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Tests for recursion depth protection in hash function."""

from datetime import timedelta

import pytest

from cachier import cachier


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_moderately_nested_structures_work(backend, tmp_path):
    """Nesting below the 100-level limit should hash and cache normally."""
    invocations = 0

    config = {"stale_after": timedelta(seconds=120), "backend": backend}
    if backend == "pickle":
        config["cache_dir"] = tmp_path

    @cachier(**config)
    def process_nested(data):
        nonlocal invocations
        invocations += 1
        return "processed"

    # Build 50 levels of list nesting (comfortably under the limit) by
    # wrapping the leaf from the inside out.
    payload = ["leaf"]
    for _ in range(50):
        payload = [payload]

    # First call computes the value and populates the cache.
    assert process_nested(payload) == "processed"
    assert invocations == 1

    # Second identical call must be served from the cache, not recomputed.
    assert process_nested(payload) == "processed"
    assert invocations == 1

    process_nested.clear_cache()


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_deeply_nested_structures_raise_error(backend, tmp_path):
    """Nesting past the 100-level limit must fail with a clear RecursionError."""
    config = {"stale_after": timedelta(seconds=120), "backend": backend}
    if backend == "pickle":
        config["cache_dir"] = tmp_path

    @cachier(**config)
    def process_nested(data):
        return "processed"

    # 150 wrapping levels — well beyond the 100-level cap — built from the
    # leaf outward.
    payload = ["leaf"]
    for _ in range(150):
        payload = [payload]

    # Hashing the argument should abort with the descriptive error message.
    with pytest.raises(
        RecursionError,
        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
    ):
        process_nested(payload)


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_nested_dicts_respect_depth_limit(backend, tmp_path):
    """Dict nesting past the 100-level limit must also raise RecursionError."""
    config = {"stale_after": timedelta(seconds=120), "backend": backend}
    if backend == "pickle":
        config["cache_dir"] = tmp_path

    @cachier(**config)
    def process_dict(data):
        return "processed"

    # Assemble a 150-deep dict from the innermost leaf outward; the result
    # has "level_0" outermost and {"leaf": "value"} at the bottom.
    node = {"leaf": "value"}
    for level in reversed(range(150)):
        node = {f"level_{level}": node}

    # Depth limit applies to dict values just like list/tuple elements.
    with pytest.raises(
        RecursionError,
        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
    ):
        process_dict(node)


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_nested_tuples_respect_depth_limit(backend, tmp_path):
    """Tuple nesting past the 100-level limit must also raise RecursionError."""
    config = {"stale_after": timedelta(seconds=120), "backend": backend}
    if backend == "pickle":
        config["cache_dir"] = tmp_path

    @cachier(**config)
    def process_tuple(data):
        return "processed"

    # Wrap a singleton tuple 150 times to exceed the cap.
    wrapped = ("leaf",)
    wraps_applied = 0
    while wraps_applied < 150:
        wrapped = (wrapped,)
        wraps_applied += 1

    # Hashing the over-deep tuple must abort with the descriptive message.
    with pytest.raises(
        RecursionError,
        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
    ):
        process_tuple(wrapped)