Compute Statistics
Source Code: Statistics
Statistics.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Statistics.py
18===================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-5-3
24:Description: this file is used to compute various statistics.
25"""
26
27# ============================================================
28# Imports
29# ============================================================
30
31import jax
32import jax.numpy as jnp
33from ..utilities.Utilities import StagGridAvg
34
35# Import configuration from namelist
36from ..config.ConfigLoader import *
37
38# Import derived variables
39from ..config.DerivedVars import *
40
41
42# ============================================================
43# Compute planar-averaged statistics
44# ============================================================
45
# StatsDict keys that hold vertical profiles (cleared to ZeRo1D on reset).
_PROFILE_KEYS = (
    "U", "V", "W", "TH",
    "dUdz", "dVdz", "dTHdz",
    "u2", "v2", "w2", "TH2",
    "uv", "uw", "vw",
    "uTH", "vTH", "wTH",
    "txy", "txz", "tyz", "qz",
    "Cs2_1", "Cs2_2", "Cs2PrRatio",
    "Beta1", "Beta2",
    "Q", "dQdz", "Q2", "wQ", "qHz",
)

# StatsDict keys that hold surface scalars (cleared to 0.0 on reset).
_SURFACE_KEYS = ("M_sfc", "ustar", "qz_sfc", "qm_sfc")


def _PlanarMean(field):
    """Return the mean of `field` over axes 0 and 1 (one value per level)."""
    return jnp.mean(field, axis=(0, 1))


def _PlanarFluct(field):
    """Return `field` minus its own axes-(0, 1) mean at every level."""
    return field - jnp.mean(field, axis=(0, 1), keepdims=True)


@jax.jit
def ComputeStats(
        u, v, w, TH, Q,
        dudz, dvdz, dTHdz, dQdz,
        M_sfc_loc, ustar, qz_sfc_avg_nd, qm_sfc_avg_nd,
        txy, txz, tyz, qz, qHz_q,
        Cs2_1D_avg1, Cs2_1D_avg2,
        Cs2PrRatio_1D,
        beta1_1D, beta2_1D,
        StatsDict, ResetFlag,
        ZeRo3D):
    """
    Accumulate planar-averaged statistics of the LES fields.

    Each call either clears every accumulator (ResetFlag == 1) or adds one
    new sample to each of them. The stored values are running sums: no
    division by the number of samples is performed here.

    Parameters:
    -----------
    u, v, w : ndarray — velocity components
    TH : ndarray — potential temperature (stored as anomaly TH - T_0)
    Q : ndarray — specific humidity (kg/kg); ZeRo3D when optMoisture=0
    dudz, dvdz : ndarray — velocity vertical gradients
    dTHdz : ndarray — potential temperature vertical gradient
    dQdz : ndarray — specific humidity vertical gradient; ZeRo3D when
        optMoisture=0
    M_sfc_loc : ndarray (nx, ny) — surface wind speed
    ustar : ndarray (nx, ny) — friction velocity
    qz_sfc_avg_nd : scalar — non-dim surface heat flux
    qm_sfc_avg_nd : scalar — non-dim surface moisture flux; 0.0 when
        optMoisture=0
    txy, txz, tyz : ndarray — SGS stress components
    qz : ndarray — SGS heat flux in z
    qHz_q : ndarray — SGS moisture flux in z; ZeRo3D when optMoisture=0
    Cs2_1D_avg1, Cs2_1D_avg2 : ndarray (nz) — Smagorinsky coefficients
    Cs2PrRatio_1D : ndarray (nz) — Cs2/Pr_t profile
    beta1_1D, beta2_1D : ndarray (nz) — beta coefficients
    StatsDict : dict — accumulated statistics (also carries the constants
        "Ugal" and "ZeRo1D", which are passed through unchanged)
    ResetFlag : int — 1 to reset, 0 to accumulate
    ZeRo3D : ndarray (nx, ny, nz) — pre-allocated zero array

    Returns:
    --------
    UpdatedStats : dict
        Same keys as StatsDict. Profile and surface entries are the
        previous sums plus this call's sample, scaled with u_scale,
        TH_scale, Q_scale and z_scale as appropriate.
    """
    # Constants that travel with the dictionary
    Ugal = StatsDict["Ugal"]
    ZeRo1D = StatsDict["ZeRo1D"]

    def ResetStats(_):
        """Return a dictionary with every accumulator cleared."""
        Cleared = {key: ZeRo1D for key in _PROFILE_KEYS}
        Cleared.update({key: 0.0 for key in _SURFACE_KEYS})
        Cleared["Ugal"] = Ugal
        Cleared["ZeRo1D"] = ZeRo1D
        return Cleared

    def UpdateStats(_):
        """Return StatsDict with the current sample added to every sum."""
        VelScale2 = u_scale ** 2
        HeatFluxScale = u_scale * TH_scale
        MoistFluxScale = u_scale * Q_scale

        # ------------------------------------------------------------
        # Resolved fluctuations (non-dimensional)
        # ------------------------------------------------------------
        # Every fluctuation is taken about the field's own non-dimensional
        # planar mean and only the final moment is scaled. Subtracting an
        # already-scaled mean from a non-dimensional field is correct only
        # when the scale equals 1. A uniform Galilean shift (Ugal) cancels
        # in a fluctuation, so it is not needed here.
        u_f = _PlanarFluct(u)
        v_f = _PlanarFluct(v)
        w_f = _PlanarFluct(w)
        TH_f = _PlanarFluct(TH)
        Q_f = _PlanarFluct(Q)

        # ------------------------------------------------------------
        # Fields on the staggered grid for vertical fluxes
        # ------------------------------------------------------------
        def ToStagLevels(field):
            # Levels 1..nz-1 come from StagGridAvg (presumably an average of
            # adjacent levels — confirm in Utilities); level 0 reuses the
            # field's first level. `.at[...].set` returns a new array, so
            # ZeRo3D itself is never modified.
            Stag = ZeRo3D.at[:, :, 1:nz].set(StagGridAvg(field))
            return Stag.at[:, :, 0].set(field[:, :, 0])

        u_stag_f = _PlanarFluct(ToStagLevels(u))
        v_stag_f = _PlanarFluct(ToStagLevels(v))
        TH_stag_f = _PlanarFluct(ToStagLevels(TH))
        Q_stag_f = _PlanarFluct(ToStagLevels(Q))
        # w is used as-is for the vertical fluxes, so w_f is reused.

        # ------------------------------------------------------------
        # This call's contribution to every accumulator
        # ------------------------------------------------------------
        Sample = {
            # Profiles of mean variables
            "U": (_PlanarMean(u) + Ugal) * u_scale,
            "V": _PlanarMean(v) * u_scale,
            "W": _PlanarMean(w) * u_scale,
            # TH is stored as anomaly TH - T_0; add T_0 back for output.
            "TH": (_PlanarMean(TH) + T_0_nondim) * TH_scale,
            "dUdz": _PlanarMean(dudz) * (u_scale / z_scale),
            "dVdz": _PlanarMean(dvdz) * (u_scale / z_scale),
            "dTHdz": _PlanarMean(dTHdz) * (TH_scale / z_scale),
            "Q": _PlanarMean(Q) * Q_scale,
            "dQdz": _PlanarMean(dQdz) * (Q_scale / z_scale),

            # Resolved variances
            "u2": _PlanarMean(u_f ** 2) * VelScale2,
            "v2": _PlanarMean(v_f ** 2) * VelScale2,
            "w2": _PlanarMean(w_f ** 2) * VelScale2,
            "TH2": _PlanarMean(TH_f ** 2) * (TH_scale ** 2),
            "Q2": _PlanarMean(Q_f ** 2) * (Q_scale ** 2),

            # Resolved horizontal fluxes
            "uv": _PlanarMean(u_f * v_f) * VelScale2,
            "uTH": _PlanarMean(u_f * TH_f) * HeatFluxScale,
            "vTH": _PlanarMean(v_f * TH_f) * HeatFluxScale,

            # Resolved vertical fluxes (staggered grid)
            "uw": _PlanarMean(u_stag_f * w_f) * VelScale2,
            "vw": _PlanarMean(v_stag_f * w_f) * VelScale2,
            "wTH": _PlanarMean(w_f * TH_stag_f) * HeatFluxScale,
            "wQ": _PlanarMean(w_f * Q_stag_f) * MoistFluxScale,

            # SGS stress and flux profiles
            "txy": _PlanarMean(txy) * VelScale2,
            "txz": _PlanarMean(txz) * VelScale2,
            "tyz": _PlanarMean(tyz) * VelScale2,
            "qz": _PlanarMean(qz) * HeatFluxScale,
            "qHz": _PlanarMean(qHz_q) * MoistFluxScale,

            # Surface variables
            # NOTE(review): M_sfc is accumulated without u_scale while ustar
            # is scaled — confirm whether M_sfc_loc is already dimensional.
            "M_sfc": jnp.mean(M_sfc_loc),
            "ustar": jnp.sqrt(jnp.mean(ustar ** 2)) * u_scale,
            "qz_sfc": qz_sfc_avg_nd * HeatFluxScale,
            "qm_sfc": qm_sfc_avg_nd * MoistFluxScale,

            # SGS coefficients (already planar profiles)
            "Cs2_1": Cs2_1D_avg1,
            "Cs2_2": Cs2_1D_avg2,
            "Cs2PrRatio": Cs2PrRatio_1D,
            "Beta1": beta1_1D,
            "Beta2": beta2_1D,
        }

        Updated = {key: StatsDict[key] + Sample[key] for key in Sample}
        Updated["Ugal"] = Ugal
        Updated["ZeRo1D"] = ZeRo1D
        return Updated

    # Both branches return dictionaries with the same key set, as
    # jax.lax.cond requires matching output structures.
    UpdatedStats = jax.lax.cond(
        ResetFlag == 1,
        ResetStats,
        UpdateStats,
        None
    )

    return UpdatedStats
333
334
def InitializeStats(ZeRo1D):
    """
    Build a fresh statistics dictionary with every accumulator cleared.

    Parameters:
    -----------
    ZeRo1D : ndarray
        Pre-allocated zero array; stored under every profile key and also
        kept under the "ZeRo1D" key.

    Returns:
    --------
    StatsDict : dict
        Profile entries set to ZeRo1D, surface scalars set to 0.0, plus the
        constants "Ugal" (read from module scope) and "ZeRo1D".
    """
    # Mean profiles, mean gradients, resolved variances, resolved fluxes
    # and SGS terms
    LeadingProfiles = (
        "U", "V", "W", "TH",
        "dUdz", "dVdz", "dTHdz",
        "u2", "v2", "w2", "TH2",
        "uv", "uw", "vw",
        "uTH", "vTH", "wTH",
        "txy", "txz", "tyz", "qz",
    )
    # Surface terms
    SurfaceScalars = ("M_sfc", "ustar", "qz_sfc")
    # SGS coefficients and moisture profiles
    # (moisture entries are non-zero only when optMoisture=1)
    TrailingProfiles = (
        "Cs2_1", "Cs2_2", "Cs2PrRatio",
        "Beta1", "Beta2",
        "Q", "dQdz", "Q2", "wQ", "qHz",
    )

    StatsDict = dict.fromkeys(LeadingProfiles, ZeRo1D)
    StatsDict.update(dict.fromkeys(SurfaceScalars, 0.0))
    StatsDict.update(dict.fromkeys(TrailingProfiles, ZeRo1D))

    # Surface moisture flux accumulator
    StatsDict["qm_sfc"] = 0.0

    # Constants
    StatsDict["Ugal"] = Ugal
    StatsDict["ZeRo1D"] = ZeRo1D

    return StatsDict