SGS Model: LASDD-WL (Scalar)
Scalar (potential temperature) Locally-Averaged Scale-Dependent Dynamic
SGS model using the Wong-Lilly base formulation. Called for optSgs = 2
(LASDD-WL) and optSgs = 4 (LAD-WL). For LAD variants (optSgs = 4),
the scalar scale-dependence parameter beta2 is set to 1.
Source Code: DynamicSGS_ScalarLASDD_WL
DynamicSGS_ScalarLASDD_WL.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: DynamicSGS_ScalarLASDD_WL.py
18===================================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2026-5-9
24:Description: locally-averaged scale-dependent dynamic (LASDD) model
25 for scalar transport using the Wong-Lilly (1994) SGS base
26 model (LASDD-WL).
27 Reference: Anderson, Basu, and Letchford (2007), EFM.
28"""
29
30# ============================================================
31# Imports
32# ============================================================
33
34import jax
35import jax.numpy as jnp
36
37# Import derived variables
38from ..config.DerivedVars import *
39
40# Import FFT modules
41from ..operations.FFT import FFT
42
43# Import filtering functions
44from ..operations.Filtering import Filtering_Level1, Filtering_Level2
45
46# Import helper functions
47from ..utilities.Utilities import PlanarMean, StagGridAvg
48from ..utilities.Utilities import Roots, Imfilter
49
50
51# ============================================================
52# Find maximum real root between 0 and 5
53# ============================================================
54
55@jax.jit
56def ComputeBeta2(ff, ee, dd, cc, bb, aa):
57 """
58 Solves the polynomial:
59 ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0
60 for each vertical level to find the scalar scale-dependence
61 parameter beta2 used in the LASDD-WL scalar model.
62
63 Parameters:
64 -----------
65 ff, ee, dd, cc, bb, aa : ndarray
66 1D arrays of polynomial coefficients at each vertical level
67
68 Returns:
69 --------
70 beta2 : ndarray
71 1D array of the maximum valid real root for each vertical level
72 """
73
74 def find_roots_for_level(k):
75 coeffs = jnp.array([ff[k], ee[k], dd[k], cc[k], bb[k], aa[k]])
76 guesses = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])
77 roots = jax.vmap(lambda guess:
78 Roots(coeffs, init_guess=guess))(guesses)
79
80 valid_roots = jnp.where(
81 (jnp.abs(jnp.imag(roots)) < 1e-6) &
82 (jnp.real(roots) > 0) &
83 (jnp.real(roots) < 5.0),
84 jnp.real(roots),
85 jnp.nan
86 )
87
88 max_root = jnp.nanmax(valid_roots)
89 return jnp.where(jnp.isnan(max_root), 1.0, max_root)
90
91 return jax.vmap(find_roots_for_level)(jnp.arange(ff.shape[0]))
92
93
94# ============================================================
95# Compute CwlPrRatio coefficient at vertical level k
96# ============================================================
97
98@jax.jit
99def CwlPrRatio_at_level_k(T_up_k, T_dn_k):
100 """
101 Parameters:
102 -----------
103 T_up_k : ndarray
104 2D horizontal slice of T_up at level k
105 T_dn_k : ndarray
106 2D horizontal slice of T_dn at level k
107
108 Returns:
109 --------
110 CwlPrRatio : ndarray
111 2D array of C_WL/Pr_t at level k
112 """
113
114 T_up_F = Imfilter(T_up_k)
115 T_dn_F = Imfilter(T_dn_k)
116
117 CwlPrRatio = T_up_F / T_dn_F
118
119 mask = (jnp.abs(T_dn_F) < 1e-10) | (CwlPrRatio < 0) | (CwlPrRatio > 1)
120 CwlPrRatio = jnp.where(mask, 0.0, CwlPrRatio)
121
122 return CwlPrRatio
123
124
125# ============================================================
126# Main LASDD-WL scalar code
127# ============================================================
128
129@jax.jit
130def ScalarLASDD(
131 u_, v_, w_,
132 u_hat, v_hat, w_hat,
133 u_hatd, v_hatd, w_hatd,
134 TH,
135 dTHdx, dTHdy, dTHdz,
136 ZeRo3D):
137 """
138 Locally-averaged scale-dependent dynamic scalar model using the
139 Wong-Lilly SGS base model (LASDD-WL).
140
141 Parameters:
142 -----------
143 u_, v_, w_ : ndarray
144 Interpolated velocity fields (from LASDD momentum step)
145 u_hat, v_hat, w_hat : ndarray
146 Level-1 filtered velocity components
147 u_hatd, v_hatd, w_hatd : ndarray
148 Level-2 filtered velocity components
149 TH : ndarray
150 Potential temperature field
151 dTHdx, dTHdy, dTHdz : ndarray
152 Potential temperature gradients
153 ZeRo3D : ndarray
154 Pre-allocated zero array
155
156 Returns:
157 --------
158 CwlPrRatio_3D : ndarray
159 3D field of C_WL/Pr_t
160 CwlPrRatio_1D : ndarray
161 1D averaged profile of C_WL/Pr_t
162 beta2_1D : ndarray
163 1D profile of scalar scale-dependence parameter beta2
164 """
165
166 # TH is stored as anomaly (TH - T_0), so Leonard flux uTH_hat - u_hat*TH_hat
167 # operates on small values (~0-5 K) with no catastrophic cancellation.
168 TH_ = TH
169
170 # Interpolate dTHdz to UVP nodes
171 THz = ZeRo3D.copy()
172 THz = THz.at[:, :, 1:nz - 1].set(StagGridAvg(dTHdz[:, :, 1:nz]))
173 THz = THz.at[:, :, 0].set(dTHdz[:, :, 0])
174 THz = THz.at[:, :, nz - 1].set(dTHdz[:, :, nz - 1])
175
176 # Scalar flux products
177 uTH = u_ * TH_
178 vTH = v_ * TH_
179 wTH = w_ * TH_
180
181 # Level-1 filtered scalar and flux products
182 TH_hat = Filtering_Level1(FFT(TH_))
183 uTH_hat = Filtering_Level1(FFT(uTH))
184 vTH_hat = Filtering_Level1(FFT(vTH))
185 wTH_hat = Filtering_Level1(FFT(wTH))
186
187 # Level-2 filtered scalar and flux products
188 TH_hatd = Filtering_Level2(FFT(TH_))
189 uTH_hatd = Filtering_Level2(FFT(uTH))
190 vTH_hatd = Filtering_Level2(FFT(vTH))
191 wTH_hatd = Filtering_Level2(FFT(wTH))
192
193 # Filtered scalar gradients (Level 1 and Level 2)
194 dTHdx_hat = Filtering_Level1(FFT(dTHdx))
195 dTHdy_hat = Filtering_Level1(FFT(dTHdy))
196 dTHdz_hat = Filtering_Level1(FFT(THz))
197
198 dTHdx_hatd = Filtering_Level2(FFT(dTHdx))
199 dTHdy_hatd = Filtering_Level2(FFT(dTHdy))
200 dTHdz_hatd = Filtering_Level2(FFT(THz))
201
202 # Scalar Leonard fluxes:
203 # K'_i = LTH (Level 1), K_i = QTH (Level 2)
204 LTH11 = uTH_hat - u_hat * TH_hat
205 LTH12 = vTH_hat - v_hat * TH_hat
206 LTH13 = wTH_hat - w_hat * TH_hat
207
208 QTH11 = uTH_hatd - u_hatd * TH_hatd
209 QTH12 = vTH_hatd - v_hatd * TH_hatd
210 QTH13 = wTH_hatd - w_hatd * TH_hatd
211
212 # ----------------------------------------------------------
213 # WL scalar polynomial coefficients (ABL07 Appendix)
214 # Independent scalars: a1, a3, a6, a8
215 # ----------------------------------------------------------
216 a1_terms = (QTH11 * dTHdx_hat + QTH12 * dTHdy_hat + QTH13 * dTHdz_hat)
217 a1 = PlanarMean(a1_terms)
218
219 a3_terms = (dTHdx_hat ** 2 + dTHdy_hat ** 2 + dTHdz_hat ** 2)
220 a3 = PlanarMean(a3_terms)
221
222 a6_terms = (LTH11 * dTHdx_hat + LTH12 * dTHdy_hat + LTH13 * dTHdz_hat)
223 a6 = PlanarMean(a6_terms)
224
225 a8_terms = (dTHdx_hatd ** 2 + dTHdy_hatd ** 2 + dTHdz_hatd ** 2)
226 a8 = PlanarMean(a8_terms)
227
228 # Derived scalars
229 a2 = -(TFR ** (8 / 3)) * a1
230 a4 = -2 * TFR ** (4 / 3) * a3
231 a5 = TFR ** (8 / 3) * a3
232 a7 = -TFR ** (4 / 3) * a6
233 a9 = -2 * TFR ** (8 / 3) * a8
234 a10 = TFR ** (16 / 3) * a8
235
236 # Polynomial coefficients A0...A5 mapped to aa...ff
237 aa = a1 * a3 - a6 * a8 # A0
238 bb = a1 * a4 - a7 * a8 # A1
239 cc = a2 * a3 + a1 * a5 - a6 * a9 # A2
240 dd = a2 * a4 - a7 * a9 # A3
241 ee = a2 * a5 - a6 * a10 # A4
242 ff = -a7 * a10 # A5
243
244 computeBeta = optSgs in [1, 2]
245 if computeBeta:
246 beta2_1D = ComputeBeta2(ff, ee, dd, cc, bb, aa)
247 else:
248 beta2_1D = jnp.ones(nz)
249 beta2_3D = jnp.broadcast_to(beta2_1D.reshape(1, 1, nz), (nx, ny, nz))
250
251 # ----------------------------------------------------------
252 # WL scalar T_up and T_dn for CwlPrRatio
253 # T_up = L^(4/3) * (K'_i * ∂_i c̄ - α^(4/3)*β2 * K'_i * ∂_i ĉ)
254 # T_dn = L^(8/3) * (|∂_i c̄|² - 2*α^(4/3)*β2*(∂_i c̄·∂_i ĉ) + α^(8/3)*β2²*|∂_i ĉ|²)
255 # ----------------------------------------------------------
256 b6_terms = (LTH11 * dTHdx_hatd +
257 LTH12 * dTHdy_hatd +
258 LTH13 * dTHdz_hatd)
259
260 c36_terms = (dTHdx_hat * dTHdx_hatd +
261 dTHdy_hat * dTHdy_hatd +
262 dTHdz_hat * dTHdz_hatd)
263
264 T_up = L ** (4 / 3) * (a6_terms - TFR ** (4 / 3) * beta2_3D * b6_terms)
265
266 T_dn = L ** (8 / 3) * (a3_terms
267 - 2 * TFR ** (4 / 3) * beta2_3D * c36_terms
268 + TFR ** (8 / 3) * beta2_3D ** 2 * a8_terms)
269
270 CwlPrRatio_3D = jax.vmap(CwlPrRatio_at_level_k,
271 in_axes=(2, 2), out_axes=2)(T_up, T_dn)
272
273 CwlPrRatio_1D = PlanarMean(CwlPrRatio_3D)
274
275 return CwlPrRatio_3D, CwlPrRatio_1D, beta2_1D