SGS Model: LASDD-SM (Momentum)
Locally-Averaged Scale-Dependent Dynamic SGS model using the Smagorinsky
base formulation. Called for optSgs = 1 (LASDD-SM) and optSgs = 3
(LAD-SM). For the LAD variant (optSgs = 3), the scale-dependence parameter
beta is fixed to 1 rather than computed; the code solves for beta only when
optSgs is 1 or 2.
Source Code: DynamicSGS_LASDD_SM
DynamicSGS_LASDD_SM.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: DynamicSGS_LASDD_SM.py
18============================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-29
24:Description: locally-averaged scale-dependent dynamic (LASDD) model
25"""
26
27# ============================================================
28# Imports
29# ============================================================
30
31import jax
32import jax.numpy as jnp
33
34# Import derived variables
35from ..config.DerivedVars import *
36
37# Import FFT modules
38from ..operations.FFT import FFT
39
40# Import filtering functions
41from ..operations.Filtering import Filtering_Level1, Filtering_Level2
42
43# Import helper functions
44from ..utilities.Utilities import PlanarMean, StagGridAvg
45from ..utilities.Utilities import Roots, Imfilter
46
47
48# ============================================================
49# Find maximum real root between 0 and 5, with focus on 0.5-1.5 range
50# ============================================================
51
@jax.jit
def ComputeBeta1(ff, ee, dd, cc, bb, aa):
    """Solve, level by level, the quintic that defines the LASDD beta1.

    At every vertical level k the polynomial

        ff[k]*x**5 + ee[k]*x**4 + dd[k]*x**3 + cc[k]*x**2 + bb[k]*x + aa[k] = 0

    is passed to ``Roots`` once for each of a fixed set of starting guesses.
    From the resulting candidates, only those with an imaginary part of
    magnitude below 1e-6 and a real part strictly between 0 and 5 are kept.
    The largest of those is returned. A level with no admissible candidate
    falls back to 1.0.

    Parameters:
    -----------
    ff, ee, dd, cc, bb, aa : ndarray
        1D arrays of polynomial coefficients, one entry per vertical level,
        ordered from the x**5 coefficient down to the constant term

    Returns:
    --------
    beta1 : ndarray
        1D array holding the selected root for each vertical level
    """
    # Starting points are clustered around the expected 0.5-1.5 range, with
    # a few wider ones so that outlying roots can still be reached.
    starts = jnp.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.5])

    def largest_admissible_root(k):
        # Highest-degree coefficient first
        poly = jnp.array([ff[k], ee[k], dd[k], cc[k], bb[k], aa[k]])
        candidates = jax.vmap(
            lambda x0: Roots(poly, init_guess=x0))(starts)

        real_part = jnp.real(candidates)
        admissible = ((jnp.abs(jnp.imag(candidates)) < 1e-6)
                      & (real_part > 0)
                      & (real_part < 5.0))

        # Rejected candidates become NaN so nanmax ignores them; an all-NaN
        # level yields NaN, which is then replaced by the 1.0 default.
        best = jnp.nanmax(jnp.where(admissible, real_part, jnp.nan))
        return jnp.where(jnp.isnan(best), 1.0, best)

    levels = jnp.arange(ff.shape[0])
    return jax.vmap(largest_admissible_root)(levels)
96
97
98# ===================================================
99# Compute squared Smagorinsky coefficient (Cs^2) at vertical level k
100# ===================================================
101
@jax.jit
def Cs2_at_level_k(LM_k, MM_k):
    """Clipped dynamic estimate of Cs^2 on one horizontal plane.

    Both contractions are first smoothed with ``Imfilter``, and their ratio
    is taken point by point. A point is set to zero when:

    - the smoothed denominator has magnitude below 1e-10, or
    - the ratio is negative, or
    - the ratio exceeds 1.

    NOTE: a NaN ratio (e.g. from NaN inputs) fails all three comparisons and
    is therefore passed through unchanged.

    Parameters:
    -----------
    LM_k : jax.numpy.ndarray
        2D horizontal slice of LM at level k
    MM_k : jax.numpy.ndarray
        2D horizontal slice of MM at level k

    Returns:
    --------
    Cs2: ndarray
        2D array of the squared Smagorinsky coefficient at level k
    """
    numerator = Imfilter(LM_k)
    denominator = Imfilter(MM_k)
    ratio = numerator / denominator

    # Accumulate the rejection criteria one at a time
    rejected = jnp.abs(denominator) < 1e-10
    rejected = rejected | (ratio < 0)
    rejected = rejected | (ratio > 1)

    return jnp.where(rejected, 0.0, ratio)
131
132
133# ============================================================
134# Main LASDD code
135# ============================================================
136
def _filter_components(fields):
    """Filter each field at both test-filter levels, sharing one FFT per field.

    Returns (level1, level2): two tuples in the same order as ``fields``.
    """
    level1, level2 = [], []
    for field in fields:
        # One forward transform feeds both filters (the original code
        # recomputed FFT(field) separately for each level).
        spectrum = FFT(field)
        level1.append(Filtering_Level1(spectrum))
        level2.append(Filtering_Level2(spectrum))
    return tuple(level1), tuple(level2)


def _resolved_stress(products_f, velocities_f):
    """Resolved stress <u_i u_j>_f - <u_i>_f <u_j>_f at one filter level.

    ``products_f`` is the filtered (uu, vv, ww, uv, uw, vw) tuple and
    ``velocities_f`` the filtered (u, v, w) tuple. The result is ordered
    (11, 22, 33, 12, 13, 23).
    """
    uu, vv, ww, uv, uw, vw = products_f
    uf, vf, wf = velocities_f
    return (uu - uf ** 2, vv - vf ** 2, ww - wf ** 2,
            uv - uf * vf, uw - uf * wf, vw - vf * wf)


def _contract(a, b):
    """Double contraction a_ij b_ij of two symmetric tensors.

    Tensors are 6-tuples ordered (11, 22, 33, 12, 13, 23). The off-diagonal
    products are doubled to account for the (ji) entries. The term order
    matches the original hand-written expressions.
    """
    a11, a22, a33, a12, a13, a23 = a
    b11, b22, b33, b12, b13, b23 = b
    return (a11 * b11 + a22 * b22 + a33 * b33 +
            2 * (a12 * b12 + a13 * b13 + a23 * b23))


def _strain_magnitude(s):
    """|S| = sqrt(2 S_ij S_ij) for a 6-tuple symmetric tensor (original term order)."""
    s11, s22, s33, s12, s13, s23 = s
    return jnp.sqrt(2 * (s11 ** 2 + s22 ** 2 + s33 ** 2 +
                         2 * s12 ** 2 +
                         2 * s13 ** 2 +
                         2 * s23 ** 2))


@jax.jit
def LASDD(
        u, v, w,
        S11, S22, S33,
        S12, S13, S23,
        S,
        ZeRo3D):
    """Locally-averaged scale-dependent dynamic (LASDD) Smagorinsky model.

    Outline:

    1. w is moved to the u/v vertical locations.
    2. Velocities, velocity products, strain components and |S| S_ij products
       are filtered with ``Filtering_Level1`` and ``Filtering_Level2``.
    3. The plane-averaged coefficients aa..ff of a quintic in beta1 are formed.
       Writing M = A - beta*B (level 1) and N = C - beta**2*D (level 2), they
       are the coefficients of
           <L:M><N:N> - <M:M><Q:N> = 0.
    4. beta1 is solved with ``ComputeBeta1`` when ``optSgs`` is 1 or 2, and
       fixed to 1 otherwise.
    5. Cs^2 is obtained plane by plane from L:M / M:M via ``Cs2_at_level_k``.

    Uses module-level names nx, ny, nz, L, TFR and optSgs (presumably
    provided by the ``config.DerivedVars`` star import - confirm).

    Parameters:
    -----------
    u, v, w : ndarray
        Velocity components (3D, last axis vertical)
    S11, S22, S33 : ndarray
        Normal strain rate components
    S12, S13, S23 : ndarray
        Shear strain rate components
    S : ndarray
        Strain rate magnitude
    ZeRo3D : ndarray
        Pre-allocated zero array used as the template for the interpolated w

    Returns:
    --------
    u_, v_, w_ : ndarray
        Interpolated velocity components
    u_hat, v_hat, w_hat : ndarray
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray
        Level-2 filtered velocity components
    S_hat, S_hatd : ndarray
        Filtered strain rate magnitudes
    Cs2_3D : ndarray
        Cs2 field
    Cs2_1D_avg1 : ndarray
        1D profile of Cs2 (method 1: square of mean of sqrt)
    Cs2_1D_avg2 : ndarray
        1D profile of Cs2 (method 2: direct mean)
    beta1_1D : ndarray
        1D profile of beta1
    """

    # Bring w onto the u/v vertical locations:
    # - interior levels via StagGridAvg (presumably an average of
    #   neighbouring staggered levels - confirm);
    # - the bottom level is half of w at index 1;
    # - the top level is copied from index nz-2.
    u_ = u.copy()
    v_ = v.copy()
    w_ = ZeRo3D.copy()
    w_ = w_.at[:, :, 1:nz - 1].set(StagGridAvg(w[:, :, 1:nz]))
    w_ = w_.at[:, :, 0].set(0.5 * w[:, :, 1])
    w_ = w_.at[:, :, nz - 1].set(w[:, :, nz - 2])

    # Filtered velocities and velocity products at both levels
    vel_hat, vel_hatd = _filter_components((u_, v_, w_))
    u_hat, v_hat, w_hat = vel_hat
    u_hatd, v_hatd, w_hatd = vel_hatd

    prod_hat, prod_hatd = _filter_components(
        (u_ ** 2, v_ ** 2, w_ ** 2, u_ * v_, u_ * w_, v_ * w_))

    # Filtered strain components and filtered |S| S_ij products,
    # all ordered (11, 22, 33, 12, 13, 23)
    Sij = (S11, S22, S33, S12, S13, S23)
    Sij_hat, Sij_hatd = _filter_components(Sij)
    SSij_hat, SSij_hatd = _filter_components(tuple(S * s for s in Sij))

    # Magnitudes of the filtered strain tensors
    S_hat = _strain_magnitude(Sij_hat)
    S_hatd = _strain_magnitude(Sij_hatd)

    # Resolved stresses at level 1 (L_ij) and level 2 (Q_ij)
    Lij = _resolved_stress(prod_hat, vel_hat)
    Qij = _resolved_stress(prod_hatd, vel_hatd)

    # Plane-averaged building blocks of the beta1 polynomial
    a1 = PlanarMean(2 * (L ** 2) * _contract(Lij, SSij_hat))
    a2 = PlanarMean(2 * (L ** 2) * _contract(Qij, SSij_hatd))

    b1 = PlanarMean(2 * (L ** 2) * (TFR ** 2) * S_hat
                    * _contract(Lij, Sij_hat))
    b2 = PlanarMean(2 * (L ** 2) * (TFR ** 4) * S_hatd
                    * _contract(Qij, Sij_hatd))

    c1 = PlanarMean((2 * L ** 2) ** 2 * _contract(SSij_hat, SSij_hat))
    c2 = PlanarMean((2 * L ** 2) ** 2 * _contract(SSij_hatd, SSij_hatd))

    d1 = PlanarMean((4 * L ** 4) * (TFR ** 4) * (S_hat ** 2)
                    * _contract(Sij_hat, Sij_hat))
    d2 = PlanarMean((4 * L ** 4) * (TFR ** 8) * (S_hatd ** 2)
                    * _contract(Sij_hatd, Sij_hatd))

    e1 = PlanarMean((8 * L ** 4) * (TFR ** 2) * S_hat
                    * _contract(Sij_hat, SSij_hat))
    e2 = PlanarMean((8 * L ** 4) * (TFR ** 4) * S_hatd
                    * _contract(Sij_hatd, SSij_hatd))

    # Polynomial coefficients, highest degree (ff, beta**5) to constant (aa)
    aa = a1 * c2 - a2 * c1
    bb = a2 * e1 - b1 * c2
    cc = b2 * c1 - a1 * e2 - a2 * d1
    dd = b1 * e2 - b2 * e1
    ee = a1 * d2 + b2 * d1
    ff = -b1 * d2

    # optSgs must be a concrete Python value at trace time for this branch
    if optSgs in (1, 2):
        beta1_1D = ComputeBeta1(ff, ee, dd, cc, bb, aa)
    else:
        beta1_1D = jnp.ones(nz)
    # Extend beta1 to a 3D field (constant on each horizontal plane)
    beta1_3D = jnp.broadcast_to(beta1_1D.reshape(1, 1, nz), (nx, ny, nz))

    # Model tensor M_ij at level 1 with the scale-dependence correction
    T1 = 2 * L ** 2
    T2 = 2 * (TFR * L) ** 2
    Mij = tuple(T1 * ss - T2 * beta1_3D * S_hat * s
                for ss, s in zip(SSij_hat, Sij_hat))

    LM = _contract(Lij, Mij)
    MM = _contract(Mij, Mij)

    # Cs2 for every vertical level (axis 2) in one vmap
    Cs2_3D = jax.vmap(Cs2_at_level_k, in_axes=(2, 2), out_axes=2)(LM, MM)

    # 1D profiles: square of the mean of sqrt(Cs2), and the plain mean
    Cs2_1D_avg1 = PlanarMean(jnp.sqrt(Cs2_3D)) ** 2
    Cs2_1D_avg2 = PlanarMean(Cs2_3D)

    return (u_, v_, w_,
            u_hat, v_hat, w_hat,
            u_hatd, v_hatd, w_hatd,
            S_hat, S_hatd,
            Cs2_3D, Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D)