JSON Constrained Decoding in Transformers (Without SGLang)
The mechanism is the same as SGLang's response_format — mask invalid tokens at each step — but you hook into model.generate() yourself.
The Hook: `prefix_allowed_tokens_fn`
Every HuggingFace `generate()` call accepts this parameter:

```python
outputs = model.generate(
    input_ids,
    prefix_allowed_tokens_fn=fn,  # called at every token step
)
```

`fn(batch_id, input_ids_so_far)` must return the list of token IDs allowed at the next position. If only `{`, `"`, `t`, ... are valid JSON continuations, all other tokens get `-inf` logits and can never be sampled.
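To make the contract concrete, here is a dependency-free toy `prefix_allowed_tokens_fn` that forces the first generated token to be `{` and leaves the rest unconstrained. The prompt length, token IDs, and vocabulary size are invented for illustration, and a plain list stands in for the tensor HF actually passes:

```python
def make_prefix_fn(prompt_len, open_brace_id, vocab_size):
    """Allow only `{` as the first generated token, then anything."""
    def prefix_fn(batch_id, input_ids):
        # input_ids is the full sequence so far, prompt included
        if len(input_ids) == prompt_len:  # next token is the first generated one
            return [open_brace_id]
        return list(range(vocab_size))    # no constraint afterwards
    return prefix_fn

fn = make_prefix_fn(prompt_len=3, open_brace_id=90, vocab_size=100)
print(fn(0, [11, 12, 13]))            # [90]: only `{` is allowed
print(len(fn(0, [11, 12, 13, 90])))   # 100: everything allowed
```

Real schema enforcement (Options A and B below) works the same way, just with a state machine deciding the allowed set at every step.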
Option A — `lm-format-enforcer` (Simplest, Drop-in)
```bash
pip install lm-format-enforcer
```

```python
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import (
    build_transformers_prefix_allowed_tokens_fn,
)

schema = {
    "type": "object",
    "properties": {
        "transcript": {"type": "string"},
        "language": {"type": "string"},
        "confidence": {"type": "string", "enum": ["high", "medium", "low"]},
    },
    "required": ["transcript", "language", "confidence"],
    "additionalProperties": False,
}

parser = JsonSchemaParser(schema)
prefix_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)

outputs = model.generate(
    input_ids,
    prefix_allowed_tokens_fn=prefix_fn,
    max_new_tokens=512,
)
```

Works with any HF model. No schema compilation step needed.
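Since the output is still a string, you round-trip it through `json.loads` before use. A minimal sketch, with a hand-written sample standing in for real model output:

```python
import json

# Stand-in for tokenizer.decode(...) on a schema-constrained generation
raw = '{"transcript": "hello world", "language": "en", "confidence": "high"}'

data = json.loads(raw)  # constrained decoding means this should never raise
assert set(data) == {"transcript", "language", "confidence"}
assert data["confidence"] in {"high", "medium", "low"}
print(data["transcript"])
```

The parse itself is guaranteed to succeed under enforcement; the asserts just document the shape your downstream code can rely on.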
Option B — `outlines` (More Powerful, Supports Regex Too)
```bash
pip install outlines
```

```python
import outlines
from pydantic import BaseModel
from typing import List

class AudioAnalysis(BaseModel):
    transcript: str
    summary: str
    language: str
    sentiment: str  # could use Literal["positive", "negative", "neutral", "mixed"]
    key_points: List[str]

# Load the model through outlines (wraps the HF model internally)
model = outlines.models.transformers("Qwen/Qwen2-Audio-7B-Instruct")
generator = outlines.generate.json(model, AudioAnalysis)

result = generator(prompt)  # result is already a typed Pydantic object
print(result.transcript)
print(result.key_points)
```

The advantage: `result` is a Pydantic object, not a string to parse. Type-safe out of the box.
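What "typed out of the box" buys you can be sketched with a stdlib dataclass standing in for the Pydantic model; the JSON string below is invented for illustration:

```python
import json
from dataclasses import dataclass
from typing import List

# Stand-in for the Pydantic AudioAnalysis model, stdlib only
@dataclass
class AudioAnalysisResult:
    transcript: str
    summary: str
    language: str
    sentiment: str
    key_points: List[str]

raw = ('{"transcript": "hi there", "summary": "a greeting", '
       '"language": "en", "sentiment": "positive", "key_points": ["greeting"]}')
result = AudioAnalysisResult(**json.loads(raw))

print(result.sentiment)    # attribute access instead of dict indexing
print(result.key_points)
```

With outlines you skip the `json.loads` and the manual mapping entirely; the generator hands you the validated object directly.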
Option C — Manual `LogitsProcessor` (No Extra Dependencies)
If you don't want extra libraries, you can write a minimal `LogitsProcessor` yourself:

```python
import torch
from transformers import LogitsProcessor

class JSONConstraintProcessor(LogitsProcessor):
    """
    Naive approach: force the output to start with `{`.
    Not true constrained decoding — use Option A or B for real enforcement.
    """
    def __init__(self, tokenizer, prompt_len):
        self.open_id = tokenizer.encode("{", add_special_tokens=False)[0]
        # input_ids passed to __call__ include the prompt, so track its length;
        # comparing against 0 would never fire
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores):
        # Force the first *generated* token to be `{`
        if input_ids.shape[1] == self.prompt_len:
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.open_id] = 0
            return mask
        return scores  # let the rest generate freely
```

Note: This is not true constrained decoding — just a nudge. Use Option A or B for real schema enforcement.
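Why setting a logit to `-inf` blocks a token entirely: after softmax, `exp(-inf)` is exactly `0.0`, so the token gets zero probability mass. A dependency-free sketch with a made-up 4-token vocabulary:

```python
import math

scores = [1.2, 0.3, -0.5, 2.0]   # raw logits over a toy 4-token vocabulary
allowed = {0}                     # pretend only token 0 (`{`) is permitted

masked = [s if i in allowed else float("-inf") for i, s in enumerate(scores)]
exps = [math.exp(s) for s in masked]   # exp(-inf) == 0.0
probs = [e / sum(exps) for e in exps]

print(probs)  # [1.0, 0.0, 0.0, 0.0]: masked tokens can never be sampled
```

This is the same mechanism every option on this page uses; the options differ only in how the allowed set is computed at each step.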
How It Fits Into Your Server
Replacing the SGLang call in `server.py` / `server_json.py`:

```python
# Instead of:
raw = await _call_sglang_json(messages, schema, ...)

# You'd do this (in a thread, so the event loop isn't blocked):
import asyncio, functools

def _infer(input_ids, schema):
    parser = JsonSchemaParser(schema)
    prefix_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)
    out = model.generate(
        input_ids,
        prefix_allowed_tokens_fn=prefix_fn,
        max_new_tokens=512,
        temperature=0.3,
        do_sample=True,
    )
    return tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
```
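The `asyncio.to_thread` offload matters because `model.generate()` is blocking and would otherwise stall every concurrent request. A self-contained sketch of the pattern, with a `time.sleep` stand-in for inference:

```python
import asyncio
import time

def blocking_infer(prompt: str) -> str:
    time.sleep(0.01)           # stand-in for a blocking model.generate() call
    return '{"ok": true}'

async def handler() -> str:
    # Runs the blocking call in a worker thread; the event loop stays free
    return await asyncio.to_thread(blocking_infer, "transcribe this")

print(asyncio.run(handler()))  # {"ok": true}
```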
```python
raw = await asyncio.to_thread(functools.partial(_infer, input_ids, schema))
```

Comparison
| | SGLang (`response_format`) | lm-format-enforcer | outlines |
|---|---|---|---|
| Extra install | None (built-in) | pip install lm-format-enforcer | pip install outlines |
| Works with audio models | Yes | Yes | Yes (via HF) |
| Returns | String | String | Pydantic object |
| Speed overhead | ~5–10% | ~10–15% | ~10–20% |
| Custom regex support | No | Yes | Yes |
| Pydantic schema | No (dict only) | No (dict only) | Yes |
Recommendation
Use lm-format-enforcer for a quick drop-in with your existing transformers-based servers (Qwen2.5-Omni, Gemma 4). It requires the fewest changes — just wrap the generate() call.
Use outlines if you want Pydantic type safety and regex-constrained generation beyond JSON.
Avoid the manual LogitsProcessor approach for production — it only nudges the model, not truly constrains it.
