Initialization of Variables
Source Code: Initialization
Initialization.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Initialization.py
18==============================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-3
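24:Description: loads the initial velocity, temperature and moisture fields and reshapes them; loads time-varying surface BC, geostrophic wind and advection forcing series; builds the Rayleigh damping layer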
25"""
26
27
28# ============================================================
29# Imports
30# ============================================================
31
32import os
33import jax
34import jax.numpy as jnp
35
36# Import configuration from namelist
37from ..config.ConfigLoader import *
38
39# Import derived variables
40from ..config.DerivedVars import *
41
42# Import StagGridAvg
43from ..utilities.Utilities import StagGridAvg
44
# Directory holding the initial-condition files. JAXALFA_RUNDIR must be set
# in the environment; a missing variable raises KeyError at import time.
InputDir = os.path.join(os.environ['JAXALFA_RUNDIR'], 'input')

# Forcing / boundary-condition file paths, each joined onto the run directory
# once at import time (file names come from the imported configuration).
_SurfaceBCFile, _GeoWindFile, _AdvectionFile, _MoistureSurfaceBCFile = (
    os.path.join(os.environ['JAXALFA_RUNDIR'], _file_name)
    for _file_name in (
        SurfaceBCFile,
        GeoWindFile,
        AdvectionFile,
        MoistureSurfaceBCFile,
    )
)
52
53
54# ============================================================
55# Load velocity field
56# ============================================================
57
def Initialize_uvw():
    """
    Read the initial velocity field from ``vel.npy`` inside InputDir.

    The loaded array is indexed as 2D: columns 0, 1 and 2 are taken as the
    u, v and w components. Ugal is subtracted from the u column only. Each
    column is reshaped to (nx, ny, nz) in Fortran (column-major) order and
    divided by u_scale.

    Returns:
    --------
    u, v, w : jnp.ndarray
        3D arrays of size (nx, ny, nz) containing the initialized
        velocity components
    """

    vel = np.load(os.path.join(InputDir, 'vel.npy'))
    grid_shape = (nx, ny, nz)

    def _as_scaled_grid(flat_component):
        # Column-major reshape first, then division by the velocity scale
        return np.reshape(flat_component, grid_shape, order='F') / u_scale

    u = _as_scaled_grid(vel[:, 0] - Ugal)
    v = _as_scaled_grid(vel[:, 1])
    w = _as_scaled_grid(vel[:, 2])

    return jnp.array(u), jnp.array(v), jnp.array(w)
78
79
80# ============================================================
81# Load potential temperature field
82# ============================================================
83
def Initialize_TH():
    """
    Read the initial potential temperature field from ``TH.npy`` inside
    InputDir and return it as an anomaly relative to T_0_nondim.

    Returns:
    --------
    TH : jnp.ndarray
        3D array of size (nx, ny, nz) containing the initialized
        potential temperature field with T_0_nondim subtracted
    """

    raw_field = np.load(os.path.join(InputDir, 'TH.npy'))
    gridded = np.reshape(raw_field, (nx, ny, nz), order='F')

    # The base state is removed on the NumPy side, before the conversion
    # to a JAX array, so the JAX array only ever holds the anomaly
    # TH' = TH - T_0.
    anomaly = gridded - T_0_nondim
    return jnp.array(anomaly)
101
102
103# ============================================================
104# Surface boundary conditions and geostrophic wind components
105# ============================================================
106
def Initialize_SurfaceBC():
    """
    Load the time-varying surface BC series (heat flux or surface temperature)
    from SurfaceBCFile. Called once before the main loop when optSurfBC >= 1.

    Validates that the file's dt_surf matches Config dt, that the file's
    optSurfBC matches Config optSurfBC, and that the series length matches
    nsteps+1.

    Returns:
    --------
    SurfaceBC_series : jnp.ndarray of shape (nsteps+1,)
        Non-dimensional surface BC values at every timestep from t=0 to t=SimTime.
        For optSurfBC=1: non-dim heat flux = data / (u_scale * TH_scale)
        Otherwise (optSurfBC=2): non-dim temp anomaly = (data - T_0) / TH_scale
        Stored as an anomaly so that small departures from T_0 survive the
        cast to the JAX array dtype. Surface flux functions are expected to
        compare this anomaly against TH_air anomalies.

    Raises:
    -------
    ValueError
        If dt_surf, optSurfBC or the series length disagree with the
        current configuration.
    """

    # np.load on an .npz archive returns an NpzFile that keeps the archive
    # open; the context manager guarantees the handle is released, also when
    # one of the validation checks below raises.
    with np.load(_SurfaceBCFile) as data:

        # --- Validation ---
        dt_surf_file = float(data['dt_surf'])
        if abs(dt_surf_file - dt) > 1e-6:
            raise ValueError(
                f"SurfaceBC file dt_surf={dt_surf_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateSurfaceBC with the current Config."
            )

        optSurfBC_file = int(data['optSurfBC'])
        if optSurfBC_file != optSurfBC:
            raise ValueError(
                f"SurfaceBC file optSurfBC={optSurfBC_file} does not match "
                f"Config optSurfBC={optSurfBC}. Re-run CreateSurfaceBC."
            )

        # Indexing materialises the array in memory, so it stays valid
        # after the archive is closed.
        series = data['data_series']
        expected_len = nsteps + 1
        if len(series) != expected_len:
            raise ValueError(
                f"SurfaceBC series length {len(series)} != nsteps+1={expected_len}. "
                f"Re-run CreateSurfaceBC with the current Config."
            )

    # --- Non-dimensionalise ---
    # optSurfBC=2: store as anomaly (theta_sfc - T_0) / TH_scale.
    # The subtraction is done here in NumPy, before the JAX cast, so the
    # small differences from T_0 are formed at the file's precision.
    if optSurfBC == 1:
        series_nondim = series / (u_scale * TH_scale)
    else:
        series_nondim = (series - T_0) / TH_scale

    return jnp.array(series_nondim)
161
162
def Initialize_GeoWind():
    """
    Build constant geostrophic wind fields from the Config scalars Ug2, Vg2.
    Used when optGeoWind == 0.

    Returns:
    --------
    Ug, Vg : jnp.ndarray
        3D arrays of size (nx, ny, nz), uniformly equal to Ug2 / u_scale
        and Vg2 / u_scale respectively
    """
    unit_field = jnp.ones((nx, ny, nz))

    # Multiply first, divide second (same evaluation order for both components)
    Ug = Ug2 * unit_field / u_scale
    Vg = Vg2 * unit_field / u_scale
    return Ug, Vg
171
172
def Initialize_GeoWind_Varying():
    """
    Load the time- and height-varying geostrophic wind series from GeoWindFile.
    Called once before the main loop when optGeoWind >= 1.

    Two file layouts are accepted:
      * full series  : keys 'Ug_series' and 'Vg_series', shape (nsteps+1, nz)
      * compact form : keys 't_profile' (ntimes,), 'Ug_profile' and
                       'Vg_profile' (ntimes, nz); each level is linearly
                       interpolated in time onto t = 0, dt, ..., nsteps*dt

    Validates that the file's dt_geo and optGeoWind match Config, that the
    series dimensions match nsteps+1 and nz, that the V arrays have the same
    shape as the U arrays, and (compact form) that the profile times cover
    the simulation interval.

    Returns:
    --------
    GeoWind_U, GeoWind_V : jnp.ndarray of shape (nsteps+1, nz)
        Non-dimensional geostrophic wind profiles at every timestep.
        Index [iteration-1] gives the profile for that iteration.

    Raises:
    -------
    ValueError
        If any of the validations above fails.
    """

    # np.load on an .npz archive returns an NpzFile that keeps the archive
    # open; the context manager releases the handle on success and on error.
    with np.load(_GeoWindFile) as data:

        dt_geo_file = float(data['dt_geo'])
        if abs(dt_geo_file - dt) > 1e-6:
            raise ValueError(
                f"GeoWind file dt_geo={dt_geo_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateGeoWind with the current Config."
            )

        optGeoWind_file = int(data['optGeoWind'])
        if optGeoWind_file != optGeoWind:
            raise ValueError(
                f"GeoWind file optGeoWind={optGeoWind_file} does not match "
                f"Config optGeoWind={optGeoWind}. Re-run CreateGeoWind."
            )

        if 'Ug_series' in data and 'Vg_series' in data:
            Ug_series = data['Ug_series']   # (nsteps+1, nz) dimensional m/s
            Vg_series = data['Vg_series']   # (nsteps+1, nz) dimensional m/s

            if Ug_series.shape[0] != nsteps + 1:
                raise ValueError(
                    f"GeoWind series length {Ug_series.shape[0]} != nsteps+1={nsteps + 1}. "
                    f"Re-run CreateGeoWind with the current Config."
                )
            if Ug_series.shape[1] != nz:
                raise ValueError(
                    f"GeoWind nz={Ug_series.shape[1]} != Config nz={nz}. "
                    f"Re-run CreateGeoWind with the current Config."
                )
            # The V component must be laid out exactly like the U component.
            if Vg_series.shape != Ug_series.shape:
                raise ValueError(
                    f"GeoWind Vg_series shape {Vg_series.shape} != Ug_series shape "
                    f"{Ug_series.shape}. Re-run CreateGeoWind with the current Config."
                )
        else:
            t_profile = data['t_profile']    # (ntimes,) seconds
            Ug_profile = data['Ug_profile']  # (ntimes, nz) dimensional m/s
            Vg_profile = data['Vg_profile']  # (ntimes, nz) dimensional m/s

            if Ug_profile.shape[1] != nz:
                raise ValueError(
                    f"GeoWind nz={Ug_profile.shape[1]} != Config nz={nz}. "
                    f"Re-run CreateGeoWind with the current Config."
                )
            # Guard against a V profile with a different number of times or
            # levels, which would otherwise be interpolated level-by-level
            # against the U layout.
            if Vg_profile.shape != Ug_profile.shape:
                raise ValueError(
                    f"GeoWind Vg_profile shape {Vg_profile.shape} != Ug_profile shape "
                    f"{Ug_profile.shape}. Re-run CreateGeoWind with the current Config."
                )

            t_target = np.arange(nsteps + 1) * dt
            if t_target[0] < t_profile[0] - 1e-6 or t_target[-1] > t_profile[-1] + 1e-6:
                raise ValueError(
                    "GeoWind compact profiles do not cover the configured simulation "
                    f"interval 0-{t_target[-1]:.1f} s. Re-run CreateGeoWind."
                )

            # np.interp is one-dimensional, so interpolate each level in turn.
            Ug_series = np.zeros((nsteps + 1, nz))
            Vg_series = np.zeros((nsteps + 1, nz))
            for k in range(nz):
                Ug_series[:, k] = np.interp(t_target, t_profile, Ug_profile[:, k])
                Vg_series[:, k] = np.interp(t_target, t_profile, Vg_profile[:, k])

    # Non-dimensionalise (u_scale = 1 in current JAX-ALFA, kept for generality)
    GeoWind_U = jnp.array(Ug_series / u_scale)
    GeoWind_V = jnp.array(Vg_series / u_scale)

    return GeoWind_U, GeoWind_V
247
248
249# ============================================================
250# Large-scale (mesoscale) advection forcing
251# ============================================================
252
def Initialize_AdvForcing():
    """
    Load the time- and height-varying large-scale advection forcing from
    AdvectionFile. Called once before the main loop when optAdvection >= 1.

    The .npz file must contain:
        Uadv_series   : (nsteps+1, nz)  dimensional [m s⁻²]
        Vadv_series   : (nsteps+1, nz)  dimensional [m s⁻²]
        THadv_series  : (nsteps+1, nz)  dimensional [K s⁻¹]       (optional)
        Qadv_series   : (nsteps+1, nz)  dimensional [kg/kg s⁻¹]   (optional)
        dt_adv        : float [s] — must equal Config dt
        optAdvection  : int32 — must equal Config optAdvection

    Returns:
    --------
    AdvForcing_U, AdvForcing_V, AdvForcing_TH, AdvForcing_Q :
        jnp.ndarray of shape (nsteps+1, nz)
        Non-dimensional advection tendencies at every timestep.
        Missing optional series default to zero.
        Non-dimensionalisation:
            U, V : series * z_scale / u_scale**2
            TH   : series * z_scale / (u_scale * TH_scale)
            Q    : series * z_scale / (u_scale * Q_scale)

    Raises:
    -------
    ValueError
        If dt_adv or optAdvection disagree with Config, if Uadv_series is
        not (nsteps+1, nz), or if any other series present in the file does
        not have the same shape as Uadv_series.
    """

    def _require_same_shape(series_name, series, reference):
        # Every forcing series is indexed with the same (time, level) layout.
        if series.shape != reference.shape:
            raise ValueError(
                f"AdvForcing {series_name} shape {series.shape} != Uadv_series shape "
                f"{reference.shape}. Re-run CreateAdvForcing with the current Config."
            )

    # np.load on an .npz archive returns an NpzFile that keeps the archive
    # open; the context manager releases the handle on success and on error.
    with np.load(_AdvectionFile) as data:

        dt_adv_file = float(data['dt_adv'])
        if abs(dt_adv_file - dt) > 1e-6:
            raise ValueError(
                f"AdvForcing file dt_adv={dt_adv_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateAdvForcing with the current Config."
            )

        optAdvection_file = int(data['optAdvection'])
        if optAdvection_file != optAdvection:
            raise ValueError(
                f"AdvForcing file optAdvection={optAdvection_file} does not match "
                f"Config optAdvection={optAdvection}. Re-run CreateAdvForcing."
            )

        Uadv_series = data['Uadv_series']   # (nsteps+1, nz) [m/s^2]
        Vadv_series = data['Vadv_series']   # (nsteps+1, nz) [m/s^2]

        if Uadv_series.shape[0] != nsteps + 1:
            raise ValueError(
                f"AdvForcing series length {Uadv_series.shape[0]} != nsteps+1={nsteps + 1}. "
                f"Re-run CreateAdvForcing with the current Config."
            )
        if Uadv_series.shape[1] != nz:
            raise ValueError(
                f"AdvForcing nz={Uadv_series.shape[1]} != Config nz={nz}. "
                f"Re-run CreateAdvForcing with the current Config."
            )
        _require_same_shape('Vadv_series', Vadv_series, Uadv_series)

        # Optional series: None marks "absent from the file".
        THadv_series = None
        if 'THadv_series' in data:
            THadv_series = data['THadv_series']   # (nsteps+1, nz) [K/s]
            _require_same_shape('THadv_series', THadv_series, Uadv_series)

        Qadv_series = None
        if 'Qadv_series' in data:
            Qadv_series = data['Qadv_series']     # (nsteps+1, nz) [kg/kg/s]
            _require_same_shape('Qadv_series', Qadv_series, Uadv_series)

    # Nondimensionalise: [m/s^2] * z_scale / u_scale^2  (u_scale = 1)
    AdvForcing_U = jnp.array(Uadv_series * z_scale / u_scale ** 2)
    AdvForcing_V = jnp.array(Vadv_series * z_scale / u_scale ** 2)

    if THadv_series is not None:
        # Nondimensionalise: [K/s] * z_scale / (u_scale * TH_scale)
        AdvForcing_TH = jnp.array(THadv_series * z_scale / (u_scale * TH_scale))
    else:
        AdvForcing_TH = jnp.zeros((nsteps + 1, nz))

    if Qadv_series is not None:
        # Nondimensionalise: [kg/kg/s] * z_scale / (u_scale * Q_scale)
        AdvForcing_Q = jnp.array(Qadv_series * z_scale / (u_scale * Q_scale))
    else:
        AdvForcing_Q = jnp.zeros((nsteps + 1, nz))

    return AdvForcing_U, AdvForcing_V, AdvForcing_TH, AdvForcing_Q
324
325
326# ============================================================
327# Moisture field
328# ============================================================
329
def Initialize_Q():
    """
    Read the initial specific humidity field from ``Q.npy`` inside InputDir.
    The values are returned as loaded — no base state is subtracted and no
    scaling is applied.

    Returns:
    --------
    Q : jnp.ndarray of shape (nx, ny, nz)  [kg/kg]
    """
    raw_field = np.load(os.path.join(InputDir, 'Q.npy'))

    # Column-major (Fortran-order) reshape onto the 3D grid
    gridded = np.reshape(raw_field, (nx, ny, nz), order='F')
    return jnp.array(gridded)
343
344
def Initialize_MoistureSurfaceBC():
    """
    Load the time-varying moisture surface BC series from MoistureSurfaceBCFile.
    Called once before the main loop when optMoisture=1 and optMoistureSurfBC >= 1.

    Validates that the file's dt_moist matches Config dt, that the file's
    optMoistureSurfBC matches Config optMoistureSurfBC, and that the series
    length matches nsteps+1.

    Returns:
    --------
    MoistureSurfaceBC_series : jnp.ndarray of shape (nsteps+1,)
        Values at every timestep.
        For optMoistureSurfBC=1: non-dim moisture flux = data / (u_scale * Q_scale)
        Otherwise (optMoistureSurfBC=2): surface Q in kg/kg, stored as-is

    Raises:
    -------
    ValueError
        If dt_moist, optMoistureSurfBC or the series length disagree with
        the current configuration.
    """

    # np.load on an .npz archive returns an NpzFile that keeps the archive
    # open; the context manager guarantees the handle is released, also when
    # one of the validation checks below raises.
    with np.load(_MoistureSurfaceBCFile) as data:

        dt_moist_file = float(data['dt_moist'])
        if abs(dt_moist_file - dt) > 1e-6:
            raise ValueError(
                f"MoistureSurfaceBC file dt_moist={dt_moist_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateMoistureSurfaceBC with the current Config."
            )

        optMoistureSurfBC_file = int(data['optMoistureSurfBC'])
        if optMoistureSurfBC_file != optMoistureSurfBC:
            raise ValueError(
                f"MoistureSurfaceBC file optMoistureSurfBC={optMoistureSurfBC_file} does not match "
                f"Config optMoistureSurfBC={optMoistureSurfBC}. Re-run CreateMoistureSurfaceBC."
            )

        # Indexing materialises the array in memory, so it stays valid
        # after the archive is closed.
        series = data['data_series']
        expected_len = nsteps + 1
        if len(series) != expected_len:
            raise ValueError(
                f"MoistureSurfaceBC series length {len(series)} != nsteps+1={expected_len}. "
                f"Re-run CreateMoistureSurfaceBC with the current Config."
            )

    if optMoistureSurfBC == 1:
        # Flux in kg/kg m/s; non-dimensionalise by u_scale * Q_scale
        series_nondim = series / (u_scale * Q_scale)
    else:
        # optMoistureSurfBC == 2: surface Q (kg/kg); store as-is
        series_nondim = series

    return jnp.array(series_nondim)
389
390
391# ============================================================
392# Rayleigh damping layer
393# ============================================================
394
def Initialize_RayleighDampingLayer():
    """
    Build the Rayleigh damping coefficient fields on full and half levels.

    Inside the layer z_damping_nondim <= z <= l_z / z_scale the coefficient is
        0.5 / RelaxTime_nondim * (1 - cos(pi * (z - z_damping_nondim) / depth))
    with depth = l_z / z_scale - z_damping_nondim; outside it is zero.
    The vertical profile is identical at every (x, y) point.

    Returns:
    --------
    RayleighDampCoeff : jnp.ndarray
        3D array of size (nx, ny, nz); coefficients evaluated at the
        full-level heights k * dz, k = 0 .. nz-1
    RayleighDampCoeff_stag : jnp.ndarray
        3D array of size (nx, ny, nz); coefficients evaluated at the
        half-level heights (k + 0.5) * dz, k = 0 .. nz-1
    """

    # Inverse non-dimensional relaxation time
    inv_relax_time = 1.0 / RelaxTime_nondim

    # Domain top and damping-layer depth (shared by both sets of levels)
    z_top_nondim = l_z / z_scale
    layer_depth = z_top_nondim - z_damping_nondim

    level_index = jnp.arange(nz)

    def _damping_field(heights):
        # True only for levels lying inside the damping layer
        inside_layer = ((heights >= z_damping_nondim) &
                        (heights <= z_top_nondim))

        # Cosine ramp evaluated on all levels; levels outside the layer
        # are zeroed by the where() below
        ramp = 0.5 * inv_relax_time * (1.0 - jnp.cos(
            jnp.pi * (heights - z_damping_nondim) / layer_depth))
        profile_1d = jnp.where(inside_layer, ramp, 0.0)

        # Replicate the vertical profile across the horizontal plane
        return jnp.broadcast_to(
            profile_1d.reshape(1, 1, nz),
            (nx, ny, nz))

    # Full levels
    RayleighDampCoeff = _damping_field(level_index * dz)

    # Half levels
    RayleighDampCoeff_stag = _damping_field((level_index + 0.5) * dz)

    return RayleighDampCoeff, RayleighDampCoeff_stag