Grokking Phase Diagrams: Mapping Delayed Generalization in Modular Arithmetic
Introduction
[power2022grokking] discovered that small neural networks trained on modular arithmetic tasks can exhibit "grokking" — a phenomenon where test accuracy remains at chance level long after training accuracy reaches near perfection, then suddenly jumps to high accuracy after many additional epochs of training. This delayed generalization represents a phase transition in the learning dynamics.
Subsequent work has deepened our understanding: [nanda2023progress] reverse-engineered the learned algorithm as a discrete Fourier transform with trigonometric identities, identifying three continuous training phases (memorization, circuit formation, cleanup). [liu2022omnigrok] introduced the "LU mechanism" explaining grokking through the mismatch between L-shaped training loss and U-shaped test loss as functions of weight norm, and demonstrated that grokking can be induced or suppressed across diverse data domains.
Despite these advances, a systematic mapping of the grokking phase diagram across multiple hyperparameters simultaneously has received less attention. In this work, we sweep over three key hyperparameters — weight decay, dataset fraction, and model width — to construct a complete phase diagram that classifies each training run into one of four outcomes.
Methods
Task and Data
We study modular addition: given integers $a, b \in \{0, 1, \dots, p-1\}$, predict $c = (a + b) \bmod p$, where $p = 97$ (a standard prime used in the grokking literature). The full dataset contains all $p^2 = 9409$ input-output pairs. Each pair is unique, and labels are uniformly distributed over $\{0, 1, \dots, p-1\}$.
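As a concrete illustration, the full dataset can be enumerated in a few lines. The repository implements this in `generate_modular_addition_data()` in `src/data.py`; the sketch below assumes a simple tensor-based signature and is not the repository's exact code.
```python
import torch

def generate_modular_addition_data(p: int = 97):
    """Enumerate all p^2 pairs (a, b) labeled with (a + b) mod p."""
    a = torch.arange(p).repeat_interleave(p)  # 0,0,...,0, 1,1,...,1, ...
    b = torch.arange(p).repeat(p)             # 0,1,...,p-1, 0,1,...,p-1, ...
    inputs = torch.stack([a, b], dim=1)       # shape (p*p, 2)
    labels = (a + b) % p                      # shape (p*p,)
    return inputs, labels

inputs, labels = generate_modular_addition_data()
assert inputs.shape == (97 * 97, 2)           # 9409 unique pairs
```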
Model Architecture
We use a one-hidden-layer MLP with learned embeddings:
- Each input $a$ and $b$ is mapped to a 16-dimensional learned embedding
- The two embeddings are concatenated to form a 32-dimensional vector
- A linear layer maps to $d$ hidden units with ReLU activation, where $d \in \{16, 32, 64\}$
- A final linear layer maps to 97 output logits
Parameter counts grow with the hidden dimension, from the smallest model ($d = 16$) to the largest ($d = 64$); all configurations remain well under 100K parameters.
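For concreteness, here is a minimal PyTorch sketch of this architecture. The class name, the shared embedding table, and the defaults are assumptions for illustration; the repository's own module may differ in detail.
```python
import torch
import torch.nn as nn

class ModularMLP(nn.Module):
    """One-hidden-layer MLP over learned 16-dim embeddings (illustrative sketch)."""

    def __init__(self, p: int = 97, embed_dim: int = 16, hidden_dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(p, embed_dim)         # one table shared by a and b
        self.hidden = nn.Linear(2 * embed_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, p)

    def forward(self, x):                               # x: (batch, 2) integer pairs
        ea, eb = self.embed(x[:, 0]), self.embed(x[:, 1])
        h = torch.relu(self.hidden(torch.cat([ea, eb], dim=-1)))
        return self.out(h)                              # (batch, p) logits

model = ModularMLP(hidden_dim=64)
print(sum(t.numel() for t in model.parameters()))       # well under 100K
```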
Training
We use the AdamW optimizer [loshchilov2019adamw] with a fixed learning rate and cross-entropy loss. Training uses full-batch gradient descent (all training examples in each step), following the standard grokking setup. Each run trains for up to 2,500 epochs, with early stopping once both train and test accuracy exceed 99% for two consecutive evaluation points. Metrics are logged every 100 epochs.
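A minimal sketch of this training loop follows, assuming the model and data tensors from the previous sketches; the learning rate shown is a placeholder, since the exact value is not pinned down here.
```python
import torch
import torch.nn.functional as F

def train(model, train_x, train_y, test_x, test_y,
          lr=1e-3, weight_decay=1e-2, max_epochs=2500, eval_every=100):
    """Full-batch AdamW training with early stopping (illustrative sketch)."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    history, streak = [], 0
    for epoch in range(1, max_epochs + 1):
        opt.zero_grad()
        loss = F.cross_entropy(model(train_x), train_y)  # one step on the full batch
        loss.backward()
        opt.step()
        if epoch % eval_every == 0:
            with torch.no_grad():
                train_acc = (model(train_x).argmax(-1) == train_y).float().mean().item()
                test_acc = (model(test_x).argmax(-1) == test_y).float().mean().item()
            history.append((epoch, train_acc, test_acc))
            # stop after two consecutive evaluations with both accuracies above 99%
            streak = streak + 1 if train_acc > 0.99 and test_acc > 0.99 else 0
            if streak >= 2:
                break
    return history
```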
Phase Classification
We classify each training run into one of four phases (a sketch of the rule follows the list):
- Confusion: Final training accuracy $< 95\%$. The model fails to memorize.
- Memorization: Training accuracy $\geq 95\%$ but test accuracy $< 95\%$. Overfitting without generalization.
- Grokking: Both accuracies reach $\geq 95\%$, but test accuracy lags train by more than 200 epochs. Delayed generalization.
- Comprehension: Both accuracies reach $\geq 95\%$, with test lagging train by at most 200 epochs. Fast generalization.
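The rule can be stated compactly in code. The constants mirror `ACC_THRESHOLD` and `GROKKING_GAP_THRESHOLD` from `src/analysis.py`; the function itself is a sketch, assuming the per-run history of (epoch, train accuracy, test accuracy) tuples produced by the training loop.
```python
ACC_THRESHOLD = 0.95          # accuracy that counts as "solved" (see src/analysis.py)
GROKKING_GAP_THRESHOLD = 200  # epochs; gap separating grokking from comprehension

def classify_run(history):
    """Map an accuracy history [(epoch, train_acc, test_acc), ...] to a phase."""
    train_epoch = next((e for e, tr, _ in history if tr >= ACC_THRESHOLD), None)
    test_epoch = next((e for e, _, te in history if te >= ACC_THRESHOLD), None)
    if train_epoch is None:
        return "confusion"       # never memorized the training set
    if test_epoch is None:
        return "memorization"    # memorized but never generalized
    gap = test_epoch - train_epoch
    return "grokking" if gap > GROKKING_GAP_THRESHOLD else "comprehension"
```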
Hyperparameter Sweep
We sweep over:
- Weight decay $\lambda$ (5 values, from $\lambda = 0$ up to strong regularization)
- Dataset fraction (4 values; the share of all $p^2$ pairs used for training, with the rest held out for testing)
- Hidden dimension $d \in \{16, 32, 64\}$ (3 values)
This yields $5 \times 4 \times 3 = 60$ training runs, all with the seed fixed to 42 for full reproducibility; the grid enumeration is sketched below.
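A sketch of the grid enumeration follows. Only the hidden dimensions [16, 32, 64] are confirmed by the run script; the weight-decay and dataset-fraction values shown are illustrative assumptions.
```python
from itertools import product

# Hypothetical grid values for illustration; the actual defaults live in run.py.
WEIGHT_DECAYS = [0.0, 1e-3, 1e-2, 1e-1, 1.0]   # 5 values (assumed)
DATASET_FRACTIONS = [0.3, 0.5, 0.7, 0.9]       # 4 values (assumed)
HIDDEN_DIMS = [16, 32, 64]                     # 3 values (per run.py)

grid = list(product(WEIGHT_DECAYS, DATASET_FRACTIONS, HIDDEN_DIMS))
assert len(grid) == 60                         # the full Cartesian sweep
for weight_decay, fraction, hidden_dim in grid:
    ...                                        # run_single(...) in src/sweep.py
```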
The execution script also records reproducibility metadata (sweep grid, seed, package versions, and runtime) in results/metadata.json, and the validator checks full grid coverage plus phase/gap consistency for each run.
Results
Phase Diagram Structure
The phase diagram reveals clear boundaries between learning regimes. The results are presented as 2D heatmaps (weight decay $\times$ dataset fraction) for each hidden dimension.
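As an illustration of how such a heatmap can be rendered from per-run phase labels, here is a small matplotlib sketch; the function name, color mapping, and input layout are assumptions, not the repository's plotting code.
```python
import matplotlib
matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt
import numpy as np

PHASE_IDS = {"confusion": 0, "memorization": 1, "grokking": 2, "comprehension": 3}

def plot_phase_heatmap(phases, weight_decays, fractions, title, path):
    """phases[i][j]: phase label at weight_decays[i], fractions[j] (assumed layout)."""
    grid = np.array([[PHASE_IDS[p] for p in row] for row in phases])
    fig, ax = plt.subplots()
    im = ax.imshow(grid, cmap="viridis", vmin=0, vmax=3)
    ax.set_xticks(range(len(fractions)))
    ax.set_xticklabels([str(f) for f in fractions])
    ax.set_yticks(range(len(weight_decays)))
    ax.set_yticklabels([str(w) for w in weight_decays])
    ax.set_xlabel("dataset fraction")
    ax.set_ylabel("weight decay")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, ticks=list(PHASE_IDS.values()))
    fig.savefig(path, dpi=150)
    plt.close(fig)
```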
The phase diagram shows that generalization is absent at the two lowest dataset fractions across the entire sweep, but becomes common at the higher fractions. Within that higher-data regime, the narrowest model ($d = 16$) still struggles, while wider models display both delayed grokking and rapid comprehension. Weight decay affects which of those regimes is more likely, but no single value cleanly separates success from failure across all widths and dataset fractions.
Role of Weight Decay
Weight decay shapes the learning regime, but in this sweep it does not act as a single universal threshold:
- No weight decay ($\lambda = 0$): Mostly confusion or memorization, but sufficiently wide models can still generalize, including two grokking runs and one rapid-comprehension run.
- Intermediate weight decay: The most reliable grokking region in our sweep. These settings account for 7 of the 9 grokking runs.
- Largest weight decay: Strong regularization suppresses grokking but does not eliminate generalization. At high dataset fraction and sufficient width, runs transition directly to comprehension instead.
Role of Dataset Fraction
Larger dataset fractions facilitate generalization. With a dataset fraction of 0.3 (30% training data), models struggle to generalize even with appropriate weight decay. At the higher fractions, grokking and comprehension are more common. This aligns with the Omnigrok finding [liu2022omnigrok] of a critical training set size below which generalization is impossible.
Role of Model Width
Wider models ($d = 32$ or $64$) tend to grok or comprehend more readily than narrow models ($d = 16$) at the same weight decay and dataset fraction. This suggests that overparameterization, when combined with appropriate regularization, facilitates the transition to generalizing solutions.
Discussion
Our phase diagram confirms and refines several findings from the grokking literature:
- Dataset fraction is the clearest gate on generalization in this setup: none of the runs at the two lowest fractions generalize, while most runs at the higher fractions do.
- Weight decay modulates how generalization appears once enough data and width are available: intermediate values favor delayed grokking, while extreme values more often yield either confusion or fast comprehension.
- Model width modulates the phase boundaries, with wider models having broader grokking regions.
- The four-phase structure (confusion, memorization, grokking, comprehension) is robust across model widths.
Limitations. Our study uses a single arithmetic operation (addition mod 97), a single optimizer (AdamW), and a single seed. The grokking gap threshold (200 epochs) is somewhat arbitrary. Extending to multiplication, varying seeds for confidence intervals, and exploring additional optimizers would strengthen these findings.
Reproducibility. All code is deterministic on CPU with pinned seeds and dependency versions. Each run emits machine-checkable artifacts (sweep_results.json, phase_diagram.json, metadata.json) and passes a validator that enforces artifact presence, Cartesian grid completeness, and phase-label consistency. In our verification pass, the complete 60-run analysis finished on CPU in 383 seconds (about 6.4 minutes). No GPU, internet access, or authentication is required.
References
[power2022grokking] Power, A., Burda, Y., Edwards, H., Babuschkin, I., and Misra, V. Grokking: Generalization beyond overfitting on small algorithmic datasets. arXiv preprint arXiv:2201.02177, 2022.
[nanda2023progress] Nanda, N., Chan, L., Lieberum, T., Smith, J., and Steinhardt, J. Progress measures for grokking via mechanistic interpretability. In ICLR, 2023.
[liu2022omnigrok] Liu, Z., Michaud, E. J., and Tegmark, M. Omnigrok: Grokking beyond algorithmic data. arXiv preprint arXiv:2210.01117, 2022.
[loshchilov2019adamw] Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. In ICLR, 2019.
Reproducibility: Skill File
Use this skill file to reproduce the research with an AI agent.
---
name: grokking-phase-diagrams
description: Train tiny MLPs on modular arithmetic (addition mod 97) and map the grokking phase diagram as a function of weight decay, dataset fraction, and model width. Classifies each training run into four phases (confusion, memorization, grokking, comprehension) and generates heatmap visualizations.
allowed-tools: Bash(git *), Bash(python *), Bash(python3 *), Bash(pip *), Bash(.venv/*), Bash(cat *), Read, Write
---
# Grokking Phase Diagrams
This skill trains tiny neural networks on modular arithmetic and studies the "grokking" phenomenon — the delayed phase transition from memorization to generalization. It sweeps over weight decay, dataset fraction, and model width to map the full phase diagram.
## Prerequisites
- Requires **Python 3.10+**.
- **No internet access needed** — all data is generated locally (modular arithmetic).
- **No GPU needed** — models are tiny (<20K parameters), trained on CPU.
- Expected runtime: **5-7 minutes** (60 training runs, up to 2500 epochs each, on CPU).
- All commands must be run from the **submission directory** (`submissions/grokking/`).
## Step 0: Get the Code
Clone the repository and navigate to the submission directory:
```bash
git clone https://github.com/davidydu/Claw4S.git
cd Claw4S/submissions/grokking/
```
All subsequent commands assume you are in this directory.
## Step 1: Environment Setup
Create a virtual environment and install dependencies:
```bash
python3 -m venv .venv
.venv/bin/pip install --upgrade pip
.venv/bin/pip install -r requirements.txt
```
Verify all packages are installed:
```bash
.venv/bin/python -c "import torch, numpy, scipy, matplotlib; print(f'PyTorch {torch.__version__}, NumPy {numpy.__version__} — All imports OK')"
```
Expected output: `PyTorch 2.6.0, NumPy 2.2.4 — All imports OK`
## Step 2: Run Unit Tests
Verify the analysis modules work correctly:
```bash
.venv/bin/python -m pytest tests/ -v
```
Expected: Pytest exits with all tests passed and exit code 0. Tests cover data generation, model architecture, training loop, phase classification, and sweep logic.
## Step 3: Run the Analysis
Execute the full phase diagram sweep:
```bash
.venv/bin/python run.py
```
Expected: Script runs 60 training experiments (5 weight decays x 4 dataset fractions x 3 hidden dims [16, 32, 64]), prints progress for each run, and exits with code 0. Output files are created in `results/`.
This will:
1. Generate modular addition dataset (all (a,b) pairs for a,b in 0..96, computing (a+b) mod 97)
2. For each hyperparameter combination: train a tiny MLP, log accuracy curves, classify the outcome
3. Generate phase diagram heatmaps showing grokking/memorization/confusion/comprehension regions
4. Generate example training curves illustrating the grokking phenomenon
5. Save results to `results/sweep_results.json`, `results/phase_diagram.json`, `results/metadata.json`, and `results/report.md`
Optional: run custom sweeps without editing source code:
```bash
.venv/bin/python run.py --weight-decays 0,0.001,0.01 --dataset-fractions 0.5,0.7,0.9 --hidden-dims 32,64 --p 97 --max-epochs 2500 --seed 42
```
## Step 4: Validate Results
Check that results were produced correctly:
```bash
.venv/bin/python validate.py
```
Expected: Prints validation checks (artifacts present, full Cartesian grid coverage, no duplicate/missing hyperparameter points, phase/gap consistency, metadata consistency) and `Validation passed.`
## Step 5: Review the Report
Read the generated report:
```bash
cat results/report.md
```
Review the phase diagram to understand where grokking occurs vs memorization vs comprehension.
The report contains:
- Phase distribution across all 60 runs
- Effect of weight decay on grokking
- Effect of dataset fraction on generalization
- Detailed per-run results table
- Phase diagram heatmaps (one per hidden dimension)
- Example training curves showing the grokking phenomenon
- Run metadata (`results/metadata.json`) including sweep config, seed, package versions, and runtime
## How to Extend
- **Change the arithmetic operation:** Modify `generate_modular_addition_data()` in `src/data.py` to compute `(a * b) % p` instead of `(a + b) % p`.
- **Change the prime modulus:** Use `run.py --p <prime>`. Smaller p (e.g., 23) runs faster; larger p may require more epochs.
- **Change sweep grids:** Use `run.py --weight-decays ... --dataset-fractions ... --hidden-dims ...` to run alternative grids without code edits.
- **Add new sweep dimensions in code:** Extend `run_single()` / `run_sweep()` in `src/sweep.py` (e.g., learning rate, embedding dimension).
- **Change grokking threshold:** Modify `ACC_THRESHOLD` (default 0.95) and `GROKKING_GAP_THRESHOLD` (default 200 epochs) in `src/analysis.py`.
- **Increase training budget:** Use `run.py --max-epochs <N>` (default 2500; larger values increase runtime).