diff --git a/src/cachier/config.py b/src/cachier/config.py
index aa1ae0d5..6910a8ec 100644
--- a/src/cachier/config.py
+++ b/src/cachier/config.py
@@ -43,7 +43,9 @@ def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
     hasher.update(value.tobytes(order="C"))
 
 
-def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
+def _update_hash_for_value(
+    hasher: "hashlib._Hash", value: Any, depth: int = 0, max_depth: int = 100
+) -> None:
     """Update hasher with a stable representation of a Python value.
 
     Parameters
@@ -52,8 +54,24 @@ def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
         The hasher to update.
     value : Any
         Value to encode.
+    depth : int, optional
+        Current recursion depth (internal use only).
+    max_depth : int, optional
+        Maximum allowed recursion depth to prevent stack overflow.
+
+    Raises
+    ------
+    RecursionError
+        If the recursion depth exceeds max_depth.
 
     """
+    if depth > max_depth:
+        raise RecursionError(
+            f"Maximum recursion depth ({max_depth}) exceeded while hashing nested "
+            f"data structure. Consider flattening your data or using a custom "
+            f"hash_func parameter."
+        )
+
     if _is_numpy_array(value):
         _hash_numpy_array(hasher, value)
         return
@@ -61,20 +79,20 @@ def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
     if isinstance(value, tuple):
         hasher.update(b"tuple")
         for item in value:
-            _update_hash_for_value(hasher, item)
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
         return
 
     if isinstance(value, list):
         hasher.update(b"list")
         for item in value:
-            _update_hash_for_value(hasher, item)
+            _update_hash_for_value(hasher, item, depth + 1, max_depth)
         return
 
     if isinstance(value, dict):
         hasher.update(b"dict")
         for dict_key in sorted(value):
-            _update_hash_for_value(hasher, dict_key)
-            _update_hash_for_value(hasher, value[dict_key])
+            _update_hash_for_value(hasher, dict_key, depth + 1, max_depth)
+            _update_hash_for_value(hasher, value[dict_key], depth + 1, max_depth)
         return
 
     hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
diff --git a/tests/test_recursion_depth.py b/tests/test_recursion_depth.py
new file mode 100644
index 00000000..34795650
--- /dev/null
+++ b/tests/test_recursion_depth.py
@@ -0,0 +1,147 @@
+"""Tests for recursion depth protection in hash function."""
+
+from datetime import timedelta
+
+import pytest
+
+from cachier import cachier
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_moderately_nested_structures_work(backend, tmp_path):
+    """Verify that moderately nested structures (< 100 levels) work fine."""
+    call_count = 0
+
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_nested(data):
+        nonlocal call_count
+        call_count += 1
+        return "processed"
+
+    # Create a nested structure with 50 levels (well below the 100 limit)
+    nested_list = []
+    current = nested_list
+    for _ in range(50):
+        inner = []
+        current.append(inner)
+        current = inner
+    current.append("leaf")
+
+    # Should work without issues
+    result1 = process_nested(nested_list)
+    assert result1 == "processed"
+    assert call_count == 1
+
+    # Second call should hit cache
+    result2 = process_nested(nested_list)
+    assert result2 == "processed"
+    assert call_count == 1
+
+    process_nested.clear_cache()
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_deeply_nested_structures_raise_error(backend, tmp_path):
+    """Verify that deeply nested structures (> 100 levels) raise RecursionError."""
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_nested(data):
+        return "processed"
+
+    # Create a nested structure with 150 levels (exceeds the 100 limit)
+    nested_list = []
+    current = nested_list
+    for _ in range(150):
+        inner = []
+        current.append(inner)
+        current = inner
+    current.append("leaf")
+
+    # Should raise RecursionError with a clear message
+    with pytest.raises(
+        RecursionError,
+        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
+    ):
+        process_nested(nested_list)
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_nested_dicts_respect_depth_limit(backend, tmp_path):
+    """Verify that nested dictionaries also respect the depth limit."""
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_dict(data):
+        return "processed"
+
+    # Create nested dictionaries beyond the limit
+    nested_dict = {}
+    current = nested_dict
+    for i in range(150):
+        current[f"level_{i}"] = {}
+        current = current[f"level_{i}"]
+    current["leaf"] = "value"
+
+    # Should raise RecursionError
+    with pytest.raises(
+        RecursionError,
+        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
+    ):
+        process_dict(nested_dict)
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
+def test_nested_tuples_respect_depth_limit(backend, tmp_path):
+    """Verify that nested tuples also respect the depth limit."""
+    decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
+    if backend == "pickle":
+        decorator_kwargs["cache_dir"] = tmp_path
+
+    @cachier(**decorator_kwargs)
+    def process_tuple(data):
+        return "processed"
+
+    # Create nested tuples beyond the limit
+    nested_tuple = ("leaf",)
+    for _ in range(150):
+        nested_tuple = (nested_tuple,)
+
+    # Should raise RecursionError
+    with pytest.raises(
+        RecursionError,
+        match=r"Maximum recursion depth \(100\) exceeded while hashing nested",
+    ):
+        process_tuple(nested_tuple)