SGS Model: LASDD-WL (Momentum)
Locally-Averaged Scale-Dependent Dynamic SGS model using the Wong-Lilly
base formulation. Called for optSgs = 2 (LASDD-WL) and optSgs = 4
(LAD-WL). For the LAD variant (optSgs = 4), the scale-dependence parameter
beta is fixed at 1 instead of being computed dynamically.
Source Code: DynamicSGS_LASDD_WL
DynamicSGS_LASDD_WL.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: DynamicSGS_LASDD_WL.py
18=============================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2026-5-9
24:Description: locally-averaged scale-dependent dynamic (LASDD) model
25 using the Wong-Lilly (1994) SGS base model (LASDD-WL).
26 Reference: Anderson, Basu, and Letchford (2007), EFM.
27"""
28
29# ============================================================
30# Imports
31# ============================================================
32
33import jax
34import jax.numpy as jnp
35
36# Import derived variables
37from ..config.DerivedVars import *
38
39# Import FFT modules
40from ..operations.FFT import FFT
41
42# Import filtering functions
43from ..operations.Filtering import Filtering_Level1, Filtering_Level2
44
45# Import helper functions
46from ..utilities.Utilities import PlanarMean, StagGridAvg
47from ..utilities.Utilities import Roots, Imfilter
48
49
50# ============================================================
51# Find maximum real root between 0 and 5
52# ============================================================
53
@jax.jit
def ComputeBeta1(ff, ee, dd, cc, bb, aa):
    """
    Return, per vertical level, the largest admissible real root of

        ff*x^5 + ee*x^4 + dd*x^3 + cc*x^2 + bb*x + aa = 0

    which is used as the scale-dependence parameter beta1.

    For every level, Roots(...) is started from a fixed set of initial
    guesses. A candidate is kept only if its imaginary part is smaller than
    1e-6 in magnitude and its real part lies strictly inside (0, 5). The
    largest kept real part is returned; if no candidate survives, the level
    falls back to 1.0.

    Parameters:
    -----------
    ff, ee, dd, cc, bb, aa : ndarray
        Equal-length 1D coefficient profiles (one entry per vertical level),
        ordered from the x^5 coefficient down to the constant term.

    Returns:
    --------
    beta1 : ndarray
        1D profile of the selected root for each vertical level
    """

    # One row of six coefficients per level, highest degree first
    coeff_rows = jnp.stack([ff, ee, dd, cc, bb, aa], axis=-1)
    starting_points = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])

    def largest_admissible_root(row):
        candidates = jax.vmap(
            lambda x0: Roots(row, init_guess=x0))(starting_points)
        re_part = jnp.real(candidates)
        admissible = ((jnp.abs(jnp.imag(candidates)) < 1e-6)
                      & (re_part > 0)
                      & (re_part < 5.0))
        # Rejected candidates become NaN so nanmax ignores them
        best = jnp.nanmax(jnp.where(admissible, re_part, jnp.nan))
        return jnp.where(jnp.isnan(best), 1.0, best)

    return jax.vmap(largest_admissible_root)(coeff_rows)
90
91
92# ============================================================
93# Compute Wong-Lilly coefficient at vertical level k
94# ============================================================
95
@jax.jit
def Cwl_at_level_k(LM_k, MM_k):
    """
    Compute the locally averaged Wong-Lilly coefficient C_WL on one level.

    LM_k and MM_k are each smoothed with Imfilter (the caller describes
    this as a local 3x3 average), and C_WL is the ratio of the smoothed
    fields. Points where the smoothed MM is near zero (|MM| < 1e-10), or
    where the ratio falls outside [0, 1], are set to 0.

    Parameters:
    -----------
    LM_k : jax.numpy.ndarray
        2D horizontal slice of LM at level k
    MM_k : jax.numpy.ndarray
        2D horizontal slice of MM at level k

    Returns:
    --------
    Cwl : ndarray
        2D array of the Wong-Lilly coefficient C_WL at level k
    """

    LMx = Imfilter(LM_k)
    MMx = Imfilter(MM_k)

    # Guard the denominator BEFORE dividing. jnp.where evaluates both
    # branches, so dividing by a (near-)zero MMx would still create inf/NaN
    # values. Those trip jax_debug_nans and poison gradients, even though
    # such points are zeroed below anyway.
    degenerate = jnp.abs(MMx) < 1e-10
    safe_MMx = jnp.where(degenerate, 1.0, MMx)
    Cwl = LMx / safe_MMx

    # Out-of-range values are set to 0 (not clipped to the bounds)
    reject = degenerate | (Cwl < 0) | (Cwl > 1)
    Cwl = jnp.where(reject, 0.0, Cwl)

    return Cwl
121
122
123# ============================================================
124# Main LASDD-WL code
125# ============================================================
126
def _filter_both_levels(field):
    """FFT `field` once and return its (Level-1, Level-2) filtered versions."""
    spectral = FFT(field)
    return Filtering_Level1(spectral), Filtering_Level2(spectral)


def _strain_magnitude(Sij):
    """Return sqrt(2 S_ij S_ij) for components ordered (11, 22, 33, 12, 13, 23)."""
    S11, S22, S33, S12, S13, S23 = Sij
    # Summation order kept identical to the original inline expression
    return jnp.sqrt(2 * (S11 ** 2 + S22 ** 2 + S33 ** 2 +
                         2 * S12 ** 2 +
                         2 * S13 ** 2 +
                         2 * S23 ** 2))


def _sym_double_dot(Aij, Bij):
    """Return A_ij B_ij for symmetric tensors ordered (11, 22, 33, 12, 13, 23)."""
    A11, A22, A33, A12, A13, A23 = Aij
    B11, B22, B33, B12, B13, B23 = Bij
    # Off-diagonal products occur twice in the full 3x3 contraction
    return ((A11 * B11 + A22 * B22 + A33 * B33) +
            2 * (A12 * B12 + A13 * B13 + A23 * B23))


def _resolved_stress(prod_f, u_f, v_f, w_f):
    """Return filtered(u_i u_j) - filtered(u_i)*filtered(u_j), ordered (11, 22, 33, 12, 13, 23)."""
    uu_f, vv_f, ww_f, uv_f, uw_f, vw_f = prod_f
    return (uu_f - u_f ** 2,
            vv_f - v_f ** 2,
            ww_f - w_f ** 2,
            uv_f - u_f * v_f,
            uw_f - u_f * w_f,
            vw_f - v_f * w_f)


@jax.jit
def LASDD(
        u, v, w,
        S11, S22, S33,
        S12, S13, S23,
        ZeRo3D):
    """
    Locally-averaged scale-dependent dynamic model using the
    Wong-Lilly SGS base model (LASDD-WL).

    Grid sizes (nx, ny, nz), the filter width L, the filter-ratio constant
    TFR and the model switch optSgs come from the star import of
    config.DerivedVars. beta1 is computed from the quintic only when
    optSgs is 1 or 2; otherwise it is fixed at 1.

    Parameters:
    -----------
    u, v, w : ndarray
        Velocity components
    S11, S22, S33 : ndarray
        Normal strain rate components
    S12, S13, S23 : ndarray
        Shear strain rate components
    ZeRo3D : ndarray
        Pre-allocated zero array (template for the interpolated w field)

    Returns:
    --------
    u_, v_, w_ : ndarray
        Interpolated velocity components (u_, v_ are copies of u, v)
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    S_hat, S_hatd : ndarray
        Filtered strain rate magnitudes (diagnostic)
    Cwl_3D : ndarray
        3D field of C_WL coefficient
    Cwl_1D_avg1 : ndarray
        1D profile of C_WL (square of mean of sqrt)
    Cwl_1D_avg2 : ndarray
        1D profile of C_WL (direct mean)
    beta1_1D : ndarray
        1D profile of scale-dependence parameter beta1
    """

    # ----------------------------------------------------------
    # Bring w onto the same levels as u and v. Interior levels use
    # StagGridAvg of w[:, :, 1:nz]; the bottom level is 0.5*w[:, :, 1];
    # the top level copies w[:, :, nz - 2].
    # ----------------------------------------------------------
    u_ = u.copy()
    v_ = v.copy()
    w_ = ZeRo3D.copy()
    w_ = w_.at[:, :, 1:nz - 1].set(StagGridAvg(w[:, :, 1:nz]))
    w_ = w_.at[:, :, 0].set(0.5 * w[:, :, 1])
    w_ = w_.at[:, :, nz - 1].set(w[:, :, nz - 2])

    # ----------------------------------------------------------
    # Level-1 and Level-2 filtering. Each field is FFT'd once and the
    # spectrum is shared by both filters (previously FFT ran twice per field).
    # ----------------------------------------------------------
    u_hat, u_hatd = _filter_both_levels(u_)
    v_hat, v_hatd = _filter_both_levels(v_)
    w_hat, w_hatd = _filter_both_levels(w_)

    # Velocity products, ordered (11, 22, 33, 12, 13, 23)
    products = (u_ ** 2, v_ ** 2, w_ ** 2,
                u_ * v_, u_ * w_, v_ * w_)
    prod_hat, prod_hatd = zip(*(_filter_both_levels(p) for p in products))

    # Filtered strain rate components, same ordering
    Sij_hat, Sij_hatd = zip(*(_filter_both_levels(s)
                              for s in (S11, S22, S33, S12, S13, S23)))

    # Filtered strain rate magnitudes (diagnostic outputs)
    S_hat = _strain_magnitude(Sij_hat)
    S_hatd = _strain_magnitude(Sij_hatd)

    # Leonard stress tensors L_ij (Level 1) and Q_ij (Level 2)
    Lij = _resolved_stress(prod_hat, u_hat, v_hat, w_hat)
    Qij = _resolved_stress(prod_hatd, u_hatd, v_hatd, w_hatd)

    # ----------------------------------------------------------
    # WL polynomial coefficients (ABL07 Appendix, Eqs. A1-A10)
    # Independent scalars: a1, a3, a6, a8
    # ----------------------------------------------------------
    a1 = PlanarMean(_sym_double_dot(Qij, Sij_hat))
    a3 = PlanarMean(_sym_double_dot(Sij_hat, Sij_hat))
    a6 = PlanarMean(_sym_double_dot(Lij, Sij_hat))
    a8 = PlanarMean(_sym_double_dot(Sij_hatd, Sij_hatd))

    # Derived scalars (all expressible in terms of a1, a3, a6, a8)
    a2 = -(TFR ** (8 / 3)) * a1
    a4 = -2 * TFR ** (4 / 3) * a3
    a5 = TFR ** (8 / 3) * a3
    a7 = -TFR ** (4 / 3) * a6
    a9 = -2 * TFR ** (8 / 3) * a8
    a10 = TFR ** (16 / 3) * a8

    # Polynomial coefficients A0...A5 mapped to aa...ff
    aa = a1 * a3 - a6 * a8               # A0 (constant term)
    bb = a1 * a4 - a7 * a8               # A1 (beta^1)
    cc = a2 * a3 + a1 * a5 - a6 * a9     # A2 (beta^2)
    dd = a2 * a4 - a7 * a9               # A3 (beta^3)
    ee = a2 * a5 - a6 * a10              # A4 (beta^4)
    ff = -a7 * a10                       # A5 (beta^5)

    # optSgs is a trace-time (Python) constant, so this branch is static
    if optSgs in (1, 2):
        beta1_1D = ComputeBeta1(ff, ee, dd, cc, bb, aa)
    else:
        beta1_1D = jnp.ones(nz)
    # Shape (1, 1, nz): broadcasts against (nx, ny, nz) fields
    beta1_col = beta1_1D.reshape(1, 1, nz)

    # ----------------------------------------------------------
    # WL M tensor: M_ij = 2 L^(4/3) S_hat_ij - 2 (TFR L)^(4/3) beta1 S_hatd_ij
    # ----------------------------------------------------------
    T1 = 2 * L ** (4 / 3)
    T2 = 2 * (TFR * L) ** (4 / 3)
    Mij = tuple(T1 * s_hat - T2 * beta1_col * s_hatd
                for s_hat, s_hatd in zip(Sij_hat, Sij_hatd))

    # LM = L_ij * M_ij, MM = M_ij * M_ij
    LM = _sym_double_dot(Lij, Mij)
    MM = _sym_double_dot(Mij, Mij)

    # C_WL field: local 3x3 averaging via Imfilter, level by level
    Cwl_3D = jax.vmap(Cwl_at_level_k, in_axes=(2, 2), out_axes=2)(LM, MM)

    # Cwl_3D is confined to [0, 1] by Cwl_at_level_k, so sqrt is safe
    Cwl_1D_avg1 = PlanarMean(jnp.sqrt(Cwl_3D)) ** 2
    Cwl_1D_avg2 = PlanarMean(Cwl_3D)

    return (u_, v_, w_,
            u_hat, v_hat, w_hat,
            u_hatd, v_hatd, w_hatd,
            S_hat, S_hatd,
            Cwl_3D, Cwl_1D_avg1, Cwl_1D_avg2, beta1_1D)