Files
aish/broker.lua
T
marfrit 74e4bffb37 broker + repl + safety: GBNF grammar-sampling passthrough (closes #88)
llama.cpp constrains the sampler to ONLY emit tokens matching a
GBNF grammar. For small models this kills format drift at the
token level — `CMD: <cmd>` is enforced by the sampler rather than
hoped for via prompt discipline.

Probe finding (this commit's pre-implementation): cloud (Anthropic
via Bedrock) silently IGNORES the `grammar` field — returns normally
via standard sampling. Default passthrough is safe for all routes;
no per-model opt-in/opt-out needed in v1.

Changes:

- broker.lua build_request: `if opts.grammar then req.grammar =
  opts.grammar end`. Malformed grammar surfaces at request time
  via the existing transport-error path.

- repl.lua ask_ai: `grammar_override = config.routing.grammars
  [req_class]` (same gating shape as #86's system_prompts override).
  Passed via opts.grammar in the call_broker invocation.

- safety.lua is_destructive threads cfg.safety.probe_grammar through
  opts.grammar so llm_probe constrains the YES/NO output. Skips
  the regex-match dance entirely when the model can't drift.
  Caller-provided opts.grammar takes precedence over cfg.

- config.lua gains two commented examples:
  * routing.grammars per class
  * safety.probe_grammar for the destructive probe

6 unit cases verified (stubbed curl.post_sse / broker.chat):
  - default: no grammar in body
  - opts.grammar -> body contains grammar JSON-encoded
  - safety probe_grammar reaches llm_probe via opts
  - no probe_grammar configured -> opts.grammar nil
  - caller opts.grammar takes precedence over cfg.safety.probe_grammar

E2E against live local broker:
  - `routing.grammars.default = "root ::= \\"ACK\\""` configured;
    prompted "tell me a long story about a fox" -> model output
    EXACTLY "ACK" (sampler forced; would normally produce paragraphs).
    Grammar passthrough end-to-end confirmed.

Regression: test_safety 87/87, test_router_model 31/31, repl loads.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-17 07:00:36 +00:00

295 lines
13 KiB
Lua

-- broker.lua — llama.cpp HTTP client.
-- Phase 0: blocking POST via ffi/curl + vendored dkjson.
-- Phase 1: streaming POST via ffi/curl.post_sse with an OpenAI-shape decoder
-- on top. M.chat becomes a thin buffering wrapper around M.chat_stream so the
-- one streaming path covers both incremental and sync callers.
-- Phase 2: optional opts.tools array passed through to the request body
-- (omitted entirely when nil/empty per §12 risk row 1). The chat_stream
-- on_delta callback widens to (kind, payload) where kind is "text" or
-- "tool_call"; tool_call deltas are accumulated by `index` (default 0 if
-- absent per C2) and emitted as complete records on finish_reason "tool_calls".
-- broker.lua does NOT depend on mcp.lua — the caller assembles opts.tools
-- and passes it in. See docs/PHASE0.md §6, PHASE1.md §3, PHASE2.md §3 / §5.
local curl = require("ffi.curl")
local json = require("dkjson")
local M = {}
-- Assemble the HTTP header list for a broker request.
-- The JSON content type is always present. When model_cfg.key_env names an
-- environment variable that is set and non-empty, its value is added as a
-- Bearer token.
-- Returns an array of "Name: value" strings.
local function build_headers(model_cfg)
  local headers = { "Content-Type: application/json" }
  local env_name = model_cfg.key_env
  if not env_name then return headers end
  local token = os.getenv(env_name)
  if token ~= nil and token ~= "" then
    table.insert(headers, "Authorization: Bearer " .. token)
  end
  return headers
end
-- build_request(model_cfg, messages, stream, opts)
-- Prepares everything needed to POST to <endpoint>/v1/chat/completions.
-- Phase 7 (A3) replaced the earlier positional (tools, max_tokens)
-- arguments with a single opts table; both internal call sites
-- (chat_stream, and M.chat via chat_stream) pass it. Recognised fields:
--   .tools          Phase 2 — sent only when the array has entries
--   .max_tokens     Phase 3 — sent only when set
--   .include_usage  Phase 7 — default true; on streaming requests sets
--                   stream_options.include_usage in the body (B1: required
--                   for local llama.cpp to emit usage; no-op for cloud,
--                   which emits anyway)
--   .grammar        #88 — GBNF grammar string, passed through when set
-- Returns:
--   url, body_json, headers, timeout_ms   on success
--   nil, errmsg                           when model_cfg lacks .endpoint
--                                         or .model
local function build_request(model_cfg, messages, stream, opts)
  local cfg_ok = model_cfg and model_cfg.endpoint and model_cfg.model
  if not cfg_ok then
    return nil, "broker: model_cfg.endpoint and .model are required"
  end
  local o = opts or {}

  -- Strip trailing slashes so the joined URL has exactly one separator.
  local base = model_cfg.endpoint:gsub("/+$", "")

  local payload = {
    model = model_cfg.model,
    messages = messages,
    stream = not not stream,
    temperature = model_cfg.temperature or 0.2,
  }

  -- Per PHASE2.md §12 risk row "Empty tools array": some servers reject
  -- "tools": [], so the key is left out unless there is at least one tool.
  local tools = o.tools
  if tools and #tools > 0 then
    payload.tools = tools
  end

  -- Phase 3 (A2): max_tokens passthrough — safety.is_destructive uses it
  -- to cap YES/NO probes at ~4 tokens. Left out when nil so model
  -- defaults still apply for Phase 1/2 callers.
  if o.max_tokens then
    payload.max_tokens = o.max_tokens
  end

  -- Phase 7 (B1): on by default for streaming requests; callers opt out
  -- with opts.include_usage = false.
  local want_usage = stream and o.include_usage ~= false
  if want_usage then
    payload.stream_options = { include_usage = true }
  end

  -- #88: GBNF grammar passthrough. llama.cpp constrains the sampler to
  -- tokens matching the grammar, which removes format drift on small
  -- models. Probed cloud (Anthropic via Bedrock) silently ignores the
  -- field, so passing it through by default is safe and v1 needs no
  -- per-model opt-out. A malformed grammar produces a broker error at
  -- request time, visible via the usual transport error path.
  if o.grammar then
    payload.grammar = o.grammar
  end

  local timeout_ms = model_cfg.timeout_ms or 60000
  return base .. "/v1/chat/completions",
    json.encode(payload),
    build_headers(model_cfg),
    timeout_ms
end
-- Streaming /v1/chat/completions.
-- opts is optional. Fields consumed here or by build_request: .tools,
-- .max_tokens, .grammar, .include_usage (Phase 7, default true),
-- .category (Phase 7, echoed into the emitted usage payload for
-- caller-side accumulator tagging) and .timeout_ms (Phase 3, overrides
-- the model's timeout for this call).
-- on_delta is called as on_delta(kind, payload):
--   on_delta("text", content_string)  - per non-empty text chunk
--   on_delta("tool_call", { id, name, arguments })
--                                     - once per completed tool call, when
--                                       a chunk carries finish_reason
--                                       "tool_calls"
--   on_delta("usage", { prompt_tokens, completion_tokens,
--                       total_tokens, cost, model, category })
--                                     - Phase 7: emitted once after the
--                                       stream completes successfully, IF
--                                       the provider sent a usage block.
--                                       Skipped on transport / API errors.
--                                       model is model_cfg.model (caller-
--                                       stable per B4 + R2); cost is nil
--                                       for providers that don't emit it
--                                       (local llama.cpp); category is
--                                       opts.category or "main".
-- Returns:
--   true          stream ended cleanly
--   nil, errmsg   bad config, or transport / API failure
function M.chat_stream(model_cfg, messages, on_delta, opts)
  opts = opts or {}
  local url, body, headers, timeout_ms =
    build_request(model_cfg, messages, true, opts)
  -- On bad config build_request returns (nil, errmsg), so the message
  -- arrives in the `body` slot.
  if not url then return nil, body end
  -- Phase 3: opts.timeout_ms overrides the model's default. Used by
  -- safety.is_destructive's LLM probe to cap YES/NO checks at ~15s even
  -- when the model's normal timeout is much higher (e.g. a deep model
  -- configured with 1800000ms for long generations).
  if opts.timeout_ms then timeout_ms = opts.timeout_ms end

  local done = false
  local api_err
  -- Tool-call accumulator keyed by index. Each slot is filled across
  -- many deltas: id+name come on the opener, arguments arrives as
  -- character-fragment JSON-string chunks (PHASE2-baseline.md §4).
  local tc_by_index = {}
  local tc_index_order = {} -- preserve emission order
  local index_absent_warned = false
  -- Phase 7: usage captured from the final SSE chunk (per B2 either on a
  -- non-empty-choices chunk with finish_reason — cloud, or on a
  -- choices=[] chunk before [DONE] — local). Emitted as
  -- on_delta("usage", ...) AFTER curl.post_sse returns (B5).
  local final_usage = nil

  local function on_event(data)
    if done then return end
    if data == "[DONE]" then done = true; return end
    local doc = json.decode(data)
    -- Ignore unparseable events, and also valid JSON that is not an
    -- object: a bare number or boolean would otherwise raise on the field
    -- accesses below, from inside the transport callback.
    if type(doc) ~= "table" then return end

    -- Some servers emit an SSE-framed error envelope at the start of the
    -- stream — surface it.
    if doc.error then
      local e = doc.error
      local msg
      if type(e) == "table" then
        -- Prefer the provider's message; otherwise serialise the whole
        -- error object so the caller sees its contents rather than a
        -- table address.
        msg = e.message or json.encode(e)
      else
        msg = e
      end
      api_err = tostring(msg)
      done = true
      return
    end

    -- N1: usage branch is INDEPENDENT of the choice/delta branches.
    -- Check unconditionally — local emits usage on choices=[] chunks
    -- where `choice` is nil; cloud emits with non-empty choices.
    -- R2: payload.model is the caller-stable model_cfg.model (upvalue),
    -- so call_broker's fallback retry naturally credits the right
    -- model — wrapper callers key by payload.model.
    local usage = doc.usage
    if type(usage) == "table" then
      final_usage = {
        prompt_tokens = usage.prompt_tokens or 0,
        completion_tokens = usage.completion_tokens or 0,
        total_tokens = usage.total_tokens or 0,
        cost = usage.cost, -- nil for local (R6 preserves nil)
        model = model_cfg.model, -- caller-stable per B4/R2
        category = opts.category or "main",
      }
      -- Don't emit yet; fired after curl.post_sse returns.
    end

    -- Only descend into choice/delta when they really are objects.
    local choices = doc.choices
    local choice = type(choices) == "table" and choices[1] or nil
    if type(choice) ~= "table" then choice = nil end
    local delta = choice and choice.delta
    if type(delta) ~= "table" then delta = nil end

    -- Text path (unchanged from Phase 1 semantics; kind widened).
    local content = delta and delta.content
    if type(content) == "string" and #content > 0 then
      on_delta("text", content)
    end

    -- Tool-call accumulation (Phase 2).
    local tcs = delta and delta.tool_calls
    if type(tcs) == "table" then
      for _, tc in ipairs(tcs) do
        if type(tc) == "table" then
          local idx = tc.index
          if idx == nil then
            idx = 0
            if not index_absent_warned then
              index_absent_warned = true
              -- One-shot debug status per stream; printed to stderr
              -- so it doesn't interleave with renderer stdout output.
              io.stderr:write(
                "[aish] broker: tool_calls[].index absent; assuming 0\n")
            end
          end
          local slot = tc_by_index[idx]
          if not slot then
            slot = { id = nil, name = nil, arguments = "" }
            tc_by_index[idx] = slot
            tc_index_order[#tc_index_order + 1] = idx
          end
          if tc.id then slot.id = tc.id end
          local fn = tc["function"]
          if type(fn) == "table" then
            if fn.name then slot.name = fn.name end
            if fn.arguments then
              slot.arguments = slot.arguments .. fn.arguments
            end
          end
        end
      end
    end

    -- On finish_reason "tool_calls", emit all accumulated calls in the
    -- order their indices first appeared, then reset the accumulator.
    if choice and choice.finish_reason == "tool_calls" then
      for _, idx in ipairs(tc_index_order) do
        on_delta("tool_call", tc_by_index[idx])
      end
      tc_by_index = {}
      tc_index_order = {}
    end
  end

  local ok, err = curl.post_sse(url, body, headers, on_event, timeout_ms)
  -- An in-stream API error takes precedence over the transport result.
  if api_err then return nil, "api: " .. api_err end
  if not ok then return nil, "transport: " .. tostring(err) end
  -- Phase 7 (B5): emit captured usage AFTER the stream completes, as the
  -- last event in stream order. Skipped on transport/api errors (the
  -- accumulator stays unchanged for the failed call).
  if final_usage then on_delta("usage", final_usage) end
  return true
end
-- Blocking chat: send a /v1/chat/completions request and return the full
-- assistant text.
-- Implemented as a buffering wrapper over M.chat_stream, so the broker
-- keeps a single HTTP shape (stream:true always) for both sync and
-- incremental callers.
-- Phase 7 (R1) widened the contract to return (text, usage). Callers that
-- only take the first value keep working, since Lua drops extra return
-- values; callers that want cost/usage tracking do
-- `local r, u = broker.chat(...)` and route u to ctx:add_usage via the
-- central _record_usage helper.
-- "tool_call" events are ignored here (no caller of M.chat passes
-- opts.tools).
-- Returns:
--   text, usage   on success (usage is nil if the provider sent none)
--   nil, errmsg   on transport / decode / API failure
function M.chat(model_cfg, messages, opts)
  local chunks, count = {}, 0
  local usage -- R1: captured so callers can see usage

  local function collect(kind, payload)
    if kind == "text" then
      count = count + 1
      chunks[count] = payload
    elseif kind == "usage" then
      usage = payload
    end
  end

  local ok, err = M.chat_stream(model_cfg, messages, collect, opts)
  if ok then
    return table.concat(chunks), usage
  end
  return nil, err
end
-- ---------------------------------------------------------------- token_count (Phase 8)
-- Returns an accurate token count by hitting <endpoint>/tokenize when
-- the endpoint supports it; falls back to the Phase 0 §8 char/4
-- heuristic otherwise. Per-endpoint capability cache (session-local;
-- key per R6 is endpoint-only since B1 confirms /tokenize ignores the
-- model field on the observed broker).
--
-- The probe uses the same headers as chat requests (build_headers), so
-- an endpoint that requires the configured API key is not rejected and
-- then cached as unsupported.
--
-- Intended never to error: a non-string `text` is stringified, and a
-- response body that is not a JSON object with a `tokens` array counts
-- as a miss. Returns a non-negative integer.
-- 2s timeout per call so a misbehaving endpoint can't stall the
-- caller; first miss caches as unsupported for the session.
local _tokenize_capable = {} -- [endpoint] = true | false (nil = unprobed)
function M.token_count(model_cfg, text)
  text = text or ""
  if type(text) ~= "string" then text = tostring(text) end
  if text == "" then return 0 end

  -- char/4 heuristic, used whenever /tokenize is unavailable.
  local estimate = math.floor(#text / 4)

  if not (model_cfg and model_cfg.endpoint) then
    return estimate
  end
  local ep = model_cfg.endpoint
  if _tokenize_capable[ep] == false then
    return estimate
  end

  local url = ep:gsub("/+$", "") .. "/tokenize"
  local body = json.encode({ content = text, model = model_cfg.model })
  local out, status = curl.post(url, body,
    build_headers(model_cfg),
    2000) -- 2s timeout per R5 risk row

  -- Decode only a 200 response; anything that is not an object carrying
  -- a `tokens` table is a miss.
  local doc
  if status == 200 and out then
    doc = json.decode(out)
  end
  local toks
  if type(doc) == "table" then toks = doc.tokens end
  if type(toks) ~= "table" then
    _tokenize_capable[ep] = false
    return estimate
  end

  _tokenize_capable[ep] = true
  return #toks
end
-- Introspection helper for the /tokenize capability cache.
-- Returns nil when model_cfg has no endpoint or the endpoint has not been
-- probed yet; otherwise the cached true/false. Used by tests and a future
-- :tokenize debug meta.
function M.tokenize_supported(model_cfg)
  local ep = model_cfg and model_cfg.endpoint
  if not ep then
    return nil
  end
  return _tokenize_capable[ep]
end
-- Test hook: forget every cached /tokenize capability so test runs that
-- share one LuaJIT VM start from an un-probed state.
function M._reset_tokenize_cache()
  for ep in pairs(_tokenize_capable) do
    _tokenize_capable[ep] = nil
  end
end
return M