Scalar SGS Flux Computations: Smagorinsky
Computes SGS scalar (potential temperature) fluxes using the Smagorinsky (SM) base formulation:
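\[q_i = -2 \frac{C_s^2}{Pr_t} \Delta^2 |\bar{S}| \frac{\partial \bar{\theta}}{\partial x_i}\]
where \(C_s^2/Pr_t\) is the model coefficient, \(\Delta\) is the filter length scale (`L` in the code), \(|\bar{S}|\) is the resolved strain-rate magnitude, and \(\partial \bar{\theta}/\partial x_i\) is the resolved potential temperature gradient.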
Used for optSgs = 1 (LASDD-SM) and optSgs = 3 (LAD-SM).
Source Code: ScalarSGSFluxes_SM
ScalarSGSFluxes_SM.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
"""
File: ScalarSGSFluxes_SM.py
===========================

:Author: Sukanta Basu
:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
    code restructuring, and performance optimization
:Date: 2025-4-29
:Description: computes SGS scalar (potential temperature) fluxes for the
    Smagorinsky-type eddy-diffusivity model:
    q_i = -2(L^2) * Cs^2/Pr_t * |S| * (∂TH/∂x_i)
    where L is the model length scale (Δ in the Smagorinsky formula),
    Cs^2/Pr_t is the model coefficient, |S| is the strain rate
    magnitude, and ∂TH/∂x_i is the potential temperature gradient.
    Horizontal fluxes (qx, qy) are evaluated on UVP nodes and the
    vertical flux (qz) on w-nodes; each has a dealiased and a
    non-dealiased variant.
"""
29
30# ============================================================
31# Imports
32# ============================================================
33
34import jax
35import jax.numpy as jnp
36
37# Import derived variables
38from ..config.DerivedVars import *
39
40# Import FFT modules
41from ..operations.FFT import FFT, FFT_pad
42
43# Import helper functions
44from ..utilities.Utilities import StagGridAvg
45
46# Import dealiasing functions
47from ..operations.Dealiasing import Dealias1, Dealias2
48
49
50# ======================================================
51# Compute SGS scalar fluxes on UVP nodes with dealiasing
52# ======================================================
53
@jax.jit
def ScalarFluxesUVPnodes_Dealias(
        dTHdx_pad, dTHdy_pad,
        S_pad, Cs2PrRatio_3D_pad,
        ZeRo3D_fft):
    """
    Horizontal SGS scalar fluxes at UVP nodes, built on the padded grid
    and then dealiased.

    Each component is q = -2 * L^2 * (Cs^2/Pr_t) * |S| * dTH/dx_i.
    The top level (index nz - 1 along the last axis) is set to zero
    before the field is transformed with ``FFT_pad`` and dealiased with
    ``Dealias2``. ``L`` and ``nz`` are module-level names supplied by the
    star import from ``DerivedVars``.

    Parameters:
    -----------
    dTHdx_pad, dTHdy_pad : ndarray
        x- and y-gradients of potential temperature on the padded grid
    S_pad : ndarray
        Strain rate magnitude on the padded grid
    Cs2PrRatio_3D_pad : ndarray
        Cs^2/Pr_t coefficient field on the padded grid
    ZeRo3D_fft : ndarray
        Pre-allocated zero array handed to ``Dealias2``

    Returns:
    --------
    qx, qy : ndarray
        SGS scalar flux components in x and y, as returned by
        ``Dealias2`` (presumably on the unpadded grid -- confirm)
    """

    # Eddy-diffusivity factor common to both components:
    # -2 L^2 (Cs^2/Pr_t) |S|
    eddy_factor = -2 * (L ** 2) * Cs2PrRatio_3D_pad * S_pad

    def _dealiased_flux(gradient_pad):
        # Padded-grid flux with the top level zeroed, then dealiased
        flux_pad = (eddy_factor * gradient_pad).at[:, :, nz - 1].set(0)
        return Dealias2(FFT_pad(flux_pad), ZeRo3D_fft)

    return _dealiased_flux(dTHdx_pad), _dealiased_flux(dTHdy_pad)
91
92
93# =========================================================
94# Compute SGS scalar fluxes on UVP nodes without dealiasing
95# =========================================================
96
@jax.jit
def ScalarFluxesUVPnodes_NoDealias(
        dTHdx, dTHdy,
        S, Cs2PrRatio_3D):
    """
    Horizontal SGS scalar fluxes at UVP nodes, without dealiasing.

    Each component is q = -2 * L^2 * (Cs^2/Pr_t) * |S| * dTH/dx_i.
    The top level (index nz - 1 along the last axis) is set to zero.
    ``L`` and ``nz`` are module-level names supplied by the star import
    from ``DerivedVars``.

    Parameters:
    -----------
    dTHdx, dTHdy : ndarray
        x- and y-gradients of potential temperature at UVP nodes
    S : ndarray
        Strain rate magnitude at UVP nodes
    Cs2PrRatio_3D : ndarray
        Cs^2/Pr_t coefficient field

    Returns:
    --------
    qx, qy : ndarray
        SGS scalar flux components in x and y
    """

    # Eddy-diffusivity factor shared by both components
    eddy_factor = -2 * (L ** 2) * Cs2PrRatio_3D * S
    top_level = nz - 1

    # Same recipe for each gradient: scale, then zero the top level
    qx, qy = (
        (eddy_factor * gradient).at[:, :, top_level].set(0)
        for gradient in (dTHdx, dTHdy)
    )
    return qx, qy
127
128
129# ====================================================
130# Compute SGS scalar fluxes on W nodes with dealiasing
131# ====================================================
132
@jax.jit
def ScalarFluxesWnodes_Dealias(
        dTHdz_pad,
        S_pad, Cs2PrRatio_3D_pad,
        qz_sfc,
        ZeRo3D_fft):
    """
    Vertical SGS scalar flux on w-nodes, built on the padded grid and then
    dealiased.

    Interior w-levels 1..nz-2 receive
    -2 * L^2 * StagGridAvg(Cs^2/Pr_t) * |S| * dTH/dz. The top level
    (nz - 1) and the bottom level start at zero on the padded grid. After
    ``FFT_pad``/``Dealias2``, the bottom level (index 0) is overwritten
    with ``qz_sfc``. ``L`` and ``nz`` are module-level names supplied by
    the star import from ``DerivedVars``.

    Parameters:
    -----------
    dTHdz_pad : ndarray
        Vertical potential temperature gradient on the padded grid
    S_pad : ndarray
        Strain rate magnitude on the padded grid; sliced at levels
        1..nz-2 without averaging, so presumably already on w-nodes
        (confirm with caller)
    Cs2PrRatio_3D_pad : ndarray
        Cs^2/Pr_t coefficient field on the padded grid
    qz_sfc : ndarray
        Surface heat flux written into the bottom level
    ZeRo3D_fft : ndarray
        Pre-allocated zero array handed to ``Dealias2``

    Returns:
    --------
    qz : ndarray
        SGS scalar flux component in z
    """

    interior = slice(1, nz - 1)

    # StagGridAvg over levels 0..nz-2 yields one value per interior
    # w-level (1..nz-2), moving the coefficient onto w-nodes.
    coeff_w = StagGridAvg(Cs2PrRatio_3D_pad[:, :, :nz - 1])
    interior_flux = (
        -2 * (L ** 2) * coeff_w
        * S_pad[:, :, interior] * dTHdz_pad[:, :, interior]
    )

    # Padded flux: zero everywhere except the interior levels; top kept at 0
    qz_pad = (
        jnp.zeros_like(S_pad)
        .at[:, :, interior].set(interior_flux)
        .at[:, :, nz - 1].set(0)
    )

    # Dealias first, then impose the surface flux on the bottom level
    qz = Dealias2(FFT_pad(qz_pad), ZeRo3D_fft)
    return qz.at[:, :, 0].set(qz_sfc)
178
179
180# =======================================================
181# Compute SGS scalar fluxes on W nodes without dealiasing
182# =======================================================
183
@jax.jit
def ScalarFluxesWnodes_NoDealias(
        dTHdz,
        S, Cs2PrRatio_3D,
        qz_sfc):
    """
    Vertical SGS scalar flux on w-nodes, without dealiasing.

    Interior w-levels 1..nz-2 receive
    -2 * L^2 * StagGridAvg(Cs^2/Pr_t) * |S| * dTH/dz. The top level
    (nz - 1) is zero and the bottom level (index 0) holds ``qz_sfc``.
    ``L`` and ``nz`` are module-level names supplied by the star import
    from ``DerivedVars``.

    Parameters:
    -----------
    dTHdz : ndarray
        Vertical potential temperature gradient
    S : ndarray
        Strain rate magnitude; sliced at levels 1..nz-2 without
        averaging, so presumably already on w-nodes (confirm with caller)
    Cs2PrRatio_3D : ndarray
        Cs^2/Pr_t coefficient field
    qz_sfc : ndarray
        Surface heat flux written into the bottom level

    Returns:
    --------
    qz : ndarray
        SGS scalar flux component in z
    """

    interior = slice(1, nz - 1)

    # StagGridAvg over levels 0..nz-2 yields one value per interior
    # w-level (1..nz-2), moving the coefficient onto w-nodes.
    coeff_w = StagGridAvg(Cs2PrRatio_3D[:, :, :nz - 1])
    interior_flux = (
        -2 * (L ** 2) * coeff_w
        * S[:, :, interior] * dTHdz[:, :, interior]
    )

    # Assemble: interior values, zero top level, surface flux at the bottom
    return (
        jnp.zeros_like(S)
        .at[:, :, interior].set(interior_flux)
        .at[:, :, nz - 1].set(0)
        .at[:, :, 0].set(qz_sfc)
    )