Main Program
Source Code: Main
Main.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Main.py
18===========================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-5
24:Description: main file for JAX-ALFA
25"""
26
27
28# ============================================================
29# Imports
30# ============================================================
31# This file is for run time
32from .config.Imports import ImportLES
33ImportLES()
34
35# Import derived variables
36from .config.DerivedVars import *
37
38# This file is for IDE static analysis during development time
39from .utilities.Pycharm import *
40
41
# ============================================================
# Initialize Static Variables
# ============================================================

# Horizontal wavenumber arrays handed to the spectral gradient routines
# (velocityGradients, potentialTemperatureGradients, SGS divergences).
kx2, ky2 = Wavenumber()

# Zero-filled template arrays, created once and reused every time step as
# placeholders / initial values. Kept one per line, in this order, in case an
# initializer relies on a template created before it.
ZeRo3D = ZeRo3DIni()
ZeRo2D = ZeRo2DIni()
ZeRo1D = ZeRo1DIni()
ZeRo3D_fft = ZeRo3D_fftIni()
ZeRo3D_pad = ZeRo3D_padIni()
ZeRo3D_pad_fft = ZeRo3D_pad_fftIni()

# Static coefficients for the pressure solver, selected by optPressureSolver:
#   1         -> Thomas algorithm (tridiagonal, faster); additionally yields
#                b_thomas / m_thomas, consumed by ThomasPressureSolve
#   otherwise -> LU solver (original implementation)
if optPressureSolver != 1:
    (kr2_pressure, kc2_pressure,
     a_pressure, b_pressure, c_pressure) = PressureInit()
else:
    (kr2_pressure, kc2_pressure,
     a_pressure, b_pressure, c_pressure,
     b_thomas, m_thomas) = ThomasPressureInit()


# ============================================================
# Initialize velocity, temperature, etc.
# ============================================================

u, v, w = Initialize_uvw()
TH = Initialize_TH()

# Large-scale moisture advection forcing defaults to zero. It is assigned
# unconditionally (not only inside the moisture branch) so that Qadv exists
# for every combination of optMoisture / optMoistureSurfBC / optAdvection;
# RHS_Moisture in the main loop reads it whenever optMoisture >= 1. When
# optAdvection >= 1 and optMoisture >= 1 it is replaced by the prescribed
# forcing profile further below.
Qadv = ZeRo3D

if optMoisture >= 1:
    Q = Initialize_Q()
    RHS_Q_previous = ZeRo3D.copy()
    if optMoistureSurfBC >= 1:
        # Time series of the surface moisture boundary condition, indexed by
        # iteration - 1 inside the main loop.
        MoistureSurfaceBC_series = Initialize_MoistureSurfaceBC()
else:
    # Moisture disabled: Q is an all-zero placeholder so that buoyancy and
    # statistics code can take it unconditionally.
    Q = ZeRo3D

# Geostrophic wind forcing.
#   optGeoWind == 0 -> constant, from Initialize_GeoWind()
#   otherwise       -> time/height-varying series; the profile for step istep
#                      is broadcast to (nx, ny, nz) here, and the main loop
#                      refreshes it each iteration when optGeoWind >= 1.
if optGeoWind == 0:
    Ug, Vg = Initialize_GeoWind()
else:
    GeoWind_U, GeoWind_V = Initialize_GeoWind_Varying()
    Ug = jnp.broadcast_to(GeoWind_U[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
    Vg = jnp.broadcast_to(GeoWind_V[istep - 1].reshape(1, 1, nz), (nx, ny, nz))

RayleighDampCoeff, RayleighDampCoeff_stag = (
    Initialize_RayleighDampingLayer())

# Previous-step right-hand sides for the Adams-Bashforth (AB2) time stepping.
# They stay zero until the first iteration of the main loop overwrites them.
RHS_u_previous = ZeRo3D.copy()
RHS_v_previous = ZeRo3D.copy()
RHS_w_previous = ZeRo3D.copy()
RHS_TH_previous = ZeRo3D.copy()

# Running maximum of the CFL number and the iteration at which it occurred.
CFLmax = 0
CFLmax_iteration = 1

# ============================================================
# Initialize surface variables
# ============================================================

# Monin-Obukhov (MOST) stability functions on the 2D surface grid: the psi*
# terms start at 0 and the fi* terms at 1 (presumably the neutral values).
psi2D_m, psi2D_m0, psi2D_h, psi2D_h0 = (ZeRo2D.copy() for _ in range(4))
fi2D_m = ZeRo2D + 1.0
fi2D_h = ZeRo2D + 1.0

# Bundled in the order expected by the SurfaceFlux_* routines, which return
# an updated tuple every iteration.
MOSTfunctions = (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h)

# Time-varying surface boundary-condition series, loaded once up front and
# indexed by iteration - 1 in the main loop (optSurfBC >= 1).
if optSurfBC >= 1:
    SurfaceBC_series = Initialize_SurfaceBC()

# Large-scale advection forcing (optAdvection >= 1): load the series once and
# broadcast the profile for step istep over the (nx, ny, nz) grid. Without
# advection forcing the momentum/heat terms are zero; Qadv is left untouched.
if optAdvection >= 1:
    AdvForcing_U, AdvForcing_V, AdvForcing_TH, AdvForcing_Q = Initialize_AdvForcing()
    Uadv, Vadv, THadv = (
        jnp.broadcast_to(series[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
        for series in (AdvForcing_U, AdvForcing_V, AdvForcing_TH))
    if optMoisture >= 1:
        Qadv = jnp.broadcast_to(AdvForcing_Q[istep - 1].reshape(1, 1, nz), (nx, ny, nz))
else:
    Uadv = Vadv = THadv = ZeRo3D

# ============================================================
# Initialize statistics variables
# ============================================================
StatsDict = InitializeStats(ZeRo1D)
SampleCounter = 0  # samples accumulated in StatsDict since the last output

# For STAB-SM (optSgs == 5) the 1D profiles passed to ComputeStats
# (Cs2_1D_avg1, Cs2_1D_avg2, Cs2PrRatio_1D, beta1_1D, beta2_1D) are produced
# by the optSgs == 5 branch of the main loop, not by the dynamic SGS branch.
# Give them valid zero arrays up front; the loop overwrites them every
# iteration.
if optSgs == 5:
    Cs2_1D_avg1 = ZeRo1D.copy()
    Cs2_1D_avg2 = ZeRo1D.copy()
    Cs2PrRatio_1D = ZeRo1D.copy()
    beta1_1D = ZeRo1D.copy()
    beta2_1D = ZeRo1D.copy()

# Output directory <JAXALFA_RUNDIR>/output, created if missing. Fail early
# with an explanatory message instead of a bare KeyError when the variable is
# missing, and refuse an empty value (which would silently write to ./output).
_run_dir = os.environ.get('JAXALFA_RUNDIR')
if not _run_dir:
    raise RuntimeError(
        "Environment variable JAXALFA_RUNDIR is not set (or is empty); it "
        "must name the run directory under which the 'output' folder is "
        "created.")
OutputDir = os.path.join(_run_dir, 'output')
os.makedirs(OutputDir, exist_ok=True)


# ============================================================
# Main simulation loop
# ============================================================

def _fmt(s):
    """Format a duration in seconds as HH:MM:SS.

    Fractional seconds are truncated by int(); hours are not wrapped at 24.
    """
    h, r = divmod(int(s), 3600)
    m, sec = divmod(r, 60)
    return f"{h:02d}:{m:02d}:{sec:02d}"


# StatsDict entries written to disk as-is rather than divided by the number
# of accumulated samples.
_NON_AVERAGED_STATS = ("Ugal", "ZeRo1D")

tic_tot = time.time()

for iteration in range(istep, nsteps + 1):

    # Shift the Adams-Bashforth (AB2) history: the RHS computed in the
    # previous iteration becomes the "previous" RHS for this one.
    if iteration > istep:
        RHS_u_previous = RHS_u
        RHS_v_previous = RHS_v
        RHS_w_previous = RHS_w
        RHS_TH_previous = RHS_TH
        if optMoisture >= 1:
            RHS_Q_previous = RHS_Q

    # ------------------------------------------------------------
    # Update time/height-varying geostrophic wind (optGeoWind >= 1)
    # ------------------------------------------------------------
    if optGeoWind >= 1:
        Ug = jnp.broadcast_to(
            GeoWind_U[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        Vg = jnp.broadcast_to(
            GeoWind_V[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))

    # ------------------------------------------------------------
    # Update time/height-varying large-scale advection (optAdvection >= 1)
    # ------------------------------------------------------------
    if optAdvection >= 1:
        Uadv = jnp.broadcast_to(
            AdvForcing_U[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        Vadv = jnp.broadcast_to(
            AdvForcing_V[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        THadv = jnp.broadcast_to(
            AdvForcing_TH[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))
        if optMoisture >= 1:
            Qadv = jnp.broadcast_to(
                AdvForcing_Q[iteration - 1].reshape(1, 1, nz), (nx, ny, nz))

    # ------------------------------------------------------------
    # Filtering and FFT Computations
    #
    # The spectral forms of u, v, w are kept for the gradient
    # computation; those of TH and Q are not needed.
    # ------------------------------------------------------------
    u, u_fft = Filtering_Explicit(FFT(u))
    v, v_fft = Filtering_Explicit(FFT(v))
    w, w_fft = Filtering_Explicit(FFT(w))

    TH, _ = Filtering_Explicit(FFT(TH))
    if optMoisture >= 1:
        Q, _ = Filtering_Explicit(FFT(Q))

    # ------------------------------------------------------------
    # Compute Surface Fluxes
    #
    # All branches set:
    #   M_sfc_loc    (nx, ny)  surface wind speed
    #   ustar        (nx, ny)  friction velocity
    #   qz_sfc_step  (nx, ny)  surface heat flux field (qz = -u* x th*)
    #   qz_sfc_avg   scalar    planar-mean, non-dimensional
    #   invOB        (nx, ny)  inverse Obukhov length
    #   MOSTfunctions          updated stability functions
    # ------------------------------------------------------------

    if optSurfBC == 0:
        # Constant prescribed heat flux
        if optSurfFlux == 0:
            (M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HomogeneousConstantFlux(u, v, TH, MOSTfunctions))
        else:
            (M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HeterogeneousConstantFlux(u, v, TH, MOSTfunctions))
        qz_sfc_step = qz_sfc  # global (nx,ny) array from DerivedVars

    elif optSurfBC == 1:
        # Time-varying prescribed heat flux
        sfc_val = SurfaceBC_series[iteration - 1]
        if optSurfFlux == 0:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HomogeneousVaryingFlux(u, v, TH, sfc_val, MOSTfunctions))
        else:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HeterogeneousVaryingFlux(u, v, TH, sfc_val, MOSTfunctions))

    else:
        # optSurfBC == 2: time-varying prescribed surface temperature
        sfc_val = SurfaceBC_series[iteration - 1]
        if optSurfFlux == 0:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HomogeneousPrescribedTemperature(
                    u, v, TH, sfc_val, MOSTfunctions))
        else:
            (M_sfc_loc, ustar, qz_sfc_step, qz_sfc_avg, invOB, MOSTfunctions) = (
                SurfaceFlux_HeterogeneousPrescribedTemperature(
                    u, v, TH, sfc_val, MOSTfunctions))

    # ------------------------------------------------------------
    # Compute Surface Moisture Flux (optMoisture >= 1)
    #
    # Uses ustar and MOSTfunctions already computed above.
    # All branches set:
    #   qm_sfc_step  (nx, ny)  surface moisture flux field
    #   qm_sfc_avg   scalar    planar-mean, non-dimensional
    # ------------------------------------------------------------
    if optMoisture >= 1:
        if optMoistureSurfBC == 0:
            qm_sfc_step = qm_sfc  # constant from DerivedVars
        elif optMoistureSurfBC == 1:
            qm_sfc_t = MoistureSurfaceBC_series[iteration - 1]
            qm_sfc_step = qm_sfc_t * jnp.ones((nx, ny))
        else:  # optMoistureSurfBC == 2: prescribed surface Q
            Q_sfc_t = MoistureSurfaceBC_series[iteration - 1]
            if optSurfFlux == 0:
                qm_sfc_step = SurfaceMoistureFlux_HomogeneousPrescribedQ(
                    Q, ustar, Q_sfc_t, MOSTfunctions)
            else:
                qm_sfc_step = SurfaceMoistureFlux_HeterogeneousPrescribedQ(
                    Q, ustar, Q_sfc_t, MOSTfunctions)
        qm_sfc_avg = jnp.mean(qm_sfc_step)
    else:
        qm_sfc_step = ZeRo2D
        qm_sfc_avg = 0.0

    # ------------------------------------------------------------
    # Compute Velocity Gradients
    # ------------------------------------------------------------
    (dudx, dvdx, dwdx,
     dudy, dvdy, dwdy,
     dudz, dvdz, dwdz) = (
        velocityGradients(
            u, v, w,
            u_fft, v_fft, w_fft,
            kx2, ky2,
            ustar, M_sfc_loc, MOSTfunctions,
            ZeRo3D))

    (dTHdx, dTHdy, dTHdz) = (
        potentialTemperatureGradients(
            TH,
            kx2, ky2,
            ustar, qz_sfc_step, MOSTfunctions,
            ZeRo3D))

    if optMoisture >= 1:
        (dQdx, dQdy, dQdz) = moistureGradients(
            Q, kx2, ky2, ustar, qm_sfc_step, MOSTfunctions, ZeRo3D)
    else:
        dQdx = dQdy = dQdz = ZeRo3D

    # ------------------------------------------------------------
    # Compute Advection Terms
    # ------------------------------------------------------------
    Cx, Cy, Cz = Advection(
        u, v, w,
        dudy, dudz, dvdx, dvdz, dwdx, dwdy,
        ZeRo3D, ZeRo3D_fft, ZeRo3D_pad,
        ZeRo3D_pad_fft)

    THAdvectionSum = ScalarAdvection(
        u, v, w,
        dTHdx, dTHdy, dTHdz,
        ZeRo3D, ZeRo3D_fft, ZeRo3D_pad,
        ZeRo3D_pad_fft)

    if optMoisture >= 1:
        QAdvectionSum = ScalarAdvection(
            u, v, w,
            dQdx, dQdy, dQdz,
            ZeRo3D, ZeRo3D_fft, ZeRo3D_pad,
            ZeRo3D_pad_fft)

    # ------------------------------------------------------------
    # Compute Buoyancy Terms
    # ------------------------------------------------------------
    H = Q if optMoisture >= 1 else ZeRo3D
    if optBuoyancy == 0:
        buoyancy = BuoyancyOpt1(TH, H, ZeRo3D)
    else:
        buoyancy = BuoyancyOpt2(TH, H, ZeRo3D)

    # ------------------------------------------------------------
    # Compute SGS Terms
    #
    # optSgs 1-4: dynamic model, re-evaluated on the first iteration and
    #             every dynamicSGS_call_time iterations; in between, the
    #             static branch reuses the latest Cs2_3D / Cs2PrRatio_3D.
    # optSgs 5:   STAB-SM, evaluated every iteration.
    # otherwise:  static branch every iteration.
    # ------------------------------------------------------------

    if 1 <= optSgs <= 4 and (iteration == istep or iteration % dynamicSGS_call_time == 0):

        (divtx, divty, divtz,
         Cs2_1D_avg1, Cs2_1D_avg2, beta1_1D,
         Cs2_3D,
         dynamicSGSmomentum) = (
            DivStressDynamicSGS(
                dudx, dvdx, dwdx,
                dudy, dvdy, dwdy,
                dudz, dvdz, dwdz,
                u, v, w, M_sfc_loc, MOSTfunctions,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        (qz, divq, Cs2PrRatio_3D, Cs2PrRatio_1D, beta2_1D) = (
            DivFluxDynamicSGS(
                dynamicSGSmomentum[10:],
                TH,
                dTHdx, dTHdy, dTHdz,
                qz_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # Moisture SGS: reuse strain rates from dynamic momentum SGS with
        # the same Cs2PrRatio (turbulent Sc = turbulent Pr approximation).
        # dynamicSGSmomentum[19:23] = (S_uvp, S_uvp_pad, S_w, S_w_pad)
        if optMoisture >= 1:
            qHz_q, divqm = DivFluxStaticSGS(
                (dynamicSGSmomentum[19], dynamicSGSmomentum[20],
                 dynamicSGSmomentum[21], dynamicSGSmomentum[22]),
                Cs2PrRatio_3D,
                dQdx, dQdy, dQdz,
                qm_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2)
        else:
            qHz_q = divqm = ZeRo3D

        # Unpack variables for computation of statistics
        _, _, _, txy, txz, tyz = dynamicSGSmomentum[0:6]

    elif optSgs == 5:

        (divtx, divty, divtz,
         stabsmSGSmomentum) = (
            DivStressStaticSGS_STABSM(
                dudx, dvdx, dwdx,
                dudy, dvdy, dwdy,
                dudz, dvdz, dwdz,
                dTHdz,
                u, v, M_sfc_loc, MOSTfunctions,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # stabsmSGSmomentum[10:14] = (Lambda_uvp2_3D, Lambda_w2_3D, fhS_uvp, fhS_w)
        qz, divq = (
            DivFluxStaticSGS_STABSM(
                stabsmSGSmomentum[10:14],
                dTHdx, dTHdy, dTHdz,
                qz_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # Moisture SGS: reuse same Lambda and fhS as heat
        if optMoisture >= 1:
            qHz_q, divqm = DivFluxStaticSGS_STABSM(
                stabsmSGSmomentum[10:14],
                dQdx, dQdy, dQdz,
                qm_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2)
        else:
            qHz_q = divqm = ZeRo3D

        # stabsmSGSmomentum[14] = Lambda_uvp2_1D (effective Cs^2 profile)
        Cs2_1D_avg1 = stabsmSGSmomentum[14]
        Cs2_1D_avg2 = stabsmSGSmomentum[14]
        Cs2PrRatio_1D = stabsmSGSmomentum[14]
        beta1_1D = ZeRo1D
        beta2_1D = ZeRo1D

        # Unpack variables for computation of statistics
        _, _, _, txy, txz, tyz = stabsmSGSmomentum[0:6]

    else:

        # NOTE(review): for optSgs 1-4 this branch reuses Cs2_3D and
        # Cs2PrRatio_3D from the most recent dynamic update. For optSgs
        # outside 1-5 neither they nor the 1D profiles passed to
        # ComputeStats (Cs2_1D_avg1, ..., beta2_1D) are assigned in this
        # file; presumably they come from the star imports — confirm.
        (divtx, divty, divtz,
         staticSGSmomentum) = (
            DivStressStaticSGS(
                dudx, dvdx, dwdx,
                dudy, dvdy, dwdy,
                dudz, dvdz, dwdz,
                Cs2_3D,
                u, v, M_sfc_loc, MOSTfunctions,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        qz, divq = (
            DivFluxStaticSGS(
                staticSGSmomentum[6:],
                Cs2PrRatio_3D,
                dTHdx, dTHdy, dTHdz,
                qz_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2))

        # Moisture SGS: same Cs2PrRatio as heat.
        if optMoisture >= 1:
            qHz_q, divqm = DivFluxStaticSGS(
                staticSGSmomentum[6:],
                Cs2PrRatio_3D,
                dQdx, dQdy, dQdz,
                qm_sfc_step,
                ZeRo3D, ZeRo3D_fft, ZeRo3D_pad_fft,
                kx2, ky2)
        else:
            qHz_q = divqm = ZeRo3D

        # Unpack variables for computation of statistics
        _, _, _, txy, txz, tyz = staticSGSmomentum[0:6]

    # ------------------------------------------------------------
    # Compute right hand side (RHS) terms
    # ------------------------------------------------------------

    (RHS_u, RHS_v, RHS_w) = (
        RHS_Momentum(u, v, w,
                     Ug, Vg,
                     Cx, Cy, Cz,
                     buoyancy,
                     divtx, divty, divtz,
                     RayleighDampCoeff, RayleighDampCoeff_stag,
                     Uadv, Vadv))

    RHS_TH = RHS_Scalar(TH, THAdvectionSum, divq, RayleighDampCoeff_stag, THadv)
    if optMoisture >= 1:
        RHS_Q = RHS_Moisture(Q, QAdvectionSum, divqm, RayleighDampCoeff_stag, Qadv)

    # ------------------------------------------------------------
    # Pressure solution
    #
    # NOTE(review): on the first iteration (iteration == istep, including
    # restarts) RHS_*_previous are still the zero arrays set before the
    # loop, whereas the AB2 step below uses RHS_*_previous = RHS_* (set just
    # after this solve). If PressureRC weights the previous RHS with AB2
    # coefficients, the first-step pressure source does not match the
    # time-advancement scheme — confirm against PressureRC.
    # ------------------------------------------------------------

    (RC_real, RC_imag, fRz_real) = (
        PressureRC(
            u, v, w,
            RHS_u, RHS_v, RHS_w,
            RHS_u_previous, RHS_v_previous, RHS_w_previous,
            divtz, kr2_pressure, kc2_pressure))

    if optPressureSolver == 1:
        (p, dpdx, dpdy, dpdz) = ThomasPressureSolve(
            RC_real, RC_imag, fRz_real,
            b_thomas, m_thomas, c_pressure)
    else:
        (p, dpdx, dpdy, dpdz) = PressureSolve(
            RC_real, RC_imag, fRz_real,
            a_pressure, b_pressure, c_pressure)

    # Add pressure gradient terms to RHS
    RHS_u = RHS_u - dpdx
    RHS_v = RHS_v - dpdy
    RHS_w = RHS_w - dpdz

    # ------------------------------------------------------------
    # Initialize RHS terms for previous time step
    #
    # On the first iteration there is no history, so the AB2 update below
    # uses the current RHS as the "previous" one as well.
    # ------------------------------------------------------------

    if iteration == istep:
        RHS_u_previous = RHS_u
        RHS_v_previous = RHS_v
        RHS_w_previous = RHS_w
        RHS_TH_previous = RHS_TH
        if optMoisture >= 1:
            RHS_Q_previous = RHS_Q

    # ------------------------------------------------------------
    # Time advancement
    # ------------------------------------------------------------

    (u, v, w) = (
        AB2_uvw(u, v, w,
                RHS_u, RHS_u_previous,
                RHS_v, RHS_v_previous,
                RHS_w, RHS_w_previous))

    TH = AB2_TH(TH, RHS_TH, RHS_TH_previous)

    if optMoisture >= 1:
        Q = AB2_Q(Q, RHS_Q, RHS_Q_previous)

    # ------------------------------------------------------------
    # Compute CFLmax
    # ------------------------------------------------------------
    CFLx = jnp.max(jnp.abs(u)) * dt_nondim / dx
    CFLy = jnp.max(jnp.abs(v)) * dt_nondim / dy
    CFLz = jnp.max(jnp.abs(w)) * dt_nondim / dz
    CFL = jnp.max(jnp.array([CFLx, CFLy, CFLz]))
    if CFL > CFLmax:
        CFLmax = CFL
        CFLmax_iteration = iteration

    # ------------------------------------------------------------
    # Compute and output averaged statistics
    # ------------------------------------------------------------

    # Collect samples at specified intervals including output intervals
    if iteration % SampleInterval == 0:
        # Accumulation of statistics
        ResetFlag = 0
        StatsDict = ComputeStats(u, v, w, TH, Q,
                                 dudz, dvdz, dTHdz, dQdz,
                                 M_sfc_loc, ustar, qz_sfc_avg, qm_sfc_avg,
                                 txy, txz, tyz, qz, qHz_q,
                                 Cs2_1D_avg1, Cs2_1D_avg2,
                                 Cs2PrRatio_1D,
                                 beta1_1D, beta2_1D,
                                 StatsDict, ResetFlag,
                                 ZeRo3D)
        SampleCounter += 1

        # Progress report
        pct = 100.0 * iteration / nsteps
        elapsed = time.time() - tic_tot
        rate = elapsed / (iteration - istep + 1)  # seconds per iteration
        eta = rate * (nsteps - iteration)

        print(f"\n============= Finished Iteration {iteration} / {nsteps} "
              f"({pct:.1f}%) =============")
        print(f" Elapsed: {_fmt(elapsed)} ETA: {_fmt(eta)}")
        print(
            f"Statistics: collected sample {SampleCounter} at iteration {iteration}")
        # RMS of the (non-dimensional) friction velocity, redimensionalized
        # with u_scale like the fluxes below so the "m/s" unit is correct.
        print(f" Friction Velocity: "
              f"{float(jnp.sqrt(jnp.mean(ustar ** 2)) * u_scale):.4f} "
              f"m/s")
        print(f" Sensible Heat Flux: "
              f"{float(qz_sfc_avg * u_scale * TH_scale):.4f} K m/s")
        if optMoisture >= 1:
            print(f" Moisture Flux: "
                  f"{float(qm_sfc_avg * u_scale * Q_scale):.6e} kg/kg m/s")
        print(f" Current CFL: {CFL:.3f}")
        print(f" CFLmax: {CFLmax:.3f}")
        print(f" CFLmax happened at iteration: {CFLmax_iteration}")

    # At output intervals, check if we've collected any samples
    if iteration % OutputInterval == 0 and SampleCounter > 0:
        # Average the accumulated statistics (bookkeeping entries as-is)
        OutputStats = {
            key: (StatsDict[key] if key in _NON_AVERAGED_STATS
                  else StatsDict[key] / SampleCounter)
            for key in StatsDict}

        # Generate output filename and save statistics
        OutputFile = f'ALFA_Statistics_Iteration_{iteration}.npz'
        OutputDirFile = os.path.join(OutputDir, OutputFile)
        np.savez(OutputDirFile, **OutputStats)
        print(
            f"Statistics saved to {OutputFile} "
            f"(averaged over {SampleCounter} samples)")

        # Reset statistics for next averaging interval
        SampleCounter = 0
        ResetFlag = 1
        StatsDict = ComputeStats(u, v, w, TH, Q,
                                 dudz, dvdz, dTHdz, dQdz,
                                 M_sfc_loc, ustar, qz_sfc_avg, qm_sfc_avg,
                                 txy, txz, tyz, qz, qHz_q,
                                 Cs2_1D_avg1, Cs2_1D_avg2,
                                 Cs2PrRatio_1D,
                                 beta1_1D, beta2_1D,
                                 StatsDict, ResetFlag,
                                 ZeRo3D)

    # At regular intervals, save 3D fields for visualizations
    if iteration % Output3DInterval == 0:
        # Create dictionary of fields to save
        Fields3D = {
            "u": u + Ugal,  # Galilean velocity added back
            "v": v,
            "w": w,
            "TH": TH + T_0_nondim  # anomaly -> absolute (TH stored as TH - T_0)
        }
        if optMoisture >= 1:
            Fields3D["Q"] = Q

        # Generate output filename and save 3D fields
        OutputFile3D = f'ALFA_3DFields_Iteration_{iteration}.npz'
        OutputDirFile3D = os.path.join(OutputDir, OutputFile3D)
        np.savez(OutputDirFile3D, **Fields3D)
        print(f"3D fields saved to {OutputFile3D}")

print(f"Total Elapsed Time: {time.time() - tic_tot:.5f} seconds")