Gradient pathologies and adaptive weights

Part 3 — Training pathologies and remedies

Learning objectives

  • Recognise the per-point gradient pathology: uniform averaging under-weights hard collocation points
  • Implement gradient-norm balancing (Wang–Teng–Perdikaris 2021) at the per-point level
  • Implement SA-PINN (McClenny–Braga-Neto 2023): per-point learnable mask parameters γ_i, trained by gradient ascent on the loss
  • See empirically that SA-PINN auto-discovers the hard regions of a problem

§3.3 fixed the loss-balance problem at the level of terms (PDE vs IC vs data). Within a single term, however, the $N$ collocation points are still averaged uniformly. If the residual is concentrated at a few hard points (a shock, a sharp transition, an initial condition), uniform averaging dilutes their gradient signal. Every point carries the same weight $1/N$, so the many easy points outvote the few hard ones. This is the per-point version of the gradient pathology.

Two adaptive remedies

(1) Gradient-norm balancing (Wang, Teng, Perdikaris 2021). At each step, compute the per-point gradient norm $g_i = \|\nabla_\theta r(x_i; \theta)\|$ and weight each point in inverse proportion to it:

$$w_i = \frac{\max_j g_j}{g_i},$$

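Points with large per-point gradients therefore get small weights (their direction is already well represented), and points with small gradients get amplified. The point with the largest gradient norm gets weight exactly 1. This is a per-point analogue of GradNorm (Chen et al. 2018), which balances gradient norms across tasks rather than across collocation points.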
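The sketch below computes the per-point weights. It assumes a residual function `residual(x, theta)` that returns $r(x;\theta)$ for a flat parameter list. Central finite differences stand in for autodiff, which is adequate for a handful of parameters. The one-unit tanh model is a hypothetical toy, not the widget's network.

```python
import math


def per_point_grad_norm(residual, x, theta, h=1e-6):
    # g_i = || grad_theta r(x_i; theta) ||, estimated by central differences
    # (a stand-in for autodiff that is fine for a handful of parameters).
    sq = 0.0
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += h
        down[k] -= h
        d = (residual(x, up) - residual(x, down)) / (2.0 * h)
        sq += d * d
    return math.sqrt(sq)


def grad_norm_weights(residual, xs, theta, eps=1e-12):
    # w_i = max_j g_j / g_i: the point with the largest gradient norm gets
    # weight ~1; points with smaller gradient norms are amplified.
    # eps only guards against division by an exactly-zero gradient.
    g = [per_point_grad_norm(residual, x, theta) for x in xs]
    g_max = max(g)
    return g, [g_max / (gi + eps) for gi in g]


# Hypothetical toy model: one tanh unit u(x) = a * tanh(b * x + c),
# with residual r(x; theta) = u(x) - tanh(20 x).
def residual(x, theta):
    a, b, c = theta
    return a * math.tanh(b * x + c) - math.tanh(20.0 * x)


xs = [-1.0 + 2.0 * i / 29 for i in range(30)]  # 30 uniform points on [-1, 1]
g, w = grad_norm_weights(residual, xs, [1.0, 2.0, 0.0])
```

With these parameters, the two grid points nearest $x = 0$ have the largest gradient norms. They therefore receive the smallest weight (about 1), so in this toy setting the hard central points are the ones the scheme down-weights.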

(2) Self-Adaptive PINNs (SA-PINN) (McClenny & Braga-Neto 2023). Attach a trainable parameter $\gamma_i$ to each collocation point and pass it through a mask function $m(\gamma_i) = \sigma(\gamma_i)$ (the sigmoid). The loss becomes

$$\mathcal{L}(\theta, \gamma) = \frac{1}{N} \sum_i m(\gamma_i)\, r(x_i; \theta)^2 .$$

SA-PINN ascends in $\gamma$ while descending in $\theta$:

$$\theta \leftarrow \theta - \eta_\theta \nabla_\theta \mathcal{L}, \qquad \gamma \leftarrow \gamma + \eta_\gamma \nabla_\gamma \mathcal{L}.$$
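The two update rules can be written as one simultaneous step. The sketch below assumes the caller supplies the per-point residuals `r` and the per-point Jacobian `jac`, from autodiff or any other source. `sa_pinn_step` and its argument names are hypothetical, and the mask is the sigmoid defined above.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sa_pinn_step(theta, gamma, r, jac, lr_theta, lr_gamma):
    """One simultaneous SA-PINN update for L = (1/N) sum_i sigmoid(gamma_i) r_i^2.

    r[i]      = r(x_i; theta)                (per-point residual)
    jac[i][k] = d r(x_i; theta) / d theta_k  (per-point residual gradient)
    """
    n = len(r)
    m = [sigmoid(g) for g in gamma]  # masks from the current gamma
    # dL/dtheta_k = (2/N) * sum_i m_i * r_i * jac[i][k]  -> gradient descent
    grad_theta = [2.0 / n * sum(m[i] * r[i] * jac[i][k] for i in range(n))
                  for k in range(len(theta))]
    # dL/dgamma_i = (1/N) * m_i * (1 - m_i) * r_i**2      -> gradient ascent
    grad_gamma = [m[i] * (1.0 - m[i]) * r[i] ** 2 / n for i in range(n)]
    new_theta = [t - lr_theta * gt for t, gt in zip(theta, grad_theta)]
    new_gamma = [g + lr_gamma * gg for g, gg in zip(gamma, grad_gamma)]
    return new_theta, new_gamma


# One parameter, two points: point 0 has residual 1, point 1 has residual 0.
theta, gamma = sa_pinn_step([0.0], [0.0, 0.0], r=[1.0, 0.0], jac=[[1.0], [1.0]],
                            lr_theta=0.1, lr_gamma=1.0)
# theta -> [-0.05], gamma -> [0.125, 0.0]
```

Only the point with a nonzero residual moves its own $\gamma$. Both points use the masks from before the step, so the $\theta$ and $\gamma$ updates are simultaneous rather than alternating.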

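Because the loss is non-decreasing in every $\gamma_i$, the ascent step raises all the $\gamma_i$. Each one rises at a rate proportional to its point's squared residual. Hard points (large residual) are pushed quickly to large $\gamma_i$ (so $m(\gamma_i) \to 1$), while easy points barely move and keep a comparatively small mask. The optimiser self-discovers where the hard regions are.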
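Differentiating the loss with respect to a single mask parameter makes the mechanism explicit. The derivation uses $\sigma'(z) = \sigma(z)\bigl(1 - \sigma(z)\bigr)$:

$$\frac{\partial \mathcal{L}}{\partial \gamma_i} = \frac{1}{N}\, \sigma(\gamma_i)\bigl(1 - \sigma(\gamma_i)\bigr)\, r(x_i;\theta)^2 \;\ge\; 0 .$$

The ascent step never lowers a $\gamma_i$, and points separate according to their squared residuals. The factor $\sigma(1-\sigma)$ slows each mask's growth as it approaches 1.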

Try it: three-mode race

The widget races three weighting strategies on the §3.1 sampling-bias pathology: it regresses $f(x) = \tanh(20x)$ on $x \in [-1, 1]$ with 30 uniformly spaced points.


What you should observe

  • uniform: the central transition stays poorly resolved. A large centre/outer residual ratio is the signature of the §3.1 pathology.
  • grad-norm: better than uniform, but because $w_i \propto 1/g_i$, points whose gradients are tiny (numerical noise) receive very large weights, and the scheme can over-correct.
  • SA-PINN: the $m(\gamma_i)$ mask concentrates near the central transition (visible in the right panel as taller bars at $x \approx 0$). The centre residual is slightly lower than under uniform weighting. Crucially, the centre/outer ratio drops by about 2×: SA-PINN balances the residual across the domain, not just at the hardest point. On larger problems (Allen-Cahn, wave equation) the centre-residual improvement grows to 5–10×.
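The following is a rough stand-in for the race, written as a sketch under stated assumptions:

- A linear random-feature model replaces the network. It uses six fixed tanh features with hypothetical centres, so $\nabla_\theta r(x_i) = \phi(x_i)$ in closed form and the grad-norm weights are constant.
- "Centre" means $|x| < 0.2$ and "outer" means $|x| > 0.5$.
- Step sizes and step counts are arbitrary.

Its numbers are not expected to match the widget's.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def race(mode, n=30, steps=500, lr_theta=0.05, lr_gamma=1.0):
    """Return the centre/outer mean-|residual| ratio after training in `mode`."""
    xs = [-1 + 2 * i / (n - 1) for i in range(n)]
    centres = [-1.0, -0.6, -0.2, 0.2, 0.6, 1.0]  # hypothetical feature centres
    # Linear model u(x) = sum_k theta_k * tanh(5 (x - c_k)) regressing tanh(20 x).
    phi = [[math.tanh(5 * (x - c)) for c in centres] for x in xs]
    f = [math.tanh(20 * x) for x in xs]
    theta = [0.0] * len(centres)
    gamma = [0.0] * n
    # For this model grad_theta r(x_i) = phi(x_i), so grad-norm weights are fixed.
    g = [math.sqrt(sum(p * p for p in row)) for row in phi]
    w_gn = [max(g) / gi for gi in g]

    def residuals():
        return [sum(t * p for t, p in zip(theta, row)) - fi for row, fi in zip(phi, f)]

    for _ in range(steps):
        r = residuals()
        if mode == "uniform":
            w = [1.0] * n
        elif mode == "grad-norm":
            w = w_gn
        elif mode == "sa-pinn":
            w = [sigmoid(gm) for gm in gamma]  # masks before this step
            gamma = [gm + lr_gamma * wi * (1.0 - wi) * ri ** 2 / n
                     for gm, wi, ri in zip(gamma, w, r)]
        else:
            raise ValueError(mode)
        # Weighted least-squares gradient: (2/N) sum_i w_i r_i phi_i
        grad = [2.0 / n * sum(w[i] * r[i] * phi[i][k] for i in range(n))
                for k in range(len(theta))]
        theta = [t - lr_theta * gk for t, gk in zip(theta, grad)]

    r = residuals()
    centre = [abs(ri) for x, ri in zip(xs, r) if abs(x) < 0.2]
    outer = [abs(ri) for x, ri in zip(xs, r) if abs(x) > 0.5]
    return (sum(centre) / len(centre)) / (sum(outer) / len(outer))


for mode in ("uniform", "grad-norm", "sa-pinn"):
    print(mode, race(mode))
```

The feature norms here are nearly equal across the grid, so the grad-norm weights stay within a few percent of 1 and that mode behaves almost like uniform in this stand-in.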

SA-PINN as automatic curriculum

The min-max formulation can be read as a primal-dual method on a constrained optimisation problem, with $m(\gamma_i)$ acting as a per-point multiplier. The constraint is that the residual must be small everywhere, not just on average. McClenny & Braga-Neto (2023) prove convergence guarantees under standard assumptions; in practice SA-PINN works well on the wave-equation, Burgers, and Allen-Cahn benchmarks. The cost is one extra parameter per collocation point, which is cheap.
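The primal-dual reading can be made concrete with a sketch of the analogy (not a derivation taken from the paper). Idealise "small everywhere" as the constraints $r(x_i;\theta)^2 \le 0$, $i = 1, \dots, N$, and attach multipliers $\lambda_i \ge 0$:

$$\Lambda(\theta, \lambda) = \sum_{i=1}^{N} \lambda_i\, r(x_i;\theta)^2, \qquad \lambda_i = \frac{m(\gamma_i)}{N} \;\Longrightarrow\; \Lambda(\theta, \lambda) = \mathcal{L}(\theta, \gamma).$$

Descent in $\theta$ is then the primal step unchanged. Because $\sigma' > 0$, ascent in $\gamma_i$ moves $\lambda_i$ in the same direction that dual ascent on $\lambda_i$ would. The sigmoid reparametrisation caps each multiplier at $1/N$, which plain dual ascent does not.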

SA-PINN doubles as an automatic curriculum. At the start of training all $\gamma_i = 0$ (so $m = 0.5$ uniformly), and the masks adapt as the network learns. The hard regions become visible in $\gamma$-space without any user input. In seismic PINNs this is particularly useful for full-waveform inversion (FWI), where the source-region residual dominates the wavefield-region residual.
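The curriculum effect can be isolated by freezing $\theta$ and running only the $\gamma$ ascent. The sketch assumes a hypothetical frozen residual profile peaked at $x = 0$ (a Gaussian of width 0.1, standing in for an unresolved transition) on the widget's 30-point grid. `evolve_masks` is an illustrative name.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def evolve_masks(residuals, steps=200, lr_gamma=5.0):
    # theta is held fixed (residuals frozen), so only the gamma ascent runs.
    n = len(residuals)
    gamma = [0.0] * n  # every mask starts at sigmoid(0) = 0.5
    for _ in range(steps):
        m = [sigmoid(g) for g in gamma]
        # gamma_i += lr * dL/dgamma_i = lr * m_i (1 - m_i) r_i^2 / N
        gamma = [g + lr_gamma * mi * (1.0 - mi) * r ** 2 / n
                 for g, mi, r in zip(gamma, m, residuals)]
    return [sigmoid(g) for g in gamma]


xs = [-1 + 2 * i / 29 for i in range(30)]
res = [math.exp(-(x / 0.1) ** 2) for x in xs]  # hypothetical residual peaked at x = 0
masks = evolve_masks(res)
```

All masks start at 0.5. After the ascent, the points nearest $x = 0$ carry masks well above 0.5, while points far from the peak stay at 0.5, so the hard region can be read straight off $\gamma$.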

References

  • Wang, S., Teng, Y., Perdikaris, P. (2021). Understanding and mitigating gradient flow pathologies in physics-informed neural networks. SIAM J. Sci. Comput. 43(5).
  • McClenny, L., Braga-Neto, U. (2023). Self-adaptive physics-informed neural networks. J. Comput. Phys. 474, 111722.
  • Chen, Z., Badrinarayanan, V., Lee, C.-Y., Rabinovich, A. (2018). GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks. ICML.
