SGS Model: LASDD-SM (Momentum)

Locally-Averaged Scale-Dependent Dynamic SGS model using the Smagorinsky base formulation. Called for optSgs = 1 (LASDD-SM) and optSgs = 3 (LAD-SM). For LAD variants (optSgs = 3), the scale-dependence parameter beta is set to 1 rather than computed.

Source Code: DynamicSGS_LASDD_SM

DynamicSGS_LASDD_SM.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: DynamicSGS_LASDD_SM.py
 18============================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-29
 24:Description: locally-averaged scale-dependent dynamic (LASDD) model
 25"""
 26
 27# ============================================================
 28#  Imports
 29# ============================================================
 30
 31import jax
 32import jax.numpy as jnp
 33
 34# Import derived variables
 35from ..config.DerivedVars import *
 36
 37# Import FFT modules
 38from ..operations.FFT import FFT
 39
 40# Import filtering functions
 41from ..operations.Filtering import Filtering_Level1, Filtering_Level2
 42
 43# Import helper functions
 44from ..utilities.Utilities import PlanarMean, StagGridAvg
 45from ..utilities.Utilities import Roots, Imfilter
 46
 47
 48# ============================================================
 49# Find maximum real root between 0 and 5, with focus on the 0.5-1.5 range
 50# ============================================================
 51
 52@jax.jit
 53def ComputeBeta1(ff, ee, dd, cc, bb, aa):
 54    """
 55    This function solves the polynomial:
 56    ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0
 57    for each vertical level to find the optimal parameter beta1
 58    used in the LASDD SGS model
 59
 60    Parameters:
 61    -----------
 62    ff, ee, dd, cc, bb, aa : ndarray
 63        1D arrays containing the polynomial coefficients at each vertical level
 64
 65    Returns:
 66    --------
 67    beta1 : ndarray
 68        1D array of the maximum valid real root for each vertical level
 69    """
 70
 71    def find_roots_for_level(k):
 72        # Construct polynomial coefficients for this level
 73        coeffs = jnp.array([ff[k], ee[k], dd[k], cc[k], bb[k], aa[k]])
 74
 75        # Use initial guesses concentrated in the expected range (0.5-1.5)
 76        # with a few wider points to catch outliers
 77        guesses = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])
 78        roots = jax.vmap(lambda guess:
 79                         Roots(coeffs, init_guess=guess))(guesses)
 80
 81        # Filter valid real roots
 82        valid_roots = jnp.where(
 83            (jnp.abs(jnp.imag(roots)) < 1e-6) &
 84            (jnp.real(roots) > 0) &
 85            (jnp.real(roots) < 5.0),
 86            jnp.real(roots),
 87            jnp.nan
 88        )
 89
 90        # Get maximum valid root or default to 1.0
 91        max_root = jnp.nanmax(valid_roots)
 92        return jnp.where(jnp.isnan(max_root), 1.0, max_root)
 93
 94    # Apply to all levels
 95    return jax.vmap(find_roots_for_level)(jnp.arange(ff.shape[0]))
 96
 97
 98# ===================================================
 99# Compute Smagorinsky coefficient at vertical level k
100# ===================================================
101
@jax.jit
def Cs2_at_level_k(LM_k, MM_k):
    """
    Compute the locally-averaged squared Smagorinsky coefficient on one
    horizontal plane as Imfilter(LM_k) / Imfilter(MM_k).

    Points are reset to zero when the filtered denominator is smaller than
    1e-10 in magnitude, or when the ratio falls outside [0, 1].

    Parameters:
    -----------
    LM_k : jax.numpy.ndarray
        2D horizontal slice of LM at level k
    MM_k : jax.numpy.ndarray
        2D horizontal slice of MM at level k

    Returns:
    --------
    Cs2: ndarray
        2D array of the squared Smagorinsky coefficient at level k
    """

    # Local averaging of numerator and denominator (project helper)
    LMx = Imfilter(LM_k)
    MMx = Imfilter(MM_k)

    # Swap (near-)zero denominators for 1 *before* dividing, so the division
    # never produces inf/NaN intermediates. Those points are zeroed below
    # anyway, so the returned values are unchanged, but the computation stays
    # clean under jax_debug_nans and under differentiation.
    denom_too_small = jnp.abs(MMx) < 1e-10
    safe_MMx = jnp.where(denom_too_small, 1.0, MMx)
    Cs2 = LMx / safe_MMx

    # Reject points with an unusable denominator or a ratio outside [0, 1]
    reject = denom_too_small | (Cs2 < 0) | (Cs2 > 1)

    return jnp.where(reject, 0.0, Cs2)
131
132
133# ============================================================
134# Main LASDD code
135# ============================================================
136
def _filter_both_levels(field):
    """Return the (level-1, level-2) filtered versions of ``field``.

    The forward FFT is evaluated once and shared by both filters.
    """
    field_fft = FFT(field)
    return Filtering_Level1(field_fft), Filtering_Level2(field_fft)


def _strain_magnitude(Sij):
    """Return sqrt(2 * Sij Sij) for a symmetric tensor given as a
    (11, 22, 33, 12, 13, 23) tuple; off-diagonal terms count twice."""
    s11, s22, s33, s12, s13, s23 = Sij
    return jnp.sqrt(2 * (s11 ** 2 + s22 ** 2 + s33 ** 2 +
                         2 * s12 ** 2 +
                         2 * s13 ** 2 +
                         2 * s23 ** 2))


def _contract_sym(A, B):
    """Return the double contraction A_ij * B_ij of two symmetric tensors
    given as (11, 22, 33, 12, 13, 23) tuples; off-diagonal terms count twice."""
    return (A[0] * B[0] + A[1] * B[1] + A[2] * B[2] +
            2 * (A[3] * B[3] + A[4] * B[4] + A[5] * B[5]))


@jax.jit
def LASDD(
        u, v, w,
        S11, S22, S33,
        S12, S13, S23,
        S,
        ZeRo3D):
    """
    Locally-averaged scale-dependent dynamic (LASDD) model, momentum part.

    Filters velocities, velocity products, strain rates and strain-rate
    products at two filter levels, assembles the fifth-order polynomial for
    the scale-dependence parameter beta1 from planar means, and evaluates the
    squared Smagorinsky coefficient Cs2 plane by plane.

    The module-level names nx, ny, nz, L, TFR and optSgs are used here but
    are not defined in the visible part of this file (presumably they come
    from the star import of config.DerivedVars). When optSgs is 1 or 2,
    beta1 is obtained from ComputeBeta1; otherwise it is set to 1 at every
    level.

    Parameters:
    -----------
    u, v, w : ndarray
        Velocity components (3D; the third axis is the vertical level)
    S11, S22, S33 : ndarray
        Normal strain rate components
    S12, S13, S23 : ndarray
        Shear strain rate components
    S : ndarray
        Strain rate magnitude
    ZeRo3D : ndarray
        Pre-allocated zero array used as the template for w_

    Returns:
    --------
    u_, v_, w_ : ndarray
        u and v as given; w averaged between adjacent levels
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    S_hat, S_hatd : ndarray
        Strain rate magnitudes built from the filtered strain rates
    Cs2_3D : ndarray
        Cs2 field
    Cs2_1D_avg1 : ndarray
        1D profile of Cs2 (method 1: square of mean of sqrt)
    Cs2_1D_avg2 : ndarray
        1D profile of Cs2 (method 2: direct mean)
    beta1_1D : ndarray
        1D profile of beta1
    """

    # ---- Velocities on a common set of levels -------------------------
    u_ = u.copy()
    v_ = v.copy()
    # .at[...].set already returns a new array, so ZeRo3D itself is never
    # modified and no defensive copy is required.
    w_ = ZeRo3D.at[:, :, 1:nz - 1].set(StagGridAvg(w[:, :, 1:nz]))
    # Lowest level: half of w at level 1
    # (presumably the average with w = 0 at the surface -- TODO confirm)
    w_ = w_.at[:, :, 0].set(0.5 * w[:, :, 1])
    # Top level: copy of w from the level below
    w_ = w_.at[:, :, nz - 1].set(w[:, :, nz - 2])

    # ---- Filter velocities and their products (one FFT per field) -----
    u_hat, u_hatd = _filter_both_levels(u_)
    v_hat, v_hatd = _filter_both_levels(v_)
    w_hat, w_hatd = _filter_both_levels(w_)
    uu_hat, uu_hatd = _filter_both_levels(u_ ** 2)
    vv_hat, vv_hatd = _filter_both_levels(v_ ** 2)
    ww_hat, ww_hatd = _filter_both_levels(w_ ** 2)
    uv_hat, uv_hatd = _filter_both_levels(u_ * v_)
    uw_hat, uw_hatd = _filter_both_levels(u_ * w_)
    vw_hat, vw_hatd = _filter_both_levels(v_ * w_)

    # ---- Filter strain rates and |S| * Sij products --------------------
    # Component order everywhere below: (11, 22, 33, 12, 13, 23)
    strain = (S11, S22, S33, S12, S13, S23)

    strain_pairs = [_filter_both_levels(Sij) for Sij in strain]
    Sij_hat = tuple(pair[0] for pair in strain_pairs)
    Sij_hatd = tuple(pair[1] for pair in strain_pairs)

    product_pairs = [_filter_both_levels(S * Sij) for Sij in strain]
    SSij_hat = tuple(pair[0] for pair in product_pairs)
    SSij_hatd = tuple(pair[1] for pair in product_pairs)

    # Strain rate magnitudes rebuilt from the filtered components
    S_hat = _strain_magnitude(Sij_hat)
    S_hatd = _strain_magnitude(Sij_hatd)

    # ---- Stress tensors L (level 1) and Q (level 2) --------------------
    Lij = (uu_hat - u_hat ** 2,
           vv_hat - v_hat ** 2,
           ww_hat - w_hat ** 2,
           uv_hat - u_hat * v_hat,
           uw_hat - u_hat * w_hat,
           vw_hat - v_hat * w_hat)

    Qij = (uu_hatd - u_hatd ** 2,
           vv_hatd - v_hatd ** 2,
           ww_hatd - w_hatd ** 2,
           uv_hatd - u_hatd * v_hatd,
           uw_hatd - u_hatd * w_hatd,
           vw_hatd - v_hatd * w_hatd)

    # ---- Planar-mean building blocks of the beta1 polynomial -----------
    # Suffix 1 uses level-1 quantities, suffix 2 level-2 quantities; the
    # level-2 terms carry twice the power of TFR of their level-1 partners.
    a1 = PlanarMean(2 * (L ** 2) * _contract_sym(Lij, SSij_hat))
    a2 = PlanarMean(2 * (L ** 2) * _contract_sym(Qij, SSij_hatd))

    b1 = PlanarMean(2 * (L ** 2) * (TFR ** 2) * S_hat
                    * _contract_sym(Lij, Sij_hat))
    b2 = PlanarMean(2 * (L ** 2) * (TFR ** 4) * S_hatd
                    * _contract_sym(Qij, Sij_hatd))

    c1 = PlanarMean((2 * L ** 2) ** 2 * _contract_sym(SSij_hat, SSij_hat))
    c2 = PlanarMean((2 * L ** 2) ** 2 * _contract_sym(SSij_hatd, SSij_hatd))

    d1 = PlanarMean((4 * L ** 4) * (TFR ** 4) * (S_hat ** 2)
                    * _contract_sym(Sij_hat, Sij_hat))
    d2 = PlanarMean((4 * L ** 4) * (TFR ** 8) * (S_hatd ** 2)
                    * _contract_sym(Sij_hatd, Sij_hatd))

    e1 = PlanarMean((8 * L ** 4) * (TFR ** 2) * S_hat
                    * _contract_sym(Sij_hat, SSij_hat))
    e2 = PlanarMean((8 * L ** 4) * (TFR ** 4) * S_hatd
                    * _contract_sym(Sij_hatd, SSij_hatd))

    # Polynomial coefficients: aa is the constant term, ff multiplies x^5
    aa = a1 * c2 - a2 * c1
    bb = a2 * e1 - b1 * c2
    cc = b2 * c1 - a1 * e2 - a2 * d1
    dd = b1 * e2 - b2 * e1
    ee = a1 * d2 + b2 * d1
    ff = -b1 * d2

    # ---- Scale-dependence parameter ------------------------------------
    # Plain Python branch: it is resolved while tracing, so optSgs must be
    # a concrete Python value rather than a traced array.
    if optSgs in (1, 2):
        beta1_1D = ComputeBeta1(ff, ee, dd, cc, bb, aa)
    else:
        beta1_1D = jnp.ones(nz)
    # Extend beta1 to a 3D field
    beta1_3D = jnp.broadcast_to(beta1_1D.reshape(1, 1, nz), (nx, ny, nz))

    # ---- M tensor, LM and MM -------------------------------------------
    T1 = 2 * L ** 2
    T2 = 2 * (TFR * L) ** 2
    Mij = tuple(T1 * SS_c - T2 * beta1_3D * S_hat * S_c
                for SS_c, S_c in zip(SSij_hat, Sij_hat))

    LM = _contract_sym(Lij, Mij)
    MM = _contract_sym(Mij, Mij)

    # ---- Cs2 on every horizontal plane ---------------------------------
    Cs2_3D = jax.vmap(Cs2_at_level_k, in_axes=(2, 2), out_axes=2)(LM, MM)

    # 1D profiles: square of the planar mean of sqrt(Cs2), and plain mean
    Cs2_1D_avg1 = PlanarMean(jnp.sqrt(Cs2_3D)) ** 2
    Cs2_1D_avg2 = PlanarMean(Cs2_3D)

    return (u_, v_, w_,
            u_hat, v_hat, w_hat,
            u_hatd, v_hatd, w_hatd,
            S_hat, S_hatd,
            Cs2_3D, Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D)