JSON Constrained Decoding in Transformers (Without SGLang)

The mechanism is the same as SGLang's response_format — mask invalid tokens at each step — but you hook into model.generate() yourself.
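The masking step itself is simple. A stand-in sketch with plain Python lists (hypothetical logit values, no torch):

```python
import math

def mask_logits(scores, allowed_ids):
    # Set every logit to -inf except the allowed token IDs; softmax then
    # assigns zero probability to everything masked out.
    allowed = set(allowed_ids)
    return [s if i in allowed else -math.inf for i, s in enumerate(scores)]

masked = mask_logits([1.2, -0.3, 0.7, 2.1], allowed_ids=[0, 2])
# tokens 1 and 3 now have -inf logits and can never be sampled
```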


The Hook: `prefix_allowed_tokens_fn`

Every HuggingFace generate() call accepts this parameter:

```python
outputs = model.generate(
    input_ids,
    prefix_allowed_tokens_fn=fn,  # called at every token step
)
```

`fn(batch_id, input_ids_so_far)` must return the list of token IDs allowed at the next position. If only `{`, `"`, `t`, ... are valid JSON continuations, every other token's logit is set to `-inf` and can never be sampled.
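A minimal sketch of such a function, forcing generation to open with `{` and otherwise allowing everything (the token ID, vocab size, and prompt length here are made-up placeholders):

```python
OPEN_BRACE_ID = 90                  # hypothetical token ID for "{"
VOCAB = list(range(1000))           # hypothetical vocabulary
PROMPT_LEN = 12                     # input_ids includes the prompt, so track its length

def prefix_fn(batch_id, input_ids):
    # input_ids = prompt + tokens generated so far for this sequence
    if len(input_ids) == PROMPT_LEN:   # about to emit the first generated token
        return [OPEN_BRACE_ID]
    return VOCAB
```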


Option A — `lm-format-enforcer` (Simplest, Drop-in)

```bash
pip install lm-format-enforcer
```

```python
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import (
    build_transformers_prefix_allowed_tokens_fn
)

schema = {
    "type": "object",
    "properties": {
        "transcript":  {"type": "string"},
        "language":    {"type": "string"},
        "confidence":  {"type": "string", "enum": ["high", "medium", "low"]}
    },
    "required": ["transcript", "language", "confidence"],
    "additionalProperties": False
}

parser    = JsonSchemaParser(schema)
prefix_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)

outputs = model.generate(
    input_ids,
    prefix_allowed_tokens_fn=prefix_fn,
    max_new_tokens=512,
)
```

Works with any HF model. No schema compilation step needed.
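Because decoding is constrained to the schema, the decoded string is guaranteed to parse (assuming generation wasn't cut off by `max_new_tokens`). A sketch with a hypothetical decoded output:

```python
import json

# Stand-in for the decoded model output
raw = '{"transcript": "hello world", "language": "en", "confidence": "high"}'

# json.loads won't raise here: the grammar only permits valid JSON,
# unless max_new_tokens truncated the output mid-object
data = json.loads(raw)
```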


Option B — `outlines` (More Powerful, Supports Regex Too)

```bash
pip install outlines
```

```python
import outlines
from pydantic import BaseModel
from typing import List

class AudioAnalysis(BaseModel):
    transcript: str
    summary:    str
    language:   str
    sentiment:  str       # could use Literal["positive","negative","neutral","mixed"]
    key_points: List[str]

# Load model through outlines (wraps HF model internally)
model     = outlines.models.transformers("Qwen/Qwen2-Audio-7B-Instruct")
generator = outlines.generate.json(model, AudioAnalysis)

result = generator(prompt)   # result is already a typed Pydantic object
print(result.transcript)
print(result.key_points)
```

The advantage: result is a Pydantic object, not a string to parse. Type-safe out of the box.
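For regex mode, `outlines.generate.regex(model, pattern)` constrains output to match an arbitrary pattern. A sketch of the kind of pattern you might use, checked here with stdlib `re` (the pattern itself is a made-up example, not from the source):

```python
import re

# Hypothetical pattern: constrain a language field to an ISO-639-1 code
# with an optional region, e.g. "en" or "en-US"
LANG_CODE = r"[a-z]{2}(-[A-Z]{2})?"

assert re.fullmatch(LANG_CODE, "en")
assert re.fullmatch(LANG_CODE, "en-US")
assert re.fullmatch(LANG_CODE, "english") is None
```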


Option C — Manual `LogitsProcessor` (No Extra Dependencies)

If you don't want extra libraries, you can write your own `LogitsProcessor`:

```python
import torch
from transformers import LogitsProcessor

class JSONConstraintProcessor(LogitsProcessor):
    """
    Naive approach: force the output to start with {.
    Not true constrained decoding — use Option A or B for real enforcement.
    """
    def __init__(self, tokenizer, prompt_len):
        self.tokenizer  = tokenizer
        self.prompt_len = prompt_len  # input_ids includes the prompt, so track its length
        self.open_id  = tokenizer.encode("{", add_special_tokens=False)[0]
        self.close_id = tokenizer.encode("}", add_special_tokens=False)[0]

    def __call__(self, input_ids, scores):
        # Force the first *generated* token to be `{`
        # (input_ids already contains the prompt, so shape[1] is never 0)
        if input_ids.shape[1] == self.prompt_len:
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.open_id] = 0
            return mask
        return scores   # let the rest generate freely
```
Note: This is not true constrained decoding — just a nudge. Use Option A or B for real schema enforcement.

How It Fits Into Your Server

Replacing the SGLang call in server.py / server_json.py:

```python
# Instead of:
raw = await _call_sglang_json(messages, schema, ...)

# You'd do (in a thread so the event loop isn't blocked):
import asyncio

def _infer(input_ids, schema):
    parser    = JsonSchemaParser(schema)
    prefix_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)
    out = model.generate(
        input_ids,
        prefix_allowed_tokens_fn=prefix_fn,
        max_new_tokens=512,
        temperature=0.3,
        do_sample=True,
    )
    return tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)

raw = await asyncio.to_thread(_infer, input_ids, schema)
```
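One practical note: rebuilding the parser and prefix function on every request is wasted work when the same schema repeats. A hypothetical cache, with a stand-in `build_prefix_fn` in place of the real `JsonSchemaParser` + builder step:

```python
import functools
import json

COMPILES = 0

def build_prefix_fn(schema):          # stand-in for the real compilation step
    global COMPILES
    COMPILES += 1
    return lambda batch_id, input_ids: []

@functools.lru_cache(maxsize=32)
def prefix_fn_for(schema_json: str):  # key on canonical JSON text (dicts aren't hashable)
    return build_prefix_fn(json.loads(schema_json))

key = json.dumps({"type": "object"}, sort_keys=True)
fn1 = prefix_fn_for(key)
fn2 = prefix_fn_for(key)              # cache hit: no second compile
```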

Comparison

| | SGLang (`response_format`) | `lm-format-enforcer` | `outlines` |
|---|---|---|---|
| Extra install | None (built-in) | `pip install lm-format-enforcer` | `pip install outlines` |
| Works with audio models | Yes | Yes | Yes (via HF) |
| Returns | String | String | Pydantic object |
| Speed overhead | ~5–10% | ~10–15% | ~10–20% |
| Custom regex support | No | Yes | Yes |
| Pydantic schema | No (dict only) | No (dict only) | Yes |

Recommendation

Use lm-format-enforcer for a quick drop-in with your existing transformers-based servers (Qwen2.5-Omni, Gemma 4). It requires the fewest changes — just wrap the generate() call.

Use outlines if you want Pydantic type safety and regex-constrained generation beyond JSON.

Avoid the manual LogitsProcessor approach for production — it only nudges the model, not truly constrains it.