Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions src/ucache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct UCacheConfig {
bool adaptive_threshold = true;
float early_step_multiplier = 0.5f;
float late_step_multiplier = 1.5f;
float relative_norm_gain = 1.6f;
bool reset_error_on_compute = true;
};

Expand All @@ -45,14 +46,16 @@ struct UCacheState {
bool has_output_prev_norm = false;
bool has_relative_transformation_rate = false;
float relative_transformation_rate = 0.0f;
float cumulative_change_rate = 0.0f;
float last_input_change = 0.0f;
bool has_last_input_change = false;
float output_change_ema = 0.0f;
bool has_output_change_ema = false;
int total_steps_skipped = 0;
int current_step_index = -1;
int steps_computed_since_active = 0;
int expected_total_steps = 0;
int consecutive_skipped_steps = 0;
float accumulated_error = 0.0f;
float reference_output_norm = 0.0f;

struct BlockMetrics {
float sum_transformation_rate = 0.0f;
Expand Down Expand Up @@ -106,14 +109,16 @@ struct UCacheState {
has_output_prev_norm = false;
has_relative_transformation_rate = false;
relative_transformation_rate = 0.0f;
cumulative_change_rate = 0.0f;
last_input_change = 0.0f;
has_last_input_change = false;
output_change_ema = 0.0f;
has_output_change_ema = false;
total_steps_skipped = 0;
current_step_index = -1;
steps_computed_since_active = 0;
expected_total_steps = 0;
consecutive_skipped_steps = 0;
accumulated_error = 0.0f;
reference_output_norm = 0.0f;
block_metrics.reset();
total_active_steps = 0;
}
Expand All @@ -133,7 +138,8 @@ struct UCacheState {
if (!initialized || sigmas.size() < 2) {
return;
}
size_t n_steps = sigmas.size() - 1;
size_t n_steps = sigmas.size() - 1;
expected_total_steps = static_cast<int>(n_steps);

size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
Expand Down Expand Up @@ -207,11 +213,15 @@ struct UCacheState {
}

int effective_total = estimated_total_steps;
if (effective_total <= 0) {
effective_total = expected_total_steps;
}
if (effective_total <= 0) {
effective_total = std::max(20, steps_computed_since_active * 2);
}

float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
progress = std::max(0.0f, std::min(1.0f, progress));

float multiplier = 1.0f;
if (progress < 0.2f) {
Expand Down Expand Up @@ -309,17 +319,31 @@ struct UCacheState {

if (has_output_prev_norm && has_relative_transformation_rate &&
last_input_change > 0.0f && output_prev_norm > 0.0f) {
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
float approx_output_change = relative_transformation_rate * last_input_change;
float approx_output_change_rate;
if (config.use_relative_threshold) {
float base_scale = std::max(output_prev_norm, 1e-6f);
float dyn_scale = has_output_change_ema
? std::max(output_change_ema * std::max(1.0f, config.relative_norm_gain), 1e-6f)
: base_scale;
float scale = std::sqrt(base_scale * dyn_scale);
approx_output_change_rate = approx_output_change / scale;
} else {
approx_output_change_rate = approx_output_change;
}
// Increase estimated error with skip horizon to avoid long extrapolation streaks
approx_output_change_rate *= (1.0f + 0.50f * consecutive_skipped_steps);
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;

float effective_threshold = get_adaptive_threshold();
if (config.use_relative_threshold && reference_output_norm > 0.0f) {
effective_threshold = effective_threshold * reference_output_norm;
if (!config.use_relative_threshold && output_prev_norm > 0.0f) {
effective_threshold = effective_threshold * output_prev_norm;
}

if (accumulated_error < effective_threshold) {
skip_current_step = true;
total_steps_skipped++;
consecutive_skipped_steps++;
apply_cache(cond, input, output);
return true;
} else if (config.reset_error_on_compute) {
Expand All @@ -340,6 +364,8 @@ struct UCacheState {
if (cond != anchor_condition) {
return;
}
steps_computed_since_active++;
consecutive_skipped_steps = 0;

size_t ne = static_cast<size_t>(ggml_nelements(input));
float* in_data = (float*)input->data;
Expand All @@ -359,6 +385,14 @@ struct UCacheState {
output_change /= static_cast<float>(ne);
}
}
if (std::isfinite(output_change) && output_change > 0.0f) {
if (!has_output_change_ema) {
output_change_ema = output_change;
has_output_change_ema = true;
} else {
output_change_ema = 0.8f * output_change_ema + 0.2f * output_change;
}
}

prev_output.resize(ne);
for (size_t i = 0; i < ne; ++i) {
Expand All @@ -373,10 +407,6 @@ struct UCacheState {
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
has_output_prev_norm = output_prev_norm > 0.0f;

if (reference_output_norm == 0.0f) {
reference_output_norm = output_prev_norm;
}

if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
float rate = output_change / last_input_change;
if (std::isfinite(rate)) {
Expand Down
Loading