SGS Model: LASDD-SM (Scalar)
Scalar (potential temperature) Locally-Averaged Scale-Dependent Dynamic
SGS model using the Smagorinsky base formulation. Called for optSgs = 1
(LASDD-SM) and optSgs = 3 (LAD-SM). For LAD variants (optSgs = 3),
the scalar scale-dependence parameter beta2 is set to 1.
Source Code: DynamicSGS_ScalarLASDD_SM
DynamicSGS_ScalarLASDD_SM.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: DynamicSGS_ScalarLASDD_SM.py
18==================================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-29
24:Description: locally-averaged scale-dependent dynamic (LASDD) model
25 for scalar transport
26"""
27
28# ============================================================
29# Imports
30# ============================================================
31
32import jax
33import jax.numpy as jnp
34
35# Import derived variables
36from ..config.DerivedVars import *
37
38# Import FFT modules
39from ..operations.FFT import FFT
40
41# Import filtering functions
42from ..operations.Filtering import Filtering_Level1, Filtering_Level2
43
44# Import helper functions
45from ..utilities.Utilities import PlanarMean, StagGridAvg
46from ..utilities.Utilities import Roots, Imfilter
47
48
49# ============================================================
50# Find maximum real root between 0 and 5, with focus on 0.5-1.5 range
51# ============================================================
52
@jax.jit
def ComputeBeta2(ff, ee, dd, cc, bb, aa):
    """
    Select the scale-dependence parameter beta2 level by level.

    For every vertical level the quintic
        ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0
    is passed to ``Roots`` from several starting points. Among the values
    returned, the largest one that is numerically real (|imag| < 1e-6) and
    lies strictly inside (0, 5) is kept. A level with no such value gets
    beta2 = 1.0.

    Parameters:
    -----------
    ff, ee, dd, cc, bb, aa : ndarray
        1D arrays of polynomial coefficients, one entry per vertical level
        (ff multiplies x^5, aa is the constant term)

    Returns:
    --------
    beta2 : ndarray
        1D array holding the selected root (or the 1.0 fallback) per level
    """

    # Starting points: dense over the expected 0.5-1.5 range, plus a few
    # wider values to catch outliers
    start_points = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])

    def beta2_for_level(level_coeffs):
        # Run the root finder once per starting point
        candidates = jax.vmap(
            lambda x0: Roots(level_coeffs, init_guess=x0)
        )(start_points)
        real_part = jnp.real(candidates)

        # Keep only real-valued candidates inside the open interval (0, 5)
        acceptable = (
            (jnp.abs(jnp.imag(candidates)) < 1e-6)
            & (real_part > 0)
            & (real_part < 5.0)
        )

        # Rejected candidates become NaN so nanmax ignores them; an
        # all-NaN level yields NaN, which is mapped to the 1.0 default
        largest = jnp.nanmax(jnp.where(acceptable, real_part, jnp.nan))
        return jnp.where(jnp.isnan(largest), 1.0, largest)

    # One row per vertical level, ordered from the x^5 coefficient down to
    # the constant term
    coeff_table = jnp.stack([ff, ee, dd, cc, bb, aa], axis=-1)

    return jax.vmap(beta2_for_level)(coeff_table)
97
98
99# ===================================================
100# Compute Cs2PrRatio coefficient at vertical level k
101# ===================================================
102
@jax.jit
def Cs2PrRatio_at_level_k(T_up_k, T_dn_k):
    """
    Form the local Cs2/Pr ratio on one horizontal plane.

    Numerator and denominator are each passed through ``Imfilter`` and then
    divided point by point. A point is reset to zero when the filtered
    denominator is smaller than 1e-10 in magnitude, or when the ratio falls
    outside the interval [0, 1].

    Parameters:
    -----------
    T_up_k : ndarray
        2D horizontal slice of T_up at level k
    T_dn_k : ndarray
        2D horizontal slice of T_dn at level k

    Returns:
    --------
    Cs2PrRatio : ndarray
        2D array of Cs2/Pr at level k
    """

    # NOTE(review): Imfilter is presumably the local spatial averaging of
    # the LASDD model -- confirm against utilities.Utilities
    numer = Imfilter(T_up_k)
    denom = Imfilter(T_dn_k)

    # Guard the denominator BEFORE dividing. The affected points are zeroed
    # below anyway, so the result is unchanged, but no inf/NaN intermediate
    # is produced (these trip jax_debug_nans and poison gradients taken
    # through jnp.where).
    tiny_denom = jnp.abs(denom) < 1e-10
    safe_denom = jnp.where(tiny_denom, 1.0, denom)
    ratio = numer / safe_denom

    # Reject ill-conditioned points and ratios outside [0, 1]
    rejected = tiny_denom | (ratio < 0) | (ratio > 1)

    return jnp.where(rejected, 0.0, ratio)
132
133
134# ==============================================
135# Main LASDD code for SGS scalar transport model
136# ==============================================
137
@jax.jit
def ScalarLASDD(
        u_, v_, w_,
        u_hat, v_hat, w_hat,
        u_hatd, v_hatd, w_hatd,
        TH,
        dTHdx, dTHdy, dTHdz,
        S, S_hat, S_hatd,
        ZeRo3D):
    """
    Compute the dynamic Cs2/Pr coefficient for SGS scalar transport.

    The scalar fluxes, scalar gradients and strain-gradient products are
    filtered at two levels. The planar means of their contractions give the
    coefficients of a fifth-order polynomial in beta2, which is solved per
    vertical level when ``optSgs`` is 1 or 2 (otherwise beta2 = 1). The
    local Cs2/Pr field is then formed plane by plane with
    ``Cs2PrRatio_at_level_k``.

    Relies on the module-level names ``nx``, ``ny``, ``nz``, ``L``, ``TFR``
    and ``optSgs`` (expected to come from the DerivedVars star import).

    Parameters:
    -----------
    u_, v_, w_ : ndarray
        Interpolated velocity fields
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    TH : ndarray
        Potential temperature field
    dTHdx, dTHdy, dTHdz : ndarray
        Potential temperature gradients
    S, S_hat, S_hatd : ndarray
        Strain rate magnitude and its filtered versions
    ZeRo3D : ndarray
        Pre-allocated zero array used as the template for THz

    Returns:
    --------
    Cs2PrRatio_3D : ndarray
        Cs2PrRatio field
    Cs2PrRatio_1D : ndarray
        1D averaged profile of Cs2PrRatio
    beta2_1D : ndarray
        1D profile of beta2
    """

    def filter_both_levels(field):
        # Transform once and reuse the result for both filter levels
        spectrum = FFT(field)
        return Filtering_Level1(spectrum), Filtering_Level2(spectrum)

    # No base state is subtracted here: TH is used as passed in. Per the
    # original note, TH is expected to be stored as an anomaly (TH - T_0),
    # so the Leonard-flux differences below act on small values and do not
    # suffer catastrophic cancellation -- TODO confirm with the caller.

    # Convert dTHdz to THz (at proper grid locations). The .at[...].set
    # updates are out-of-place, so ZeRo3D itself is left untouched.
    THz = ZeRo3D.at[:, :, 1:nz - 1].set(StagGridAvg(dTHdz[:, :, 1:nz]))
    THz = THz.at[:, :, 0].set(dTHdz[:, :, 0])
    THz = THz.at[:, :, nz - 1].set(dTHdz[:, :, nz - 1])

    # Scalar and velocity-scalar products, filtered at both levels
    TH_hat, TH_hatd = filter_both_levels(TH)
    uTH_hat, uTH_hatd = filter_both_levels(u_ * TH)
    vTH_hat, vTH_hatd = filter_both_levels(v_ * TH)
    wTH_hat, wTH_hatd = filter_both_levels(w_ * TH)

    # Scalar gradients, filtered at both levels
    dTHdx_hat, dTHdx_hatd = filter_both_levels(dTHdx)
    dTHdy_hat, dTHdy_hatd = filter_both_levels(dTHdy)
    dTHdz_hat, dTHdz_hatd = filter_both_levels(THz)

    # Strain-gradient products, filtered at both levels
    SdTHdx_hat, SdTHdx_hatd = filter_both_levels(S * dTHdx)
    SdTHdy_hat, SdTHdy_hatd = filter_both_levels(S * dTHdy)
    SdTHdz_hat, SdTHdz_hatd = filter_both_levels(S * THz)

    # Compute L and Q terms (level-1 and level-2 flux differences)
    LTH11 = uTH_hat - u_hat * TH_hat
    LTH12 = vTH_hat - v_hat * TH_hat
    LTH13 = wTH_hat - w_hat * TH_hat

    QTH11 = uTH_hatd - u_hatd * TH_hatd
    QTH12 = vTH_hatd - v_hatd * TH_hatd
    QTH13 = wTH_hatd - w_hatd * TH_hatd

    # ----- Level-1 contractions -----
    # The weighted a2..e2 integrands are kept because they are reused
    # point-wise for T_up / T_dn further down.
    a2_terms = (LTH11 * SdTHdx_hat +
                LTH12 * SdTHdy_hat +
                LTH13 * SdTHdz_hat)
    a2_field = (L ** 2) * a2_terms
    a2 = PlanarMean(a2_field)

    b2_terms = (LTH11 * S_hat * dTHdx_hat +
                LTH12 * S_hat * dTHdy_hat +
                LTH13 * S_hat * dTHdz_hat)
    b2_field = (-L ** 2 * TFR ** 2) * b2_terms
    b2 = PlanarMean(b2_field)

    c2_terms = (SdTHdx_hat ** 2 +
                SdTHdy_hat ** 2 +
                SdTHdz_hat ** 2)
    c2_field = (L ** 4) * c2_terms
    c2 = PlanarMean(c2_field)

    d2_terms = (SdTHdx_hat * S_hat * dTHdx_hat +
                SdTHdy_hat * S_hat * dTHdy_hat +
                SdTHdz_hat * S_hat * dTHdz_hat)
    d2_field = (-2 * L ** 4 * TFR ** 2) * d2_terms
    d2 = PlanarMean(d2_field)

    e2_terms = ((S_hat * dTHdx_hat) ** 2 +
                (S_hat * dTHdy_hat) ** 2 +
                (S_hat * dTHdz_hat) ** 2)
    e2_field = (L ** 4 * TFR ** 4) * e2_terms
    e2 = PlanarMean(e2_field)

    # ----- Level-2 contractions -----
    a4_terms = (QTH11 * SdTHdx_hatd +
                QTH12 * SdTHdy_hatd +
                QTH13 * SdTHdz_hatd)
    a4 = PlanarMean((L ** 2) * a4_terms)

    b4_terms = (QTH11 * S_hatd * dTHdx_hatd +
                QTH12 * S_hatd * dTHdy_hatd +
                QTH13 * S_hatd * dTHdz_hatd)
    b4 = PlanarMean((-L ** 2 * TFR ** 4) * b4_terms)

    c4_terms = (SdTHdx_hatd ** 2 +
                SdTHdy_hatd ** 2 +
                SdTHdz_hatd ** 2)
    c4 = PlanarMean((L ** 4) * c4_terms)

    d4_terms = (SdTHdx_hatd * S_hatd * dTHdx_hatd +
                SdTHdy_hatd * S_hatd * dTHdy_hatd +
                SdTHdz_hatd * S_hatd * dTHdz_hatd)
    d4 = PlanarMean((-2 * L ** 4 * TFR ** 4) * d4_terms)

    e4_terms = ((S_hatd * dTHdx_hatd) ** 2 +
                (S_hatd * dTHdy_hatd) ** 2 +
                (S_hatd * dTHdz_hatd) ** 2)
    e4 = PlanarMean((L ** 4 * TFR ** 8) * e4_terms)

    # Polynomial coefficients for beta2, from cross-multiplying
    #   (a2 + b2*x) / (c2 + d2*x + e2*x^2)
    #     = (a4 + b4*x^2) / (c4 + d4*x^2 + e4*x^4)
    # and collecting powers of x (aa: x^0 ... ff: x^5)
    aa = a2 * c4 - a4 * c2
    bb = -a4 * d2 + b2 * c4
    cc = -c2 * b4 + a2 * d4 - a4 * e2
    dd = b2 * d4 - b4 * d2
    ee = a2 * e4 - b4 * e2
    ff = b2 * e4

    # optSgs is a Python-level constant, so this branch is resolved when
    # the function is traced. Only optSgs 1 or 2 solve for beta2; any other
    # value uses beta2 = 1 at every level.
    if optSgs in (1, 2):
        beta2_1D = ComputeBeta2(ff, ee, dd, cc, bb, aa)
    else:
        beta2_1D = jnp.ones(nz)

    # Extend beta2 to 3D field
    beta2_3D = jnp.broadcast_to(beta2_1D.reshape(1, 1, nz), (nx, ny, nz))

    # Point-wise numerator and denominator of Cs2PrRatio
    T_up = a2_field + b2_field * beta2_3D
    T_dn = (c2_field +
            d2_field * beta2_3D +
            e2_field * beta2_3D ** 2)

    # Evaluate the ratio plane by plane along the third (vertical) axis
    Cs2PrRatio_3D = jax.vmap(Cs2PrRatio_at_level_k,
                             in_axes=(2, 2), out_axes=2)(T_up, T_dn)

    # Compute 1D average from the 3D field
    Cs2PrRatio_1D = PlanarMean(Cs2PrRatio_3D)

    return Cs2PrRatio_3D, Cs2PrRatio_1D, beta2_1D