Surface Flux Parameterization
Source Code: SurfaceFlux
SurfaceFlux.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: SurfaceFlux.py
18==============================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-29
24:Description: surface flux calculation module.
25 Supports constant flux (optSurfBC=0), time-varying flux
26 (optSurfBC=1), and prescribed surface temperature (optSurfBC=2),
27 each in homogeneous (optSurfFlux=0) and heterogeneous
28 (optSurfFlux=1) flavours.
29
30 Sign convention: qz = -u* x th*
31 Stable BL: qz < 0 (downward), th* > 0
32"""
33
34# ============================================================
35# Imports
36# ============================================================
37
38import jax
39import jax.numpy as jnp
40
41# Import configuration from namelist
42from ..config.ConfigLoader import *
43
44# Import derived variables
45from ..config.DerivedVars import *
46
47
48# ============================================================
49# Monin-Obukhov similarity functions
50# ============================================================
51
@jax.jit
def MOSTstable(z_over_L):
    """
    Monin-Obukhov stability functions for the stable regime.

    A linear form with coefficient 5 is used for both momentum and heat.
    In this module the integrated corrections are *added* to the
    logarithmic term, e.g. u* = vonk * M / (ln(z/z0m) + psi_m - psi_m0).

    Parameters:
    -----------
    z_over_L : jnp.ndarray
        Stability parameter z/L (height / Obukhov length); any shape.

    Returns:
    --------
    psi_m : jnp.ndarray
        Integrated stability correction for momentum, 5 * z/L
    psi_h : jnp.ndarray
        Integrated stability correction for heat, 5 * z/L
    fi_m : jnp.ndarray
        Normalized gradient function for momentum, 1 + 5 * z/L
    fi_h : jnp.ndarray
        Normalized gradient function for heat, 1 + 5 * z/L
    """
    # Momentum and heat share the same linear coefficient, so both
    # corrections (and both gradients) are the same expression.
    linear_correction = 5.0 * z_over_L
    normalized_gradient = 1.0 + linear_correction
    return (linear_correction, linear_correction,
            normalized_gradient, normalized_gradient)
81
82
def MOSTunstable(z_over_L):
    """
    Monin-Obukhov stability functions for the unstable regime.

    With x = (1 - 15 z/L)^(1/4), the integrated corrections are the
    negatives of the classic Paulson forms. They are negated because this
    module adds psi to the logarithmic term rather than subtracting it.
    For z/L > 1/15 the base of the power is negative and the results are
    NaN, so the function is meant only for z/L <= 0.

    Parameters:
    -----------
    z_over_L : jnp.ndarray of shape (nx, ny)
        Stability parameter z/L (height / Obukhov length); any shape works.

    Returns:
    --------
    psi2D_m : jnp.ndarray of shape (nx, ny)
        Stability function for momentum
    psi2D_h : jnp.ndarray of shape (nx, ny)
        Stability function for heat
    fi2D_m : jnp.ndarray of shape (nx, ny)
        Normalized gradient function for momentum, 1 / x
    fi2D_h : jnp.ndarray of shape (nx, ny)
        Normalized gradient function for heat, (1 / x)^2
    """
    x = (1 - 15 * z_over_L) ** 0.25

    # Shared logarithmic term ln((1 + x^2) / 2).
    log_half_one_plus_xsq = jnp.log(0.5 * (1 + x ** 2))

    psi2D_h = -2 * log_half_one_plus_xsq
    psi2D_m = (-2 * jnp.log(0.5 * (1 + x)) - log_half_one_plus_xsq
               + 2 * jnp.arctan(x) - jnp.pi / 2)

    reciprocal_x = 1.0 / x
    return psi2D_m, psi2D_h, reciprocal_x, reciprocal_x ** 2
113
114
115# ============================================================
116# Shared helper: update MOST stability functions
117# ============================================================
118
def _update_MOSTfunctions(ustar, qz_sfc_avg, TH_ref, psi2D_m, psi2D_m0,
                          psi2D_h, psi2D_h0):
    """
    Compute the inverse Obukhov length and re-evaluate all MOST functions.

    Shared by the six SurfaceFlux_* variants. The regime is chosen once from
    the sign of the planar-mean flux: stable when qz_sfc_avg <= 0, unstable
    otherwise. In that regime the stability functions are evaluated at
    three heights:

      * z1 = 0.5 * dz                      -> psi2D_m, psi2D_h, fi2D_m, fi2D_h
      * z0m / z_scale                      -> psi2D_m0
      * zTemperature / z_scale (if > 0),
        otherwise z0T / z_scale            -> psi2D_h0

    Parameters:
    -----------
    ustar : friction velocity (non-dimensional)
    qz_sfc_avg : scalar surface heat flux (qz = -u* x th*); it must be a
        scalar because it drives a jax.lax.cond predicate
    TH_ref : absolute reference temperature (non-dimensional)
    psi2D_m, psi2D_m0, psi2D_h, psi2D_h0 :
        Previous stability functions. Accepted for interface compatibility
        only; their values are not read.

    Returns:
    --------
    MOSTfunctions : tuple (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h)
    invOB : inverse Obukhov length 1/L
    """
    # 1/L = -vonk * g * qz / (u*^3 * TH)
    invOB = -(vonk * g_nondim * qz_sfc_avg) / ((ustar ** 3) * TH_ref)

    # When the surface temperature is observed at a screen height
    # (zTemperature > 0), that height replaces z0T as the heat reference.
    if zTemperature > 0.0:
        z_ref_T = zTemperature
    else:
        z_ref_T = z0T

    zeta_first_level = (0.5 * dz) * invOB
    zeta_z0m = (z0m / z_scale) * invOB
    zeta_ref_T = (z_ref_T / z_scale) * invOB

    def _evaluate_regime(stability_fn):
        # All six functions are evaluated with one regime's formulas.
        psi_m, psi_h, fi_m, fi_h = stability_fn(zeta_first_level)
        psi_m0 = stability_fn(zeta_z0m)[0]
        psi_h0 = stability_fn(zeta_ref_T)[1]
        return psi_m, psi_m0, psi_h, psi_h0, fi_m, fi_h

    MOSTfunctions = jax.lax.cond(
        qz_sfc_avg <= 0,
        lambda _: _evaluate_regime(MOSTstable),
        lambda _: _evaluate_regime(MOSTunstable),
        operand=None)
    return MOSTfunctions, invOB
159
160
161# ============================================================
162# optSurfBC = 0 : constant prescribed heat flux
163# ============================================================
164
@jax.jit
def SurfaceFlux_HomogeneousConstantFlux(u, v, TH, MOSTfunctions):
    """
    optSurfFlux=0, optSurfBC=0: homogeneous surface, constant heat flux.

    The wind speed and reference temperature at the first grid level are
    averaged over the plane and broadcast back to (nx, ny). The prescribed
    flux field qz_sfc is also planar-averaged.

    Parameters:
    -----------
    u, v : jnp.ndarray of shape (nx, ny, nz)
    TH : jnp.ndarray of shape (nx, ny, nz), anomaly about T_0
    MOSTfunctions : tuple of six (nx, ny) arrays
        (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h)

    Returns:
    --------
    M_sfc_loc : (nx, ny) surface wind speed (planar mean, broadcast)
    ustar : (nx, ny) friction velocity, floored at 1e-3
    qz_sfc_avg : scalar, non-dimensional surface heat flux (= qz = -u* x th*)
    invOB : (nx, ny) inverse Obukhov length
    MOSTfunctions: updated tuple
    """
    psi_m, psi_m0, psi_h, psi_h0, _, _ = MOSTfunctions
    ones_2d = jnp.ones((nx, ny))

    qz_sfc_avg = jnp.mean(qz_sfc)

    # Planar-mean horizontal speed at the first level (Ugal added to u).
    u_first = u[:, :, 0] + Ugal
    v_first = v[:, :, 0]
    M_sfc_loc = jnp.mean(jnp.sqrt(u_first ** 2 + v_first ** 2)) * ones_2d

    # TH holds anomalies; the Obukhov length needs the absolute value.
    TH_ref = (jnp.mean(TH[:, :, 0]) + T_0_nondim) * ones_2d

    log_term_m = jnp.log(0.5 * dz * z_scale / z0m)
    ustar = jnp.maximum(vonk * M_sfc_loc / (log_term_m + psi_m - psi_m0),
                        1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi_m, psi_m0, psi_h, psi_h0)
    return M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions
206
207
@jax.jit
def SurfaceFlux_HeterogeneousConstantFlux(u, v, TH, MOSTfunctions):
    """
    optSurfFlux=1, optSurfBC=0: heterogeneous surface, constant heat flux.

    Wind speed and reference temperature are taken column by column at the
    first grid level. The prescribed flux field qz_sfc is planar-averaged.

    Parameters:
    -----------
    u, v, TH : jnp.ndarray of shape (nx, ny, nz); TH is an anomaly about T_0
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc : (nx, ny) local surface wind speed
    ustar : (nx, ny) friction velocity, floored at 1e-3
    qz_sfc_avg : scalar, non-dimensional surface heat flux
    invOB : (nx, ny)
    MOSTfunctions: updated tuple
    """
    psi_m, psi_m0, psi_h, psi_h0, _, _ = MOSTfunctions

    qz_sfc_avg = jnp.mean(qz_sfc)

    u_first = u[:, :, 0] + Ugal
    v_first = v[:, :, 0]
    M_sfc_loc = jnp.sqrt(u_first ** 2 + v_first ** 2)

    # Per-column absolute temperature for the Obukhov length.
    TH_ref = TH[:, :, 0] + T_0_nondim

    log_term_m = jnp.log(0.5 * dz * z_scale / z0m)
    ustar = jnp.maximum(vonk * M_sfc_loc / (log_term_m + psi_m - psi_m0),
                        1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi_m, psi_m0, psi_h, psi_h0)
    return M_sfc_loc, ustar, qz_sfc_avg, invOB, MOSTfunctions
239
240
241# ============================================================
242# optSurfBC = 1 : time-varying prescribed heat flux
243# ============================================================
244
@jax.jit
def SurfaceFlux_HomogeneousVaryingFlux(u, v, TH, qz_sfc_t, MOSTfunctions):
    """
    optSurfFlux=0, optSurfBC=1: homogeneous surface, time-varying heat flux.

    Parameters:
    -----------
    u, v, TH : jnp.ndarray of shape (nx, ny, nz); TH is an anomaly about T_0
    qz_sfc_t : scalar JAX value, non-dimensional heat flux at current timestep
        (loaded from SurfaceBC.npz series, already non-dimensionalised)
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc : (nx, ny) planar-mean wind speed, broadcast
    ustar : (nx, ny) friction velocity, floored at 1e-3
    qz_sfc_2D : (nx, ny) spatially uniform flux field
    qz_sfc_avg : scalar (the input qz_sfc_t)
    invOB : (nx, ny)
    MOSTfunctions: updated tuple
    """
    psi_m, psi_m0, psi_h, psi_h0, _, _ = MOSTfunctions
    ones_2d = jnp.ones((nx, ny))

    # The current-time flux is applied uniformly over the plane.
    qz_sfc_avg = qz_sfc_t
    qz_sfc_2D = qz_sfc_t * ones_2d

    speed = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)
    M_sfc_loc = jnp.mean(speed) * ones_2d

    # Absolute temperature (TH is an anomaly) for the Obukhov length.
    TH_ref = (jnp.mean(TH[:, :, 0]) + T_0_nondim) * ones_2d

    log_term_m = jnp.log(0.5 * dz * z_scale / z0m)
    ustar = jnp.maximum(vonk * M_sfc_loc / (log_term_m + psi_m - psi_m0),
                        1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi_m, psi_m0, psi_h, psi_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
287
288
@jax.jit
def SurfaceFlux_HeterogeneousVaryingFlux(u, v, TH, qz_sfc_t, MOSTfunctions):
    """
    optSurfFlux=1, optSurfBC=1: heterogeneous surface, time-varying heat flux.

    The flux is spatially uniform; wind speed and reference temperature are
    taken per column at the first grid level.

    Parameters:
    -----------
    u, v, TH : jnp.ndarray of shape (nx, ny, nz); TH is an anomaly about T_0
    qz_sfc_t : scalar JAX value, non-dimensional heat flux at current timestep
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc : (nx, ny) local wind speed
    ustar : (nx, ny) friction velocity, floored at 1e-3
    qz_sfc_2D : (nx, ny) uniform flux field
    qz_sfc_avg : scalar (the input qz_sfc_t)
    invOB : (nx, ny)
    MOSTfunctions: updated tuple
    """
    psi_m, psi_m0, psi_h, psi_h0, _, _ = MOSTfunctions

    qz_sfc_avg = qz_sfc_t
    qz_sfc_2D = qz_sfc_t * jnp.ones((nx, ny))

    M_sfc_loc = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)

    # Per-column absolute temperature (TH is stored as an anomaly).
    TH_ref = TH[:, :, 0] + T_0_nondim

    log_term_m = jnp.log(0.5 * dz * z_scale / z0m)
    ustar = jnp.maximum(vonk * M_sfc_loc / (log_term_m + psi_m - psi_m0),
                        1e-3)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_ref, psi_m, psi_m0, psi_h, psi_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
328
329
330# ============================================================
331# optSurfBC = 2 : time-varying prescribed surface temperature
332# ============================================================
333
@jax.jit
def SurfaceFlux_HomogeneousPrescribedTemperature(u, v, TH, TH_sfc_t,
                                                 MOSTfunctions):
    """
    optSurfFlux=0, optSurfBC=2: homogeneous surface, prescribed T_s(t).

    The surface heat flux is diagnosed from MOST:
        qz = u* x vonk x (TH_s - TH_air) / denom_h
    consistent with qz = -u* x th*, where th* = vonk x (TH_air - TH_s) / denom_h.
    The stability functions used in denom_m and denom_h are the incoming
    (previous) ones; updated functions are returned.

    Parameters:
    -----------
    u, v, TH : jnp.ndarray of shape (nx, ny, nz); TH is an anomaly about T_0
    TH_sfc_t : scalar JAX value, non-dimensional surface temperature anomaly
               (theta_sfc - T_0) / TH_scale at current timestep, as returned
               by Initialize_SurfaceBC for optSurfBC=2
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc : (nx, ny) planar-mean wind speed, broadcast
    ustar : (nx, ny) friction velocity, floored at 1e-3
    qz_sfc_2D : (nx, ny) diagnosed surface heat flux
    qz_sfc_avg : scalar
    invOB : (nx, ny)
    MOSTfunctions: updated tuple
    """
    psi_m, psi_m0, psi_h, psi_h0, _, _ = MOSTfunctions
    ones_2d = jnp.ones((nx, ny))
    z_first = 0.5 * dz * z_scale

    speed = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)
    M_sfc_loc = jnp.mean(speed) * ones_2d

    # Both TH and TH_sfc_t are anomalies, so their difference is direct;
    # the absolute value is needed only for the Obukhov length.
    TH_air_anom_avg = jnp.mean(TH[:, :, 0])
    TH_air_abs = (TH_air_anom_avg + T_0_nondim) * ones_2d

    ustar = jnp.maximum(
        vonk * M_sfc_loc / (jnp.log(z_first / z0m) + psi_m - psi_m0), 1e-3)

    # Heat reference height: screen level when given, otherwise z0T.
    z_ref_T = zTemperature if zTemperature > 0.0 else z0T
    denom_h = jnp.log(z_first / z_ref_T) + psi_h - psi_h0
    qz_sfc_2D = ustar * vonk * (TH_sfc_t - TH_air_anom_avg) * ones_2d / denom_h
    qz_sfc_avg = jnp.mean(qz_sfc_2D)

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_air_abs, psi_m, psi_m0, psi_h, psi_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
390
391
@jax.jit
def SurfaceFlux_HeterogeneousPrescribedTemperature(u, v, TH, TH_sfc_t,
                                                   MOSTfunctions):
    """
    optSurfFlux=1, optSurfBC=2: heterogeneous surface, prescribed T_s(t).

    Wind speed and air temperature are used per column when diagnosing the
    flux. TH_sfc_t is spatially uniform but temporally varying. The Obukhov
    length uses the planar-mean absolute air temperature as its reference.

    Parameters:
    -----------
    u, v, TH : jnp.ndarray of shape (nx, ny, nz); TH is an anomaly about T_0
    TH_sfc_t : scalar JAX value, non-dimensional surface temperature anomaly
               (theta_sfc - T_0) / TH_scale at current timestep, as returned
               by Initialize_SurfaceBC for optSurfBC=2
    MOSTfunctions : tuple of six (nx, ny) arrays

    Returns:
    --------
    M_sfc_loc : (nx, ny) local wind speed
    ustar : (nx, ny) friction velocity, floored at 1e-3
    qz_sfc_2D : (nx, ny) diagnosed surface heat flux
    qz_sfc_avg : scalar
    invOB : (nx, ny)
    MOSTfunctions: updated tuple
    """
    psi_m, psi_m0, psi_h, psi_h0, _, _ = MOSTfunctions
    z_first = 0.5 * dz * z_scale

    M_sfc_loc = jnp.sqrt((u[:, :, 0] + Ugal) ** 2 + v[:, :, 0] ** 2)
    TH_air_anom_loc = TH[:, :, 0]

    ustar = jnp.maximum(
        vonk * M_sfc_loc / (jnp.log(z_first / z0m) + psi_m - psi_m0), 1e-3)

    # Per-column flux; both temperatures are anomalies about T_0.
    z_ref_T = zTemperature if zTemperature > 0.0 else z0T
    denom_h = jnp.log(z_first / z_ref_T) + psi_h - psi_h0
    qz_sfc_2D = ustar * vonk * (TH_sfc_t - TH_air_anom_loc) / denom_h
    qz_sfc_avg = jnp.mean(qz_sfc_2D)

    TH_air_ref = (jnp.mean(TH_air_anom_loc) + T_0_nondim) * jnp.ones((nx, ny))

    MOSTfunctions, invOB = _update_MOSTfunctions(
        ustar, qz_sfc_avg, TH_air_ref, psi_m, psi_m0, psi_h, psi_h0)
    return M_sfc_loc, ustar, qz_sfc_2D, qz_sfc_avg, invOB, MOSTfunctions
444
445
446# ============================================================
447# optMoistureSurfBC = 2 : time-varying prescribed surface humidity
448# ============================================================
449
@jax.jit
def SurfaceMoistureFlux_HomogeneousPrescribedQ(Q, ustar, Q_sfc_t, MOSTfunctions,
                                               invOB=None):
    """
    optMoistureSurfBC=2, optSurfFlux=0: diagnose surface moisture flux from
    prescribed surface specific humidity (spatially homogeneous).
    Uses the same MOST stability functions as heat
    (turbulent Schmidt number = turbulent Prandtl number):

        qm = u* x vonk x (Q_s - Q_air) / (ln(z1 / z_ref_Q) + psi_h - psi_h0)

    with z1 = 0.5 * dz * z_scale and z_ref_Q = zMoisture if zMoisture > 0,
    otherwise z0T.

    The psi2D_h0 entry of MOSTfunctions is evaluated at the *temperature*
    reference height (zTemperature if > 0, otherwise z0T). When the moisture
    reference height differs from it and invOB is supplied, psi_h is
    re-evaluated at z_ref_Q, so the stability correction matches the log
    term. If invOB is not given, psi2D_h0 is taken from MOSTfunctions
    unchanged.

    Parameters:
    -----------
    Q : jnp.ndarray (nx, ny, nz) — specific humidity (kg/kg)
    ustar : jnp.ndarray (nx, ny) — friction velocity
    Q_sfc_t : scalar — prescribed surface Q at current timestep (kg/kg)
    MOSTfunctions : tuple of six (nx, ny) stability arrays
    invOB : jnp.ndarray (nx, ny), optional — inverse Obukhov length returned
        together with MOSTfunctions by the SurfaceFlux_* routines

    Returns:
    --------
    qm_sfc_2D : jnp.ndarray (nx, ny) — surface moisture flux w'q' (non-dim)
    """
    One2D = jnp.ones((nx, ny))
    (_, _, psi2D_h, psi2D_h0, _, _) = MOSTfunctions

    Q_air_avg = jnp.mean(Q[:, :, 0])
    _z_ref_Q = zMoisture if zMoisture > 0.0 else z0T
    _z_ref_T = zTemperature if zTemperature > 0.0 else z0T

    if invOB is not None and _z_ref_Q != _z_ref_T:
        zQ_over_L = (_z_ref_Q / z_scale) * invOB
        # invOB = -(vonk g qz) / (u*^3 TH) with u*, TH > 0, so qz <= 0
        # (the stable test used for the heat functions) <=> invOB >= 0.
        _, psi2D_h0, _, _ = jax.lax.cond(
            jnp.all(invOB >= 0),
            lambda _: MOSTstable(zQ_over_L),
            lambda _: MOSTunstable(zQ_over_L),
            operand=None)

    denom_qm = jnp.log(0.5 * dz * z_scale / _z_ref_Q) + psi2D_h - psi2D_h0
    qm_sfc_2D = ustar * vonk * (Q_sfc_t - Q_air_avg) * One2D / denom_qm

    return qm_sfc_2D
478
479
@jax.jit
def SurfaceMoistureFlux_HeterogeneousPrescribedQ(Q, ustar, Q_sfc_t, MOSTfunctions,
                                                 invOB=None):
    """
    optMoistureSurfBC=2, optSurfFlux=1: per-column surface moisture flux from
    prescribed surface specific humidity.

        qm = u* x vonk x (Q_s - Q_air) / (ln(z1 / z_ref_Q) + psi_h - psi_h0)

    with z1 = 0.5 * dz * z_scale and z_ref_Q = zMoisture if zMoisture > 0,
    otherwise z0T. The heat stability functions are reused for moisture.

    The psi2D_h0 entry of MOSTfunctions is evaluated at the *temperature*
    reference height (zTemperature if > 0, otherwise z0T). When the moisture
    reference height differs from it and invOB is supplied, psi_h is
    re-evaluated at z_ref_Q, so the stability correction matches the log
    term. If invOB is not given, psi2D_h0 is taken from MOSTfunctions
    unchanged.

    Parameters:
    -----------
    Q : jnp.ndarray (nx, ny, nz) — specific humidity (kg/kg)
    ustar : jnp.ndarray (nx, ny) — friction velocity (per column)
    Q_sfc_t : scalar — prescribed surface Q (kg/kg, spatially uniform)
    MOSTfunctions : tuple of six (nx, ny) stability arrays
    invOB : jnp.ndarray (nx, ny), optional — inverse Obukhov length returned
        together with MOSTfunctions by the SurfaceFlux_* routines

    Returns:
    --------
    qm_sfc_2D : jnp.ndarray (nx, ny) — surface moisture flux (per column)
    """
    (_, _, psi2D_h, psi2D_h0, _, _) = MOSTfunctions

    Q_air_loc = Q[:, :, 0]
    _z_ref_Q = zMoisture if zMoisture > 0.0 else z0T
    _z_ref_T = zTemperature if zTemperature > 0.0 else z0T

    if invOB is not None and _z_ref_Q != _z_ref_T:
        zQ_over_L = (_z_ref_Q / z_scale) * invOB
        # invOB = -(vonk g qz) / (u*^3 TH) with u*, TH > 0, so qz <= 0
        # (the stable test used for the heat functions) <=> invOB >= 0.
        _, psi2D_h0, _, _ = jax.lax.cond(
            jnp.all(invOB >= 0),
            lambda _: MOSTstable(zQ_over_L),
            lambda _: MOSTunstable(zQ_over_L),
            operand=None)

    denom_qm = jnp.log(0.5 * dz * z_scale / _z_ref_Q) + psi2D_h - psi2D_h0
    qm_sfc_2D = ustar * vonk * (Q_sfc_t - Q_air_loc) / denom_qm

    return qm_sfc_2D