NTK-balanced weighting
Learning objectives
- State the Neural Tangent Kernel (NTK) for a multi-term PINN loss
- Derive Wang–Yu–Perdikaris (2022) NTK-trace weighting: λ_a = (Σ_b tr K_b) / tr K_a
- Confirm empirically that NTK weighting eliminates the loss-balance crisis on the harmonic IVP
- Recognise the cost: full NTK trace computation every K epochs is expensive on large nets
§3.2 left us with a question: the optimal loss weight depends on per-term gradient magnitudes, but those magnitudes change across training. Can the optimiser detect the imbalance and re-weight automatically? Wang, Yu & Perdikaris (2022) gave the answer in the language of the Neural Tangent Kernel.
The NTK and its connection to loss balance
For a network u(x; θ) with parameters θ, the Neural Tangent Kernel is

    K(x, x′) = ∇_θ u(x; θ)ᵀ ∇_θ u(x′; θ)

Jacot, Gabriel & Hongler (2018) showed that in the infinite-width limit, gradient flow on an MSE loss is equivalent to kernel regression with kernel K. For a finite network the equivalence is approximate, but the NTK still governs the early-training dynamics.
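To make the definition concrete, the following is a minimal sketch that evaluates NTK entries as inner products of parameter gradients. It assumes a tiny one-hidden-layer tanh network with hand-written gradients; the architecture, width, initialisation, and the names `param_gradient` and `ntk_entry` are illustrative choices, not the networks used elsewhere in this course.

```python
import math
import random

def param_gradient(x, w, b, v):
    """Gradient of u(x) = sum_j v[j] * tanh(w[j] * x + b[j]) w.r.t. all parameters.

    Returned as one flat list ordered [d/dw..., d/db..., d/dv...].
    """
    h = [math.tanh(wj * x + bj) for wj, bj in zip(w, b)]
    sech2 = [1.0 - hj * hj for hj in h]            # derivative of tanh
    d_w = [vj * sj * x for vj, sj in zip(v, sech2)]
    d_b = [vj * sj for vj, sj in zip(v, sech2)]
    d_v = h
    return d_w + d_b + d_v

def ntk_entry(x1, x2, w, b, v):
    """K(x1, x2): inner product of the parameter gradients at x1 and x2."""
    g1 = param_gradient(x1, w, b, v)
    g2 = param_gradient(x2, w, b, v)
    return sum(a * c for a, c in zip(g1, g2))

# Hypothetical initialisation: width-16 network, Gaussian weights.
rng = random.Random(0)
width = 16
w = [rng.gauss(0.0, 1.0) for _ in range(width)]
b = [rng.gauss(0.0, 1.0) for _ in range(width)]
v = [rng.gauss(0.0, 1.0 / math.sqrt(width)) for _ in range(width)]

# Kernel matrix on three sample inputs.
points = [0.0, 0.5, 1.0]
K = [[ntk_entry(p, q, w, b, v) for q in points] for p in points]
```

The resulting matrix is symmetric with a positive diagonal, as expected of a Gram matrix of gradients. In practice the gradients would come from automatic differentiation rather than hand-derived formulas.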
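For a multi-term PINN loss ℒ = Σ_a λ_a ℒ_a, where each term ℒ_a is the mean squared residual r_a over its own points x_i (a ∈ {PDE, IC} for the harmonic IVP), define the per-term NTK

    (K_a)_ij = ∇_θ r_a(x_i; θ)ᵀ ∇_θ r_a(x_j; θ)

The trace of this matrix is

    tr K_a = Σ_i ‖∇_θ r_a(x_i; θ)‖²

which measures the total gradient mass each loss term contributes. Wang, Yu & Perdikaris (2022) show that setting

    λ_a = (Σ_b tr K_b) / tr K_a

makes every loss term contribute equally to the gradient flow, since each weighted trace λ_a tr K_a then equals the same total Σ_b tr K_b. The loss-balance crisis is resolved automatically: no hand-tuning, no slider.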
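The rule λ_a = (Σ_b tr K_b) / tr K_a can be checked by hand on a model whose residual gradients are known in closed form. The sketch below assumes a linear-in-parameters surrogate u(t) = Σ_k θ_k cos(k t) in place of a neural network, a harmonic IVP of the form u″ + ω²u = 0 with u(0) = 1, and hypothetical values for ω, the feature count, and the collocation grid. For such a model the NTK is simply the Gram matrix of the residual Jacobian, so the trace is a sum of squared Jacobian entries.

```python
import math

def ntk_trace(jacobian_rows):
    """tr(J J^T): the sum of squared entries of the residual Jacobian J."""
    return sum(g * g for row in jacobian_rows for g in row)

def ntk_weights(traces):
    """lambda_a = (sum_b tr K_b) / tr K_a for every loss term a."""
    total = sum(traces.values())
    return {name: total / tr for name, tr in traces.items()}

# Hypothetical setup, not taken from the lesson's widget.
omega = 2.0                       # angular frequency
n_features = 8                    # u(t) = sum_k theta_k * cos(k t), k = 1..8
t_colloc = [i * 2.0 * math.pi / 50 for i in range(50)]

# Row i holds d r_PDE(t_i) / d theta_k, with r_PDE = u'' + omega^2 * u.
jac_pde = [[(omega**2 - k**2) * math.cos(k * t) for k in range(1, n_features + 1)]
           for t in t_colloc]
# Single row: d r_IC / d theta_k, with r_IC = u(0) - 1 and cos(0) = 1.
jac_ic = [[1.0 for _ in range(n_features)]]

traces = {"pde": ntk_trace(jac_pde), "ic": ntk_trace(jac_ic)}
weights = ntk_weights(traces)

# After weighting, every term carries the same total trace.
balanced = {name: weights[name] * traces[name] for name in traces}
```

Under these assumptions the PDE trace exceeds the IC trace by about four orders of magnitude, because differentiating twice multiplies each feature by k². The rule therefore leaves the PDE weight near 1 and assigns the IC weight a value on the order of 10⁴.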
Try it: NTK race
The widget races two trainings on the same harmonic IVP from §3.1/§3.2: vanilla (λ = 1) versus NTK-balanced (weights recomputed every 100 epochs). Watch the per-term loss curves and the auto-adjusted λ values.
What you should observe
- Vanilla λ=1: the §3.1 pathology repeats. The PDE residual collapses while the IC loss stalls, and the relative-L² error is ~100%.
- NTK-weighted: λ_IC rises sharply (clamped to ~100, in the §3.2 sweet-spot range), pushing the optimiser to honour the IC from the start. Both terms now decrease together. The relative-L² error typically drops by 30× or more compared to vanilla, and the network recovers a clean cosine.
The remarkable fact is that the NTK-derived λ_IC converges to roughly the same order of magnitude that you found by hand-scanning in §3.2, but the algorithm finds it automatically, in a single training run. The implementation uses two stabilisations: EMA smoothing, so successive updates do not whip-saw, and clamping the implied weight to a bounded range, to prevent over-correction on small toy problems where individual NTK trace ratios can spike. Wang, Yu & Perdikaris (2022) recommend both, and they are common in production-quality NTK-weighted PINN implementations.
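The two stabilisations can be sketched as a single update function. The smoothing factor `beta = 0.9`, the clamp range `[1, 100]`, the function name `update_weight`, and the raw weights fed to it are all hypothetical values chosen for illustration; the text above only establishes that the widget's weight saturates near 100. The sketch clamps the raw NTK-implied weight first and then blends it into the running value.

```python
def update_weight(previous, raw, beta=0.9, lo=1.0, hi=100.0):
    """Clamp the raw NTK-implied weight to [lo, hi], then apply EMA smoothing.

    beta, lo and hi are illustrative defaults, not values from the lesson.
    """
    clamped = min(max(raw, lo), hi)
    return beta * previous + (1.0 - beta) * clamped

# Hypothetical sequence of raw weights with spikes and dips between recomputations.
raw_weights = [22713.0, 50.0, 18000.0, 0.2, 9000.0]

lam = 1.0            # start from the vanilla weight
history = []
for raw in raw_weights:
    lam = update_weight(lam, raw)
    history.append(lam)
```

Because the clamped input always lies in `[lo, hi]`, a running value that starts inside that range stays inside it, and a single spiking trace ratio moves the weight by at most a fraction `1 - beta` of the range.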
The cost of NTK weighting
Computing tr K_a requires a backward pass per collocation point per loss term: O(N·P) work for N collocation points, where P is the parameter count. On a small toy problem (50 collocation points, 1k parameters) this is cheap. On a 2D wave PINN with 10⁴ collocation points and 10⁵ parameters, computing the trace every step costs more than the actual training step. The standard practical fix:
- Recompute λ only periodically (the widget above recomputes every 100 epochs), not every step.
- Subsample collocation points (a few hundred) for the trace computation.
- Smooth with EMA (exponential moving average) so successive updates do not whip-saw.
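The three fixes above can be combined in one schedule, sketched below under stated assumptions. The gradient functions `grad_pde` and `grad_ic` are hypothetical stand-ins for per-point residual gradients of a two-parameter toy model, and the step count, recompute interval, subsample size, and `beta` are illustrative values. The subsampled sum is rescaled by the ratio of total to sampled points so that traces of terms sampled at different rates remain comparable.

```python
import random

def subsampled_trace(grad_fn, points, n_sub, rng):
    """Estimate tr K = sum_i |grad_theta r(x_i)|^2 from a random subset of points.

    The partial sum is rescaled by len(points) / n_used so that loss terms
    sampled at different rates stay comparable.
    """
    n_used = min(n_sub, len(points))
    subset = rng.sample(points, n_used)
    partial = sum(g * g for x in subset for g in grad_fn(x))
    return partial * len(points) / n_used

# Hypothetical stand-ins for grad_theta of each residual (2-parameter toy model).
def grad_pde(x):
    return [2.0, 2.0 * x]

def grad_ic(x):
    return [1.0, x]

rng = random.Random(0)
pde_points = [i / 1000.0 for i in range(1000)]
ic_points = [0.0]

n_steps, every, n_sub, beta = 1000, 100, 200, 0.9   # illustrative settings
lam_ic = 1.0
grad_evals = 0
for step in range(n_steps):
    if step % every == 0:
        tr_pde = subsampled_trace(grad_pde, pde_points, n_sub, rng)
        tr_ic = subsampled_trace(grad_ic, ic_points, n_sub, rng)
        grad_evals += min(n_sub, len(pde_points)) + min(n_sub, len(ic_points))
        raw = (tr_pde + tr_ic) / tr_ic                  # NTK-trace rule for the IC term
        lam_ic = beta * lam_ic + (1.0 - beta) * raw     # EMA smoothing
    # The ordinary optimiser step, using the current lam_ic, would go here.

# Per-point gradient evaluations if the full trace were recomputed every step.
full_cost = n_steps * (len(pde_points) + len(ic_points))
```

With these settings the schedule evaluates 2,010 per-point gradients over the run, against 1,001,000 for a full trace at every step. The price is a noisier trace estimate, which the EMA absorbs.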
NTK weighting in seismic PINNs
For 2D acoustic-wave forward problems (Part 4), NTK weighting is now the default in most published seismic PINNs (Rasht-Behesht et al. 2022; Song et al. 2023). For inverse problems (Part 6) where the loss has data + PDE + regularisation terms, the same machinery applies but you weight each term separately. The §3.4 alternatives (gradient-norm balancing, SA-PINN) are cheaper and work for the more common case where you just need to keep the PDE residual and the data-fit term in lockstep.
References
- Jacot, A., Gabriel, F., Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. NeurIPS.
- Wang, S., Yu, X., Perdikaris, P. (2022). When and why PINNs fail to train: A neural tangent kernel perspective. J. Comput. Phys. 449, 110768.
- Rasht-Behesht, M., et al. (2022). Physics-informed neural networks (PINNs) for wave propagation and full-waveform inversions. JGR Solid Earth.