Initialization of Variables

Source Code: Initialization

Initialization.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Initialization.py
 18==============================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
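  24:Description: loads the initial velocity, potential temperature and specific humidity fields and reshapes them; also loads the time-varying forcing series and builds the Rayleigh damping coefficients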
 25"""
 26
 27
 28# ============================================================
 29#  Imports
 30# ============================================================
 31
import os

import jax
import jax.numpy as jnp
import numpy as np

# Import configuration from namelist
from ..config.ConfigLoader import *

# Import derived variables
from ..config.DerivedVars import *

# Import StagGridAvg
from ..utilities.Utilities import StagGridAvg
 44
 45InputDir = os.path.join(os.environ['JAXALFA_RUNDIR'], 'input')
 46
 47# Absolute paths resolved once at import time
 48_SurfaceBCFile        = os.path.join(os.environ['JAXALFA_RUNDIR'], SurfaceBCFile)
 49_GeoWindFile          = os.path.join(os.environ['JAXALFA_RUNDIR'], GeoWindFile)
 50_AdvectionFile        = os.path.join(os.environ['JAXALFA_RUNDIR'], AdvectionFile)
 51_MoistureSurfaceBCFile = os.path.join(os.environ['JAXALFA_RUNDIR'], MoistureSurfaceBCFile)
 52
 53
 54# ============================================================
 55# Load velocity field
 56# ============================================================
 57
 58def Initialize_uvw():
 59    """
 60    Returns:
 61    --------
 62    u, v, w : ndarray
 63        3D arrays of size (nx, ny, nz) containing the initialized
 64        velocity components
 65    """
 66
 67    InputVelocity = os.path.join(InputDir, 'vel.npy')
 68    vel = np.load(InputVelocity)
 69    u = vel[:, 0] - Ugal
 70    v = vel[:, 1]
 71    w = vel[:, 2]
 72
 73    u = np.reshape(u, (nx, ny, nz), order='F') / u_scale
 74    v = np.reshape(v, (nx, ny, nz), order='F') / u_scale
 75    w = np.reshape(w, (nx, ny, nz), order='F') / u_scale
 76
 77    return jnp.array(u), jnp.array(v), jnp.array(w)
 78
 79
 80# ============================================================
 81# Load potential temperature field
 82# ============================================================
 83
 84def Initialize_TH():
 85    """
 86    Returns:
 87    --------
 88    TH : ndarray
 89        3D array of size (nx, ny, nz) containing the initialized
 90        potential temperature field
 91    """
 92
 93    InputTH = os.path.join(InputDir, 'TH.npy')
 94    TH = np.load(InputTH)
 95
 96    TH = np.reshape(TH, (nx, ny, nz), order='F')
 97
 98    # Subtract base state in numpy float64 before JAX cast; TH is stored as
 99    # anomaly TH' = TH - T_0 throughout the simulation.
100    return jnp.array(TH - T_0_nondim)
101
102
103# ============================================================
 104# Surface boundary condition and geostrophic wind
105# ============================================================
106
def Initialize_SurfaceBC():
    """
    Load the time-varying surface BC series (heat flux or surface temperature)
    from SurfaceBCFile. Called once before the main loop when optSurfBC >= 1.

    Validates that the file's dt_surf and optSurfBC match Config, and that the
    series length matches nsteps+1.

    Returns:
    --------
    SurfaceBC_series : jnp.ndarray of shape (nsteps+1,)
        Non-dimensional surface BC values at every timestep from t=0 to t=SimTime.
        For optSurfBC=1: non-dim heat flux     = data / (u_scale * TH_scale)
        For optSurfBC=2: non-dim temp anomaly  = (data - T_0) / TH_scale
            Stored as anomaly so float32 represents small values (~0 to -2.25 K)
            rather than absolute temperature (~265 K). Surface flux functions
            expect this anomaly and compare it against TH_air anomalies.

    Raises:
    -------
    ValueError
        If dt_surf, optSurfBC or the series length stored in the file
        disagree with Config.
    """
    # np.load on an .npz returns an NpzFile that holds an open file handle.
    # The context manager closes it. Arrays read inside the block are
    # materialised in memory and stay valid after the file is closed.
    with np.load(_SurfaceBCFile) as data:
        # --- Validation ---
        dt_surf_file = float(data['dt_surf'])
        if abs(dt_surf_file - dt) > 1e-6:
            raise ValueError(
                f"SurfaceBC file dt_surf={dt_surf_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateSurfaceBC with the current Config."
            )

        optSurfBC_file = int(data['optSurfBC'])
        if optSurfBC_file != optSurfBC:
            raise ValueError(
                f"SurfaceBC file optSurfBC={optSurfBC_file} does not match "
                f"Config optSurfBC={optSurfBC}. Re-run CreateSurfaceBC."
            )

        series = data['data_series']

    expected_len = nsteps + 1
    if len(series) != expected_len:
        raise ValueError(
            f"SurfaceBC series length {len(series)} != nsteps+1={expected_len}. "
            f"Re-run CreateSurfaceBC with the current Config."
        )

    # --- Non-dimensionalise ---
    # optSurfBC=2: store as anomaly (theta_sfc - T_0) / TH_scale.
    # The subtraction happens here in NumPy float64, so the small differences
    # (0 to -2.25 K) are stored accurately when JAX casts to float32.
    if optSurfBC == 1:
        series_nondim = series / (u_scale * TH_scale)
    else:
        series_nondim = (series - T_0) / TH_scale

    return jnp.array(series_nondim)
161
162
def Initialize_GeoWind():
    """
    Build height- and time-invariant geostrophic wind fields.

    Used when optGeoWind == 0. Every grid point receives the Config scalars
    Ug2 and Vg2, divided by u_scale.

    Returns:
    --------
    Ug, Vg : jnp.ndarray of shape (nx, ny, nz)
    """
    unit_field = jnp.ones((nx, ny, nz))
    Ug = Ug2 * unit_field / u_scale
    Vg = Vg2 * unit_field / u_scale
    return Ug, Vg
171
172
def Initialize_GeoWind_Varying():
    """
    Load the time- and height-varying geostrophic wind series from GeoWindFile.
    Called once before the main loop when optGeoWind >= 1.

    The file holds either
      * full series Ug_series, Vg_series of shape (nsteps+1, nz) in m/s, or
      * compact profiles: t_profile (ntimes,) in s, plus Ug_profile and
        Vg_profile (ntimes, nz) in m/s. These are linearly interpolated in
        time onto t = 0, dt, ..., nsteps*dt.

    Validation:
      * the file's dt_geo and optGeoWind must match Config;
      * both wind components must have the same shape with nz levels;
      * in the compact form, t_profile must be non-decreasing and must
        cover the simulated interval.

    Returns:
    --------
    GeoWind_U, GeoWind_V : jnp.ndarray of shape (nsteps+1, nz)
        Non-dimensional geostrophic wind profiles at every timestep.
        Index [iteration-1] gives the profile for that iteration.

    Raises:
    -------
    ValueError
        If any of the checks above fails.
    """
    # NpzFile keeps a file handle open; the context manager closes it.
    # Arrays read inside the block remain valid afterwards.
    with np.load(_GeoWindFile) as data:
        dt_geo_file = float(data['dt_geo'])
        if abs(dt_geo_file - dt) > 1e-6:
            raise ValueError(
                f"GeoWind file dt_geo={dt_geo_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateGeoWind with the current Config."
            )

        optGeoWind_file = int(data['optGeoWind'])
        if optGeoWind_file != optGeoWind:
            raise ValueError(
                f"GeoWind file optGeoWind={optGeoWind_file} does not match "
                f"Config optGeoWind={optGeoWind}. Re-run CreateGeoWind."
            )

        has_full_series = 'Ug_series' in data and 'Vg_series' in data
        if has_full_series:
            Ug_series = data['Ug_series']       # (nsteps+1, nz) dimensional m/s
            Vg_series = data['Vg_series']       # (nsteps+1, nz) dimensional m/s
        else:
            t_profile = data['t_profile']       # (ntimes,) seconds
            Ug_profile = data['Ug_profile']     # (ntimes, nz) dimensional m/s
            Vg_profile = data['Vg_profile']     # (ntimes, nz) dimensional m/s

    if has_full_series:
        if Ug_series.shape[0] != nsteps + 1:
            raise ValueError(
                f"GeoWind series length {Ug_series.shape[0]} != nsteps+1={nsteps + 1}. "
                f"Re-run CreateGeoWind with the current Config."
            )
        if Ug_series.shape[1] != nz:
            raise ValueError(
                f"GeoWind nz={Ug_series.shape[1]} != Config nz={nz}. "
                f"Re-run CreateGeoWind with the current Config."
            )
        # Vg is consumed with the same indexing as Ug, so it must match exactly.
        if Vg_series.shape != Ug_series.shape:
            raise ValueError(
                f"GeoWind Vg_series shape {Vg_series.shape} != Ug_series shape "
                f"{Ug_series.shape}. Re-run CreateGeoWind with the current Config."
            )
    else:
        if Ug_profile.shape[1] != nz:
            raise ValueError(
                f"GeoWind nz={Ug_profile.shape[1]} != Config nz={nz}. "
                f"Re-run CreateGeoWind with the current Config."
            )
        if Vg_profile.shape != Ug_profile.shape:
            raise ValueError(
                f"GeoWind Vg_profile shape {Vg_profile.shape} != Ug_profile shape "
                f"{Ug_profile.shape}. Re-run CreateGeoWind with the current Config."
            )
        # np.interp does not verify ordering and returns meaningless values
        # for a decreasing abscissa, so reject such a profile explicitly.
        if np.any(np.diff(t_profile) < 0):
            raise ValueError(
                "GeoWind t_profile is not monotonically increasing. "
                "Re-run CreateGeoWind."
            )

        t_target = np.arange(nsteps + 1) * dt
        if t_target[0] < t_profile[0] - 1e-6 or t_target[-1] > t_profile[-1] + 1e-6:
            raise ValueError(
                "GeoWind compact profiles do not cover the configured simulation "
                f"interval 0-{t_target[-1]:.1f} s. Re-run CreateGeoWind."
            )

        # Linear interpolation in time, level by level
        Ug_series = np.zeros((nsteps + 1, nz))
        Vg_series = np.zeros((nsteps + 1, nz))
        for k in range(nz):
            Ug_series[:, k] = np.interp(t_target, t_profile, Ug_profile[:, k])
            Vg_series[:, k] = np.interp(t_target, t_profile, Vg_profile[:, k])

    # Non-dimensionalise (u_scale = 1 in current JAX-ALFA, kept for generality)
    GeoWind_U = jnp.array(Ug_series / u_scale)
    GeoWind_V = jnp.array(Vg_series / u_scale)

    return GeoWind_U, GeoWind_V
247
248
249# ============================================================
250# Large-scale (mesoscale) advection forcing
251# ============================================================
252
def Initialize_AdvForcing():
    """
    Load the time- and height-varying large-scale advection forcing from
    AdvectionFile.  Called once before the main loop when optAdvection >= 1.

    The .npz file must contain:
        Uadv_series  : (nsteps+1, nz)  dimensional  [m s⁻²]
        Vadv_series  : (nsteps+1, nz)  dimensional  [m s⁻²]
        THadv_series : (nsteps+1, nz)  dimensional  [K s⁻¹]   (optional)
        Qadv_series  : (nsteps+1, nz)  dimensional  [kg/kg s⁻¹] (optional)
        dt_adv       : float           [s]  — must equal Config dt
        optAdvection : int32

    Every series that is present is validated against the shape
    (nsteps+1, nz).

    Returns:
    --------
    AdvForcing_U, AdvForcing_V, AdvForcing_TH, AdvForcing_Q :
        jnp.ndarray of shape (nsteps+1, nz)
        Non-dimensional advection tendencies at every timestep.
        Missing series default to zero.
        Non-dimensionalisation: [m s⁻²] × z_scale  (u_scale = TH_scale = Q_scale = 1).

    Raises:
    -------
    ValueError
        If dt_adv or optAdvection disagree with Config, or if any series
        present has the wrong shape.
    """

    def _check_shape(name, series):
        """Raise ValueError unless *series* has shape (nsteps+1, nz)."""
        if series.shape[0] != nsteps + 1:
            raise ValueError(
                f"AdvForcing {name} length {series.shape[0]} != nsteps+1={nsteps + 1}. "
                f"Re-run CreateAdvForcing with the current Config."
            )
        if series.shape != (nsteps + 1, nz):
            raise ValueError(
                f"AdvForcing {name} shape {series.shape} != "
                f"(nsteps+1, nz)=({nsteps + 1}, {nz}). "
                f"Re-run CreateAdvForcing with the current Config."
            )

    # NpzFile keeps a file handle open; the context manager closes it.
    with np.load(_AdvectionFile) as data:
        dt_adv_file = float(data['dt_adv'])
        if abs(dt_adv_file - dt) > 1e-6:
            raise ValueError(
                f"AdvForcing file dt_adv={dt_adv_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateAdvForcing with the current Config."
            )

        optAdvection_file = int(data['optAdvection'])
        if optAdvection_file != optAdvection:
            raise ValueError(
                f"AdvForcing file optAdvection={optAdvection_file} does not match "
                f"Config optAdvection={optAdvection}. Re-run CreateAdvForcing."
            )

        Uadv_series = data['Uadv_series']   # (nsteps+1, nz)  [m/s^2]
        Vadv_series = data['Vadv_series']   # (nsteps+1, nz)  [m/s^2]
        # Optional series; None when absent from the file
        THadv_series = data['THadv_series'] if 'THadv_series' in data else None
        Qadv_series = data['Qadv_series'] if 'Qadv_series' in data else None

    _check_shape('Uadv_series', Uadv_series)
    _check_shape('Vadv_series', Vadv_series)

    # Nondimensionalise: [m/s^2] * z_scale / u_scale^2  (u_scale = 1)
    AdvForcing_U  = jnp.array(Uadv_series * z_scale / u_scale ** 2)
    AdvForcing_V  = jnp.array(Vadv_series * z_scale / u_scale ** 2)

    if THadv_series is not None:
        _check_shape('THadv_series', THadv_series)
        # Nondimensionalise: [K/s] * z_scale / (u_scale * TH_scale)  (both = 1)
        AdvForcing_TH = jnp.array(THadv_series * z_scale / (u_scale * TH_scale))
    else:
        AdvForcing_TH = jnp.zeros((nsteps + 1, nz))

    if Qadv_series is not None:
        _check_shape('Qadv_series', Qadv_series)
        # Nondimensionalise: [kg/kg/s] * z_scale / (u_scale * Q_scale)  (both = 1)
        AdvForcing_Q = jnp.array(Qadv_series * z_scale / (u_scale * Q_scale))
    else:
        AdvForcing_Q = jnp.zeros((nsteps + 1, nz))

    return AdvForcing_U, AdvForcing_V, AdvForcing_TH, AdvForcing_Q
324
325
326# ============================================================
327# Moisture field
328# ============================================================
329
def Initialize_Q():
    """
    Load the initial specific humidity field from input/Q.npy.

    Q is stored as absolute values (kg/kg); no base state is subtracted.
    The flat file is unpacked in Fortran (column-major) order.

    Returns:
    --------
    Q : jnp.ndarray of shape (nx, ny, nz)  [kg/kg]
    """
    InputQ = os.path.join(InputDir, 'Q.npy')
    Q_flat = np.load(InputQ)
    Q = np.reshape(Q_flat, (nx, ny, nz), order='F')
    return jnp.array(Q)
343
344
def Initialize_MoistureSurfaceBC():
    """
    Load the time-varying moisture surface BC series from MoistureSurfaceBCFile.
    Called once before the main loop when optMoisture=1 and optMoistureSurfBC >= 1.

    Returns:
    --------
    MoistureSurfaceBC_series : jnp.ndarray of shape (nsteps+1,)
        Value at every timestep:
        For optMoistureSurfBC=1: non-dim moisture flux = data / (u_scale * Q_scale)
        For optMoistureSurfBC=2: surface Q in kg/kg (stored as-is, already dimensional)

    Raises:
    -------
    ValueError
        If dt_moist, optMoistureSurfBC or the series length stored in the
        file disagree with Config.
    """
    # NpzFile keeps a file handle open; the context manager closes it.
    # Arrays read inside the block remain valid afterwards.
    with np.load(_MoistureSurfaceBCFile) as data:
        dt_moist_file = float(data['dt_moist'])
        if abs(dt_moist_file - dt) > 1e-6:
            raise ValueError(
                f"MoistureSurfaceBC file dt_moist={dt_moist_file:.6f} s does not match "
                f"Config dt={dt:.6f} s. Re-run CreateMoistureSurfaceBC with the current Config."
            )

        optMoistureSurfBC_file = int(data['optMoistureSurfBC'])
        if optMoistureSurfBC_file != optMoistureSurfBC:
            raise ValueError(
                f"MoistureSurfaceBC file optMoistureSurfBC={optMoistureSurfBC_file} does not match "
                f"Config optMoistureSurfBC={optMoistureSurfBC}. Re-run CreateMoistureSurfaceBC."
            )

        series = data['data_series']

    expected_len = nsteps + 1
    if len(series) != expected_len:
        raise ValueError(
            f"MoistureSurfaceBC series length {len(series)} != nsteps+1={expected_len}. "
            f"Re-run CreateMoistureSurfaceBC with the current Config."
        )

    if optMoistureSurfBC == 1:
        # Flux in kg/kg m/s; non-dimensionalise by u_scale * Q_scale
        series_nondim = series / (u_scale * Q_scale)
    else:
        # optMoistureSurfBC == 2: surface Q (kg/kg); store as-is
        series_nondim = series

    return jnp.array(series_nondim)
389
390
391# ============================================================
392# Rayleigh damping layer
393# ============================================================
394
def Initialize_RayleighDampingLayer():
    """
    Build the Rayleigh damping coefficients on full and half vertical levels.

    Inside the layer [z_damping_nondim, z_top_nondim] the coefficient follows
    a cosine ramp:
        0.5 / RelaxTime_nondim * (1 - cos(pi * (z - z_damping_nondim) / thickness))
    It rises from 0 at z_damping_nondim to 1 / RelaxTime_nondim at
    z_top_nondim. Levels outside the layer get 0.

    Returns:
    --------
    RayleighDampCoeff : jnp.ndarray
        3D array of size (nx, ny, nz); coefficients at full levels
        z = k * dz, k = 0..nz-1.
    RayleighDampCoeff_stag : jnp.ndarray
        3D array of size (nx, ny, nz); coefficients at half levels
        z = (k + 0.5) * dz, k = 0..nz-1.
    """

    # Inverse non-dimensional relaxation time
    invRelaxTime_nondim = 1.0 / RelaxTime_nondim

    # Valid for both full and half levels
    z_top_nondim = l_z / z_scale

    # Damping layer depth
    RayleighDampThickness = z_top_nondim - z_damping_nondim

    def _damping_coeff_3d(z_levels):
        """Cosine-ramp damping at the given heights, broadcast to (nx, ny, nz)."""
        in_layer = (z_levels >= z_damping_nondim) & (z_levels <= z_top_nondim)
        coeff_1d = jnp.where(
            in_layer,
            0.5 * invRelaxTime_nondim * (1.0 - jnp.cos(
                jnp.pi * (z_levels - z_damping_nondim) / RayleighDampThickness)),
            0.0)
        # Coefficient depends on height only; replicate across x and y
        return jnp.broadcast_to(coeff_1d.reshape(1, 1, nz), (nx, ny, nz))

    # Full levels
    RayleighDampCoeff = _damping_coeff_3d(jnp.arange(nz) * dz)

    # Half (staggered) levels
    RayleighDampCoeff_stag = _damping_coeff_3d((jnp.arange(nz) + 0.5) * dz)

    return RayleighDampCoeff, RayleighDampCoeff_stag