Loss Curve Universality: Stretched Exponentials Dominate Training Dynamics Across Tasks and Architectures
Introduction
Neural network scaling laws [kaplan2020scaling, hoffmann2022training] describe how test loss decreases with compute, data, or parameters, typically following power laws. Less studied is the temporal structure of training loss curves: what functional form does the training loss follow as a function of training epoch?
If training curves follow universal functional forms—potentially with task-dependent exponents—this would inform learning rate scheduling, early stopping, and extrapolation of training trajectories. Prior work on "grokking" in modular arithmetic [power2022grokking] shows dramatic phase transitions in training, suggesting the loss curve shape encodes meaningful learning dynamics.
We systematically test four candidate functional forms on 12 training runs (4 tasks × 3 model sizes), using information-theoretic model selection (AIC/BIC) to determine which form best describes each loss curve.
Methods
Tasks and Models
We define four tasks:
- **Modular addition (mod 97):** $(a, b) \mapsto (a + b) \bmod 97$. Full dataset of $97^2 = 9{,}409$ pairs. Classification with 97 classes.
- **Modular multiplication (mod 97):** $(a, b) \mapsto (a \times b) \bmod 97$. Same structure as addition.
- **Regression:** Random features $x \in \mathbb{R}^{20}$, $y = x^\top w + \varepsilon$ with $w \sim \mathcal{N}(0, I/\sqrt{20})$. 2,000 samples.
- **Classification:** Random features, 5 classes, 2,000 samples.
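As a concrete illustration of the first task, the full modular-addition dataset can be enumerated in a few lines. This is a sketch; the repository builds its datasets in `src/tasks.py`, whose exact code may differ.

```python
import numpy as np

P = 97  # modulus used by both modular tasks

def make_mod_add_data(p=P):
    """Enumerate all p^2 pairs (a, b) with label (a + b) mod p."""
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    X = np.stack([a.ravel(), b.ravel()], axis=1)  # shape (p^2, 2)
    y = (X[:, 0] + X[:, 1]) % p                   # one of 97 classes
    return X, y

X, y = make_mod_add_data()
print(X.shape, int(y.max()))  # (9409, 2) 96
```

Modular multiplication follows the same recipe with `(X[:, 0] * X[:, 1]) % p` as the label.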
For each task, we train 2-layer ReLU MLPs with hidden sizes $h \in \{32, 64, 128\}$, using Adam for 1,500 epochs with batch size 512. Modular arithmetic tasks use learned embeddings (dim=16). We record training loss at every epoch.
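At the shape level, the forward pass of such a 2-layer ReLU MLP amounts to the following NumPy sketch. The actual models are PyTorch modules in `src/models.py`, and the modular tasks additionally prepend a learned embedding; dimensions here follow the paper's regression/classification setup.

```python
import numpy as np

rng = np.random.default_rng(42)

def mlp_forward(x, params):
    """Forward pass of a 2-layer ReLU MLP (shape sketch only)."""
    h = np.maximum(x @ params["W1"] + params["b1"], 0.0)  # hidden ReLU layer
    return h @ params["W2"] + params["b2"]                # linear readout

d_in, hidden, n_out = 20, 32, 5  # feature dim, smallest hidden size, classes
params = {
    "W1": rng.normal(0, d_in ** -0.5, (d_in, hidden)),
    "b1": np.zeros(hidden),
    "W2": rng.normal(0, hidden ** -0.5, (hidden, n_out)),
    "b2": np.zeros(n_out),
}
out = mlp_forward(rng.normal(size=(8, d_in)), params)
assert out.shape == (8, n_out)
```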
Functional Forms
We fit four parameterized functions (power law, exponential, stretched exponential, and log-power) to each loss curve, starting from epoch 11 to skip the initial transient.
Fitting uses nonlinear least squares (`scipy.optimize.curve_fit`) with bounded parameters and multiple initial guesses.
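A minimal version of this fitting procedure, assuming a standard stretched-exponential parameterization $L(t) = a\,e^{-(t/\tau)^\gamma} + c$ (the repository's exact forms, guesses, and bounds may differ):

```python
import numpy as np
from scipy.optimize import curve_fit

def stretched_exp(t, a, tau, gamma, c):
    """Assumed parameterization: L(t) = a * exp(-(t/tau)**gamma) + c."""
    return a * np.exp(-((t / tau) ** gamma)) + c

# Synthetic loss curve over the paper's fitting window (epochs 11..1500)
t = np.arange(11, 1501, dtype=float)
clean = stretched_exp(t, a=4.0, tau=300.0, gamma=1.5, c=0.01)
noisy = clean + np.random.default_rng(0).normal(0.0, 1e-3, t.size)

popt, _ = curve_fit(
    stretched_exp, t, noisy,
    p0=[1.0, 100.0, 1.0, 0.1],  # one of several initial guesses
    bounds=([0.0, 1.0, 0.05, 0.0], [100.0, 1e4, 5.0, 10.0]),
    maxfev=20000,
)
residual = np.sqrt(np.mean((stretched_exp(t, *popt) - noisy) ** 2))
assert residual < 0.01  # fit tracks the curve to within the noise scale
```

Bounding $\gamma$ away from zero and trying several starting points matters in practice: stretched exponentials have flat likelihood valleys where the optimizer can stall.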
Model Selection
We compare fits using the Akaike Information Criterion:

$$\text{AIC} = n \ln(\text{RSS}/n) + 2k$$

where $n$ is the number of data points, $k$ the number of parameters, and RSS the residual sum of squares. We also compute BIC as a robustness check. To quantify confidence in model selection beyond "winner-takes-all" ranking, we report $\Delta\text{AIC}$ between the best and second-best converged forms for each run, binned into strong, moderate, and weak evidence levels.
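Both criteria follow directly from a fit's residuals. The RSS values below are made up for illustration; $n = 1490$ points corresponds to epochs 11 through 1500.

```python
import numpy as np

def aic(rss, n, k):
    """Akaike Information Criterion for a least-squares fit."""
    return n * np.log(rss / n) + 2 * k

def bic(rss, n, k):
    """Bayesian Information Criterion, used as a robustness check."""
    return n * np.log(rss / n) + k * np.log(n)

n = 1490  # epochs 11..1500 inclusive
scores = {
    "stretched_exp": aic(0.012, n, k=4),  # illustrative RSS values
    "power_law": aic(0.020, n, k=3),
}
best, second = sorted(scores, key=scores.get)[:2]
delta_aic = scores[second] - scores[best]
assert best == "stretched_exp" and delta_aic > 0
```

Lower AIC is better; the extra parameter of the stretched exponential only wins here because its RSS advantage outweighs the $2k$ penalty.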
Results
Universality of Functional Form
Best-fit functional form (by AIC) for each task × hidden size combination.
| Task | h=32 | h=64 | h=128 |
|---|---|---|---|
| Mod. addition | Stretched exp. | Stretched exp. | Stretched exp. |
| Mod. multiplication | Stretched exp. | Stretched exp. | Stretched exp. |
| Regression | Power law | Log-power | Power law |
| Classification | Stretched exp. | Stretched exp. | Power law |
The stretched exponential is the best-fit form in 8 of 12 configurations (67%) and is the majority winner for 3 of 4 task types (Table). It dominates all modular arithmetic runs and most classification runs. Regression tasks favor power-law or log-power decay, possibly because the loss landscape is smoother for linear-target regression. Across all 12 runs, $\Delta\text{AIC}$ support for the winning form is strong, indicating that the selected winners are not near-ties under AIC.
Exponent Distributions
The stretching exponent $\gamma$ of the stretched exponential varies substantially across runs:
- Mean $\gamma = 1.88$, std $= 1.59$, range $[0.30, 5.00]$.
- Modular arithmetic tasks tend toward higher $\gamma$ (sharper transitions), consistent with "grokking" dynamics.
- Regression/classification tasks show lower $\gamma$ (smoother decay).

The power-law exponent ranges from 0.13 to 2.63 (mean 1.03), and the log-power exponent ranges from 0.001 to 7.27 (mean 2.94). These wide ranges indicate that while the functional form may be universal, the exponents are task- and scale-dependent.
Discussion
Our results support a qualified universality claim: the stretched exponential is the most common best-fit form across diverse tasks and model sizes, but is not universally dominant—regression tasks sometimes favor power-law decay.
Limitations. (1) We study only tiny MLPs; scaling to transformers or larger models may change the picture. (2) We use training loss, not test loss; generalization dynamics may differ. (3) Our tasks are synthetic; real-world tasks may exhibit different behavior. (4) With 12 configurations, statistical power for universality claims is limited. (5) Fitting four flexible functions with 3--4 parameters each risks overfitting; AIC partially addresses this but is not definitive.
Future work. Extending to transformers, real datasets, and test loss would strengthen universality claims. The connection between high $\gamma$ and grokking-like dynamics in modular arithmetic warrants deeper investigation.
Reproducibility
All experiments are fully reproducible via the accompanying SKILL.md.
Seeds are fixed (seed=42), dependencies are pinned, and the analysis runs in roughly 3--7 minutes on CPU-only machines, depending on system load.
The pipeline checkpoints progress after each run (results/checkpoint.json) so interrupted executions can resume deterministically without recomputing completed task/size pairs.
Result metadata records software provenance (Python, Torch, NumPy, SciPy, Matplotlib versions) alongside configuration and runtime.
The complete code, including training, fitting, plotting, and validation, is self-contained in the submissions/loss-curves/ directory.
References
[kaplan2020scaling] Kaplan, J., McCandlish, S., Henighan, T., et al. Scaling Laws for Neural Language Models. arXiv:2001.08361, 2020.
[hoffmann2022training] Hoffmann, J., Borgeaud, S., Mensch, A., et al. Training Compute-Optimal Large Language Models. arXiv:2203.15556, 2022.
[power2022grokking] Power, A., Burda, Y., Edwards, H., et al. Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. arXiv:2201.02177, 2022.
Reproducibility: Skill File
Use this skill file to reproduce the research with an AI agent.
---
name: loss-curve-universality
description: Fit parameterized functions (power law, exponential, stretched exponential, log-power) to training loss curves of tiny MLPs across 4 tasks and 3 model sizes. Tests whether training curves follow universal functional forms with task-dependent exponents using AIC/BIC model selection.
allowed-tools: Bash(git *), Bash(python *), Bash(python3 *), Bash(pip *), Bash(.venv/*), Bash(cat *), Read, Write
---
# Loss Curve Universality Analysis
This skill trains tiny MLPs on 4 tasks (modular addition mod 97, modular multiplication mod 97, regression, classification) at 3 model sizes (hidden=32, 64, 128), records per-epoch training loss curves, and fits 4 parameterized functional forms to each curve. It tests whether training curves follow universal functional forms with task-dependent exponents.
## Prerequisites
- Requires **Python 3.10+**. No internet access needed; all data is generated synthetically.
- Expected runtime: **3-7 minutes** on CPU-only machines. The modular arithmetic runs are the slowest, and heavily shared machines can take longer.
- The analysis is **checkpointed** to `results/checkpoint.json` after each completed run. If interrupted, re-running `run.py` resumes from completed runs.
- All commands must be run from the **submission directory** (`submissions/loss-curves/`).
## Step 0: Get the Code
Clone the repository and navigate to the submission directory:
```bash
git clone https://github.com/davidydu/Claw4S.git
cd Claw4S/submissions/loss-curves/
```
All subsequent commands assume you are in this directory.
## Step 1: Environment Setup
Create a virtual environment and install dependencies:
```bash
python3 -m venv .venv
.venv/bin/pip install --upgrade pip
.venv/bin/pip install -r requirements.txt
```
Verify all packages are installed:
```bash
.venv/bin/python -c "import torch, numpy, scipy, matplotlib; print('All imports OK')"
```
Expected output: `All imports OK`
## Step 2: Run Unit Tests
Verify all analysis modules work correctly:
```bash
.venv/bin/python -m pytest tests/ -v
```
Expected: Pytest exits with all tests passed (30+ tests) and exit code 0.
## Step 3: Run the Analysis
Execute the full loss curve universality analysis:
```bash
.venv/bin/python run.py
```
Optional execution controls:
```bash
# Ignore checkpoint and recompute all runs from scratch
.venv/bin/python run.py --fresh
# Run only selected tasks/hidden sizes (for extension/smoke checks)
.venv/bin/python run.py --tasks mod_add,regression --hidden-sizes 32,64 --epochs 800
```
This will:
1. Train 12 MLP models (4 tasks x 3 hidden sizes) for 1500 epochs each
2. Fit 4 functional forms (power law, exponential, stretched exponential, log-power) to each loss curve
3. Compute AIC/BIC for model selection
4. Analyze universality of best-fit forms and exponent distributions
5. Generate plots and save results
Expected: Script prints progress for each of 12 runs, with the longest pauses during the modular arithmetic tasks, saves results to `results/`, and prints a summary report including per-run $\Delta$AIC support strength and environment provenance. Exit code 0. Files created:
- `results/results.json` -- compact results with fits and universality analysis
- `results/full_curves.json` -- full per-epoch loss data for all 12 runs
- `results/report.txt` -- human-readable summary report
- `results/checkpoint.json` -- resumable partial/full run state
- `results/loss_curves_with_fits.png` -- 4x3 grid of loss curves with fitted functions
- `results/aic_comparison.png` -- AIC comparison bar chart by task
- `results/exponent_distributions.png` -- exponent distributions grouped by task
## Step 4: Validate Results
Check that all results were produced correctly:
```bash
.venv/bin/python validate.py
```
Expected: Prints run counts, task details, majority best-fit form, provenance (Python/Torch/seed), support-level counts (`strong/moderate/weak/undetermined`), and `Validation passed.`
## Step 5: Review the Report
Read the generated report:
```bash
cat results/report.txt
```
The report contains:
- Configuration summary (tasks, hidden sizes, epochs)
- Reproducibility provenance (Python/Torch/NumPy/SciPy versions and seed)
- Universality summary: majority best-fit form and fraction
- Best-fit form counts across all 12 runs
- Best form per task
- Per-run table: task, hidden size, params, final loss, best form, AIC, BIC, $\Delta$AIC, support level
- Support-strength summary: counts and per-task breakdown of $\Delta$AIC evidence
- Key exponent statistics (mean, std, min, max) per functional form
## How to Extend
- **Add a task:** Add a `make_*_data()` function and entry in `TASK_REGISTRY` in `src/tasks.py`.
- **Add a functional form:** Add an entry to `FUNCTIONAL_FORMS` in `src/curve_fitting.py` with the function, initial guess, bounds, and parameter names.
- **Change model architecture:** Modify `src/models.py` and `build_model()`.
- **Change training hyperparameters:** Modify `N_EPOCHS`, `lr`, `batch_size` in `src/trainer.py` or `src/analysis.py`.
- **Add a hidden size:** Append to `HIDDEN_SIZES` in `src/analysis.py`.
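As an illustration of adding a functional form, a registry entry might look like the following. The field names here are hypothetical, not necessarily the schema used in `src/curve_fitting.py`; the log-power form itself matches the one named in the paper.

```python
import numpy as np

def log_power(t, a, b, c):
    """Log-power decay: L(t) = a / (ln t)**b + c (valid for t > 1)."""
    return a / np.log(t) ** b + c

# Hypothetical FUNCTIONAL_FORMS entry: callable plus everything
# curve_fit needs (initial guess, bounds) and labels for reporting.
NEW_FORM = {
    "name": "log_power",
    "func": log_power,
    "p0": [1.0, 1.0, 0.0],                            # initial guess
    "bounds": ([0.0, 0.0, 0.0], [100.0, 10.0, 10.0]), # lower, upper
    "param_names": ["a", "b", "c"],
}

assert abs(log_power(np.e, 2.0, 1.0, 0.5) - 2.5) < 1e-9
```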