Compute Statistics

Source Code: Statistics

Statistics.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Statistics.py
 18===================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-5-3
 24:Description: this file is used to compute various statistics.
 25"""
 26
 27# ============================================================
 28#  Imports
 29# ============================================================
 30
 31import jax
 32import jax.numpy as jnp
 33from ..utilities.Utilities import StagGridAvg
 34
 35# Import configuration from namelist
 36from ..config.ConfigLoader import *
 37
 38# Import derived variables
 39from ..config.DerivedVars import *
 40
 41
 42# ============================================================
 43#  Compute planar-averaged statistics
 44# ============================================================
 45
 46@jax.jit
 47def ComputeStats(
 48        u, v, w, TH, Q,
 49        dudz, dvdz, dTHdz, dQdz,
 50        M_sfc_loc, ustar, qz_sfc_avg_nd, qm_sfc_avg_nd,
 51        txy, txz, tyz, qz, qHz_q,
 52        Cs2_1D_avg1, Cs2_1D_avg2,
 53        Cs2PrRatio_1D,
 54        beta1_1D, beta2_1D,
 55        StatsDict, ResetFlag,
 56        ZeRo3D):
 57    """
 58    Computes spatial averaged-statistics for LES flow variables.
 59
 60    Parameters:
 61    -----------
 62    u, v, w : ndarray — velocity components
 63    TH : ndarray — potential temperature (stored as anomaly TH - T_0)
 64    Q : ndarray — specific humidity (kg/kg); ZeRo3D when optMoisture=0
 65    dudz, dvdz : ndarray — velocity vertical gradients
 66    dTHdz : ndarray — potential temperature vertical gradient
 67    dQdz : ndarray — specific humidity vertical gradient; ZeRo3D when optMoisture=0
 68    M_sfc_loc : ndarray (nx, ny) — surface wind speed
 69    ustar : ndarray (nx, ny) — friction velocity
 70    qz_sfc_avg_nd : scalar — non-dim surface heat flux
 71    qm_sfc_avg_nd : scalar — non-dim surface moisture flux; 0.0 when optMoisture=0
 72    txy, txz, tyz : ndarray — SGS stress components
 73    qz : ndarray — SGS heat flux in z
 74    qHz_q : ndarray — SGS moisture flux in z; ZeRo3D when optMoisture=0
 75    Cs2_1D_avg1, Cs2_1D_avg2 : ndarray (nz) — Smagorinsky coefficients
 76    Cs2PrRatio_1D : ndarray (nz) — Cs2/Pr_t profile
 77    beta1_1D, beta2_1D : ndarray (nz) — beta coefficients
 78    StatsDict : dict — accumulated statistics
 79    ResetFlag : int — 1 to reset, 0 to accumulate
 80    ZeRo3D : ndarray (nx, ny, nz) — pre-allocated zero array
 81
 82    Returns:
 83    --------
 84    UpdatedStats : dict
 85    """
 86
 87    # Extract existing statistics
 88
 89    # Mean profiles
 90    U_avg = StatsDict["U"]; V_avg = StatsDict["V"]
 91    W_avg = StatsDict["W"]; TH_avg = StatsDict["TH"]
 92
 93    # Mean gradients
 94    dUdz_avg = StatsDict["dUdz"]; dVdz_avg = StatsDict["dVdz"]
 95    dTHdz_avg = StatsDict["dTHdz"]
 96
 97    # Resolved variances
 98    u2_avg = StatsDict["u2"]; v2_avg = StatsDict["v2"]
 99    w2_avg = StatsDict["w2"]; TH2_avg = StatsDict["TH2"]
100
101    # Resolved fluxes
102    uv_avg = StatsDict["uv"]; uw_avg = StatsDict["uw"]
103    vw_avg = StatsDict["vw"]; uTH_avg = StatsDict["uTH"]
104    vTH_avg = StatsDict["vTH"]; wTH_avg = StatsDict["wTH"]
105
106    # Surface terms
107    M_sfc_avg = StatsDict["M_sfc"]; ustar_avg = StatsDict["ustar"]
108    qz_sfc_sum = StatsDict["qz_sfc"]
109
110    # SGS terms
111    txy_avg = StatsDict["txy"]; txz_avg = StatsDict["txz"]
112    tyz_avg = StatsDict["tyz"]; qz_avg = StatsDict["qz"]
113
114    # SGS coefficients
115    Cs2_1_avg = StatsDict["Cs2_1"]; Cs2_2_avg = StatsDict["Cs2_2"]
116    Cs2PrRatio_avg = StatsDict["Cs2PrRatio"]
117    Beta1_avg = StatsDict["Beta1"]; Beta2_avg = StatsDict["Beta2"]
118
119    # Moisture statistics
120    Q_avg      = StatsDict["Q"]
121    dQdz_avg   = StatsDict["dQdz"]
122    Q2_avg     = StatsDict["Q2"]
123    wQ_avg     = StatsDict["wQ"]
124    qHz_avg    = StatsDict["qHz"]
125    qm_sfc_sum = StatsDict["qm_sfc"]
126
127    # Constants
128    Ugal = StatsDict["Ugal"]; ZeRo1D = StatsDict["ZeRo1D"]
129
130    # Reset statistics function
131    def ResetStats(_):
132        return {
133            "U": ZeRo1D, "V": ZeRo1D, "W": ZeRo1D, "TH": ZeRo1D,
134            "dUdz": ZeRo1D, "dVdz": ZeRo1D, "dTHdz": ZeRo1D,
135            "u2": ZeRo1D, "v2": ZeRo1D, "w2": ZeRo1D, "TH2": ZeRo1D,
136            "uv": ZeRo1D, "uw": ZeRo1D, "vw": ZeRo1D,
137            "uTH": ZeRo1D, "vTH": ZeRo1D, "wTH": ZeRo1D,
138            "txy": ZeRo1D, "txz": ZeRo1D, "tyz": ZeRo1D,
139            "qz": ZeRo1D,
140            "M_sfc": 0.0, "ustar": 0.0, "qz_sfc": 0.0,
141            "Cs2_1": ZeRo1D, "Cs2_2": ZeRo1D,
142            "Cs2PrRatio": ZeRo1D,
143            "Beta1": ZeRo1D, "Beta2": ZeRo1D,
144            "Q": ZeRo1D, "dQdz": ZeRo1D, "Q2": ZeRo1D,
145            "wQ": ZeRo1D, "qHz": ZeRo1D, "qm_sfc": 0.0,
146            "Ugal": Ugal, "ZeRo1D": ZeRo1D
147        }
148
149    # Update statistics function
150    def UpdateStats(_):
151        # ------------------------------------------------------------
152        # Profiles of mean variables
153        # ------------------------------------------------------------
154        mU = (jnp.mean(u, axis=(0, 1)) + Ugal) * u_scale
155        mV = jnp.mean(v, axis=(0, 1)) * u_scale
156        mW = jnp.mean(w, axis=(0, 1)) * u_scale
157        # TH is stored as anomaly TH' = TH - T_0; add T_0 back for output.
158        # mTH_anom is kept separately as the non-dim fluctuation reference.
159        mTH_anom = jnp.mean(TH, axis=(0, 1))
160        mTH = (mTH_anom + T_0_nondim) * TH_scale
161
162        mdUdz  = jnp.mean(dudz,  axis=(0, 1)) * (u_scale  / z_scale)
163        mdVdz  = jnp.mean(dvdz,  axis=(0, 1)) * (u_scale  / z_scale)
164        mdTHdz = jnp.mean(dTHdz, axis=(0, 1)) * (TH_scale / z_scale)
165
166        # Q mean profile (dimensional kg/kg, stored as-is)
167        mQ     = jnp.mean(Q,    axis=(0, 1)) * Q_scale
168        mdQdz  = jnp.mean(dQdz, axis=(0, 1)) * (Q_scale  / z_scale)
169
170        # Updated mean profiles
171        new_U_avg    = U_avg    + mU
172        new_V_avg    = V_avg    + mV
173        new_W_avg    = W_avg    + mW
174        new_TH_avg   = TH_avg   + mTH
175        new_dUdz_avg = dUdz_avg + mdUdz
176        new_dVdz_avg = dVdz_avg + mdVdz
177        new_dTHdz_avg = dTHdz_avg + mdTHdz
178        new_Q_avg    = Q_avg    + mQ
179        new_dQdz_avg = dQdz_avg + mdQdz
180
181        # ------------------------------------------------------------
182        # Profiles of resolved variances and horizontal fluxes
183        # ------------------------------------------------------------
184        def ComputeLevel1(k):
185            """Compute variances and horizontal fluxes at vertical level k."""
186            u_f  = u[:, :, k]  + Ugal - mU[k]
187            v_f  = v[:, :, k]         - mV[k]
188            w_f  = w[:, :, k]         - mW[k]
189            TH_f = TH[:, :, k]        - mTH_anom[k]
190
191            u2  = jnp.mean(u_f  ** 2) * (u_scale  ** 2)
192            v2  = jnp.mean(v_f  ** 2) * (u_scale  ** 2)
193            w2  = jnp.mean(w_f  ** 2) * (u_scale  ** 2)
194            TH2 = jnp.mean(TH_f ** 2) * (TH_scale ** 2)
195            uv  = jnp.mean(u_f  * v_f)  * (u_scale ** 2)
196            uTH = jnp.mean(u_f  * TH_f) * (u_scale * TH_scale)
197            vTH = jnp.mean(v_f  * TH_f) * (u_scale * TH_scale)
198
199            # Q variance (at half levels, using mQ in Q_scale units)
200            Q_f = Q[:, :, k] - mQ[k] / Q_scale   # fluctuation in non-dim units
201            Q2  = jnp.mean(Q_f ** 2) * (Q_scale ** 2)
202
203            return u2, v2, w2, TH2, uv, uTH, vTH, Q2
204
205        all_levels = jnp.arange(nz)
206        (u2_profile, v2_profile, w2_profile, TH2_profile,
207         uv_profile, uTH_profile, vTH_profile, Q2_profile) = (
208            jax.vmap(ComputeLevel1)(all_levels))
209
210        new_u2_avg  = u2_avg  + u2_profile
211        new_v2_avg  = v2_avg  + v2_profile
212        new_w2_avg  = w2_avg  + w2_profile
213        new_TH2_avg = TH2_avg + TH2_profile
214        new_uv_avg  = uv_avg  + uv_profile
215        new_uTH_avg = uTH_avg + uTH_profile
216        new_vTH_avg = vTH_avg + vTH_profile
217        new_Q2_avg  = Q2_avg  + Q2_profile
218
219        # ------------------------------------------------------------
220        # Resolved flux profiles (staggered grid)
221        # ------------------------------------------------------------
222        u_stag  = ZeRo3D.copy()
223        v_stag  = ZeRo3D.copy()
224        w_stag  = w.copy()
225        TH_stag = ZeRo3D.copy()
226        Q_stag  = ZeRo3D.copy()
227
228        u_stag  = u_stag.at[:,  :, 1:nz].set(StagGridAvg(u))
229        v_stag  = v_stag.at[:,  :, 1:nz].set(StagGridAvg(v))
230        TH_stag = TH_stag.at[:, :, 1:nz].set(StagGridAvg(TH))
231        Q_stag  = Q_stag.at[:,  :, 1:nz].set(StagGridAvg(Q))
232
233        u_stag  = u_stag.at[:,  :, 0].set(u[:, :, 0] + Ugal)
234        v_stag  = v_stag.at[:,  :, 0].set(v[:, :, 0])
235        TH_stag = TH_stag.at[:, :, 0].set(TH[:, :, 0])
236        Q_stag  = Q_stag.at[:,  :, 0].set(Q[:, :, 0])
237
238        mu_stag  = (jnp.mean(u_stag,  axis=(0, 1)) + Ugal) * u_scale
239        mv_stag  = jnp.mean(v_stag,   axis=(0, 1)) * u_scale
240        mw_stag  = jnp.mean(w_stag,   axis=(0, 1)) * u_scale
241        mTH_stag = jnp.mean(TH_stag,  axis=(0, 1)) * TH_scale
242        mQ_stag  = jnp.mean(Q_stag,   axis=(0, 1)) * Q_scale
243
244        def ComputeLevel2(k):
245            """Compute vertical fluxes at vertical level k."""
246            u_stag_f  = u_stag[:, :, k]  + Ugal - mu_stag[k]
247            v_stag_f  = v_stag[:, :, k]         - mv_stag[k]
248            w_stag_f  = w_stag[:, :, k]         - mw_stag[k]
249            TH_stag_f = TH_stag[:, :, k]        - mTH_stag[k]
250            # Q_stag fluctuation in non-dim units
251            Q_stag_f  = Q_stag[:, :, k]         - mQ_stag[k] / Q_scale
252
253            uw  = jnp.mean(u_stag_f  * w_stag_f) * (u_scale  ** 2)
254            vw  = jnp.mean(v_stag_f  * w_stag_f) * (u_scale  ** 2)
255            wTH = jnp.mean(w_stag_f  * TH_stag_f) * (u_scale * TH_scale)
256            wQ  = jnp.mean(w_stag_f  * Q_stag_f)  * (u_scale * Q_scale)
257
258            return uw, vw, wTH, wQ
259
260        (uw_profile, vw_profile, wTH_profile, wQ_profile) = (
261            jax.vmap(ComputeLevel2)(all_levels))
262
263        new_uw_avg  = uw_avg  + uw_profile
264        new_vw_avg  = vw_avg  + vw_profile
265        new_wTH_avg = wTH_avg + wTH_profile
266        new_wQ_avg  = wQ_avg  + wQ_profile
267
268        # ------------------------------------------------------------
269        # SGS stress and flux profiles
270        # ------------------------------------------------------------
271        mtxy = jnp.mean(txy,   axis=(0, 1)) * (u_scale ** 2)
272        mtxz = jnp.mean(txz,   axis=(0, 1)) * (u_scale ** 2)
273        mtyz = jnp.mean(tyz,   axis=(0, 1)) * (u_scale ** 2)
274        mqz  = jnp.mean(qz,    axis=(0, 1)) * (u_scale * TH_scale)
275        mqHz = jnp.mean(qHz_q, axis=(0, 1)) * (u_scale * Q_scale)
276
277        new_txy_avg = txy_avg + mtxy
278        new_txz_avg = txz_avg + mtxz
279        new_tyz_avg = tyz_avg + mtyz
280        new_qz_avg  = qz_avg  + mqz
281        new_qHz_avg = qHz_avg + mqHz
282
283        # ------------------------------------------------------------
284        # Surface variables
285        # ------------------------------------------------------------
286        mM_sfc = jnp.mean(M_sfc_loc)
287        mustar = jnp.sqrt(jnp.mean(ustar ** 2)) * u_scale
288        mq_sfc  = qz_sfc_avg_nd * u_scale * TH_scale   # dimensional K m/s
289        mqm_sfc = qm_sfc_avg_nd * u_scale * Q_scale    # dimensional kg/kg m/s
290
291        new_M_sfc_avg   = M_sfc_avg   + mM_sfc
292        new_ustar_avg   = ustar_avg   + mustar
293        new_qz_sfc_sum  = qz_sfc_sum  + mq_sfc
294        new_qm_sfc_sum  = qm_sfc_sum  + mqm_sfc
295
296        # ------------------------------------------------------------
297        # SGS coefficients
298        # ------------------------------------------------------------
299        new_Cs2_1_avg      = Cs2_1_avg      + Cs2_1D_avg1
300        new_Cs2_2_avg      = Cs2_2_avg      + Cs2_1D_avg2
301        new_Cs2PrRatio_avg = Cs2PrRatio_avg + Cs2PrRatio_1D
302        new_Beta1_avg      = Beta1_avg      + beta1_1D
303        new_Beta2_avg      = Beta2_avg      + beta2_1D
304
305        return {
306            "U": new_U_avg, "V": new_V_avg, "W": new_W_avg, "TH": new_TH_avg,
307            "dUdz": new_dUdz_avg, "dVdz": new_dVdz_avg,
308            "dTHdz": new_dTHdz_avg,
309            "u2": new_u2_avg, "v2": new_v2_avg, "w2": new_w2_avg,
310            "TH2": new_TH2_avg,
311            "uv": new_uv_avg, "uw": new_uw_avg, "vw": new_vw_avg,
312            "uTH": new_uTH_avg, "vTH": new_vTH_avg, "wTH": new_wTH_avg,
313            "txy": new_txy_avg, "txz": new_txz_avg, "tyz": new_tyz_avg,
314            "qz": new_qz_avg,
315            "M_sfc": new_M_sfc_avg, "ustar": new_ustar_avg,
316            "qz_sfc": new_qz_sfc_sum,
317            "Cs2_1": new_Cs2_1_avg, "Cs2_2": new_Cs2_2_avg,
318            "Cs2PrRatio": new_Cs2PrRatio_avg,
319            "Beta1": new_Beta1_avg, "Beta2": new_Beta2_avg,
320            "Q": new_Q_avg, "dQdz": new_dQdz_avg, "Q2": new_Q2_avg,
321            "wQ": new_wQ_avg, "qHz": new_qHz_avg, "qm_sfc": new_qm_sfc_sum,
322            "Ugal": Ugal, "ZeRo1D": ZeRo1D
323        }
324
325    UpdatedStats = jax.lax.cond(
326        ResetFlag == 1,
327        ResetStats,
328        UpdateStats,
329        None
330    )
331
332    return UpdatedStats
333
334
def InitializeStats(ZeRo1D):
    """
    Create the statistics accumulator dictionary in its zeroed state.

    Parameters:
    -----------
    ZeRo1D : ndarray
        Pre-allocated zero array. It is shared as the initial value of every
        profile accumulator and is also stored under the "ZeRo1D" key.

    Returns:
    --------
    StatsDict : dict
        Profile accumulators set to ZeRo1D; surface accumulators ("M_sfc",
        "ustar", "qz_sfc", "qm_sfc") set to 0.0; plus the constants "Ugal"
        (the module-level Ugal, presumably from the namelist configuration)
        and "ZeRo1D".
    """
    # Profile accumulators, inserted group by group to keep a stable order
    StatsDict = dict.fromkeys(
        (
            # Mean profiles
            "U", "V", "W", "TH",
            # Mean gradients
            "dUdz", "dVdz", "dTHdz",
            # Resolved variances
            "u2", "v2", "w2", "TH2",
            # Resolved fluxes
            "uv", "uw", "vw", "uTH", "vTH", "wTH",
            # SGS terms
            "txy", "txz", "tyz", "qz",
        ),
        ZeRo1D,
    )

    # Surface terms start as plain scalars
    StatsDict.update(M_sfc=0.0, ustar=0.0, qz_sfc=0.0)

    # SGS coefficients
    StatsDict.update(
        dict.fromkeys(("Cs2_1", "Cs2_2", "Cs2PrRatio", "Beta1", "Beta2"), ZeRo1D)
    )

    # Moisture statistics (non-zero only when optMoisture=1)
    StatsDict.update(dict.fromkeys(("Q", "dQdz", "Q2", "wQ", "qHz"), ZeRo1D))
    StatsDict["qm_sfc"] = 0.0

    # Constants carried alongside the accumulators
    StatsDict["Ugal"] = Ugal
    StatsDict["ZeRo1D"] = ZeRo1D

    return StatsDict