SGS Model: LASDD-SM (Scalar)

Scalar (potential temperature) Locally-Averaged Scale-Dependent Dynamic SGS model using the Smagorinsky base formulation. Called for optSgs = 1 (LASDD-SM) and optSgs = 3 (LAD-SM). For LAD variants (optSgs = 3), the scalar scale-dependence parameter beta2 is set to 1.

Source Code: DynamicSGS_ScalarLASDD_SM

DynamicSGS_ScalarLASDD_SM.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: DynamicSGS_ScalarLASDD_SM.py
 18==================================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-29
 24:Description: locally-averaged scale-dependent dynamic (LASDD) model
 25              for scalar transport
 26"""
 27
 28# ============================================================
 29#  Imports
 30# ============================================================
 31
 32import jax
 33import jax.numpy as jnp
 34
 35# Import derived variables
 36from ..config.DerivedVars import *
 37
 38# Import FFT modules
 39from ..operations.FFT import FFT
 40
 41# Import filtering functions
 42from ..operations.Filtering import Filtering_Level1, Filtering_Level2
 43
 44# Import helper functions
 45from ..utilities.Utilities import PlanarMean, StagGridAvg
 46from ..utilities.Utilities import Roots, Imfilter
 47
 48
 49# ============================================================
 50# Find the maximum real root between 0 and 5, with focus on the 0.5-1.5 range
 51# ============================================================
 52
 53@jax.jit
 54def ComputeBeta2(ff, ee, dd, cc, bb, aa):
 55    """
 56    This function solves the polynomial:
 57    ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0
 58    for each vertical level to find the optimal parameter beta2
 59    used in the LASDD SGS model for scalar transport
 60
 61    Parameters:
 62    -----------
 63    ff, ee, dd, cc, bb, aa : ndarray
 64        1D arrays containing the polynomial coefficients at each vertical level
 65
 66    Returns:
 67    --------
 68    beta2 : ndarray
 69        1D array of the maximum valid real root for each vertical level
 70    """
 71
 72    def find_roots_for_level(k):
 73        # Construct polynomial coefficients for this level
 74        coeffs = jnp.array([ff[k], ee[k], dd[k], cc[k], bb[k], aa[k]])
 75
 76        # Use initial guesses concentrated in the expected range (0.5-1.5)
 77        # with a few wider points to catch outliers
 78        guesses = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])
 79        roots = jax.vmap(lambda guess:
 80                         Roots(coeffs, init_guess=guess))(guesses)
 81
 82        # Filter valid real roots
 83        valid_roots = jnp.where(
 84            (jnp.abs(jnp.imag(roots)) < 1e-6) &
 85            (jnp.real(roots) > 0) &
 86            (jnp.real(roots) < 5.0),
 87            jnp.real(roots),
 88            jnp.nan
 89        )
 90
 91        # Get maximum valid root or default to 1.0
 92        max_root = jnp.nanmax(valid_roots)
 93        return jnp.where(jnp.isnan(max_root), 1.0, max_root)
 94
 95    # Apply to all levels
 96    return jax.vmap(find_roots_for_level)(jnp.arange(ff.shape[0]))
 97
 98
 99# ===================================================
100# Compute Cs2PrRatio coefficient at vertical level k
101# ===================================================
102
@jax.jit
def Cs2PrRatio_at_level_k(T_up_k, T_dn_k):
    """
    Filter the numerator and denominator of the dynamic procedure on one
    horizontal plane with Imfilter (presumably the local-averaging step of
    the LASDD procedure) and return their ratio, Cs2/Pr.

    Parameters:
    -----------
    T_up_k : ndarray
        2D horizontal slice (level k) of the numerator field T_up
    T_dn_k : ndarray
        2D horizontal slice (level k) of the denominator field T_dn

    Returns:
    --------
    Cs2PrRatio : ndarray
        2D array of Cs2/Pr at level k. Points where the filtered denominator
        is smaller than 1e-10 in magnitude, or where the ratio is negative or
        greater than 1, are set to 0.
    """

    num_avg = Imfilter(T_up_k)
    den_avg = Imfilter(T_dn_k)
    ratio = num_avg / den_avg

    # Points to discard: near-singular denominator or out-of-range ratio
    near_singular = jnp.abs(den_avg) < 1e-10
    out_of_range = jnp.logical_or(ratio < 0, ratio > 1)
    reject = jnp.logical_or(near_singular, out_of_range)

    return jnp.where(reject, 0.0, ratio)
132
133
134# ==============================================
135# Main LASDD code for SGS scalar transport model
136# ==============================================
137
def _FilterBothLevels(field):
    """
    Filter a 3D field at both dynamic-procedure filter levels.

    The forward FFT of ``field`` is computed once and shared by the level-1
    and level-2 filters.

    Returns
    -------
    (field_hat, field_hatd) : tuple of ndarray
        Level-1 and level-2 filtered versions of ``field``
    """
    field_fft = FFT(field)
    return Filtering_Level1(field_fft), Filtering_Level2(field_fft)


def _GermanoTerms(M1, M2, M3,
                  SdTHdx_f, SdTHdy_f, SdTHdz_f,
                  S_f, dTHdx_f, dTHdy_f, dTHdz_f):
    """
    Build the pointwise (unscaled) products that enter the dynamic-procedure
    coefficients at one filter level.

    Parameters
    ----------
    M1, M2, M3 : ndarray
        Resolved scalar-flux components at this level (LTH1j or QTH1j)
    SdTHdx_f, SdTHdy_f, SdTHdz_f : ndarray
        Filtered products S * dTH/dx_i
    S_f : ndarray
        Filtered strain-rate magnitude
    dTHdx_f, dTHdy_f, dTHdz_f : ndarray
        Filtered potential-temperature gradients

    Returns
    -------
    a_terms, b_terms, c_terms, d_terms, e_terms : ndarray
        M . (S dTH)_f,  M . (S_f dTH_f),  |(S dTH)_f|^2,
        (S dTH)_f . (S_f dTH_f),  |S_f dTH_f|^2
    """
    a_terms = (M1 * SdTHdx_f +
               M2 * SdTHdy_f +
               M3 * SdTHdz_f)
    b_terms = (M1 * S_f * dTHdx_f +
               M2 * S_f * dTHdy_f +
               M3 * S_f * dTHdz_f)
    c_terms = (SdTHdx_f ** 2 +
               SdTHdy_f ** 2 +
               SdTHdz_f ** 2)
    d_terms = (SdTHdx_f * S_f * dTHdx_f +
               SdTHdy_f * S_f * dTHdy_f +
               SdTHdz_f * S_f * dTHdz_f)
    e_terms = ((S_f * dTHdx_f) ** 2 +
               (S_f * dTHdy_f) ** 2 +
               (S_f * dTHdz_f) ** 2)
    return a_terms, b_terms, c_terms, d_terms, e_terms


@jax.jit
def ScalarLASDD(
        u_, v_, w_,
        u_hat, v_hat, w_hat,
        u_hatd, v_hatd, w_hatd,
        TH,
        dTHdx, dTHdy, dTHdz,
        S, S_hat, S_hatd,
        ZeRo3D):
    """
    Compute the dynamic Cs^2/Pr coefficient for potential temperature with
    the locally-averaged scale-dependent dynamic (LASDD) procedure.

    Steps:

    1. Form the resolved scalar fluxes at the level-1 (LTH) and level-2
       (QTH) filter scales.
    2. Build, per vertical level, planar-mean coefficients a2..e2 (level 1)
       and a4..e4 (level 2) such that the coefficient estimates at the two
       filter scales are (a2 + b*b2) / (c2 + b*d2 + b^2*e2) and
       (a4 + b^2*b4) / (c4 + b^2*d4 + b^4*e4).
    3. Equate the two estimates and solve the resulting quintic for the
       scale-dependence parameter beta2 (optSgs 1 or 2). Otherwise beta2 = 1
       (e.g. the LAD variant).
    4. Form the pointwise numerator/denominator with that beta2, filter them
       plane by plane (Cs2PrRatio_at_level_k), and take their ratio.

    Module-level names nz, L, TFR and optSgs come from the star import of
    DerivedVars. Presumably L is the grid filter width and TFR the
    test-filter ratio, consistent with the TFR**2 / TFR**4 factors for the
    two filter levels.

    Parameters:
    -----------
    u_, v_, w_ : ndarray
        Interpolated velocity fields
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    TH : ndarray
        Potential temperature field
    dTHdx, dTHdy, dTHdz : ndarray
        Potential temperature gradients
    S, S_hat, S_hatd : ndarray
        Strain rate magnitude and its filtered versions
    ZeRo3D : ndarray
        Pre-allocated zero array (template for the re-located dTHdz field)

    Returns:
    --------
    Cs2PrRatio_3D : ndarray
        Cs2PrRatio field
    Cs2PrRatio_1D : ndarray
        1D planar-averaged profile of Cs2PrRatio
    beta2_1D : ndarray
        1D profile of beta2
    """

    # TH enters the Leonard-type terms (e.g. uTH_hat - u_hat*TH_hat) as is;
    # no base state is subtracted here. Those terms are differences of
    # products, so they would lose precision if TH carried a large constant
    # offset. NOTE(review): assumes the caller passes TH as an anomaly
    # (TH - T_0) -- confirm at the call site.

    # Re-locate dTHdz to the levels of the other fields (the original
    # comment calls these the "proper grid locations"): interior levels are
    # averaged with StagGridAvg, bottom and top levels are copied.
    # .at[].set() returns a new array, so ZeRo3D itself is not modified.
    THz = ZeRo3D.at[:, :, 1:nz - 1].set(StagGridAvg(dTHdz[:, :, 1:nz]))
    THz = THz.at[:, :, 0].set(dTHdz[:, :, 0])
    THz = THz.at[:, :, nz - 1].set(dTHdz[:, :, nz - 1])

    # Filter scalar, scalar fluxes, gradients and strain-gradient products
    # at both levels (one forward FFT per field)
    TH_hat, TH_hatd = _FilterBothLevels(TH)
    uTH_hat, uTH_hatd = _FilterBothLevels(u_ * TH)
    vTH_hat, vTH_hatd = _FilterBothLevels(v_ * TH)
    wTH_hat, wTH_hatd = _FilterBothLevels(w_ * TH)

    dTHdx_hat, dTHdx_hatd = _FilterBothLevels(dTHdx)
    dTHdy_hat, dTHdy_hatd = _FilterBothLevels(dTHdy)
    dTHdz_hat, dTHdz_hatd = _FilterBothLevels(THz)

    SdTHdx_hat, SdTHdx_hatd = _FilterBothLevels(S * dTHdx)
    SdTHdy_hat, SdTHdy_hatd = _FilterBothLevels(S * dTHdy)
    SdTHdz_hat, SdTHdz_hatd = _FilterBothLevels(S * THz)

    # Resolved scalar fluxes between grid and filter scales
    LTH11 = uTH_hat - u_hat * TH_hat
    LTH12 = vTH_hat - v_hat * TH_hat
    LTH13 = wTH_hat - w_hat * TH_hat

    QTH11 = uTH_hatd - u_hatd * TH_hatd
    QTH12 = vTH_hatd - v_hatd * TH_hatd
    QTH13 = wTH_hatd - w_hatd * TH_hatd

    # Pointwise (unscaled) terms at each filter level
    a2_terms, b2_terms, c2_terms, d2_terms, e2_terms = _GermanoTerms(
        LTH11, LTH12, LTH13,
        SdTHdx_hat, SdTHdy_hat, SdTHdz_hat,
        S_hat, dTHdx_hat, dTHdy_hat, dTHdz_hat)
    a4_terms, b4_terms, c4_terms, d4_terms, e4_terms = _GermanoTerms(
        QTH11, QTH12, QTH13,
        SdTHdx_hatd, SdTHdy_hatd, SdTHdz_hatd,
        S_hatd, dTHdx_hatd, dTHdy_hatd, dTHdz_hatd)

    # Scale factors for level 1 (reused below for T_up / T_dn)
    fac_a2 = L ** 2
    fac_b2 = -L ** 2 * TFR ** 2
    fac_c2 = L ** 4
    fac_d2 = -2 * L ** 4 * TFR ** 2
    fac_e2 = L ** 4 * TFR ** 4

    a2 = PlanarMean(fac_a2 * a2_terms)
    b2 = PlanarMean(fac_b2 * b2_terms)
    c2 = PlanarMean(fac_c2 * c2_terms)
    d2 = PlanarMean(fac_d2 * d2_terms)
    e2 = PlanarMean(fac_e2 * e2_terms)

    # Level 2: the scale-dependence factor appears squared (beta2^2)
    a4 = PlanarMean((L ** 2) * a4_terms)
    b4 = PlanarMean((-L ** 2 * TFR ** 4) * b4_terms)
    c4 = PlanarMean((L ** 4) * c4_terms)
    d4 = PlanarMean((-2 * L ** 4 * TFR ** 4) * d4_terms)
    e4 = PlanarMean((L ** 4 * TFR ** 8) * e4_terms)

    # Quintic in beta2 obtained by cross-multiplying the level-1 and level-2
    # coefficient estimates: ff*b^5 + ee*b^4 + dd*b^3 + cc*b^2 + bb*b + aa = 0
    aa = a2 * c4 - a4 * c2
    bb = -a4 * d2 + b2 * c4
    cc = -c2 * b4 + a2 * d4 - a4 * e2
    dd = b2 * d4 - b4 * d2
    ee = a2 * e4 - b4 * e2
    ff = b2 * e4

    # Scale-dependent variants solve for beta2; otherwise beta2 = 1
    if optSgs in (1, 2):
        beta2_1D = ComputeBeta2(ff, ee, dd, cc, bb, aa)
    else:
        beta2_1D = jnp.ones(nz)

    # (1, 1, nz) view broadcasts against the (.., .., nz) term fields
    beta2_3D = beta2_1D.reshape(1, 1, nz)

    # Pointwise numerator and denominator of the level-1 estimate
    T_up = (fac_a2 * a2_terms +
            fac_b2 * b2_terms * beta2_3D)
    T_dn = (fac_c2 * c2_terms +
            fac_d2 * d2_terms * beta2_3D +
            fac_e2 * e2_terms * beta2_3D ** 2)

    # Filter and divide plane by plane (vmap over the vertical axis 2)
    Cs2PrRatio_3D = jax.vmap(Cs2PrRatio_at_level_k,
                             in_axes=(2, 2), out_axes=2)(T_up, T_dn)

    # Horizontal-mean profile
    Cs2PrRatio_1D = PlanarMean(Cs2PrRatio_3D)

    return Cs2PrRatio_3D, Cs2PrRatio_1D, beta2_1D