Surface Flux Parameterization

Source Code: SurfaceFlux

SurfaceFlux.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: SurfaceFlux.py
 18==============================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-29
 24:Description: surface flux calculation module.
 25              Supports constant flux (optSurfBC=0), time-varying flux
 26              (optSurfBC=1), and prescribed surface temperature (optSurfBC=2),
 27              each in homogeneous (optSurfFlux=0) and heterogeneous
 28              (optSurfFlux=1) flavours.
 29
 30              Sign convention: qz = -u* x th*
 31              Stable BL: qz < 0 (downward), th* > 0
 32"""
 33
 34# ============================================================
 35#  Imports
 36# ============================================================
 37
 38import jax
 39import jax.numpy as jnp
 40
 41# Import configuration from namelist
 42from ..config.ConfigLoader import *
 43
 44# Import derived variables
 45from ..config.DerivedVars import *
 46
 47
 48# ============================================================
 49#  Monin-Obukhov similarity functions
 50# ============================================================
 51
 52@jax.jit
 53def MOSTstable(z_over_L):
 54    """
 55    Compute Monin-Obukhov stability functions for stable conditions
 56
 57    Parameters:
 58    -----------
 59    z_over_L : jnp.ndarray
 60        Stability parameter z/L (height / Obukhov length)
 61
 62    Returns:
 63    --------
 64    psi_m : jnp.ndarray
 65        Stability function for momentum
 66    psi_h : jnp.ndarray
 67        Stability function for heat
 68    fi_m : jnp.ndarray
 69        Normalized gradient function for momentum
 70    fi_h : jnp.ndarray
 71        Normalized gradient function for heat
 72    """
 73
 74    psi_m = 5.0 * z_over_L
 75    psi_h = 5.0 * z_over_L
 76
 77    fi_m = 1.0 + psi_m
 78    fi_h = 1.0 + psi_h
 79
 80    return psi_m, psi_h, fi_m, fi_h
 81
 82
 83def MOSTunstable(z_over_L):
 84    """
 85    Compute Monin-Obukhov stability functions for unstable conditions
 86
 87    Parameters:
 88    -----------
 89    z_over_L : jnp.ndarray of shape (nx, ny)
 90        Stability parameter z/L (height / Obukhov length)
 91
 92    Returns:
 93    --------
 94    psi2D_m : jnp.ndarray of shape (nx, ny)
 95        Stability function for momentum
 96    psi2D_h : jnp.ndarray of shape (nx, ny)
 97        Stability function for heat
 98    fi2D_m : jnp.ndarray of shape (nx, ny)
 99        Normalized gradient function for momentum
100    fi2D_h : jnp.ndarray of shape (nx, ny)
101        Normalized gradient function for heat
102    """
103
104    x = (1 - 15 * z_over_L) ** 0.25
105
106    psi2D_m = (-2 * jnp.log(0.5 * (1 + x)) - jnp.log(0.5 * (1 + x ** 2))
107               + 2 * jnp.arctan(x) - jnp.pi / 2)
108    psi2D_h = -2 * jnp.log(0.5 * (1 + x ** 2))
109
110    fi2D_m = 1.0 / x
111    fi2D_h = (1.0 / x) ** 2
112    return psi2D_m, psi2D_h, fi2D_m, fi2D_h
113
114
115# ============================================================
116#  Shared helper: update MOST stability functions
117# ============================================================
118
def _update_MOSTfunctions(ustar, qz_sfc_avg, TH_ref, psi2D_m, psi2D_m0,
                          psi2D_h, psi2D_h0):
    """
    Recompute the inverse Obukhov length and every MOST stability function.

    Shared by the six SurfaceFlux_* variants.  The incoming psi2D_* arrays
    are accepted for interface compatibility but are not read.  All four psi
    values and both fi values are recomputed from the new invOB.

    Parameters:
    -----------
    ustar : jnp.ndarray
        Friction velocity.
    qz_sfc_avg : scalar
        Mean surface heat flux (qz = -u* x th*).  Its sign selects the
        stable (qz <= 0) or unstable branch for the whole domain.
    TH_ref : jnp.ndarray
        Absolute reference temperature used in the Obukhov length.
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0 : ignored

    Returns:
    --------
    MOSTfunctions : tuple
        (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h).  The psi_m,
        psi_h, fi_m and fi_h values are taken at height 0.5*dz.  psi_m0 is
        taken at z0m / z_scale.  psi_h0 is taken at zTemperature / z_scale
        when zTemperature > 0, and at z0T / z_scale otherwise.
    invOB : jnp.ndarray
        Inverse Obukhov length, -(k g qz) / (u*^3 TH_ref).
    """
    invOB = -(vonk * g_nondim * qz_sfc_avg) / ((ustar ** 3) * TH_ref)
    stable = qz_sfc_avg <= 0

    def _most(zeta):
        # Domain-wide branch choice driven by the sign of the mean flux.
        return jax.lax.cond(
            stable,
            lambda _: MOSTstable(zeta),
            lambda _: MOSTunstable(zeta),
            operand=None)

    # A positive zTemperature means the surface temperature is observed at
    # that screen height (m) rather than at z0T.
    z_ref_T = zTemperature if zTemperature > 0.0 else z0T

    psi_m1, psi_h1, fi_m1, fi_h1 = _most((0.5 * dz) * invOB)
    psi_m0 = _most((z0m / z_scale) * invOB)[0]
    psi_h0 = _most((z_ref_T / z_scale) * invOB)[1]

    return (psi_m1, psi_m0, psi_h1, psi_h0, fi_m1, fi_h1), invOB
159
160
161# ============================================================
162#  optSurfBC = 0 : constant prescribed heat flux
163# ============================================================
164
@jax.jit
def SurfaceFlux_HomogeneousConstantFlux(u, v, TH, MOSTfunctions):
    """
    Surface fluxes for optSurfFlux=0, optSurfBC=0.

    This is the homogeneous-surface, constant prescribed heat flux case.
    Plane-averaged wind speed and first-level temperature are broadcast to
    every column, so ustar and invOB are spatially uniform.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
        Velocity components; only vertical level 0 is used.
    TH : jnp.ndarray of shape (nx, ny, nz)
        Potential temperature anomaly (relative to T_0); only level 0 is used.
    MOSTfunctions : tuple of six (nx, ny) arrays
        (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h) from the
        previous call.

    Returns:
    --------
    M_sfc_loc    : (nx, ny) plane-averaged surface wind speed
    ustar        : (nx, ny) friction velocity (floored at 1e-3)
    qz_sfc_avg   : scalar, mean of the configured qz_sfc (qz = -u* x th*)
    invOB        : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, _, _ = MOSTfunctions
    uniform = jnp.ones((nx, ny))

    # The heat flux comes straight from the configuration.
    qz_sfc_avg = jnp.mean(qz_sfc)

    # Ugal is added to u before the speed is formed (presumably a Galilean
    # frame offset -- confirm in the config).
    speed = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)
    M_sfc_loc = jnp.mean(speed) * uniform

    # TH is an anomaly; adding T_0 gives the absolute reference temperature.
    TH_ref = (jnp.mean(TH[:, :, 0]) + T_0_nondim) * uniform

    # Log law with the psi values from the previous call.  The floor keeps
    # u*^3 in the Obukhov length away from zero.
    log_term = jnp.log(0.5 * dz * z_scale / z0m)
    ustar = jnp.maximum(
        vonk * M_sfc_loc / (log_term + psi2D_m - psi2D_m0), 1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi2D_m, psi2D_m0, psi2D_h, psi2D_h0)
    return M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions
206
207
@jax.jit
def SurfaceFlux_HeterogeneousConstantFlux(u, v, TH, MOSTfunctions):
    """
    Surface fluxes for optSurfFlux=1, optSurfBC=0.

    This is the heterogeneous-surface, constant prescribed heat flux case.
    Wind speed and reference temperature are taken per column.  The heat
    flux entering the Obukhov length is the domain mean of the configured
    qz_sfc.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
        Velocity components; only level 0 is used.
    TH : jnp.ndarray of shape (nx, ny, nz)
        Potential temperature anomaly; only level 0 is used.
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc    : (nx, ny) local surface wind speed
    ustar        : (nx, ny) local friction velocity (floored at 1e-3)
    qz_sfc_avg   : scalar, non-dimensional surface heat flux
    invOB        : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    # Per-column first-level velocity (u shifted by Ugal).
    u1 = u[:, :, 0] + Ugal
    v1 = v[:, :, 0]
    M_sfc_loc = jnp.sqrt(u1 ** 2 + v1 ** 2)

    # Absolute per-column temperature (TH is stored as an anomaly).
    TH_ref = TH[:, :, 0] + T_0_nondim

    qz_sfc_avg = jnp.mean(qz_sfc)

    denom_m = jnp.log(0.5 * dz * z_scale / z0m) + psi2D_m - psi2D_m0
    ustar = jnp.maximum(vonk * M_sfc_loc / denom_m, 1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi2D_m, psi2D_m0, psi2D_h, psi2D_h0)
    return M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions
239
240
241# ============================================================
242#  optSurfBC = 1 : time-varying prescribed heat flux
243# ============================================================
244
@jax.jit
def SurfaceFlux_HomogeneousVaryingFlux(u, v, TH, qz_sfc_t, MOSTfunctions):
    """
    Surface fluxes for optSurfFlux=0, optSurfBC=1.

    This is the homogeneous-surface, time-varying prescribed heat flux case.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
        Velocity components; only level 0 is used.
    TH : jnp.ndarray of shape (nx, ny, nz)
        Potential temperature anomaly; only level 0 is used.
    qz_sfc_t : scalar JAX value
        Non-dimensional heat flux at the current timestep (loaded from the
        SurfaceBC.npz series, already non-dimensionalised).
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc    : (nx, ny) plane-averaged wind speed, broadcast
    ustar        : (nx, ny) friction velocity (floored at 1e-3)
    qz_sfc_2D    : (nx, ny) qz_sfc_t broadcast to every column
    qz_sfc_avg   : scalar, equal to qz_sfc_t
    invOB        : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    field = jnp.ones((nx, ny))
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    # The flux is uniform in space; only its time series varies.
    qz_sfc_avg = qz_sfc_t
    qz_sfc_2D = qz_sfc_t * field

    wind_sfc = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)
    M_sfc_loc = jnp.mean(wind_sfc) * field

    # Absolute reference temperature from the plane-mean anomaly.
    TH_ref = (jnp.mean(TH[:, :, 0]) + T_0_nondim) * field

    denom_m = jnp.log(0.5 * dz * z_scale / z0m) + psi2D_m - psi2D_m0
    ustar = jnp.maximum(vonk * M_sfc_loc / denom_m, 1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi2D_m, psi2D_m0, psi2D_h, psi2D_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
287
288
@jax.jit
def SurfaceFlux_HeterogeneousVaryingFlux(u, v, TH, qz_sfc_t, MOSTfunctions):
    """
    Surface fluxes for optSurfFlux=1, optSurfBC=1.

    This is the heterogeneous-surface, time-varying prescribed heat flux case.
    Wind speed and reference temperature are taken per column.  The flux
    itself is spatially uniform.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
        Velocity components; only level 0 is used.
    TH : jnp.ndarray of shape (nx, ny, nz)
        Potential temperature anomaly; only level 0 is used.
    qz_sfc_t : scalar JAX value
        Non-dimensional heat flux at the current timestep.
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc    : (nx, ny) local surface wind speed
    ustar        : (nx, ny) local friction velocity (floored at 1e-3)
    qz_sfc_2D    : (nx, ny) qz_sfc_t broadcast to every column
    qz_sfc_avg   : scalar, equal to qz_sfc_t
    invOB        : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    qz_sfc_avg = qz_sfc_t
    qz_sfc_2D = qz_sfc_t * jnp.ones((nx, ny))

    u_sfc = u[:, :, 0] + Ugal
    v_sfc = v[:, :, 0]
    M_sfc_loc = jnp.sqrt(u_sfc ** 2 + v_sfc ** 2)

    # Per-column absolute temperature (TH is stored as an anomaly).
    TH_ref = TH[:, :, 0] + T_0_nondim

    log_m = jnp.log(0.5 * dz * z_scale / z0m)
    ustar = jnp.maximum(vonk * M_sfc_loc / (log_m + psi2D_m - psi2D_m0), 1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi2D_m, psi2D_m0, psi2D_h, psi2D_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
328
329
330# ============================================================
331#  optSurfBC = 2 : time-varying prescribed surface temperature
332# ============================================================
333
@jax.jit
def SurfaceFlux_HomogeneousPrescribedTemperature(u, v, TH, TH_sfc_t,
                                                  MOSTfunctions):
    """
    Surface fluxes for optSurfFlux=0, optSurfBC=2.

    This is the homogeneous-surface, prescribed surface temperature T_s(t)
    case.  The heat flux is diagnosed from MOST:

        qz = u* x vonk x (TH_s - TH_air) / denom_h

    which is qz = -u* x th* with th* = vonk x (TH_air - TH_s) / denom_h.
    Wind speed and air temperature are plane averages broadcast to every
    column.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
        Velocity components; only level 0 is used.
    TH : jnp.ndarray of shape (nx, ny, nz)
        Potential temperature anomaly (TH - T_0); only level 0 is used.
    TH_sfc_t : scalar JAX value
        Non-dimensional surface temperature anomaly
        (theta_sfc - T_0) / TH_scale at the current timestep, as returned by
        Initialize_SurfaceBC for optSurfBC=2.
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc    : (nx, ny) plane-averaged wind speed, broadcast
    ustar        : (nx, ny) friction velocity (floored at 1e-3)
    qz_sfc_2D    : (nx, ny) diagnosed surface heat flux
    qz_sfc_avg   : scalar, mean of qz_sfc_2D
    invOB        : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    ones2D = jnp.ones((nx, ny))
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    speed_avg = jnp.mean(jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2))
    M_sfc_loc = speed_avg * ones2D

    # Both TH and TH_sfc_t are anomalies, so their difference needs no T_0.
    # The absolute value is only needed for the Obukhov length.
    TH_air_anom_avg = jnp.mean(TH[:, :, 0])
    TH_air_loc = (TH_air_anom_avg + T_0_nondim) * ones2D

    z1 = 0.5 * dz * z_scale
    z_ref_T = zTemperature if zTemperature > 0.0 else z0T

    # The floor keeps u* away from zero before it is cubed in invOB.
    ustar = jnp.maximum(
        vonk * M_sfc_loc / (jnp.log(z1 / z0m) + psi2D_m - psi2D_m0), 1e-3)

    # th* > 0 (and therefore qz < 0) when the air is warmer than the surface.
    denom_h = jnp.log(z1 / z_ref_T) + psi2D_h - psi2D_h0
    qz_sfc_2D = ustar * vonk * (TH_sfc_t - TH_air_anom_avg) * ones2D / denom_h
    qz_sfc_avg = jnp.mean(qz_sfc_2D)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_air_loc, psi2D_m, psi2D_m0, psi2D_h, psi2D_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
390
391
@jax.jit
def SurfaceFlux_HeterogeneousPrescribedTemperature(u, v, TH, TH_sfc_t,
                                                    MOSTfunctions):
    """
    Surface fluxes for optSurfFlux=1, optSurfBC=2.

    This is the heterogeneous-surface, prescribed surface temperature T_s(t)
    case.  The local (per-column) wind speed and air temperature are used to
    diagnose a per-column heat flux.  TH_sfc_t is uniform in space but
    varies in time.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
        Velocity components; only level 0 is used.
    TH : jnp.ndarray of shape (nx, ny, nz)
        Potential temperature anomaly (TH - T_0); only level 0 is used.
    TH_sfc_t : scalar JAX value
        Non-dimensional surface temperature anomaly
        (theta_sfc - T_0) / TH_scale at the current timestep, as returned by
        Initialize_SurfaceBC for optSurfBC=2.
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc    : (nx, ny) local surface wind speed
    ustar        : (nx, ny) local friction velocity (floored at 1e-3)
    qz_sfc_2D    : (nx, ny) diagnosed surface heat flux
    qz_sfc_avg   : scalar, mean of qz_sfc_2D
    invOB        : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    M_sfc_loc = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)
    TH_air_anom_loc = TH[:, :, 0]

    z1 = 0.5 * dz * z_scale
    z_ref_T = zTemperature if zTemperature > 0.0 else z0T

    ustar = jnp.maximum(
        vonk * M_sfc_loc / (jnp.log(z1 / z0m) + psi2D_m - psi2D_m0), 1e-3)

    # qz = u* k (TH_s - TH_air) / denom_h.  Both temperatures are anomalies.
    denom_h = jnp.log(z1 / z_ref_T) + psi2D_h - psi2D_h0
    qz_sfc_2D = ustar * vonk * (TH_sfc_t - TH_air_anom_loc) / denom_h
    qz_sfc_avg = jnp.mean(qz_sfc_2D)

    # NOTE(review): the Obukhov reference temperature here is the plane
    # mean broadcast to all columns.  The heterogeneous flux-BC variants
    # instead pass the local TH + T_0.  Confirm this difference is intended.
    TH_air_ref = (jnp.mean(TH_air_anom_loc) + T_0_nondim) * jnp.ones((nx, ny))

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_air_ref, psi2D_m, psi2D_m0, psi2D_h, psi2D_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
444
445
446# ============================================================
447#  optMoistureSurfBC = 2 : time-varying prescribed surface humidity
448# ============================================================
449
@jax.jit
def SurfaceMoistureFlux_HomogeneousPrescribedQ(Q, ustar, Q_sfc_t, MOSTfunctions):
    """
    Diagnose the surface moisture flux for optMoistureSurfBC=2, optSurfFlux=0.

    The surface specific humidity is prescribed and spatially homogeneous.
    The flux mirrors the heat-flux diagnosis and reuses the heat stability
    functions, i.e. the turbulent Schmidt number equals the turbulent Prandtl
    number:

        qm = u* x vonk x (Q_s - mean(Q_air)) / denom_qm

    Parameters:
    -----------
    Q : jnp.ndarray (nx, ny, nz)
        Specific humidity (kg/kg); only the plane mean of level 0 is used.
    ustar : jnp.ndarray (nx, ny)
        Friction velocity.
    Q_sfc_t : scalar
        Prescribed surface Q at the current timestep (kg/kg).
    MOSTfunctions : tuple of six (nx, ny) stability arrays
        Only psi2D_h and psi2D_h0 (entries 2 and 3) are read.

    Returns:
    --------
    qm_sfc_2D : jnp.ndarray (nx, ny)
        Surface moisture flux w'q'.
    """
    _, _, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    # Moisture reference height: zMoisture when positive, otherwise z0T.
    # NOTE(review): if MOSTfunctions comes from _update_MOSTfunctions,
    # psi2D_h0 was evaluated at zTemperature (or z0T), not at z_ref_Q.  When
    # the two heights differ, the stability correction uses the wrong height.
    z_ref_Q = zMoisture if zMoisture > 0.0 else z0T
    denom_qm = jnp.log(0.5 * dz * z_scale / z_ref_Q) + psi2D_h - psi2D_h0

    delta_Q = Q_sfc_t - jnp.mean(Q[:, :, 0])
    return ustar * vonk * delta_Q * jnp.ones((nx, ny)) / denom_qm
478
479
@jax.jit
def SurfaceMoistureFlux_HeterogeneousPrescribedQ(Q, ustar, Q_sfc_t, MOSTfunctions):
    """
    Diagnose the per-column surface moisture flux for optMoistureSurfBC=2,
    optSurfFlux=1.

    The surface specific humidity is prescribed.  Each column's flux uses its
    own first-level humidity and friction velocity:

        qm = u* x vonk x (Q_s - Q_air) / denom_qm

    Parameters:
    -----------
    Q : jnp.ndarray (nx, ny, nz)
        Specific humidity (kg/kg); only level 0 is used.
    ustar : jnp.ndarray (nx, ny)
        Friction velocity (per column).
    Q_sfc_t : scalar
        Prescribed surface Q (kg/kg, spatially uniform).
    MOSTfunctions : tuple of six (nx, ny) stability arrays
        Only psi2D_h and psi2D_h0 (entries 2 and 3) are read.

    Returns:
    --------
    qm_sfc_2D : jnp.ndarray (nx, ny)
        Surface moisture flux (per column).
    """
    _, _, psi2D_h, psi2D_h0, _, _ = MOSTfunctions

    # Moisture reference height: zMoisture when positive, otherwise z0T.
    # NOTE(review): psi2D_h0 may have been evaluated at the temperature
    # reference height rather than at z_ref_Q.  Confirm this when zMoisture
    # differs from zTemperature.
    z_ref_Q = zMoisture if zMoisture > 0.0 else z0T
    denom_qm = jnp.log(0.5 * dz * z_scale / z_ref_Q) + psi2D_h - psi2D_h0

    delta_Q = Q_sfc_t - Q[:, :, 0]
    return ustar * vonk * delta_Q / denom_qm