Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions spec/unit/bundle_loader_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,13 @@ runner:then_("^the compiled bundle includes runtime override blocks$", function(
assert.equals(true, ctx.compiled.global_shadow.enabled)
assert.equals("incident-global-shadow", ctx.compiled.global_shadow.reason)
assert.equals("2030-01-01T00:00:00Z", ctx.compiled.global_shadow.expires_at)
assert.is_number(ctx.compiled.global_shadow._expires_epoch)

assert.is_table(ctx.compiled.kill_switch_override)
assert.equals(true, ctx.compiled.kill_switch_override.enabled)
assert.equals("incident-ks-override", ctx.compiled.kill_switch_override.reason)
assert.equals("2030-01-01T00:00:00Z", ctx.compiled.kill_switch_override.expires_at)
assert.is_number(ctx.compiled.kill_switch_override._expires_epoch)
end)

runner:then_("^policies_by_id contains exactly one entry for that ID with the last policy spec$", function(ctx)
Expand Down
28 changes: 28 additions & 0 deletions spec/unit/circuit_breaker_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,32 @@ describe("circuit_breaker targeted direct coverage", function()
assert.is_nil(result.spend_rate)
assert.is_nil(result.reason)
end)

it("clears rate keys on auto-reset before resuming traffic", function()
local env = mock_ngx.setup_ngx()
local now = env.time.now()
local limit_key = "org-auto-reset"
local config = {
enabled = true,
spend_rate_threshold_per_minute = 100,
action = "reject",
alert = false,
auto_reset_after_minutes = 5,
}
local current_window = _window_start(now)
local previous_window = current_window - 60

env.dict:set(circuit_breaker.build_state_key(limit_key), "open:" .. tostring(now - 300))
env.dict:set(circuit_breaker.build_rate_key(limit_key, current_window), 100)
env.dict:set(circuit_breaker.build_rate_key(limit_key, previous_window), 100)

local result = circuit_breaker.check(env.dict, config, limit_key, 1, now)

assert.is_false(result.tripped)
assert.equals("closed", result.state)
assert.equals(1, result.spend_rate)
assert.is_nil(env.dict:get(circuit_breaker.build_state_key(limit_key)))
assert.equals(1, env.dict:get(circuit_breaker.build_rate_key(limit_key, current_window)))
assert.is_nil(env.dict:get(circuit_breaker.build_rate_key(limit_key, previous_window)))
end)
end)
34 changes: 34 additions & 0 deletions spec/unit/rule_engine_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -916,4 +916,38 @@ describe("rule_engine targeted direct coverage", function()
assert.is_true(wrapped.would_reject)
debug.setupvalue(rule_engine.evaluate, shadow_index, original_shadow_mode)
end)

it("uses cached override expiry without reparsing expires_at", function()
local ctx = {}
_setup_engine(ctx)

ctx.matching_policy_ids = { "p1" }
ctx.request_context._descriptors["jwt:org_id"] = "org-shadow-cached"
ctx.bundle.global_shadow = {
enabled = true,
reason = "incident-global-shadow-cached",
expires_at = "not-a-date",
_expires_epoch = ctx.time.now() + 300,
}
ctx.bundle.policies_by_id.p1 = _new_policy("p1", "enforce", {
{
name = "reject_rule",
algorithm = "token_bucket",
limit_keys = { "jwt:org_id" },
algorithm_config = {},
},
})
ctx.rule_results.reject_rule = {
allowed = false,
reason = "rate_limited",
limit = 100,
remaining = 0,
retry_after = 2,
}

local decision = ctx.engine.evaluate(ctx.request_context, ctx.bundle)
assert.equals("allow", decision.action)
assert.equals("true", decision.headers["X-Fairvisor-Global-Shadow"])
assert.equals("incident-global-shadow-cached", decision.headers["X-Fairvisor-Global-Shadow-Reason"])
end)
end)
40 changes: 40 additions & 0 deletions spec/unit/saas_client_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -743,4 +743,44 @@ describe("saas_client targeted direct coverage", function()
local flushed = reloaded.flush_events()
assert.equals(3, flushed)
end)

it("rebuilds cached auth header after reinit with a different token", function()
local reloaded = _reload_saas_client()
local http = mock_http.new()
http.queue_response("POST", "https://s/api/v1/edge/register", { status = 200 })
http.queue_response("POST", "https://s/api/v1/edge/register", { status = 200 })

local deps = {
bundle_loader = { get_current = function() end, load_from_string = function() end, apply = function() end },
health = { set = function() end, inc = function() end },
http_client = http.client,
}

local ok, err = reloaded.init({
edge_id = "e",
edge_token = "token-1",
saas_url = "https://s",
}, deps)
assert.is_true(ok)
assert.is_nil(err)

ok, err = reloaded.init({
edge_id = "e",
edge_token = "token-2",
saas_url = "https://s",
}, deps)
assert.is_true(ok)
assert.is_nil(err)

local register_calls = {}
for _, request in ipairs(http.requests) do
if request.method == "POST" and request.url == "https://s/api/v1/edge/register" then
register_calls[#register_calls + 1] = request
end
end

assert.equals(2, #register_calls)
assert.equals("Bearer token-1", register_calls[1].headers.Authorization)
assert.equals("Bearer token-2", register_calls[2].headers.Authorization)
end)
end)
2 changes: 2 additions & 0 deletions src/fairvisor/bundle_loader.lua
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ local function _validate_top_level(bundle)
if expires_err then
return nil, "expires_at_invalid"
end
bundle._expires_epoch = expires_epoch

if ngx and ngx.now and expires_epoch <= ngx.now() then
return nil, "bundle_expired"
Expand Down Expand Up @@ -457,6 +458,7 @@ local function _validate_top_level(bundle)
if expires_err then
return nil, field_name .. "_invalid: expires_at_invalid"
end
block._expires_epoch = expires_epoch

if ngx and ngx.now and expires_epoch <= ngx.now() then
return nil, field_name .. "_invalid: expired"
Expand Down
2 changes: 1 addition & 1 deletion src/fairvisor/circuit_breaker.lua
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function _M.check(dict, config, limit_key, cost, now)
}
end

dict:delete(state_key)
_M.reset(dict, limit_key, now)
else
return {
tripped = true,
Expand Down
26 changes: 17 additions & 9 deletions src/fairvisor/llm_limiter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ local string_gsub = string.gsub
local string_format = string.format
local string_lower = string.lower
local string_sub = string.sub
local string_byte = string.byte
local os_date = os.date

local token_bucket = require("fairvisor.token_bucket")
Expand Down Expand Up @@ -160,42 +161,47 @@ local function _simple_word_estimate(request_context)
if body == "" then
return 0
end
if #body > MAX_BODY_SCAN_BYTES then
body = string_sub(body, 1, MAX_BODY_SCAN_BYTES)

local scan_limit = #body
if scan_limit > MAX_BODY_SCAN_BYTES then
scan_limit = MAX_BODY_SCAN_BYTES
end

local messages_start = string_find(body, "\"messages\"", 1, true)
if messages_start then
if messages_start and messages_start < scan_limit then
local array_start = string_find(body, "[", messages_start, true)
if array_start then
if array_start and array_start < scan_limit then
local position = array_start
local char_count = 0

while true do
-- Find "content" key
local key_start = string_find(body, "\"content\"", position, true)
if not key_start then
if not key_start or key_start >= scan_limit then
break
end

-- Look for value start: : "..."
-- pattern: %s*:%s*"
local val_marker_start, val_marker_end = string_find(body, "^%s*:%s*\"", key_start + 9)

if not val_marker_start then
if not val_marker_start or val_marker_start > scan_limit then
-- False positive (e.g. key was in a string), skip it
position = key_start + 9
else
local content_start = val_marker_end + 1
local content_end = content_start
while true do
content_end = string_find(body, "\"", content_end, true)
if not content_end then break end
if not content_end or content_end > scan_limit then
content_end = nil
break
end

-- Count backslashes before this quote
local bs_count = 0
local p = content_end - 1
while p >= content_start and string_sub(body, p, p) == "\\" do
while p >= content_start and string_byte(body, p) == 92 do
bs_count = bs_count + 1
p = p - 1
end
Expand All @@ -219,7 +225,9 @@ local function _simple_word_estimate(request_context)
end
end

return ceil(#body / 4)
local total_len = #body
if total_len > MAX_BODY_SCAN_BYTES then total_len = MAX_BODY_SCAN_BYTES end
return ceil(total_len / 4)
end

local function _extract_max_tokens(body)
Expand Down
9 changes: 7 additions & 2 deletions src/fairvisor/rule_engine.lua
Original file line number Diff line number Diff line change
Expand Up @@ -459,17 +459,22 @@ local function _is_override_active(block, now)
return false
end

local expires_epoch = block._expires_epoch
if expires_epoch then
return expires_epoch > now
end

local expires_at = block.expires_at
if type(expires_at) ~= "string" or expires_at == "" then
return false
end

local expires_epoch, parse_err = utils.parse_iso8601(expires_at)
local parsed_epoch, parse_err = utils.parse_iso8601(expires_at)
if parse_err ~= nil then
return false
end

return expires_epoch > now
return parsed_epoch > now
end

local function _maybe_log_override_state(flags)
Expand Down
14 changes: 11 additions & 3 deletions src/fairvisor/saas_client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,22 @@ local function _http_client()
end

local function _auth_header()
if _state.auth_header then
return _state.auth_header
end

local token = _state.config.edge_token
if type(token) ~= "string" then
return "Bearer "
_state.auth_header = "Bearer "
return _state.auth_header
end
-- Defensive: reject tokens that could inject CR/LF into HTTP headers.
if token:find("[\r\n]") then
return "Bearer "
_state.auth_header = "Bearer "
return _state.auth_header
end
return "Bearer " .. token
_state.auth_header = "Bearer " .. token
return _state.auth_header
end

local function _is_non_retriable_status(status)
Expand Down Expand Up @@ -697,6 +704,7 @@ local function _reset_state(config, deps)
_state.register_attempt = 0
_state.register_next_retry_at = 0
_state.last_config_poll_at = 0
_state.auth_header = nil
end

local function _validate_config(config)
Expand Down
Loading