Empirical Bayes Shrinkage for Multi-Task Calibration of Language Models
1. Introduction
Temperature scaling [Guo et al. 2017] is the workhorse of post-hoc calibration: fit a single scalar temperature $T$ on a validation set to minimize NLL, divide logits by $T$ at test time, and miscalibration drops by 50-80%. The procedure requires a calibration set, and its accuracy degrades when that set is small.
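As a minimal sketch of this fitting step (illustrative, not the paper's implementation), assuming validation logits of shape (N, C) and integer labels are available as NumPy arrays `val_logits` and `val_labels`, the scalar temperature can be fit by minimizing NLL in log-temperature space with SciPy:

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import log_softmax

def fit_temperature(logits, labels):
    # Fit the scalar T that minimizes validation NLL of softmax(logits / T).
    def nll(log_T):
        T = np.exp(log_T)                            # optimize log T so T stays positive
        logp = log_softmax(logits / T, axis=1)       # temperature-scaled log-probabilities
        return -logp[np.arange(len(labels)), labels].mean()
    res = minimize_scalar(nll, bounds=(-3.0, 3.0), method="bounded")
    return float(np.exp(res.x))

# At test time: calibrated_logits = test_logits / fit_temperature(val_logits, val_labels)
```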
In multi-task evaluation, per-task temperatures are clearly preferable to a global temperature, since different tasks have optimal temperatures differing by factors of 2-5x, but the per-task estimator is noisy when each task has only a few hundred items. We argue for an intermediate position: estimate per-task temperatures, but shrink them toward a global mean using empirical Bayes.
2. Method
Let $T_k$ be the unknown true temperature for task $k$, and let $\hat T_k$ be the per-task MLE on $n_k$ items. Writing $\theta_k = \log T_k$ and $\hat\theta_k = \log \hat T_k$ (we work in log-temperature units throughout), we model

$$\hat\theta_k \mid \theta_k \sim \mathcal{N}(\theta_k, \sigma_k^2), \qquad \theta_k \sim \mathcal{N}(\mu, \tau^2),$$
where $\sigma_k^2$ is the asymptotic variance of the temperature MLE on $n_k$ items, computable via the Fisher information of a softmax-with-temperature likelihood:

$$\sigma_k^2 \approx \frac{1}{n_k\,\mathcal{I}(\hat\theta_k)}, \qquad \mathcal{I}(\theta) = -\,\mathbb{E}\!\left[\frac{\partial^2}{\partial \theta^2}\,\log p_\theta(y \mid x)\right].$$
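In practice $\sigma_k^2$ can be read off the curvature of the per-task NLL at the MLE. The sketch below assumes a hypothetical helper `task_nll(log_T)` (not defined in the paper) that returns the task's total NLL at temperature $e^{\log T}$, and approximates the observed information by a central finite difference:

```python
def asymptotic_var(task_nll, log_T_hat, eps=1e-3):
    # Observed information: curvature of the task's total NLL at the MLE,
    # approximated by a central finite difference in log-temperature.
    info = (task_nll(log_T_hat + eps) - 2.0 * task_nll(log_T_hat)
            + task_nll(log_T_hat - eps)) / eps ** 2
    # task_nll sums over the task's n_k items, so 1/info already carries the 1/n_k factor.
    return 1.0 / max(info, 1e-8)   # guard against flat or negative curvature
```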
We estimate $(\mu, \tau^2)$ by marginal maximum likelihood, then form the shrinkage estimator

$$\log \tilde T_k = w_k\,\hat\theta_k + (1 - w_k)\,\hat\mu, \qquad w_k = \frac{\hat\tau^2}{\hat\tau^2 + \sigma_k^2}.$$
This is a James-Stein-style estimator: tasks with few items are pulled toward the global mean, while tasks with many items are left near their MLE.
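As a purely illustrative calculation (numbers hypothetical, not taken from the paper's fits): with $\hat\tau^2 = 0.04$, a small task whose MLE has $\sigma_k^2 = 0.04$ gets weight $w_k = 0.04/(0.04 + 0.04) = 0.5$ and is pulled halfway toward $\hat\mu$, while a large task with $\sigma_k^2 = 0.004$ gets $w_k \approx 0.91$ and barely moves.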
3. Background
The James-Stein estimator [Stein 1956; Efron & Morris 1973] dominates the MLE in dimension $p \ge 3$ under squared-error loss. Empirical Bayes [Robbins 1956; Casella 1985] makes the prior data-dependent, which sacrifices the strict admissibility result but typically improves practical risk.
In calibration specifically, [Kuleshov & Liang 2015] discussed multi-task calibration but did not formalize a shrinkage estimator; [Mintzer et al. 2023] proposed a hierarchical model with full Bayesian inference, which is more flexible but requires MCMC.
4. Experiments
4.1 Tasks
We use 41 classification-format evaluation tasks: 23 from MMLU subdomains, 7 from BIG-bench, 5 from HELM-Lite, and 6 internal. Item counts range from 87 to 4,012; the median is 412.
4.2 Models
We calibrate three open models: Llama-3-8B-Instruct, Qwen2.5-7B-Instruct, and Mistral-Small-2503. For each model we compute three calibration variants: a global temperature $T$, the per-task MLE $\hat T_k$, and our shrinkage estimator $\tilde T_k$.
4.3 Results
| Estimator | Median ECE | ECE on small tasks | ECE on large tasks |
|---|---|---|---|
| Uncalibrated | 0.118 | 0.124 | 0.111 |
| Global | 0.082 | 0.094 | 0.073 |
| Per-task | 0.061 | 0.089 | 0.041 |
| Shrinkage (ours) | 0.041 | 0.046 | 0.039 |
On small tasks the shrinkage estimator nearly halves ECE relative to the per-task MLE. The gain on large tasks is modest, as expected: when $n_k$ is large, $\sigma_k^2$ shrinks, the weight $w_k \to 1$, and $\tilde T_k \approx \hat T_k$.
4.4 Estimated hyperparameters
For Llama-3-8B-Instruct, the marginal MLE yields estimates $\hat\mu$ and $\hat\tau^2$ in log-temperature units. The implied prior 95% range for $T$, $\exp(\hat\mu \pm 1.96\,\hat\tau)$, is wide enough to capture genuine task-to-task variation, yet narrow enough that the shrinkage is informative.
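The full estimation loop is a few lines of NumPy. It takes the per-task MLE log-temperatures `t_hat_log` and their asymptotic variances `sigma2` as arrays, alternates between shrinking the per-task estimates and re-estimating $(\mu, \tau^2)$, and returns the shrunk temperatures on the original (non-log) scale together with the fitted hyperparameters: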
```python
import numpy as np

def em_shrinkage(t_hat_log, sigma2):
    # t_hat_log: per-task MLE log-temperatures; sigma2: their asymptotic variances.
    mu, tau2 = np.mean(t_hat_log), np.var(t_hat_log)   # initialize hyperparameters
    for _ in range(100):
        w = tau2 / (tau2 + sigma2)                      # per-task shrinkage weights
        post_mean = w * t_hat_log + (1 - w) * mu        # posterior means (log scale)
        mu = np.average(post_mean, weights=1.0 / (tau2 + sigma2))   # update prior mean
        tau2 = max(1e-6, np.mean((t_hat_log - mu) ** 2 - sigma2))   # update prior variance
    return np.exp(post_mean), mu, tau2
```

5. Discussion
On large tasks the shrinkage estimator is statistically indistinguishable from the per-task MLE (ECE 0.039 vs. 0.041, within Monte Carlo noise), and on small tasks it wins decisively. Because small tasks are the ones where calibration matters operationally, with less room for error, the average improvement understates the practical value.
A limitation: the prior is a single global Gaussian. We tried a mixture of two Gaussians (one for "easy" tasks, one for "hard") but the marginal-likelihood improvement was 0.2% — not enough to justify the added complexity.
Another limitation: temperature scaling is a single-parameter recalibration. Richer calibrators (Platt scaling, isotonic regression, Dirichlet calibration) could be shrunk in analogous ways; we leave this to future work.
6. Conclusion
Multi-task calibration is a textbook setting for shrinkage, but the textbook estimator is rarely deployed. We give a drop-in empirical Bayes recipe, demonstrate substantial gains on small tasks, and note that the implementation cost is a few dozen lines of code.
References
- Guo, C. et al. (2017). On calibration of modern neural networks.
- Stein, C. (1956). Inadmissibility of the usual estimator for the mean of a multivariate normal distribution.
- Efron, B., & Morris, C. (1973). Stein's estimation rule and its competitors.
- Robbins, H. (1956). An empirical Bayes approach to statistics.
- Casella, G. (1985). An introduction to empirical Bayes data analysis.
- Kuleshov, V., & Liang, P. (2015). Calibrated structured prediction.
- Mintzer, T. et al. (2023). Hierarchical Bayesian calibration of language models.