Scalar SGS Flux Computations: Smagorinsky

Computes SGS scalar (potential temperature) fluxes using the Smagorinsky (SM) base formulation:

\[q_i = -2 \frac{C_s^2}{Pr_t} \Delta^2 |\bar{S}| \frac{\partial \bar{\theta}}{\partial x_i}\]

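Here $C_s^2/Pr_t$ is the model coefficient (a 3-D field), $\Delta$ is the length scale that appears as `L` in the code, $|\bar{S}|$ is the strain-rate magnitude, and $\partial \bar{\theta}/\partial x_i$ is the potential temperature gradient. This formulation is used for optSgs = 1 (LASDD-SM) and optSgs = 3 (LAD-SM).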

Source Code: ScalarSGSFluxes_SM

ScalarSGSFluxes_SM.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: ScalarSGSFluxes_SM.py
 18===========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-29
 24:Description: computes SGS scalar fluxes for the eddy-diffusivity model:
 25              q_i = -2(L^2) * Cs^2/Pr_t * |S| * (∂TH/∂x_i)
 26              where Cs^2/Pr_t is the model coefficient, |S| is the strain rate
 27              magnitude, and ∂TH/∂x_i is the potential temperature gradient
 28"""
 29
 30# ============================================================
 31#  Imports
 32# ============================================================
 33
 34import jax
 35import jax.numpy as jnp
 36
 37# Import derived variables
 38from ..config.DerivedVars import *
 39
 40# Import FFT modules
 41from ..operations.FFT import FFT, FFT_pad
 42
 43# Import helper functions
 44from ..utilities.Utilities import StagGridAvg
 45
 46# Import dealiasing functions
 47from ..operations.Dealiasing import Dealias1, Dealias2
 48
 49
 50# ======================================================
 51# Compute SGS scalar fluxes on UVP nodes with dealiasing
 52# ======================================================
 53
 54@jax.jit
 55def ScalarFluxesUVPnodes_Dealias(
 56        dTHdx_pad, dTHdy_pad,
 57        S_pad, Cs2PrRatio_3D_pad,
 58        ZeRo3D_fft):
 59    """
 60    Parameters:
 61    -----------
 62    dTHdx_pad, dTHdy_pad : ndarray
 63        Dealiased potential temperature gradients
 64    S_pad : ndarray
 65        Dealiased strain rate magnitude
 66    Cs2PrRatio_3D_pad : ndarray
 67        Dealiased Cs^2/Pr_t field
 68    ZeRo3D_fft : ndarray
 69        Pre-allocated zero array for dealiasing
 70
 71    Returns:
 72    --------
 73    qx, qy : ndarray
 74        SGS scalar flux components in x and y directions
 75    """
 76
 77    # Compute SGS scalar fluxes at UVP nodes
 78    preCompute = -2 * (L ** 2) * Cs2PrRatio_3D_pad * S_pad
 79    qx_pad = preCompute * dTHdx_pad
 80    qy_pad = preCompute * dTHdy_pad
 81
 82    # Set top boundary conditions
 83    qx_pad = qx_pad.at[:, :, nz - 1].set(0)
 84    qy_pad = qy_pad.at[:, :, nz - 1].set(0)
 85
 86    # Apply dealiasing to horizontal fluxes
 87    qx = Dealias2(FFT_pad(qx_pad), ZeRo3D_fft)
 88    qy = Dealias2(FFT_pad(qy_pad), ZeRo3D_fft)
 89
 90    return qx, qy
 91
 92
 93# =========================================================
 94# Compute SGS scalar fluxes on UVP nodes without dealiasing
 95# =========================================================
 96
 97@jax.jit
 98def ScalarFluxesUVPnodes_NoDealias(
 99        dTHdx, dTHdy,
100        S, Cs2PrRatio_3D):
101    """
102    Parameters:
103    -----------
104    dTHdx, dTHdy : ndarray
105        Potential temperature gradients at UVP nodes
106    S : ndarray
107        Strain rate magnitude at UVP nodes
108    Cs2PrRatio_3D : ndarray
109        Cs^2/Pr_t field
110
111    Returns:
112    --------
113    qx, qy : ndarray
114        SGS scalar flux components in x and y directions
115    """
116
117    # Compute SGS scalar fluxes at UVP nodes
118    preCompute = -2 * (L ** 2) * Cs2PrRatio_3D * S
119    qx = preCompute * dTHdx
120    qy = preCompute * dTHdy
121
122    # Set top boundary conditions
123    qx = qx.at[:, :, nz - 1].set(0)
124    qy = qy.at[:, :, nz - 1].set(0)
125
126    return qx, qy
127
128
129# ====================================================
130# Compute SGS scalar fluxes on W nodes with dealiasing
131# ====================================================
132
@jax.jit
def ScalarFluxesWnodes_Dealias(
        dTHdz_pad,
        S_pad, Cs2PrRatio_3D_pad,
        qz_sfc,
        ZeRo3D_fft):
    """
    Compute the vertical SGS scalar flux at w-nodes with dealiasing.

    Interior levels 1 .. nz-2 receive
        qz = -2 * L^2 * StagGridAvg(Cs^2/Pr_t) * |S| * dTH/dz
    where StagGridAvg is applied to coefficient levels 0 .. nz-2.
    Level nz-1 is zero. The field is then passed through FFT_pad and
    Dealias2. Only after that is level 0 overwritten with qz_sfc, so the
    surface flux itself is not dealiased.

    Parameters:
    -----------
    dTHdz_pad : ndarray
        Padded-grid potential temperature gradient in z
    S_pad : ndarray
        Padded-grid strain rate magnitude |S|
    Cs2PrRatio_3D_pad : ndarray
        Padded-grid model coefficient field Cs^2/Pr_t
    qz_sfc : ndarray
        Surface heat flux, written into level 0 of the result
    ZeRo3D_fft : ndarray
        Pre-allocated zero array forwarded to Dealias2

    Returns:
    --------
    qz : ndarray
        SGS scalar flux component in z-direction
    """

    topLevel = nz - 1
    interior = slice(1, topLevel)

    # The coefficient is staggered-averaged over levels 0..nz-2, giving one
    # value per interior w-level (presumably a UVP -> w interpolation)
    coeffOnW = StagGridAvg(Cs2PrRatio_3D_pad[:, :, :topLevel])

    # Operand order matches -2 * L^2 * coeff * |S| * dTH/dz exactly
    interiorFlux = (
        -2 * (L ** 2) * coeffOnW
        * S_pad[:, :, interior] * dTHdz_pad[:, :, interior]
    )

    # Zero field -> fill interior -> enforce zero top boundary
    fluxPadded = (
        jnp.zeros_like(S_pad)
        .at[:, :, interior].set(interiorFlux)
        .at[:, :, topLevel].set(0)
    )

    dealiased = Dealias2(FFT_pad(fluxPadded), ZeRo3D_fft)

    # Surface flux is imposed after dealiasing
    return dealiased.at[:, :, 0].set(qz_sfc)
178
179
180# =======================================================
181# Compute SGS scalar fluxes on W nodes without dealiasing
182# =======================================================
183
@jax.jit
def ScalarFluxesWnodes_NoDealias(
        dTHdz,
        S, Cs2PrRatio_3D,
        qz_sfc):
    """
    Compute the vertical SGS scalar flux at w-nodes (no dealiasing).

    Interior levels 1 .. nz-2 receive
        qz = -2 * L^2 * StagGridAvg(Cs^2/Pr_t) * |S| * dTH/dz
    where StagGridAvg is applied to coefficient levels 0 .. nz-2.
    Level nz-1 is set to zero and level 0 is set to qz_sfc.

    Parameters:
    -----------
    dTHdz : ndarray
        Potential temperature gradient in z-direction
    S : ndarray
        Strain rate magnitude |S|
    Cs2PrRatio_3D : ndarray
        Model coefficient field Cs^2/Pr_t
    qz_sfc : ndarray
        Surface heat flux, written into level 0 of the result

    Returns:
    --------
    qz : ndarray
        SGS scalar flux component in z-direction
    """

    topLevel = nz - 1
    interior = slice(1, topLevel)

    # Staggered average of the coefficient over levels 0..nz-2 gives one
    # value per interior w-level (presumably a UVP -> w interpolation)
    coeffOnW = StagGridAvg(Cs2PrRatio_3D[:, :, :topLevel])

    # Operand order matches -2 * L^2 * coeff * |S| * dTH/dz exactly
    interiorFlux = (
        -2 * (L ** 2) * coeffOnW
        * S[:, :, interior] * dTHdz[:, :, interior]
    )

    # Zero field -> interior values -> zero top -> prescribed surface flux
    return (
        jnp.zeros_like(S)
        .at[:, :, interior].set(interiorFlux)
        .at[:, :, topLevel].set(0)
        .at[:, :, 0].set(qz_sfc)
    )
222    return qz