diff --git a/src/ucache.hpp b/src/ucache.hpp index 92e94a16d..d32476185 100644 --- a/src/ucache.hpp +++ b/src/ucache.hpp @@ -19,6 +19,7 @@ struct UCacheConfig { bool adaptive_threshold = true; float early_step_multiplier = 0.5f; float late_step_multiplier = 1.5f; + float relative_norm_gain = 1.6f; bool reset_error_on_compute = true; }; @@ -45,14 +46,16 @@ struct UCacheState { bool has_output_prev_norm = false; bool has_relative_transformation_rate = false; float relative_transformation_rate = 0.0f; - float cumulative_change_rate = 0.0f; float last_input_change = 0.0f; bool has_last_input_change = false; + float output_change_ema = 0.0f; + bool has_output_change_ema = false; int total_steps_skipped = 0; int current_step_index = -1; int steps_computed_since_active = 0; + int expected_total_steps = 0; + int consecutive_skipped_steps = 0; float accumulated_error = 0.0f; - float reference_output_norm = 0.0f; struct BlockMetrics { float sum_transformation_rate = 0.0f; @@ -106,14 +109,16 @@ struct UCacheState { has_output_prev_norm = false; has_relative_transformation_rate = false; relative_transformation_rate = 0.0f; - cumulative_change_rate = 0.0f; last_input_change = 0.0f; has_last_input_change = false; + output_change_ema = 0.0f; + has_output_change_ema = false; total_steps_skipped = 0; current_step_index = -1; steps_computed_since_active = 0; + expected_total_steps = 0; + consecutive_skipped_steps = 0; accumulated_error = 0.0f; - reference_output_norm = 0.0f; block_metrics.reset(); total_active_steps = 0; } @@ -133,7 +138,8 @@ struct UCacheState { if (!initialized || sigmas.size() < 2) { return; } - size_t n_steps = sigmas.size() - 1; + size_t n_steps = sigmas.size() - 1; + expected_total_steps = static_cast(n_steps); size_t start_step = static_cast(config.start_percent * n_steps); size_t end_step = static_cast(config.end_percent * n_steps); @@ -207,11 +213,15 @@ struct UCacheState { } int effective_total = estimated_total_steps; + if (effective_total <= 0) { + effective_total = expected_total_steps; + } if (effective_total <= 0) { effective_total = std::max(20, steps_computed_since_active * 2); } float progress = (effective_total > 0) ? (static_cast(steps_computed_since_active) / effective_total) : 0.0f; + progress = std::max(0.0f, std::min(1.0f, progress)); float multiplier = 1.0f; if (progress < 0.2f) { @@ -309,17 +319,31 @@ struct UCacheState { if (has_output_prev_norm && has_relative_transformation_rate && last_input_change > 0.0f && output_prev_norm > 0.0f) { - float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm; - accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate; + float approx_output_change = relative_transformation_rate * last_input_change; + float approx_output_change_rate; + if (config.use_relative_threshold) { + float base_scale = std::max(output_prev_norm, 1e-6f); + float dyn_scale = has_output_change_ema + ? std::max(output_change_ema * std::max(1.0f, config.relative_norm_gain), 1e-6f) + : base_scale; + float scale = std::sqrt(base_scale * dyn_scale); + approx_output_change_rate = approx_output_change / scale; + } else { + approx_output_change_rate = approx_output_change; + } + // Increase estimated error with skip horizon to avoid long extrapolation streaks + approx_output_change_rate *= (1.0f + 0.50f * consecutive_skipped_steps); + accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate; float effective_threshold = get_adaptive_threshold(); - if (config.use_relative_threshold && reference_output_norm > 0.0f) { - effective_threshold = effective_threshold * reference_output_norm; + if (!config.use_relative_threshold && output_prev_norm > 0.0f) { + effective_threshold = effective_threshold * output_prev_norm; } if (accumulated_error < effective_threshold) { skip_current_step = true; total_steps_skipped++; + consecutive_skipped_steps++; apply_cache(cond, input, output); return true; } else if (config.reset_error_on_compute) { @@ -340,6 +364,8 @@ struct UCacheState { if (cond != anchor_condition) { return; } + steps_computed_since_active++; + consecutive_skipped_steps = 0; size_t ne = static_cast(ggml_nelements(input)); float* in_data = (float*)input->data; @@ -359,6 +385,14 @@ struct UCacheState { output_change /= static_cast(ne); } } + if (std::isfinite(output_change) && output_change > 0.0f) { + if (!has_output_change_ema) { + output_change_ema = output_change; + has_output_change_ema = true; + } else { + output_change_ema = 0.8f * output_change_ema + 0.2f * output_change; + } + } prev_output.resize(ne); for (size_t i = 0; i < ne; ++i) { @@ -373,10 +407,6 @@ struct UCacheState { output_prev_norm = (ne > 0) ? (mean_abs / static_cast(ne)) : 0.0f; has_output_prev_norm = output_prev_norm > 0.0f; - if (reference_output_norm == 0.0f) { - reference_output_norm = output_prev_norm; - } - if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) { float rate = output_change / last_input_change; if (std::isfinite(rate)) {