SGS Model: LASDD-WL (Momentum)

Locally-averaged scale-dependent dynamic (LASDD) SGS model using the Wong-Lilly (WL) base formulation. It is called for optSgs = 2 (LASDD-WL) and optSgs = 4 (LAD-WL). For the LAD variant (optSgs = 4), the scale-dependence parameter beta is fixed at 1 instead of being computed.

Source Code: DynamicSGS_LASDD_WL

DynamicSGS_LASDD_WL.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: DynamicSGS_LASDD_WL.py
 18=============================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2026-5-9
 24:Description: locally-averaged scale-dependent dynamic (LASDD) model
 25              using the Wong-Lilly (1994) SGS base model (LASDD-WL).
 26              Reference: Anderson, Basu, and Letchford (2007), EFM.
 27"""
 28
 29# ============================================================
 30#  Imports
 31# ============================================================
 32
 33import jax
 34import jax.numpy as jnp
 35
 36# Import derived variables
 37from ..config.DerivedVars import *
 38
 39# Import FFT modules
 40from ..operations.FFT import FFT
 41
 42# Import filtering functions
 43from ..operations.Filtering import Filtering_Level1, Filtering_Level2
 44
 45# Import helper functions
 46from ..utilities.Utilities import PlanarMean, StagGridAvg
 47from ..utilities.Utilities import Roots, Imfilter
 48
 49
 50# ============================================================
 51# Find maximum real root between 0 and 5
 52# ============================================================
 53
 54@jax.jit
 55def ComputeBeta1(ff, ee, dd, cc, bb, aa):
 56    """
 57    Solves the polynomial:
 58    ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0
 59    for each vertical level to find the scale-dependent parameter beta1.
 60
 61    Parameters:
 62    -----------
 63    ff, ee, dd, cc, bb, aa : ndarray
 64        1D arrays of polynomial coefficients at each vertical level
 65
 66    Returns:
 67    --------
 68    beta1 : ndarray
 69        1D array of the maximum valid real root for each vertical level
 70    """
 71
 72    def find_roots_for_level(k):
 73        coeffs = jnp.array([ff[k], ee[k], dd[k], cc[k], bb[k], aa[k]])
 74        guesses = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])
 75        roots = jax.vmap(lambda guess:
 76                         Roots(coeffs, init_guess=guess))(guesses)
 77
 78        valid_roots = jnp.where(
 79            (jnp.abs(jnp.imag(roots)) < 1e-6) &
 80            (jnp.real(roots) > 0) &
 81            (jnp.real(roots) < 5.0),
 82            jnp.real(roots),
 83            jnp.nan
 84        )
 85
 86        max_root = jnp.nanmax(valid_roots)
 87        return jnp.where(jnp.isnan(max_root), 1.0, max_root)
 88
 89    return jax.vmap(find_roots_for_level)(jnp.arange(ff.shape[0]))
 90
 91
 92# ============================================================
 93# Compute Wong-Lilly coefficient at vertical level k
 94# ============================================================
 95
 96@jax.jit
 97def Cwl_at_level_k(LM_k, MM_k):
 98    """
 99    Parameters:
100    -----------
101    LM_k : jax.numpy.ndarray
102        2D horizontal slice of LM at level k
103    MM_k : jax.numpy.ndarray
104        2D horizontal slice of MM at level k
105
106    Returns:
107    --------
108    Cwl : ndarray
109        2D array of the Wong-Lilly coefficient C_WL at level k
110    """
111
112    LMx = Imfilter(LM_k)
113    MMx = Imfilter(MM_k)
114
115    Cwl = LMx / MMx
116
117    mask = (jnp.abs(MMx) < 1e-10) | (Cwl < 0) | (Cwl > 1)
118    Cwl = jnp.where(mask, 0.0, Cwl)
119
120    return Cwl
121
122
123# ============================================================
124# Main LASDD-WL code
125# ============================================================
126
def _sym_contract(a, b):
    """
    Double contraction a_ij * b_ij of two symmetric 3x3 tensors.

    Each tensor is a 6-tuple of component arrays ordered
    (11, 22, 33, 12, 13, 23). Both tensors are symmetric, so each
    off-diagonal product is counted twice.
    """
    a11, a22, a33, a12, a13, a23 = a
    b11, b22, b33, b12, b13, b23 = b
    return (a11 * b11 + a22 * b22 + a33 * b33 +
            2 * (a12 * b12 + a13 * b13 + a23 * b23))


def _resolved_stress(vel_f, prod_f):
    """
    Resolved (Leonard-type) stress  f(u_i u_j) - f(u_i) f(u_j)  at one level.

    vel_f  : (u_f, v_f, w_f) filtered velocity components
    prod_f : (uu_f, vv_f, ww_f, uv_f, uw_f, vw_f) filtered velocity products

    Returns a 6-tuple ordered (11, 22, 33, 12, 13, 23).
    """
    uf, vf, wf = vel_f
    uu_f, vv_f, ww_f, uv_f, uw_f, vw_f = prod_f
    return (uu_f - uf * uf, vv_f - vf * vf, ww_f - wf * wf,
            uv_f - uf * vf, uw_f - uf * wf, vw_f - vf * wf)


def _wl_beta_polynomial(L_ij, Q_ij, Shat_ij, Shatd_ij, SS_hat, SS_hatd):
    """
    Coefficients of the fifth-order beta polynomial (ABL07 Appendix,
    Eqs. A1-A10).

    They are returned highest power first as (ff, ee, dd, cc, bb, aa),
    the argument order of ComputeBeta1. With alpha = TFR and <.> the
    planar mean, the polynomial is

        <Q.N><M.M> - <L.M><N.N> = 0,
        M ~ (1 - alpha^(4/3) beta)   * Shat_ij    (level-1 strain)
        N ~ (1 - alpha^(8/3) beta^2) * Shatd_ij   (level-2 strain)

    Expanding leaves four independent planar means (a1, a3, a6, a8).
    The remaining a's are scalar multiples of those four.
    """
    # Each Leonard tensor is paired with the strain at its own filter
    # level. N (and hence a8) is built from the level-2 strain, so
    # <Q.N> = a1 * (1 - alpha^(8/3) beta^2) holds only with a1 = <Q.Shatd>.
    a1 = PlanarMean(_sym_contract(Q_ij, Shatd_ij))   # <Q_ij  Shatd_ij>
    a3 = PlanarMean(SS_hat)                          # <Shat_ij  Shat_ij>
    a6 = PlanarMean(_sym_contract(L_ij, Shat_ij))    # <L_ij  Shat_ij>
    a8 = PlanarMean(SS_hatd)                         # <Shatd_ij Shatd_ij>

    r4 = TFR ** (4 / 3)
    r8 = TFR ** (8 / 3)
    r16 = TFR ** (16 / 3)

    a2 = -r8 * a1
    a4 = -2 * r4 * a3
    a5 = r8 * a3
    a7 = -r4 * a6
    a9 = -2 * r8 * a8
    a10 = r16 * a8

    aa = a1 * a3 - a6 * a8            # A0 (constant term)
    bb = a1 * a4 - a7 * a8            # A1 (beta^1)
    cc = a2 * a3 + a1 * a5 - a6 * a9  # A2 (beta^2)
    dd = a2 * a4 - a7 * a9            # A3 (beta^3)
    ee = a2 * a5 - a6 * a10           # A4 (beta^4)
    ff = -a7 * a10                    # A5 (beta^5)
    return ff, ee, dd, cc, bb, aa


@jax.jit
def LASDD(
        u, v, w,
        S11, S22, S33,
        S12, S13, S23,
        ZeRo3D):
    """
    Locally-averaged scale-dependent dynamic model using the
    Wong-Lilly SGS base model (LASDD-WL).

    The module-level names nx, ny, nz, L, TFR and optSgs are used here.
    They presumably come from the DerivedVars star import (confirm).
    beta1 is solved for when optSgs is 1 or 2; otherwise it is fixed at 1.

    Parameters:
    -----------
    u, v, w : ndarray
        Velocity components
    S11, S22, S33 : ndarray
        Normal strain rate components
    S12, S13, S23 : ndarray
        Shear strain rate components
    ZeRo3D : ndarray
        Pre-allocated zero array

    Returns:
    --------
    u_, v_, w_ : ndarray
        Interpolated velocity components
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    S_hat, S_hatd : ndarray
        Filtered strain rate magnitudes (diagnostic)
    Cwl_3D : ndarray
        3D field of C_WL coefficient
    Cwl_1D_avg1 : ndarray
        1D profile of C_WL (square of mean of sqrt)
    Cwl_1D_avg2 : ndarray
        1D profile of C_WL (direct mean)
    beta1_1D : ndarray
        1D profile of scale-dependence parameter beta1
    """

    # Interior w levels come from StagGridAvg. The bottom level is half of
    # w at level 1; the top level is copied from level nz - 2.
    u_ = u.copy()
    v_ = v.copy()
    w_ = ZeRo3D.copy()
    w_ = w_.at[:, :, 1:nz - 1].set(StagGridAvg(w[:, :, 1:nz]))
    w_ = w_.at[:, :, 0].set(0.5 * w[:, :, 1])
    w_ = w_.at[:, :, nz - 1].set(w[:, :, nz - 2])

    # Forward-transform each field once; both filter levels reuse it.
    vel_spec = [FFT(f) for f in (u_, v_, w_)]
    prod_spec = [FFT(f) for f in (u_ * u_, v_ * v_, w_ * w_,
                                  u_ * v_, u_ * w_, v_ * w_)]
    strain_spec = [FFT(f) for f in (S11, S22, S33, S12, S13, S23)]

    # Level-1 (hat) and level-2 (hatd) filtered fields
    u_hat, v_hat, w_hat = [Filtering_Level1(f) for f in vel_spec]
    u_hatd, v_hatd, w_hatd = [Filtering_Level2(f) for f in vel_spec]
    prod_hat = [Filtering_Level1(f) for f in prod_spec]
    prod_hatd = [Filtering_Level2(f) for f in prod_spec]
    Shat_ij = tuple(Filtering_Level1(f) for f in strain_spec)
    Shatd_ij = tuple(Filtering_Level2(f) for f in strain_spec)

    # Leonard stress tensors L_ij (level 1) and Q_ij (level 2)
    L_ij = _resolved_stress((u_hat, v_hat, w_hat), prod_hat)
    Q_ij = _resolved_stress((u_hatd, v_hatd, w_hatd), prod_hatd)

    # S_ij S_ij at both levels; shared by |S| and the beta polynomial
    SS_hat = _sym_contract(Shat_ij, Shat_ij)
    SS_hatd = _sym_contract(Shatd_ij, Shatd_ij)

    # Filtered strain rate magnitudes |S| = sqrt(2 S_ij S_ij) (diagnostic)
    S_hat = jnp.sqrt(2 * SS_hat)
    S_hatd = jnp.sqrt(2 * SS_hatd)

    # Scale-dependence parameter (fixed at 1 for the LAD variants)
    if optSgs in (1, 2):
        beta1_1D = ComputeBeta1(*_wl_beta_polynomial(
            L_ij, Q_ij, Shat_ij, Shatd_ij, SS_hat, SS_hatd))
    else:
        beta1_1D = jnp.ones(nz)

    # ----------------------------------------------------------
    # WL M tensor: M_ij = 2*L^(4/3)*Shat_ij
    #                     - 2*(TFR*L)^(4/3)*beta*Shatd_ij
    # NOTE(review): the beta polynomial assumes
    # M_ij ~ (1 - TFR^(4/3) beta) * Shat_ij (level-1 strain only).
    # The second term here uses the level-2 strain instead. Confirm the
    # intended form against ABL07.
    # ----------------------------------------------------------
    T1 = 2 * L ** (4 / 3)
    T2_beta = 2 * (TFR * L) ** (4 / 3) * beta1_1D.reshape(1, 1, nz)
    M_ij = tuple(T1 * sh - T2_beta * shd
                 for sh, shd in zip(Shat_ij, Shatd_ij))

    # LM = L_ij * M_ij,  MM = M_ij * M_ij
    LM = _sym_contract(L_ij, M_ij)
    MM = _sym_contract(M_ij, M_ij)

    # C_WL field: per-level local averaging and ratio via Cwl_at_level_k
    Cwl_3D = jax.vmap(Cwl_at_level_k, in_axes=(2, 2), out_axes=2)(LM, MM)

    Cwl_1D_avg1 = PlanarMean(jnp.sqrt(Cwl_3D)) ** 2
    Cwl_1D_avg2 = PlanarMean(Cwl_3D)

    return (u_, v_, w_,
            u_hat, v_hat, w_hat,
            u_hatd, v_hatd, w_hatd,
            S_hat, S_hatd,
            Cwl_3D, Cwl_1D_avg1, Cwl_1D_avg2, beta1_1D)