Main Program

Source Code: Main

Main.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Main.py
 18===========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-5
 24:Description: main file for JAX-ALFA
 25"""
 26
 27
 28# ============================================================
 29#  Imports
 30# ============================================================
 31# This file is for run time
 32from .config.Imports import ImportLES
 33ImportLES()
 34
 35# Import derived variables
 36from .config.DerivedVars import *
 37
 38# This file is for IDE static analysis during development time
 39from .utilities.Pycharm import *
 40
 41
 42# ============================================================
 43#  Initialize Static Variables
 44# ============================================================
 45
 46kx2, ky2 = Wavenumber()
 47ZeRo3D = ZeRo3DIni()
 48ZeRo2D = ZeRo2DIni()
 49ZeRo1D = ZeRo1DIni()
 50ZeRo3D_fft = ZeRo3D_fftIni()
 51ZeRo3D_pad = ZeRo3D_padIni()
 52ZeRo3D_pad_fft = ZeRo3D_pad_fftIni()
 53
 54# Static variables related to pressure solver
 55# optPressureSolver = 0: LU (original)   1: Thomas (tridiagonal, faster)
 56if optPressureSolver == 1:
 57    (kr2_pressure, kc2_pressure,
 58     a_pressure, b_pressure, c_pressure,
 59     b_thomas, m_thomas) = ThomasPressureInit()
 60else:
 61    (kr2_pressure, kc2_pressure,
 62     a_pressure, b_pressure, c_pressure) = PressureInit()
 63
 64
 65# ============================================================
 66#  Initialize velocity, temperature, etc.
 67# ============================================================
 68
 69u, v, w = Initialize_uvw()
 70TH = Initialize_TH()
 71if optMoisture >= 1:
 72    Q = Initialize_Q()
 73    RHS_Q_previous = ZeRo3D.copy()
 74    if optMoistureSurfBC >= 1:
 75        MoistureSurfaceBC_series = Initialize_MoistureSurfaceBC()
 76    Qadv = ZeRo3D
 77else:
 78    Q = ZeRo3D
 79    Qadv = ZeRo3D
 80
 81if optGeoWind == 0:
 82    Ug, Vg = Initialize_GeoWind()
 83else:
 84    GeoWind_U, GeoWind_V = Initialize_GeoWind_Varying()
 85    Ug = jnp.broadcast_to(GeoWind_U[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
 86    Vg = jnp.broadcast_to(GeoWind_V[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
 87
 88RayleighDampCoeff, RayleighDampCoeff_stag = (
 89    Initialize_RayleighDampingLayer())
 90
 91RHS_u_previous  = ZeRo3D.copy()
 92RHS_v_previous  = ZeRo3D.copy()
 93RHS_w_previous  = ZeRo3D.copy()
 94RHS_TH_previous = ZeRo3D.copy()
 95
 96CFLmax = 0
 97CFLmax_iteration = 1
 98
 99# ============================================================
100#  Initialize surface variables
101# ============================================================
102psi2D_m = ZeRo2D.copy()
103psi2D_m0 = ZeRo2D.copy()
104psi2D_h = ZeRo2D.copy()
105psi2D_h0 = ZeRo2D.copy()
106fi2D_m = 1.0 + ZeRo2D.copy()
107fi2D_h = 1.0 + ZeRo2D.copy()
108
109MOSTfunctions = (psi2D_m, psi2D_m0,
110                 psi2D_h, psi2D_h0,
111                 fi2D_m, fi2D_h)
112
113# Load time-varying surface BC series once before the loop (optSurfBC >= 1)
114if optSurfBC >= 1:
115    SurfaceBC_series = Initialize_SurfaceBC()
116
117# Load large-scale advection forcing once before the loop (optAdvection >= 1)
118if optAdvection >= 1:
119    AdvForcing_U, AdvForcing_V, AdvForcing_TH, AdvForcing_Q = Initialize_AdvForcing()
120    Uadv  = jnp.broadcast_to(AdvForcing_U[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
121    Vadv  = jnp.broadcast_to(AdvForcing_V[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
122    THadv = jnp.broadcast_to(AdvForcing_TH[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
123    if optMoisture >= 1:
124        Qadv  = jnp.broadcast_to(AdvForcing_Q[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
125else:
126    Uadv  = ZeRo3D
127    Vadv  = ZeRo3D
128    THadv = ZeRo3D
129
130
# ============================================================
# Initialize statistics variables
# ============================================================
StatsDict = InitializeStats(ZeRo1D)

# Number of samples accumulated since the last statistics output.
SampleCounter = 0

# STAB-SM (optSgs == 5) never runs the dynamic SGS branch, which is the
# branch that normally produces these 1D profiles. Define them up front so
# ComputeStats always receives valid arrays; the optSgs == 5 block in the
# main loop overwrites them on every iteration.
if optSgs == 5:
    Cs2_1D_avg1 = ZeRo1D.copy()
    Cs2_1D_avg2 = ZeRo1D.copy()
    Cs2PrRatio_1D = ZeRo1D.copy()
    beta1_1D = ZeRo1D.copy()
    beta2_1D = ZeRo1D.copy()

# All output goes to <JAXALFA_RUNDIR>/output, created if it is missing.
# A KeyError is raised here when JAXALFA_RUNDIR is not set.
OutputDir = os.path.join(os.environ['JAXALFA_RUNDIR'], 'output')
os.makedirs(OutputDir, exist_ok=True)
148
149
# ============================================================
#  Main simulation loop
# ============================================================
# Each iteration, in order:
#   1. shifts the previous-step RHS terms and refreshes time-varying forcing
#   2. filters the prognostic fields
#   3. computes surface fluxes, gradients, advection, buoyancy and SGS terms
#   4. assembles the RHS, solves for pressure and advances in time
#   5. tracks the CFL number, accumulates statistics and writes output


def _fmt(s):
    """Format a duration in seconds as a zero-padded HH:MM:SS string."""
    h, r = divmod(int(s), 3600)
    m, sec = divmod(r, 60)
    return f"{h:02d}:{m:02d}:{sec:02d}"


tic_tot = time.time()

for iteration in range(istep, nsteps + 1):

    # From the second iteration on, the RHS of the step just completed
    # becomes the "previous" RHS used by the time-advancement routines.
    if iteration > istep:
        RHS_u_previous = RHS_u
        RHS_v_previous = RHS_v
        RHS_w_previous = RHS_w
        RHS_TH_previous = RHS_TH
        if optMoisture >= 1:
            RHS_Q_previous = RHS_Q

    # ------------------------------------------------------------
    #  Update time/height-varying geostrophic wind (optGeoWind >= 1)
    # ------------------------------------------------------------
    if optGeoWind >= 1:
        Ug = jnp.broadcast_to(
            GeoWind_U[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        Vg = jnp.broadcast_to(
            GeoWind_V[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))

    # ------------------------------------------------------------
    #  Update time/height-varying large-scale advection (optAdvection >= 1)
    # ------------------------------------------------------------
    if optAdvection >= 1:
        Uadv = jnp.broadcast_to(
            AdvForcing_U[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        Vadv = jnp.broadcast_to(
            AdvForcing_V[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        THadv = jnp.broadcast_to(
            AdvForcing_TH[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        if optMoisture >= 1:
            Qadv = jnp.broadcast_to(
                AdvForcing_Q[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))

    # ------------------------------------------------------------
    #  Filtering and FFT Computations
    # ------------------------------------------------------------
    # Filtering_Explicit returns the filtered field and a second value that
    # is kept as *_fft for the velocities and discarded for the scalars.
    u, u_fft = Filtering_Explicit(FFT(u))
    v, v_fft = Filtering_Explicit(FFT(v))
    w, w_fft = Filtering_Explicit(FFT(w))

    TH, _ = Filtering_Explicit(FFT(TH))
    if optMoisture >= 1:
        Q, _ = Filtering_Explicit(FFT(Q))

    # ------------------------------------------------------------
    #  Compute Surface Fluxes
    #
    #  All branches set:
    #    M_sfc_loc   (nx, ny)  surface wind speed
    #    ustar       (nx, ny)  friction velocity
    #    qz_sfc_step (nx, ny)  surface heat flux field  (qz = -u* x th*)
    #    qz_sfc_avg  scalar    planar-mean, non-dimensional
    #    invOB       (nx, ny)  inverse Obukhov length
    #    MOSTfunctions         updated stability functions
    # ------------------------------------------------------------

    if optSurfBC == 0:
        # Constant prescribed heat flux
        if optSurfFlux == 0:
            (M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HomogeneousConstantFlux(u, v, TH, MOSTfunctions))
        else:
            (M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HeterogeneousConstantFlux(u, v, TH, MOSTfunctions))
        qz_sfc_step = qz_sfc  # global (nx,ny) array from DerivedVars

    elif optSurfBC == 1:
        # Time-varying prescribed heat flux
        sfc_val = SurfaceBC_series[iteration - 1]
        if optSurfFlux == 0:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HomogeneousVaryingFlux(u, v, TH, sfc_val, MOSTfunctions))
        else:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HeterogeneousVaryingFlux(u, v, TH, sfc_val, MOSTfunctions))

    else:
        # optSurfBC == 2: time-varying prescribed surface temperature
        sfc_val = SurfaceBC_series[iteration - 1]
        if optSurfFlux == 0:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HomogeneousPrescribedTemperature(
                    u, v, TH, sfc_val, MOSTfunctions))
        else:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HeterogeneousPrescribedTemperature(
                    u, v, TH, sfc_val, MOSTfunctions))

    # ------------------------------------------------------------
    #  Compute Surface Moisture Flux (optMoisture >= 1)
    #
    #  Uses ustar and MOSTfunctions already computed above.
    #  All branches set:
    #    qm_sfc_step (nx, ny)   surface moisture flux field
    #    qm_sfc_avg  scalar     planar-mean, non-dimensional
    # ------------------------------------------------------------
    if optMoisture >= 1:
        if optMoistureSurfBC == 0:
            qm_sfc_step = qm_sfc          # constant from DerivedVars
        elif optMoistureSurfBC == 1:
            # Prescribed flux value for this step, applied uniformly in x-y
            qm_sfc_t = MoistureSurfaceBC_series[iteration - 1]
            qm_sfc_step = qm_sfc_t * jnp.ones((nx, ny))
        else:  # optMoistureSurfBC == 2: prescribed surface Q
            Q_sfc_t = MoistureSurfaceBC_series[iteration - 1]
            if optSurfFlux == 0:
                qm_sfc_step = SurfaceMoistureFlux_HomogeneousPrescribedQ(
                    Q, ustar, Q_sfc_t, MOSTfunctions)
            else:
                qm_sfc_step = SurfaceMoistureFlux_HeterogeneousPrescribedQ(
                    Q, ustar, Q_sfc_t, MOSTfunctions)
        qm_sfc_avg = jnp.mean(qm_sfc_step)
    else:
        qm_sfc_step = ZeRo2D
        qm_sfc_avg = 0.0

    # ------------------------------------------------------------
    #  Compute Velocity Gradients
    # ------------------------------------------------------------
    (dudx, dvdx, dwdx,
     dudy, dvdy, dwdy,
     dudz, dvdz, dwdz) = (
        velocityGradients(
            u, v, w,
            u_fft, v_fft, w_fft,
            kx2, ky2,
            ustar, M_sfc_loc, MOSTfunctions,
            ZeRo3D))

    (dTHdx, dTHdy, dTHdz) = (
        potentialTemperatureGradients(
            TH,
            kx2, ky2,
            ustar, qz_sfc_step, MOSTfunctions,
            ZeRo3D))

    if optMoisture >= 1:
        (dQdx, dQdy, dQdz) = moistureGradients(
            Q, kx2, ky2, ustar, qm_sfc_step, MOSTfunctions, ZeRo3D)
    else:
        # No moisture: all three gradients are the ZeRo3D array
        dQdx = dQdy = dQdz = ZeRo3D

    # ------------------------------------------------------------
    #  Compute Advection Terms
    # ------------------------------------------------------------
    Cx, Cy, Cz = Advection(
        u, v, w,
        dudy, dudz, dvdx, dvdz, dwdx, dwdy,
        ZeRo3D, ZeRo3D_fft, ZeRo3D_pad,
        ZeRo3D_pad_fft)

    THAdvectionSum = ScalarAdvection(
        u, v, w,
        dTHdx, dTHdy, dTHdz,
        ZeRo3D, ZeRo3D_fft, ZeRo3D_pad,
        ZeRo3D_pad_fft)

    if optMoisture >= 1:
        QAdvectionSum = ScalarAdvection(
            u, v, w,
            dQdx, dQdy, dQdz,
            ZeRo3D, ZeRo3D_fft, ZeRo3D_pad,
            ZeRo3D_pad_fft)

    # ------------------------------------------------------------
    #  Compute Buoyancy Terms
    # ------------------------------------------------------------
    # H is the moisture field handed to the buoyancy routine
    # (ZeRo3D when moisture is disabled).
    H = Q if optMoisture >= 1 else ZeRo3D
    if optBuoyancy == 0:
        buoyancy = BuoyancyOpt1(TH, H, ZeRo3D)
    else:
        buoyancy = BuoyancyOpt2(TH, H, ZeRo3D)

    # ------------------------------------------------------------
    #  Compute SGS Terms
    #
    #  optSgs in 1..4 : dynamic SGS on the first iteration and on every
    #                   iteration that is a multiple of dynamicSGS_call_time
    #  optSgs == 5    : STAB-SM
    #  otherwise      : static SGS; this branch also serves optSgs in 1..4
    #                   between dynamic updates, reusing the Cs2_3D and
    #                   Cs2PrRatio_3D fields from the last dynamic call
    # ------------------------------------------------------------

    if 1 <= optSgs <= 4 and (iteration == istep or iteration % dynamicSGS_call_time == 0):

        # print('Dynamic SGS')

        (divtx, divty, divtz,
         Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D,
         Cs2_3D,
         dynamicSGSmomentum) = (
            DivStressDynamicSGS(
                dudx, dvdx, dwdx,
                dudy, dvdy, dwdy,
                dudz, dvdz, dwdz,
                u, v, w, M_sfc_loc, MOSTfunctions,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        (qz, divq, Cs2PrRatio_3D, Cs2PrRatio_1D, beta2_1D) = (
            DivFluxDynamicSGS(
                dynamicSGSmomentum[10:],
                TH,
                dTHdx, dTHdy, dTHdz,
                qz_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # Moisture SGS: reuse strain rates from dynamic momentum SGS with
        # the same Cs2PrRatio (turbulent Sc = turbulent Pr approximation).
        # dynamicSGSmomentum[19:23] = (S_uvp, S_uvp_pad, S_w, S_w_pad)
        if optMoisture >= 1:
            qHz_q, divqm = DivFluxStaticSGS(
                (dynamicSGSmomentum[19], dynamicSGSmomentum[20],
                 dynamicSGSmomentum[21], dynamicSGSmomentum[22]),
                Cs2PrRatio_3D,
                dQdx, dQdy, dQdz,
                qm_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2)
        else:
            qHz_q = divqm = ZeRo3D

        # Unpack variables for computation of statistics
        _, _, _, txy, txz, tyz = dynamicSGSmomentum[0:6]

    elif optSgs == 5:

        # print('STAB-SM SGS')

        (divtx, divty, divtz,
         stabsmSGSmomentum) = (
            DivStressStaticSGS_STABSM(
                dudx, dvdx, dwdx,
                dudy, dvdy, dwdy,
                dudz, dvdz, dwdz,
                dTHdz,
                u, v, M_sfc_loc, MOSTfunctions,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # stabsmSGSmomentum[10:14] = (Lambda_uvp2_3D, Lambda_w2_3D, fhS_uvp, fhS_w)
        qz, divq = (
            DivFluxStaticSGS_STABSM(
                stabsmSGSmomentum[10:14],
                dTHdx, dTHdy, dTHdz,
                qz_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # Moisture SGS: reuse same Lambda and fhS as heat
        if optMoisture >= 1:
            qHz_q, divqm = DivFluxStaticSGS_STABSM(
                stabsmSGSmomentum[10:14],
                dQdx, dQdy, dQdz,
                qm_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2)
        else:
            qHz_q = divqm = ZeRo3D

        # stabsmSGSmomentum[14] = Lambda_uvp2_1D (effective Cs^2 profile).
        # The same profile fills all three Cs2 statistics slots and the
        # beta profiles are ZeRo1D, since STAB-SM produces no dynamic values.
        Cs2_1D_avg1 = stabsmSGSmomentum[14]
        Cs2_1D_avg2 = stabsmSGSmomentum[14]
        Cs2PrRatio_1D = stabsmSGSmomentum[14]
        beta1_1D = ZeRo1D
        beta2_1D = ZeRo1D

        # Unpack variables for computation of statistics
        _, _, _, txy, txz, tyz = stabsmSGSmomentum[0:6]

    else:

        # print('Static SGS')

        (divtx, divty, divtz,
         staticSGSmomentum) = (
            DivStressStaticSGS(
                dudx, dvdx, dwdx,
                dudy, dvdy, dwdy,
                dudz, dvdz, dwdz,
                Cs2_3D,
                u, v, M_sfc_loc, MOSTfunctions,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        qz, divq = (
            DivFluxStaticSGS(
                staticSGSmomentum[6:],
                Cs2PrRatio_3D,
                dTHdx, dTHdy, dTHdz,
                qz_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # Moisture SGS: same Cs2PrRatio as heat.
        if optMoisture >= 1:
            qHz_q, divqm = DivFluxStaticSGS(
                staticSGSmomentum[6:],
                Cs2PrRatio_3D,
                dQdx, dQdy, dQdz,
                qm_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2)
        else:
            qHz_q = divqm = ZeRo3D

        # Unpack variables for computation of statistics
        _, _, _, txy, txz, tyz = staticSGSmomentum[0:6]

    # ------------------------------------------------------------
    #  Compute right hand side (RHS) terms
    # ------------------------------------------------------------

    (RHS_u, RHS_v, RHS_w) = (
        RHS_Momentum(u, v, w,
                     Ug, Vg,
                     Cx, Cy, Cz,
                     buoyancy,
                     divtx, divty, divtz,
                     RayleighDampCoeff, RayleighDampCoeff_stag,
                     Uadv, Vadv))

    RHS_TH = RHS_Scalar(TH, THAdvectionSum, divq, RayleighDampCoeff_stag, THadv)
    if optMoisture >= 1:
        RHS_Q = RHS_Moisture(Q, QAdvectionSum, divqm, RayleighDampCoeff_stag, Qadv)

    # ------------------------------------------------------------
    #  Pressure solution
    # ------------------------------------------------------------

    (RC_real, RC_imag, fRz_real) = (
        PressureRC(
            u, v, w,
            RHS_u, RHS_v, RHS_w,
            RHS_u_previous, RHS_v_previous, RHS_w_previous,
            divtz, kr2_pressure, kc2_pressure))

    # Same solver choice as at initialization: 1 -> Thomas, otherwise LU
    if optPressureSolver == 1:
        (p, dpdx, dpdy, dpdz) = ThomasPressureSolve(
            RC_real, RC_imag, fRz_real,
            b_thomas, m_thomas, c_pressure)
    else:
        (p, dpdx, dpdy, dpdz) = PressureSolve(
            RC_real, RC_imag, fRz_real,
            a_pressure, b_pressure, c_pressure)

    # Add pressure gradient terms to RHS
    RHS_u = RHS_u - dpdx
    RHS_v = RHS_v - dpdy
    RHS_w = RHS_w - dpdz

    # ------------------------------------------------------------
    #  Initialize RHS terms for previous time step
    # ------------------------------------------------------------
    # On the first iteration there is no earlier step, so the current RHS
    # (pressure gradient included) doubles as the "previous" one.

    if iteration == istep:
        RHS_u_previous = RHS_u
        RHS_v_previous = RHS_v
        RHS_w_previous = RHS_w
        RHS_TH_previous = RHS_TH
        if optMoisture >= 1:
            RHS_Q_previous = RHS_Q

    # ------------------------------------------------------------
    #  Time advancement
    # ------------------------------------------------------------
    # The AB2_* routines take the current and previous RHS (presumably a
    # second-order Adams-Bashforth step -- confirm in their definitions).

    (u, v, w) = (
        AB2_uvw(u, v, w,
                RHS_u, RHS_u_previous,
                RHS_v, RHS_v_previous,
                RHS_w, RHS_w_previous))

    TH = AB2_TH(TH, RHS_TH, RHS_TH_previous)

    if optMoisture >= 1:
        Q = AB2_Q(Q, RHS_Q, RHS_Q_previous)

    # ------------------------------------------------------------
    #  Compute CFLmax
    # ------------------------------------------------------------
    CFLx = jnp.max(jnp.abs(u)) * dt_nondim / dx
    CFLy = jnp.max(jnp.abs(v)) * dt_nondim / dy
    CFLz = jnp.max(jnp.abs(w)) * dt_nondim / dz
    CFL = jnp.max(jnp.array([CFLx, CFLy, CFLz]))
    if CFL > CFLmax:
        CFLmax = CFL
        CFLmax_iteration = iteration

    # ------------------------------------------------------------
    #  Compute and output averaged statistics
    # ------------------------------------------------------------

    # Collect samples at specified intervals including output intervals
    if iteration % SampleInterval == 0:
        # Accumulation of statistics
        ResetFlag = 0
        StatsDict = ComputeStats(u, v, w, TH, Q,
                                 dudz, dvdz, dTHdz, dQdz,
                                 M_sfc_loc, ustar, qz_sfc_avg, qm_sfc_avg,
                                 txy, txz, tyz, qz, qHz_q,
                                 Cs2_1D_avg1, Cs2_1D_avg2,
                                 Cs2PrRatio_1D,
                                 beta1_1D, beta2_1D,
                                 StatsDict, ResetFlag,
                                 ZeRo3D)
        SampleCounter += 1

        # Progress report: ETA extrapolates the mean wall-clock time per
        # iteration so far over the remaining iterations.
        pct = 100.0 * iteration / nsteps
        elapsed = time.time() - tic_tot
        rate = elapsed / (iteration - istep + 1)   # seconds per iteration
        eta = rate * (nsteps - iteration)

        print(f"\n============= Finished Iteration {iteration} / {nsteps} "
              f"({pct:.1f}%) =============")
        print(f"  Elapsed: {_fmt(elapsed)}   ETA: {_fmt(eta)}")
        print(
            f"Statistics: collected sample {SampleCounter} at iteration {iteration}")
        print(f"  Friction Velocity:    {jnp.sqrt(jnp.mean(ustar ** 2)):.4f} "
              f"m/s")
        print(f"  Sensible Heat Flux:   "
              f"{float(qz_sfc_avg * u_scale * TH_scale):.4f} K m/s")
        if optMoisture >= 1:
            print(f"  Moisture Flux:        "
                  f"{float(qm_sfc_avg * u_scale * Q_scale):.6e} kg/kg m/s")
        print(f"  Current CFL:          {CFL:.3f}")
        print(f"  CFLmax:               {CFLmax:.3f}")
        print(f"  CFLmax happened at iteration: {CFLmax_iteration}")

    # At output intervals, check if we've collected any samples
    if iteration % OutputInterval == 0 and SampleCounter > 0:
        OutputStats = {}
        for key in StatsDict:
            if key not in ["Ugal", "ZeRo1D"]:
                # Average the accumulated statistics
                OutputStats[key] = StatsDict[key] / SampleCounter
            else:
                # These two entries are stored as-is, not averaged
                OutputStats[key] = StatsDict[key]

        # Generate output filename and save statistics
        OutputFile = f'ALFA_Statistics_Iteration_{iteration}.npz'
        OutputDirFile = os.path.join(OutputDir, OutputFile)
        np.savez(OutputDirFile, **OutputStats)
        print(
            f"Statistics saved to {OutputFile} "
            f"(averaged over {SampleCounter} samples)")

        # Reset statistics for next averaging interval
        SampleCounter = 0
        ResetFlag = 1
        StatsDict = ComputeStats(u, v, w, TH, Q,
                                 dudz, dvdz, dTHdz, dQdz,
                                 M_sfc_loc, ustar, qz_sfc_avg, qm_sfc_avg,
                                 txy, txz, tyz, qz, qHz_q,
                                 Cs2_1D_avg1, Cs2_1D_avg2,
                                 Cs2PrRatio_1D,
                                 beta1_1D, beta2_1D,
                                 StatsDict, ResetFlag,
                                 ZeRo3D)

    # Output 3D fields at specified intervals (for visualization)
    if iteration % Output3DInterval == 0:
        # Create dictionary of fields to save
        Fields3D = {
            "u": u + Ugal,        # Galilean velocity added back
            "v": v,
            "w": w,
            "TH": TH + T_0_nondim  # anomaly -> absolute (TH stored as TH - T_0)
        }
        if optMoisture >= 1:
            Fields3D["Q"] = Q

        # Generate output filename and save 3D fields
        OutputFile3D = f'ALFA_3DFields_Iteration_{iteration}.npz'
        OutputDirFile3D = os.path.join(OutputDir, OutputFile3D)
        np.savez(OutputDirFile3D, **Fields3D)
        print(f"3D fields saved to {OutputFile3D}")

print(f"Total Elapsed Time: {time.time() - tic_tot:.5f} seconds")