74e4bffb37
llama.cpp constrains the sampler to ONLY emit tokens matching a GBNF grammar. For small models this kills format drift at the token level — `CMD: <cmd>` is enforced by the sampler rather than hoped for via prompt discipline. Probe finding (this commit's pre-implementation): cloud (Anthropic via Bedrock) silently IGNORES the `grammar` field — returns normally via standard sampling. Default passthrough is safe for all routes; no per-model opt-in/opt-out needed in v1. Changes: - broker.lua build_request: `if opts.grammar then req.grammar = opts.grammar end`. Misformed grammar surfaces at request time via the existing transport-error path. - repl.lua ask_ai: `grammar_override = config.routing.grammars [req_class]` (same gating shape as #86's system_prompts override). Passed via opts.grammar in the call_broker invocation. - safety.lua is_destructive threads cfg.safety.probe_grammar through opts.grammar so llm_probe constrains the YES/NO output. Skips the regex-match dance entirely when the model can't drift. Caller-provided opts.grammar takes precedence over cfg. - config.lua gains two commented examples: * routing.grammars per class * safety.probe_grammar for the destructive probe 6 unit cases verified (stubbed curl.post_sse / broker.chat): - default: no grammar in body - opts.grammar -> body contains grammar JSON-encoded - safety probe_grammar reaches llm_probe via opts - no probe_grammar configured -> opts.grammar nil - caller opts.grammar takes precedence over cfg.safety.probe_grammar E2E against live local broker: - `routing.grammars.default = "root ::= \\"ACK\\""` configured; prompted "tell me a long story about a fox" -> model output EXACTLY "ACK" (sampler forced; would normally produce paragraphs). Grammar passthrough end-to-end confirmed. Regression: test_safety 87/87, test_router_model 31/31, repl loads. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
295 lines
13 KiB
Lua
295 lines
13 KiB
Lua
-- broker.lua — llama.cpp HTTP client.
-- Phase 0: blocking POST via ffi/curl + vendored dkjson.
-- Phase 1: streaming POST via ffi/curl.post_sse with an OpenAI-shape decoder
-- on top. M.chat becomes a thin buffering wrapper around M.chat_stream so the
-- one streaming path covers both incremental and sync callers.
-- Phase 2: optional opts.tools array passed through to the request body
-- (omitted entirely when nil/empty per §12 risk row 1). The chat_stream
-- on_delta callback widens to (kind, payload) where kind is "text" or
-- "tool_call"; tool_call deltas are accumulated by `index` (default 0 if
-- absent per C2) and emitted as complete records on finish_reason "tool_calls".
-- broker.lua does NOT depend on mcp.lua — the caller assembles opts.tools
-- and passes it in. See docs/PHASE0.md §6, PHASE1.md §3, PHASE2.md §3 / §5.
|
|
|
|
local curl = require("ffi.curl")
|
|
local json = require("dkjson")
|
|
|
|
-- Module table: the public functions below are attached to it and it is
-- returned at the end of the file.
local M = {}
|
|
|
|
-- Build the HTTP header list for a request to model_cfg's endpoint.
-- Always contains the JSON content type. When model_cfg.key_env names an
-- environment variable that is set and non-empty, a Bearer Authorization
-- header carrying its value is appended.
-- Returns an array of "Name: value" strings.
local function build_headers(model_cfg)
  local headers = { "Content-Type: application/json" }

  local env_name = model_cfg.key_env
  if not env_name then
    return headers
  end

  -- An unset or empty variable means "no credentials": send no header.
  local token = os.getenv(env_name)
  if token ~= nil and token ~= "" then
    table.insert(headers, "Authorization: Bearer " .. token)
  end
  return headers
end
|
|
|
|
-- Assemble everything needed for a POST to <endpoint>/v1/chat/completions.
-- Since Phase 7 (A3) the optional settings travel in an opts table (earlier
-- revisions took positional tools / max_tokens). Recognised opts fields:
--   .tools          Phase 2 — sent only when the list has entries
--   .max_tokens     Phase 3 — sent only when set
--   .include_usage  Phase 7 — on by default for streaming requests; adds
--                   stream_options.include_usage to the body (B1: required
--                   for local llama.cpp to emit usage; no-op for cloud,
--                   which emits anyway)
--   .grammar        #88 — GBNF grammar, sent only when set
-- Returns:
--   url, body_json, headers, timeout_ms   on success
--   nil, errmsg                           when model_cfg lacks .endpoint/.model
local function build_request(model_cfg, messages, stream, opts)
  if not model_cfg or not model_cfg.endpoint or not model_cfg.model then
    return nil, "broker: model_cfg.endpoint and .model are required"
  end
  opts = opts or {}

  -- Drop trailing slashes so the joined URL has exactly one separator.
  local base = model_cfg.endpoint:gsub("/+$", "")
  local url = base .. "/v1/chat/completions"

  local req = {
    model = model_cfg.model,
    messages = messages,
    stream = not not stream,
    temperature = model_cfg.temperature or 0.2,
  }

  -- PHASE2.md §12 risk row "Empty tools array": some servers reject
  -- "tools": [], so the field is only present for a non-empty list.
  local tools = opts.tools
  if tools and #tools > 0 then
    req.tools = tools
  end

  -- Phase 3 (A2): used by safety.is_destructive to cap YES/NO probes at
  -- ~4 tokens. Left out when unset so model defaults keep applying.
  local max_tokens = opts.max_tokens
  if max_tokens then
    req.max_tokens = max_tokens
  end

  -- Phase 7 (B1): streaming requests ask for usage unless the caller
  -- explicitly passes opts.include_usage = false.
  if stream and opts.include_usage ~= false then
    req.stream_options = { include_usage = true }
  end

  -- #88: GBNF grammar passthrough. llama.cpp restricts sampling to tokens
  -- matching the grammar, which removes format drift on small models.
  -- The probed cloud route (Anthropic via Bedrock) silently ignores the
  -- field, so passing it unconditionally is safe and no per-model opt-out
  -- exists in v1. A malformed grammar surfaces as a broker error at
  -- request time through the usual transport error path.
  local grammar = opts.grammar
  if grammar then
    req.grammar = grammar
  end

  local body = json.encode(req)
  local headers = build_headers(model_cfg)
  local timeout_ms = model_cfg.timeout_ms or 60000
  return url, body, headers, timeout_ms
end
|
|
|
|
-- Streaming /v1/chat/completions.
-- opts is optional. It is forwarded to build_request (.tools, .max_tokens,
-- .include_usage, .grammar) and additionally read here for:
--   .timeout_ms  overrides the model's default timeout (Phase 3)
--   .category    echoed into the usage payload (Phase 7), default "main"
-- on_delta is called as on_delta(kind, payload):
--   on_delta("text", content_string)              - per non-empty text chunk
--   on_delta("tool_call", { id, name, arguments }) - once per accumulated
--                                 tool call, when a chunk carries
--                                 finish_reason "tool_calls".
--   on_delta("usage", { prompt_tokens, completion_tokens,
--                       total_tokens, cost, model, category })
--                               - Phase 7: emitted once after the stream
--                                 completes successfully, IF the provider
--                                 sent a usage block. Skipped on transport /
--                                 API errors. model is model_cfg.model
--                                 (caller-stable per B4 + R2); cost is nil
--                                 when the provider does not send it;
--                                 category is opts.category or "main".
-- Returns:
--   true                stream ended cleanly
--   nil, errmsg         bad config ("broker: ..."), API error envelope
--                       ("api: ...") or transport failure ("transport: ...")
function M.chat_stream(model_cfg, messages, on_delta, opts)
  opts = opts or {}
  local url, body, headers, timeout_ms =
    build_request(model_cfg, messages, true, opts)
  -- On bad config build_request returns (nil, errmsg), so the message
  -- arrives in the second slot, named `body` here.
  if not url then return nil, body end
  -- Phase 3: opts.timeout_ms overrides the model's default. Used by
  -- safety.is_destructive's LLM probe to cap YES/NO checks at ~15s even
  -- when the model's normal timeout is much higher.
  if opts.timeout_ms then timeout_ms = opts.timeout_ms end

  local done = false
  local api_err
  -- Tool-call accumulator keyed by index. Each slot is filled across many
  -- deltas: id + name come on the opener, arguments arrives as JSON-string
  -- fragments (PHASE2-baseline.md §4).
  local tc_by_index = {}
  local tc_index_order = {} -- preserve emission order
  local index_absent_warned = false
  -- Phase 7: usage captured from whichever chunk carries it (B2: cloud
  -- sends it alongside choices, local on a choices=[] chunk). It is only
  -- handed to on_delta after curl.post_sse returns (B5).
  local final_usage = nil

  -- NOTE(review): if curl.post_sse invokes this from inside a C callback,
  -- a Lua error raised here may not propagate safely — confirm against
  -- ffi/curl. The type guards below keep malformed payloads from raising.
  local function on_event(data)
    if done then return end
    if data == "[DONE]" then done = true; return end

    local doc = json.decode(data)
    -- Ignore unparseable events AND valid JSON that is not an object
    -- (a bare number/boolean would otherwise raise on the field reads).
    if type(doc) ~= "table" then return end

    -- Some servers emit an SSE-framed error envelope — surface it.
    if doc.error then
      local e = doc.error
      local msg
      if type(e) == "table" then
        msg = e.message
        -- Without .message, report the envelope itself rather than an
        -- opaque "table: 0x..." address.
        if not msg then msg = json.encode(e) or tostring(e) end
      else
        msg = e
      end
      api_err = tostring(msg)
      done = true
      return
    end

    -- N1: the usage branch is independent of the choice/delta branches,
    -- because usage may arrive on a chunk with no choices at all.
    -- R2: payload.model is the caller-stable model_cfg.model.
    local usage = doc.usage
    if type(usage) == "table" then
      final_usage = {
        prompt_tokens = usage.prompt_tokens or 0,
        completion_tokens = usage.completion_tokens or 0,
        total_tokens = usage.total_tokens or 0,
        cost = usage.cost, -- stays nil when the provider omits it
        model = model_cfg.model,
        category = opts.category or "main",
      }
      -- Not emitted yet; fired after curl.post_sse returns.
    end

    local choices = doc.choices
    local choice = type(choices) == "table" and choices[1] or nil
    if type(choice) ~= "table" then choice = nil end
    local delta = choice and choice.delta
    if type(delta) ~= "table" then delta = nil end

    -- Text path: forward every non-empty content fragment.
    local content = delta and delta.content
    if type(content) == "string" and #content > 0 then
      on_delta("text", content)
    end

    -- Tool-call accumulation (Phase 2).
    local tcs = delta and delta.tool_calls
    if type(tcs) == "table" then
      for _, tc in ipairs(tcs) do
        if type(tc) == "table" then
          local idx = tc.index
          if idx == nil then
            idx = 0
            if not index_absent_warned then
              index_absent_warned = true
              -- One-shot notice per stream; written to stderr so it
              -- does not interleave with renderer stdout output.
              io.stderr:write(
                "[aish] broker: tool_calls[].index absent; assuming 0\n")
            end
          end
          local slot = tc_by_index[idx]
          if not slot then
            slot = { id = nil, name = nil, arguments = "" }
            tc_by_index[idx] = slot
            tc_index_order[#tc_index_order + 1] = idx
          end
          if tc.id then slot.id = tc.id end
          local fn = tc["function"]
          if type(fn) == "table" then
            if fn.name then slot.name = fn.name end
            local args = fn.arguments
            -- Only concatenable fragments are appended; anything else
            -- would raise on "..".
            if type(args) == "string" or type(args) == "number" then
              slot.arguments = slot.arguments .. args
            end
          end
        end
      end
    end

    -- On finish_reason "tool_calls", emit all accumulated calls in the
    -- order their indices first appeared, then reset the accumulator.
    if choice and choice.finish_reason == "tool_calls" then
      for _, idx in ipairs(tc_index_order) do
        on_delta("tool_call", tc_by_index[idx])
      end
      tc_by_index = {}
      tc_index_order = {}
    end
  end

  local ok, err = curl.post_sse(url, body, headers, on_event, timeout_ms)
  -- An API error envelope takes precedence over the transport result.
  if api_err then return nil, "api: " .. api_err end
  if not ok then return nil, "transport: " .. tostring(err) end
  -- Phase 7 (B5): usage is the last event in stream order and is skipped
  -- on transport/api errors (both returned above).
  if final_usage then on_delta("usage", final_usage) end
  return true
end
|
|
|
|
-- Send a /v1/chat/completions request and return the whole assistant text.
-- This is a buffering wrapper over M.chat_stream, so sync callers share
-- the streaming HTTP path (stream:true always).
-- Phase 7 (R1): the result is (text, usage). Callers that only take the
-- first value keep working, since Lua drops surplus return values. Callers
-- that track cost/usage do `local r, u = broker.chat(...)` and route u to
-- ctx:add_usage via the central _record_usage helper.
-- Event kinds other than "text" and "usage" (e.g. "tool_call") are ignored
-- here; per the original note no caller of M.chat passes opts.tools.
-- Returns:
--   text, usage         on success (usage is nil when no usage event
--                       was emitted)
--   nil, errmsg         on failure, errmsg as returned by M.chat_stream
function M.chat(model_cfg, messages, opts)
  local chunks, n = {}, 0
  local usage -- R1: captured so callers can see it

  local function collect(kind, payload)
    if kind == "usage" then
      usage = payload
      return
    end
    if kind == "text" then
      n = n + 1
      chunks[n] = payload
    end
  end

  local ok, err = M.chat_stream(model_cfg, messages, collect, opts)
  if ok then
    return table.concat(chunks), usage
  end
  return nil, err
end
|
|
|
|
-- ---------------------------------------------------------------- token_count (Phase 8)
-- Returns a token count by POSTing to <endpoint>/tokenize when the
-- endpoint supports it; otherwise falls back to the Phase 0 §8 char/4
-- heuristic. Capability is cached per endpoint for the session (key per R6
-- is endpoint-only since B1 confirms /tokenize ignores the model field on
-- the observed broker).
--
-- Returns a non-negative integer. Malformed or unexpected responses fall
-- back to the heuristic instead of raising.
-- 2s timeout per call so a misbehaving endpoint can't stall the caller;
-- the first miss caches the endpoint as unsupported for the session.
local _tokenize_capable = {} -- [endpoint] = true | false (nil = unprobed)

function M.token_count(model_cfg, text)
  text = text or ""
  -- `#` on a non-string (e.g. a number) would raise; count its string form.
  if type(text) ~= "string" then text = tostring(text) end
  if text == "" then return 0 end

  -- Fallback estimate, used by every path that cannot get a real count.
  local estimate = math.floor(#text / 4)

  if not (model_cfg and model_cfg.endpoint) then
    return estimate
  end
  local ep = model_cfg.endpoint
  if _tokenize_capable[ep] == false then
    return estimate
  end

  local url = ep:gsub("/+$", "") .. "/tokenize"
  local body = json.encode({ content = text, model = model_cfg.model })
  -- Same headers as chat requests, including the Bearer key when
  -- model_cfg.key_env is configured; without it a key-protected server
  -- rejects the probe and gets cached as unsupported.
  local out, status = curl.post(url, body,
    build_headers(model_cfg),
    2000) -- 2s timeout per R5 risk row

  local doc
  if status == 200 and out then
    doc = json.decode(out)
  end
  -- Only an object with a `tokens` array counts as a usable answer; a
  -- bare JSON scalar must not be indexed.
  local toks = type(doc) == "table" and doc.tokens or nil
  if type(toks) ~= "table" then
    _tokenize_capable[ep] = false
    return estimate
  end

  _tokenize_capable[ep] = true
  return #toks
end
|
|
|
|
-- Introspection of the /tokenize capability cache.
-- Returns nil when model_cfg has no endpoint or the endpoint has not been
-- probed yet, otherwise the cached true/false. Used by tests and a future
-- :tokenize debug meta.
function M.tokenize_supported(model_cfg)
  local ep = model_cfg and model_cfg.endpoint
  if not ep then
    return nil
  end
  return _tokenize_capable[ep]
end
|
|
|
|
-- Test hook: forget every cached /tokenize capability so test runs that
-- share one LuaJIT VM start from an un-probed state.
function M._reset_tokenize_cache()
  local fresh = {}
  _tokenize_capable = fresh
end
|
|
|
|
-- Export the module table to require() callers.
return M
|