Dynamic SGS Modeling: Main Code

Source Code: DynamicSGS_Main

DynamicSGS_Main.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: DynamicSGS_Main.py
 18========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-29
 24:Description: dynamic SGS modeling - main code.
 25              Dispatches between SM (optSgs=1,3) and WL (optSgs=2,4) based on Config.
 26              optSgs=1: LASDD-SM, optSgs=2: LASDD-WL,
 27              optSgs=3: LAD-SM, optSgs=4: LAD-WL
 28"""
 29
 30# ============================================================
 31#  Imports
 32# ============================================================
 33
 34import jax
 35
 36# Import derived variables
 37from ..config.DerivedVars import *
 38
 39# Import FFT modules
 40from ..operations.FFT import FFT
 41
 42# Import dealiasing functions
 43from ..operations.Dealiasing import Dealias1
 44
 45# Import strain rates functions
 46from .StrainRates import StrainsUVPnodes_Dealias, StrainsWnodes_Dealias
 47from .StrainRates import StrainsUVPnodes_NoDealias, StrainsWnodes_NoDealias
 48
 49# Import LASDD models (SM and WL)
 50from .DynamicSGS_LASDD_SM import LASDD as LASDD_SM
 51from .DynamicSGS_LASDD_WL import LASDD as LASDD_WL
 52from .DynamicSGS_ScalarLASDD_SM import ScalarLASDD as ScalarLASDD_SM
 53from .DynamicSGS_ScalarLASDD_WL import ScalarLASDD as ScalarLASDD_WL
 54
 55# Import stress functions (SM and WL)
 56from .SGSStresses_SM import (
 57    StressesUVPnodes_Dealias   as StressesUVPnodes_Dealias_SM,
 58    StressesUVPnodes_NoDealias as StressesUVPnodes_NoDealias_SM,
 59    StressesWnodes_Dealias     as StressesWnodes_Dealias_SM,
 60    StressesWnodes_NoDealias   as StressesWnodes_NoDealias_SM)
 61from .SGSStresses_WL import (
 62    StressesUVPnodes_Dealias   as StressesUVPnodes_Dealias_WL,
 63    StressesUVPnodes_NoDealias as StressesUVPnodes_NoDealias_WL,
 64    StressesWnodes_Dealias     as StressesWnodes_Dealias_WL,
 65    StressesWnodes_NoDealias   as StressesWnodes_NoDealias_WL)
 66
 67# Import scalar flux functions (SM and WL)
 68from .ScalarSGSFluxes_SM import (
 69    ScalarFluxesUVPnodes_Dealias   as ScalarFluxesUVPnodes_Dealias_SM,
 70    ScalarFluxesUVPnodes_NoDealias as ScalarFluxesUVPnodes_NoDealias_SM,
 71    ScalarFluxesWnodes_Dealias     as ScalarFluxesWnodes_Dealias_SM,
 72    ScalarFluxesWnodes_NoDealias   as ScalarFluxesWnodes_NoDealias_SM)
 73from .ScalarSGSFluxes_WL import (
 74    ScalarFluxesUVPnodes_Dealias   as ScalarFluxesUVPnodes_Dealias_WL,
 75    ScalarFluxesUVPnodes_NoDealias as ScalarFluxesUVPnodes_NoDealias_WL,
 76    ScalarFluxesWnodes_Dealias     as ScalarFluxesWnodes_Dealias_WL,
 77    ScalarFluxesWnodes_NoDealias   as ScalarFluxesWnodes_NoDealias_WL)
 78
 79
 80# ============================================================
 81# Dynamic SGS: compute all the SGS stresses on proper nodes
 82# ============================================================
 83
 84@jax.jit
 85def DynamicSGS(
 86        dudx, dvdx, dwdx,
 87        dudy, dvdy, dwdy,
 88        dudz, dvdz, dwdz,
 89        u, v, w, M_sfc_loc, psi2D_m, psi2D_m0,
 90        ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft):
 91    """
 92    Computes all SGS stresses on proper grid nodes using the dynamic model.
 93    Dispatches to SM variants (optSgs=1,3) or WL variants (optSgs=2,4).
 94
 95    Parameters:
 96    -----------
 97    dudx, dvdx, dwdx : ndarray of shape (nx, ny, nz)
 98        Derivatives of velocity components in x-direction
 99    dudy, dvdy, dwdy : ndarray of shape (nx, ny, nz)
100        Derivatives of velocity components in y-direction
101    dudz, dvdz, dwdz : ndarray of shape (nx, ny, nz)
102        Derivatives of velocity components in z-direction
103    u, v, w : ndarray of shape (nx, ny, nz)
104        Velocity components
105    M_sfc_loc : ndarray of shape (nx, ny)
106        Near-surface wind speed
107    psi2D_m, psi2D_m0 : ndarray of shape (nx, ny)
108        Stability correction functions
109    ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft : ndarray
110        Pre-allocated zero arrays
111
112    Returns:
113    --------
114    txx, tyy, tzz, txy, txz, tyz : ndarray of shape (nx, ny, nz)
115        SGS stress components
116    Cs2_1D_avg1, Cs2_1D_avg2 : ndarray of shape (nz)
117        1D profiles of SGS model coefficient (two averaging methods)
118    beta1_1D : ndarray of shape (nz)
119        1D profile of scale-dependence parameter beta1
120    u_, v_, w_ : ndarray of shape (nx, ny, nz)
121        Interpolated velocity components
122    u_hat, v_hat, w_hat : ndarray of shape (nx, ny, nz)
123        Level-1 filtered velocity components
124    u_hatd, v_hatd, w_hatd : ndarray of shape (nx, ny, nz)
125        Level-2 filtered velocity components
126    S_uvp, S_uvp_pad : ndarray of shape (nx, ny, nz)
127        Strain rate magnitude at UVP nodes and its padded version
128    S_w, S_w_pad : ndarray of shape (nx, ny, nz)
129        Strain rate magnitude at W nodes and its padded version
130    S_uvp_hat, S_uvp_hatd : ndarray of shape (nx, ny, nz)
131        Filtered strain rate magnitudes
132    """
133
134    # ----------------------------------------
135    # Compute txx, tyy, tzz and txy components
136    # ----------------------------------------
137    if optDealias == 1:
138
139        # --------------------------------------
140        # Compute strain rates
141        # --------------------------------------
142        (S11, S22, S33,
143         S12, S13, S23,
144         S_uvp,
145         S11_pad, S22_pad, S33_pad,
146         S12_pad, S13_pad, S23_pad,
147         S_uvp_pad) = (
148            StrainsUVPnodes_Dealias(
149                dudx, dvdx, dwdx,
150                dudy, dvdy, dwdy,
151                dudz, dvdz, dwdz,
152                ZeRo3D, ZeRo3D_pad_fft))
153
154        # --------------------------------------
155        # Call LASDD and compute UVP stresses
156        # --------------------------------------
157        if optSgs in [1, 3]:  # SM variants
158
159            (u_, v_, w_,
160             u_hat, v_hat, w_hat,
161             u_hatd, v_hatd, w_hatd,
162             S_uvp_hat, S_uvp_hatd,
163             Cs2_3D, Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D) = (
164                LASDD_SM(
165                    u, v, w,
166                    S11, S22, S33,
167                    S12, S13, S23,
168                    S_uvp,
169                    ZeRo3D))
170
171            Cs2_3D_pad = Dealias1(FFT(Cs2_3D), ZeRo3D_pad_fft)
172
173            (txx, tyy, tzz, txy) = (
174                StressesUVPnodes_Dealias_SM(
175                    S11_pad, S22_pad, S33_pad, S12_pad,
176                    S_uvp_pad,
177                    Cs2_3D_pad,
178                    ZeRo3D_fft))
179
180        elif optSgs in [2, 4]:  # WL variants
181
182            (u_, v_, w_,
183             u_hat, v_hat, w_hat,
184             u_hatd, v_hatd, w_hatd,
185             S_uvp_hat, S_uvp_hatd,
186             Cs2_3D, Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D) = (
187                LASDD_WL(
188                    u, v, w,
189                    S11, S22, S33,
190                    S12, S13, S23,
191                    ZeRo3D))
192
193            Cs2_3D_pad = Dealias1(FFT(Cs2_3D), ZeRo3D_pad_fft)
194
195            (txx, tyy, tzz, txy) = (
196                StressesUVPnodes_Dealias_WL(
197                    S11_pad, S22_pad, S33_pad, S12_pad,
198                    Cs2_3D_pad,
199                    ZeRo3D_fft))
200
201        else:
202            raise ValueError(f"Unsupported optSgs={optSgs} for dynamic SGS")
203
204    else:
205
206        # --------------------------------------
207        # Compute strain rates
208        # --------------------------------------
209        (S11, S22, S33,
210         S12, S13, S23,
211         S_uvp) = (
212            StrainsUVPnodes_NoDealias(
213                dudx, dvdx, dwdx,
214                dudy, dvdy, dwdy,
215                dudz, dvdz, dwdz,
216                ZeRo3D))
217
218        S_uvp_pad = S_uvp
219
220        # --------------------------------------
221        # Call LASDD and compute UVP stresses
222        # --------------------------------------
223        if optSgs in [1, 3]:  # SM variants
224
225            (u_, v_, w_,
226             u_hat, v_hat, w_hat,
227             u_hatd, v_hatd, w_hatd,
228             S_uvp_hat, S_uvp_hatd,
229             Cs2_3D, Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D) = (
230                LASDD_SM(
231                    u, v, w,
232                    S11, S22, S33,
233                    S12, S13, S23,
234                    S_uvp,
235                    ZeRo3D))
236
237            (txx, tyy, tzz, txy) = (
238                StressesUVPnodes_NoDealias_SM(
239                    S11, S22, S33, S12,
240                    S_uvp,
241                    Cs2_3D))
242
243        elif optSgs in [2, 4]:  # WL variants
244
245            (u_, v_, w_,
246             u_hat, v_hat, w_hat,
247             u_hatd, v_hatd, w_hatd,
248             S_uvp_hat, S_uvp_hatd,
249             Cs2_3D, Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D) = (
250                LASDD_WL(
251                    u, v, w,
252                    S11, S22, S33,
253                    S12, S13, S23,
254                    ZeRo3D))
255
256            (txx, tyy, tzz, txy) = (
257                StressesUVPnodes_NoDealias_WL(
258                    S11, S22, S33, S12,
259                    Cs2_3D))
260
261        else:
262            raise ValueError(f"Unsupported optSgs={optSgs} for dynamic SGS")
263
264    # ------------------------------------------------------------
265    # Compute txz and tyz components
266    # ------------------------------------------------------------
267    if optDealias == 1:
268
269        (S13_pad, S23_pad,
270         S_w_pad) = (
271            StrainsWnodes_Dealias(
272                dudx, dvdx, dwdx,
273                dudy, dvdy, dwdy,
274                dudz, dvdz, dwdz,
275                ZeRo3D, ZeRo3D_pad_fft))
276
277        S_w = S_w_pad
278
279        if optSgs in [1, 3]:  # SM variants
280            (txz, tyz) = (
281                StressesWnodes_Dealias_SM(
282                    S13_pad, S23_pad,
283                    S_w_pad,
284                    Cs2_3D_pad,
285                    u, v, M_sfc_loc, psi2D_m, psi2D_m0,
286                    ZeRo3D_fft))
287        elif optSgs in [2, 4]:  # WL variants
288            (txz, tyz) = (
289                StressesWnodes_Dealias_WL(
290                    S13_pad, S23_pad,
291                    Cs2_3D_pad,
292                    u, v, M_sfc_loc, psi2D_m, psi2D_m0,
293                    ZeRo3D_fft))
294        else:
295            raise ValueError(f"Unsupported optSgs={optSgs} for dynamic SGS")
296
297    else:
298
299        (S13, S23,
300         S_w) = (
301            StrainsWnodes_NoDealias(
302                dudx, dvdx, dwdx,
303                dudy, dvdy, dwdy,
304                dudz, dvdz, dwdz,
305                ZeRo3D))
306
307        S_w_pad = S_w
308
309        if optSgs in [1, 3]:  # SM variants
310            (txz, tyz) = (
311                StressesWnodes_NoDealias_SM(
312                    S13, S23,
313                    S_w,
314                    Cs2_3D,
315                    u, v, M_sfc_loc, psi2D_m, psi2D_m0))
316        elif optSgs in [2, 4]:  # WL variants
317            (txz, tyz) = (
318                StressesWnodes_NoDealias_WL(
319                    S13, S23,
320                    Cs2_3D,
321                    u, v, M_sfc_loc, psi2D_m, psi2D_m0))
322        else:
323            raise ValueError(f"Unsupported optSgs={optSgs} for dynamic SGS")
324
325    return (txx, tyy, tzz, txy, txz, tyz,
326            Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D,
327            Cs2_3D,
328            u_, v_, w_,
329            u_hat, v_hat, w_hat,
330            u_hatd, v_hatd, w_hatd,
331            S_uvp, S_uvp_pad,
332            S_w, S_w_pad,
333            S_uvp_hat, S_uvp_hatd)
334
335
336# ============================================================
337# Dynamic SGS: compute scalar SGS fluxes on proper nodes
338# ============================================================
339
@jax.jit
def DynamicSGSscalar(
        u_, v_, w_,
        u_hat, v_hat, w_hat,
        u_hatd, v_hatd, w_hatd,
        S_uvp, S_uvp_pad,
        S_w, S_w_pad,
        S_uvp_hat, S_uvp_hatd,
        TH,
        dTHdx, dTHdy, dTHdz,
        qz_sfc,
        ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft):
    """
    Compute the SGS scalar fluxes on their proper grid nodes with the dynamic model.

    The model family is selected by the module-level setting ``optSgs``
    (optSgs=1,3: SM routines; optSgs=2,4: WL routines), and
    ``optDealias == 1`` selects the dealiased (padded) flux routines. Both
    are tested with plain Python ``if`` statements, so they are resolved
    while ``jax.jit`` traces this function. The velocity/strain inputs are
    the corresponding outputs of ``DynamicSGS``.

    Parameters:
    -----------
    u_, v_, w_ : ndarray of shape (nx, ny, nz)
        Interpolated velocity components
    u_hat, v_hat, w_hat : ndarray of shape (nx, ny, nz)
        Level-1 filtered velocity components
    u_hatd, v_hatd, w_hatd : ndarray of shape (nx, ny, nz)
        Level-2 filtered velocity components
    S_uvp, S_uvp_pad : ndarray
        Strain rate magnitude at UVP nodes and its padded version.
        S_uvp is used by the SM coefficient and the SM non-dealiased
        fluxes; S_uvp_pad only by the SM dealiased fluxes.
    S_w, S_w_pad : ndarray
        Strain rate magnitude at W nodes and its padded version
        (used only by the SM flux routines)
    S_uvp_hat, S_uvp_hatd : ndarray of shape (nx, ny, nz)
        Filtered strain rate magnitudes (used only by the SM coefficient)
    TH : ndarray of shape (nx, ny, nz)
        Potential temperature
    dTHdx, dTHdy, dTHdz : ndarray of shape (nx, ny, nz)
        Derivatives of potential temperature
    qz_sfc : ndarray of shape (nx, ny)
        Surface sensible heat flux
    ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft : ndarray
        Pre-allocated arrays for calculations

    Returns:
    --------
    A 6-element tuple, in this order:

    qx, qy, qz : ndarray of shape (nx, ny, nz)
        SGS scalar flux components (qx, qy from the UVP-node routines,
        qz from the W-node routine)
    Cs2PrRatio_3D : ndarray
        3D field of SGS coefficient / Pr_t returned by ScalarLASDD
    Cs2PrRatio_1D : ndarray of shape (nz)
        1D profile of SGS coefficient / Pr_t
    beta2_1D : ndarray of shape (nz)
        1D profile of scalar scale-dependence parameter beta2

    Raises:
    -------
    ValueError
        If optSgs is not one of 1, 2, 3, 4 (raised during tracing).
    """

    # Resolve the model family once (trace-time constant).
    if optSgs in (1, 3):    # SM variants
        use_sm = True
    elif optSgs in (2, 4):  # WL variants
        use_sm = False
    else:
        raise ValueError(f"Unsupported optSgs={optSgs} for dynamic SGS scalar")

    # ------------------------------------------------------------
    # Compute scalar SGS model coefficient
    # ------------------------------------------------------------
    if use_sm:
        (Cs2PrRatio_3D, Cs2PrRatio_1D, beta2_1D) = ScalarLASDD_SM(
            u_, v_, w_,
            u_hat, v_hat, w_hat,
            u_hatd, v_hatd, w_hatd,
            TH,
            dTHdx, dTHdy, dTHdz,
            S_uvp, S_uvp_hat, S_uvp_hatd,
            ZeRo3D)
    else:
        # ScalarLASDD_WL is not given any strain-rate magnitudes.
        (Cs2PrRatio_3D, Cs2PrRatio_1D, beta2_1D) = ScalarLASDD_WL(
            u_, v_, w_,
            u_hat, v_hat, w_hat,
            u_hatd, v_hatd, w_hatd,
            TH,
            dTHdx, dTHdy, dTHdz,
            ZeRo3D)

    # ------------------------------------------------------------
    # Compute qx, qy and qz components
    # ------------------------------------------------------------
    if optDealias == 1:

        # Move the gradients and the coefficient onto the padded grid.
        dTHdx_pad = Dealias1(FFT(dTHdx), ZeRo3D_pad_fft)
        dTHdy_pad = Dealias1(FFT(dTHdy), ZeRo3D_pad_fft)
        dTHdz_pad = Dealias1(FFT(dTHdz), ZeRo3D_pad_fft)
        Cs2PrRatio_3D_pad = Dealias1(FFT(Cs2PrRatio_3D), ZeRo3D_pad_fft)

        if use_sm:
            (qx, qy) = ScalarFluxesUVPnodes_Dealias_SM(
                dTHdx_pad, dTHdy_pad,
                S_uvp_pad,
                Cs2PrRatio_3D_pad,
                ZeRo3D_fft)
            qz = ScalarFluxesWnodes_Dealias_SM(
                dTHdz_pad,
                S_w_pad,
                Cs2PrRatio_3D_pad,
                qz_sfc,
                ZeRo3D_fft)
        else:
            (qx, qy) = ScalarFluxesUVPnodes_Dealias_WL(
                dTHdx_pad, dTHdy_pad,
                Cs2PrRatio_3D_pad,
                ZeRo3D_fft)
            qz = ScalarFluxesWnodes_Dealias_WL(
                dTHdz_pad,
                Cs2PrRatio_3D_pad,
                qz_sfc,
                ZeRo3D_fft)

    elif use_sm:
        (qx, qy) = ScalarFluxesUVPnodes_NoDealias_SM(
            dTHdx, dTHdy,
            S_uvp,
            Cs2PrRatio_3D)
        qz = ScalarFluxesWnodes_NoDealias_SM(
            dTHdz,
            S_w,
            Cs2PrRatio_3D,
            qz_sfc)

    else:
        (qx, qy) = ScalarFluxesUVPnodes_NoDealias_WL(
            dTHdx, dTHdy,
            Cs2PrRatio_3D)
        qz = ScalarFluxesWnodes_NoDealias_WL(
            dTHdz,
            Cs2PrRatio_3D,
            qz_sfc)

    return (qx, qy, qz,
            Cs2PrRatio_3D,
            Cs2PrRatio_1D, beta2_1D)
476            raise ValueError(f"Unsupported optSgs={optSgs} for dynamic SGS scalar")
477
478    return (qx, qy, qz,
479            Cs2PrRatio_3D,
480            Cs2PrRatio_1D, beta2_1D)