SGS Model: LASDD-WL (Scalar)

Scalar (potential temperature) locally-averaged scale-dependent dynamic (LASDD) SGS model using the Wong-Lilly base formulation. It is called for optSgs = 2 (LASDD-WL) and optSgs = 4 (LAD-WL). For the LAD variant (optSgs = 4), the scalar scale-dependence parameter beta2 is not solved for but fixed at 1 (i.e., the coefficient is assumed scale-invariant).

Source Code: DynamicSGS_ScalarLASDD_WL

DynamicSGS_ScalarLASDD_WL.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: DynamicSGS_ScalarLASDD_WL.py
 18===================================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2026-5-9
 24:Description: locally-averaged scale-dependent dynamic (LASDD) model
 25              for scalar transport using the Wong-Lilly (1994) SGS base
 26              model (LASDD-WL).
 27              Reference: Anderson, Basu, and Letchford (2007), EFM.
 28"""
 29
 30# ============================================================
 31#  Imports
 32# ============================================================
 33
 34import jax
 35import jax.numpy as jnp
 36
 37# Import derived variables
 38from ..config.DerivedVars import *
 39
 40# Import FFT modules
 41from ..operations.FFT import FFT
 42
 43# Import filtering functions
 44from ..operations.Filtering import Filtering_Level1, Filtering_Level2
 45
 46# Import helper functions
 47from ..utilities.Utilities import PlanarMean, StagGridAvg
 48from ..utilities.Utilities import Roots, Imfilter
 49
 50
 51# ============================================================
 52# Find maximum real root between 0 and 5
 53# ============================================================
 54
 55@jax.jit
 56def ComputeBeta2(ff, ee, dd, cc, bb, aa):
 57    """
 58    Solves the polynomial:
 59    ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0
 60    for each vertical level to find the scalar scale-dependence
 61    parameter beta2 used in the LASDD-WL scalar model.
 62
 63    Parameters:
 64    -----------
 65    ff, ee, dd, cc, bb, aa : ndarray
 66        1D arrays of polynomial coefficients at each vertical level
 67
 68    Returns:
 69    --------
 70    beta2 : ndarray
 71        1D array of the maximum valid real root for each vertical level
 72    """
 73
 74    def find_roots_for_level(k):
 75        coeffs = jnp.array([ff[k], ee[k], dd[k], cc[k], bb[k], aa[k]])
 76        guesses = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])
 77        roots = jax.vmap(lambda guess:
 78                         Roots(coeffs, init_guess=guess))(guesses)
 79
 80        valid_roots = jnp.where(
 81            (jnp.abs(jnp.imag(roots)) < 1e-6) &
 82            (jnp.real(roots) > 0) &
 83            (jnp.real(roots) < 5.0),
 84            jnp.real(roots),
 85            jnp.nan
 86        )
 87
 88        max_root = jnp.nanmax(valid_roots)
 89        return jnp.where(jnp.isnan(max_root), 1.0, max_root)
 90
 91    return jax.vmap(find_roots_for_level)(jnp.arange(ff.shape[0]))
 92
 93
 94# ============================================================
 95# Compute CwlPrRatio coefficient at vertical level k
 96# ============================================================
 97
 98@jax.jit
 99def CwlPrRatio_at_level_k(T_up_k, T_dn_k):
100    """
101    Parameters:
102    -----------
103    T_up_k : ndarray
104        2D horizontal slice of T_up at level k
105    T_dn_k : ndarray
106        2D horizontal slice of T_dn at level k
107
108    Returns:
109    --------
110    CwlPrRatio : ndarray
111        2D array of C_WL/Pr_t at level k
112    """
113
114    T_up_F = Imfilter(T_up_k)
115    T_dn_F = Imfilter(T_dn_k)
116
117    CwlPrRatio = T_up_F / T_dn_F
118
119    mask = (jnp.abs(T_dn_F) < 1e-10) | (CwlPrRatio < 0) | (CwlPrRatio > 1)
120    CwlPrRatio = jnp.where(mask, 0.0, CwlPrRatio)
121
122    return CwlPrRatio
123
124
125# ============================================================
126# Main LASDD-WL scalar code
127# ============================================================
128
@jax.jit
def ScalarLASDD(
        u_, v_, w_,
        u_hat, v_hat, w_hat,
        u_hatd, v_hatd, w_hatd,
        TH,
        dTHdx, dTHdy, dTHdz,
        ZeRo3D):
    """
    Scalar LASDD SGS model with the Wong-Lilly (WL) base closure (LASDD-WL).

    Builds the Level-1 (LTH) and Level-2 (QTH) scalar Leonard fluxes. It then
    solves for the scale-dependence parameter beta2 on every vertical level,
    or fixes beta2 = 1 when ``optSgs`` is not 1 or 2. Finally it returns the
    locally averaged coefficient ratio C_WL/Pr_t.

    The function reads the module-level names nx, ny, nz, L, TFR and optSgs,
    presumably supplied by the ``DerivedVars`` star import. TFR takes the
    place of the filter-width ratio alpha in the formulas below.

    Parameters:
    -----------
    u_, v_, w_ : ndarray
        Interpolated velocity fields (from LASDD momentum step)
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    TH : ndarray
        Potential temperature field (stored as the anomaly TH - T_0)
    dTHdx, dTHdy, dTHdz : ndarray
        Potential temperature gradients; dTHdz is averaged onto the UVP
        levels before it is filtered
    ZeRo3D : ndarray
        Zero array that supplies the shape/dtype of the interpolated dTHdz

    Returns:
    --------
    CwlPrRatio_3D : ndarray
        3D field of C_WL/Pr_t (vertical levels on the last axis)
    CwlPrRatio_1D : ndarray
        Plane-averaged profile of CwlPrRatio_3D
    beta2_1D : ndarray
        1D profile of the scalar scale-dependence parameter beta2
    """

    def _both_levels(field):
        # Transform once and reuse the spectrum for both filter levels.
        spectrum = FFT(field)
        return Filtering_Level1(spectrum), Filtering_Level2(spectrum)

    def _dot(ax, ay, az, bx, by, bz):
        # Three-component dot product, summed x + y + z (left to right).
        return ax * bx + ay * by + az * bz

    def _norm2(gx, gy, gz):
        # Squared magnitude of a three-component vector.
        return gx ** 2 + gy ** 2 + gz ** 2

    # dTHdz averaged onto UVP levels; bottom and top levels copied as-is.
    THz = (ZeRo3D.copy()
           .at[:, :, 1:nz - 1].set(StagGridAvg(dTHdz[:, :, 1:nz]))
           .at[:, :, 0].set(dTHdz[:, :, 0])
           .at[:, :, nz - 1].set(dTHdz[:, :, nz - 1]))

    # TH is an anomaly (TH - T_0), so the Leonard differences below act on
    # small values and do not suffer catastrophic cancellation.
    uTH = u_ * TH
    vTH = v_ * TH
    wTH = w_ * TH

    # Level-1 (_hat) and Level-2 (_hatd) filtered scalar and flux products.
    TH_hat, TH_hatd = _both_levels(TH)
    uTH_hat, uTH_hatd = _both_levels(uTH)
    vTH_hat, vTH_hatd = _both_levels(vTH)
    wTH_hat, wTH_hatd = _both_levels(wTH)

    # Level-1 / Level-2 filtered scalar gradients.
    dTHdx_hat, dTHdx_hatd = _both_levels(dTHdx)
    dTHdy_hat, dTHdy_hatd = _both_levels(dTHdy)
    dTHdz_hat, dTHdz_hatd = _both_levels(THz)

    # Scalar Leonard fluxes: K'_i = LTH (Level 1), K_i = QTH (Level 2).
    LTH11 = uTH_hat - u_hat * TH_hat
    LTH12 = vTH_hat - v_hat * TH_hat
    LTH13 = wTH_hat - w_hat * TH_hat

    QTH11 = uTH_hatd - u_hatd * TH_hatd
    QTH12 = vTH_hatd - v_hatd * TH_hatd
    QTH13 = wTH_hatd - w_hatd * TH_hatd

    # ----------------------------------------------------------
    # WL scalar polynomial coefficients (ABL07 Appendix).
    # Independent plane averages: a1, a3, a6, a8.
    # ----------------------------------------------------------
    a6_terms = _dot(LTH11, LTH12, LTH13, dTHdx_hat, dTHdy_hat, dTHdz_hat)
    a3_terms = _norm2(dTHdx_hat, dTHdy_hat, dTHdz_hat)
    a8_terms = _norm2(dTHdx_hatd, dTHdy_hatd, dTHdz_hatd)

    a1 = PlanarMean(_dot(QTH11, QTH12, QTH13,
                         dTHdx_hat, dTHdy_hat, dTHdz_hat))
    a3 = PlanarMean(a3_terms)
    a6 = PlanarMean(a6_terms)
    a8 = PlanarMean(a8_terms)

    # Powers of the filter-width ratio, reused below.
    r43 = TFR ** (4 / 3)
    r83 = TFR ** (8 / 3)
    r163 = TFR ** (16 / 3)

    a2 = -r83 * a1
    a4 = -2 * r43 * a3
    a5 = r83 * a3
    a7 = -r43 * a6
    a9 = -2 * r83 * a8
    a10 = r163 * a8

    # NOTE(review): with x = TFR^(4/3) * beta2 the quintic assembled below is
    #   a1*a3*(1 - x^2)*(1 - x)^2 - a6*a8*(1 - x)*(1 - x^2)^2
    #     = (1 - x)^3 * (1 + x) * [a1*a3 - a6*a8*(1 + x)].
    # So beta2 = TFR^(-4/3) is always a (triple) root. The only data-dependent
    # root is beta2 = (a1*a3/(a6*a8) - 1) / TFR^(4/3). ComputeBeta2 keeps the
    # largest root in (0, 5). If the root finder locates the fixed root, it is
    # returned whenever the data-dependent root is smaller or out of range.
    # Also, a1 projects QTH onto the Level-1 gradient while its normaliser a8
    # uses the Level-2 gradient. TODO: confirm both points against ABL07.
    aa = a1 * a3 - a6 * a8            # A0
    bb = a1 * a4 - a7 * a8            # A1
    cc = a2 * a3 + a1 * a5 - a6 * a9  # A2
    dd = a2 * a4 - a7 * a9            # A3
    ee = a2 * a5 - a6 * a10           # A4
    ff = -a7 * a10                    # A5

    # Scale-dependent variants (optSgs 1 or 2) solve for beta2; otherwise 1.
    if optSgs in (1, 2):
        beta2_1D = ComputeBeta2(ff, ee, dd, cc, bb, aa)
    else:
        beta2_1D = jnp.ones(nz)
    beta2_3D = jnp.broadcast_to(beta2_1D.reshape(1, 1, nz), (nx, ny, nz))

    # ----------------------------------------------------------
    # WL scalar T_up and T_dn for CwlPrRatio, with grad_1 / grad_2 the
    # Level-1 / Level-2 filtered gradients:
    #   T_up = L^(4/3) * (K'.grad_1 - TFR^(4/3)*beta2 * K'.grad_2)
    #   T_dn = L^(8/3) * (|grad_1|^2 - 2*TFR^(4/3)*beta2*(grad_1.grad_2)
    #                     + TFR^(8/3)*beta2^2*|grad_2|^2)
    # NOTE(review): the polynomial's factor a3*(1 - x)^2 corresponds to a
    # model term L^(4/3)*(1 - x)*grad_1 (Level-1 gradient only). Here the
    # beta2 term uses grad_2 instead. Confirm which form ABL07 intends.
    # ----------------------------------------------------------
    b6_terms = _dot(LTH11, LTH12, LTH13, dTHdx_hatd, dTHdy_hatd, dTHdz_hatd)
    c36_terms = _dot(dTHdx_hat, dTHdy_hat, dTHdz_hat,
                     dTHdx_hatd, dTHdy_hatd, dTHdz_hatd)

    T_up = L ** (4 / 3) * (a6_terms - r43 * beta2_3D * b6_terms)
    T_dn = L ** (8 / 3) * (a3_terms
                           - 2 * r43 * beta2_3D * c36_terms
                           + r83 * beta2_3D ** 2 * a8_terms)

    # Locally averaged ratio, evaluated level by level along the last axis.
    CwlPrRatio_3D = jax.vmap(CwlPrRatio_at_level_k,
                             in_axes=(2, 2), out_axes=2)(T_up, T_dn)
    CwlPrRatio_1D = PlanarMean(CwlPrRatio_3D)

    return CwlPrRatio_3D, CwlPrRatio_1D, beta2_1D