NTK-balanced weighting

Part 3 — Training pathologies and remedies

Learning objectives

  • State the Neural Tangent Kernel (NTK) for a multi-term PINN loss
  • Derive Wang–Yu–Perdikaris (2022) NTK-trace weighting: λ_a = (Σ_b tr K_b) / tr K_a
  • Confirm empirically that NTK weighting eliminates the loss-balance crisis on the harmonic IVP
  • Recognise the cost: full NTK trace computation every K epochs is expensive on large nets

§3.2 left us with a question: the optimal loss weight depends on per-term gradient magnitudes, but those magnitudes change across training. Can the optimiser detect the imbalance and re-weight automatically? Wang, Yu & Perdikaris (2022) gave the answer in the language of the Neural Tangent Kernel.

The NTK and its connection to loss balance

For a network $u_\theta(x)$ with parameters $\theta \in \mathbb{R}^P$, the Neural Tangent Kernel is

$$K(x, x') = \nabla_\theta u_\theta(x) \cdot \nabla_\theta u_\theta(x') \in \mathbb{R}.$$

Jacot, Gabriel & Hongler (2018) showed that in the infinite-width limit, gradient flow on an MSE loss is equivalent to kernel regression with kernel $K$. For a finite network the equivalence is approximate, but the NTK still governs the early-training dynamics.
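To make the definition concrete, here is a minimal sketch that evaluates $K(x, x')$ for a one-hidden-layer tanh network with hand-written parameter gradients. The architecture, the parameter values and the function names (`param_grad`, `ntk`) are illustrative assumptions, not the network used in this course.

```python
import math

def param_grad(x, a, w, b):
    """Gradient of u(x) = sum_k a[k] * tanh(w[k] * x + b[k]) w.r.t. all parameters.

    Returns a flat list ordered as [du/da..., du/dw..., du/db...].
    """
    t = [math.tanh(wk * x + bk) for wk, bk in zip(w, b)]
    du_da = t
    du_dw = [ak * (1.0 - tk * tk) * x for ak, tk in zip(a, t)]
    du_db = [ak * (1.0 - tk * tk) for ak, tk in zip(a, t)]
    return du_da + du_dw + du_db

def ntk(x, x_prime, a, w, b):
    """K(x, x') = grad_theta u(x) . grad_theta u(x')."""
    g1 = param_grad(x, a, w, b)
    g2 = param_grad(x_prime, a, w, b)
    return sum(p * q for p, q in zip(g1, g2))

# Hypothetical parameter values for a 3-neuron hidden layer (P = 9).
a = [0.5, -1.0, 0.8]
w = [1.0, 2.0, -0.5]
b = [0.0, 0.1, -0.2]

# Kernel matrix on three sample inputs.
xs = [0.0, 0.5, 1.0]
K = [[ntk(xi, xj, a, w, b) for xj in xs] for xi in xs]
for row in K:
    print(["%.4f" % v for v in row])
```

The resulting matrix is symmetric with a positive diagonal, as expected for a Gram matrix of gradient vectors. In practice the gradients would come from automatic differentiation rather than hand-derived formulas.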

For a multi-term PINN loss $\mathcal{L} = \sum_a \lambda_a \mathcal{L}_a$ with $\mathcal{L}_a = \frac{1}{N_a} \sum_i r_a(x_i; \theta)^2$, define the per-term NTK

$$K_a(x_i, x_j) = \nabla_\theta r_a(x_i; \theta) \cdot \nabla_\theta r_a(x_j; \theta).$$

The trace of this matrix is

$$\textrm{tr}(K_a) = \sum_{i=1}^{N_a} \| \nabla_\theta r_a(x_i; \theta) \|^2,$$

which measures the total gradient mass each loss term contributes. Wang, Yu & Perdikaris (2022) propose setting

$$\lambda_a = \frac{\sum_b \textrm{tr}(K_b)}{\textrm{tr}(K_a)}$$

gives every weighted term the same trace, $\lambda_a \, \textrm{tr}(K_a) = \sum_b \textrm{tr}(K_b)$, so each loss term contributes comparably to the gradient flow. The loss balance is set automatically, with no hand-tuning and no slider.
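The trace and weight formulas above can be sketched in a few lines. The gradient vectors below are made-up numbers chosen so the arithmetic is easy to follow, and the helper names `ntk_trace` and `ntk_weights` are hypothetical.

```python
def ntk_trace(per_sample_grads):
    """tr(K_a) = sum_i ||grad_theta r_a(x_i)||^2 from a list of gradient vectors."""
    return sum(sum(g * g for g in grad) for grad in per_sample_grads)

def ntk_weights(traces):
    """lambda_a = (sum_b tr K_b) / tr K_a for each named loss term."""
    total = sum(traces.values())
    return {name: total / tr for name, tr in traces.items()}

# Made-up residual gradients: two PDE collocation points, one IC point, P = 2.
grads = {
    "pde": [[6.0, 8.0], [0.0, 10.0]],  # squared norms 100 + 100
    "ic": [[1.0, 1.0]],                # squared norm 2
}
traces = {name: ntk_trace(g) for name, g in grads.items()}
weights = ntk_weights(traces)
print(traces)
print(weights)
```

Here the PDE term carries 100 times the gradient mass of the IC term (traces 200 and 2), so the IC weight comes out as 202 / 2 = 101 and the PDE weight as 202 / 200 = 1.01. After weighting, both products $\lambda_a \, \textrm{tr}(K_a)$ equal 202.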

Try it: NTK race

The widget races two trainings on the same harmonic IVP from §3.1/§3.2: vanilla $\lambda_{\textrm{PDE}} = \lambda_{\textrm{IC}} = 1$ versus NTK-balanced (recomputed every 100 epochs). Watch the per-term loss curves and the auto-adjusted $\lambda$ values.

[Interactive figure: NTK balance race. Enable JavaScript to interact.]

What you should observe

  • Vanilla λ=1: the §3.1 pathology repeats. PDE residual collapses, IC stays at $\sim 1$, relative-L² is ~100%.
  • NTK-weighted: $\lambda_{\textrm{IC}}$ rises sharply (clamped to ~100, in the §3.2 sweet-spot range), pushing the optimiser to honour the IC from the start. Both terms now decrease together. Relative-L² typically drops by 30× or more compared to vanilla, and the network recovers a clean cosine.

The remarkable fact is that the NTK-derived $\lambda_{\textrm{IC}}$ converges to roughly the same order of magnitude that you found by hand-scanning in §3.2, but the algorithm finds it automatically, in a single training run. The implementation uses two stabilisations: EMA smoothing ($\alpha = 0.9$) so successive updates do not whip-saw, and clamping the implied weight to $[10^{-2}, 10^{2}]$ to prevent over-correction on small toy problems where individual NTK trace ratios can spike. Wang, Yu & Perdikaris (2022) recommend both, and both are common in practical NTK-weighted PINN implementations.
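The two stabilisations can be sketched as follows. This is an illustration under stated assumptions rather than the widget's actual code: it assumes $\alpha$ is the weight kept on the previous value and that the raw weight is clamped before smoothing. The names `clamp` and `smooth_weight` are hypothetical.

```python
def clamp(value, lo=1e-2, hi=1e2):
    """Restrict value to the interval [lo, hi]."""
    return max(lo, min(hi, value))

def smooth_weight(lam_prev, lam_raw, alpha=0.9, lo=1e-2, hi=1e2):
    """Clamp the raw NTK weight, then blend it into the running value.

    Assumed convention: alpha is the fraction kept from the previous value.
    """
    target = clamp(lam_raw, lo, hi)
    return alpha * lam_prev + (1.0 - alpha) * target

# Start from the vanilla weight and feed in a spiky sequence of raw weights.
lam = 1.0
history = []
for raw in [500.0, 500.0, 0.001, 500.0]:
    lam = smooth_weight(lam, raw)
    history.append(lam)
print(history)
```

A raw weight of 500 is first clamped to 100, so the first update moves $\lambda$ from 1 to about 10.9 instead of jumping straight to the cap. The single outlier of 0.001 only nudges the running value down slightly.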

The cost of NTK weighting

Computing $\textrm{tr}(K_a)$ requires a backward pass per collocation point per loss term: $O(N_c \cdot P)$ work where $P$ is the parameter count. On a small toy problem (50 collocation points, 1k parameters) this is cheap. On a 2D wave PINN with 10⁴ collocation points and 10⁵ parameters, computing the trace every step costs more than the actual training step. The standard practical fix:

  • Recompute $\lambda$ every $K \sim 100$ to $1000$ steps, not every step.
  • Subsample collocation points (a few hundred) for the trace computation.
  • Smooth with EMA (exponential moving average) so successive $\lambda$ updates do not whip-saw.
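In terms of the $N_c \cdot P$ count, the toy problem needs about 50 × 10³ = 5 × 10⁴ gradient entries per loss term, while the 2D wave example needs about 10⁴ × 10⁵ = 10⁹. The first two fixes can be sketched as below. The gradient function `toy_grad` is a hypothetical stand-in for $\nabla_\theta r_a(x; \theta)$, and rescaling the subsampled sum by $N / n_{\textrm{sub}}$ is an assumption made so the estimate is comparable to the full trace.

```python
import random

def subsampled_trace(points, grad_fn, n_sub, rng):
    """Estimate tr(K_a) from n_sub randomly chosen points, rescaled to the full set."""
    n_sub = min(n_sub, len(points))
    batch = rng.sample(points, n_sub)
    partial = sum(sum(g * g for g in grad_fn(x)) for x in batch)
    return partial * len(points) / n_sub

def should_recompute(step, every):
    """True on steps 0, every, 2 * every, ..."""
    return step % every == 0

def toy_grad(x):
    """Hypothetical stand-in for grad_theta r_a(x): a 2-parameter gradient."""
    return [1.0, x]

rng = random.Random(0)
points = [i / 9999 for i in range(10_000)]  # 10^4 collocation points
n_calls = 0
for step in range(1000):
    if should_recompute(step, every=100):
        # Only 200 of the 10^4 points are used for each trace estimate.
        est = subsampled_trace(points, toy_grad, n_sub=200, rng=rng)
        n_calls += 1
print(n_calls, est)
```

Over 1000 steps with $K = 100$ the trace is estimated 10 times, each from 200 points, instead of 1000 times from all 10⁴ points. The estimate is noisy, which is one reason to combine subsampling with EMA smoothing.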

NTK weighting in seismic PINNs

For 2D acoustic-wave forward problems (Part 4), NTK weighting is now the default in most published seismic PINNs (Rasht-Behesht et al. 2022; Song et al. 2023). For inverse problems (Part 6) where the loss has data + PDE + regularisation terms, the same machinery applies but you weight each term separately. The §3.4 alternatives (gradient-norm balancing, SA-PINN) are cheaper and work for the more common case where you just need to keep the PDE residual and the data-fit term in lockstep.

References

  • Jacot, A., Gabriel, F., Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. NeurIPS.
  • Wang, S., Yu, X., Perdikaris, P. (2022). When and why PINNs fail to train: A neural tangent kernel perspective. J. Comput. Phys. 449, 110768.
  • Rasht-Behesht, M., et al. (2022). Physics-informed neural networks (PINNs) for wave propagation and full-waveform inversions. JGR Solid Earth.
