Derivative Computations
Source Code: Derivatives
Derivatives.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Derivatives.py
18===========================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-3
24:Description: computes derivatives in x, y and z directions
25"""
26
27
28# ============================================================
29# Imports
30# ============================================================
31
32import jax
33import jax.numpy as jnp
34
35# Import configuration from namelist
36from ..config.ConfigLoader import *
37
38# Import derived variables
39from ..config.DerivedVars import *
40
41
42# ============================================================
43# Compute spatial derivatives
44# ============================================================
45
@jax.jit
def velocityGradients(
    u, v, w,
    u_fft, v_fft, w_fft,
    kx2, ky2,
    ustar, M_sfc_loc, MOSTfunctions, ZeRo3D):
    """
    Compute all nine spatial derivatives of the velocity field.

    Parameters:
    -----------
    u, v, w : ndarray of shape (nx, ny, nz)
        Velocity components in x, y, and z directions in physical space
    u_fft, v_fft, w_fft : ndarray of shape (nx, ny//2 + 1, nz)
        Pre-computed 2D real FFTs (over axes 0 and 1) of the velocity components
    kx2, ky2 : ndarray of shape (nx, ny//2 + 1, nz)
        Pre-computed wavenumber arrays for spectral derivatives
    ustar : ndarray of shape (nx, ny)
        Friction velocity, used in the surface boundary condition
    M_sfc_loc : ndarray of shape (nx, ny)
        Near-surface wind speed, used in the surface boundary condition
    MOSTfunctions : tuple of six (nx, ny) arrays
        (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h); only the
        momentum gradient function fi2D_m is used here
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the z-derivative results

    Returns:
    --------
    dudx, dvdx, dwdx :
        x-derivatives of the velocity components
    dudy, dvdy, dwdy :
        y-derivatives of the velocity components
    dudz, dvdz, dwdz :
        z-derivatives of the velocity components

    Notes:
    ------
    - Horizontal derivatives (x, y) are computed spectrally via `Derivxy`
    - Vertical derivatives (z) are computed with finite differences via
      `Derivz_M`, which also applies the vertical boundary conditions
    """

    # X derivatives (spectral)
    dudx = Derivxy(u_fft, kx2)
    dvdx = Derivxy(v_fft, kx2)
    dwdx = Derivxy(w_fft, kx2)

    # Y derivatives (spectral)
    dudy = Derivxy(u_fft, ky2)
    dvdy = Derivxy(v_fft, ky2)
    dwdy = Derivxy(w_fft, ky2)

    # Only the momentum stability function is needed. Unpacking all six
    # elements still enforces the expected tuple layout.
    (_, _, _, _, fi2D_m, _) = MOSTfunctions

    # Z derivatives (finite differences + MOST bottom boundary condition)
    dudz, dvdz, dwdz = Derivz_M(u, v, w, ustar, M_sfc_loc, fi2D_m, ZeRo3D)

    return dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz
104
105
@jax.jit
def potentialTemperatureGradients(
    TH,
    kx2, ky2,
    ustar, qz_sfc, MOSTfunctions, ZeRo3D):
    """
    Compute the x, y and z derivatives of potential temperature.

    Parameters:
    -----------
    TH : ndarray of shape (nx, ny, nz)
        Potential temperature in physical space, stored as the anomaly
        TH' = TH - T_0
    kx2, ky2 : ndarray of shape (nx, ny//2 + 1, nz)
        Pre-computed wavenumber arrays for spectral derivatives
    ustar : ndarray of shape (nx, ny)
        Friction velocity, used in the surface boundary condition
    qz_sfc : ndarray of shape (nx, ny)
        Surface sensible heat flux, unit: K m/s
    MOSTfunctions : tuple of six (nx, ny) arrays
        (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h); only the
        heat gradient function fi2D_h is used here
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the z-derivative result

    Returns:
    --------
    dTHdx, dTHdy, dTHdz :
        x, y, z-derivatives of potential temperature

    Notes:
    ------
    - Horizontal derivatives (x, y) are computed spectrally from the anomaly
    - Vertical derivatives (z) are computed with finite differences via
      `Derivz_TH`, which also applies the bottom boundary condition
    """

    # TH already holds the anomaly (TH - T_0), so no base-state subtraction
    # is needed. Transforming the small anomaly (a few K) rather than the
    # full temperature (~265 K) keeps FFT round-off small.
    TH_fft = jnp.fft.rfft2(TH, axes=(0, 1))
    dTHdx = Derivxy(TH_fft, kx2)
    dTHdy = Derivxy(TH_fft, ky2)

    # Only the heat stability function is needed. Unpacking all six
    # elements still enforces the expected tuple layout.
    (_, _, _, _, _, fi2D_h) = MOSTfunctions

    # A constant offset such as T_0 cancels in a finite difference, so the
    # anomaly form neither helps nor hurts the vertical derivative.
    dTHdz = Derivz_TH(TH, ustar, qz_sfc, fi2D_h, ZeRo3D)

    return dTHdx, dTHdy, dTHdz
154
155
156# ============================================================
157# Compute spectral derivatives in x or y direction
158# ============================================================
159
@jax.jit
def Derivxy(F_fft, kxy2):
    """
    Spectral x- or y-derivative of a field supplied in 2D Fourier space.

    Parameters:
    -----------
    F_fft : ndarray with shape (nx, ny//2 + 1, nz)
        Real 2D FFT (over axes 0 and 1) of a 3D field
    kxy2 : ndarray with shape (nx, ny//2 + 1, nz)
        Pre-computed wavenumbers: pass kx2 for an x-derivative, ky2 for a
        y-derivative

    Returns:
    --------
    dFdxy : ndarray of shape (nx, ny, nz)
        The x- or y-derivative in physical space

    Notes:
    ------
    - Differentiation is multiplication by i*k in Fourier space, followed
      by an inverse real FFT back onto the (nx, ny) physical grid.
    - This function does not zero any modes itself; any Nyquist-mode
      treatment presumably lives in the pre-computed kxy2 arrays
      (NOTE(review): confirm where kx2/ky2 are built).
    """
    ik_times_F = 1j * kxy2 * F_fft
    return jnp.fft.irfft2(ik_times_F, s=(nx, ny), axes=(0, 1))
183
184
185# ============================================================
186# Finite difference-based vertical derivatives for velocity
187# ============================================================
188
@jax.jit
def Derivz_M(u, v, w, ustar, M_sfc_loc, fi2D_m, ZeRo3D):
    """
    Vertical derivatives of the three velocity components.

    Parameters:
    -----------
    u : ndarray of shape (nx, ny, nz)
        Longitudinal velocity component
    v : ndarray of shape (nx, ny, nz)
        Lateral velocity component
    w : ndarray of shape (nx, ny, nz)
        Vertical velocity component
    ustar : ndarray of shape (nx, ny)
        Friction velocity
    M_sfc_loc : ndarray of shape (nx, ny)
        Near-surface wind speed
    fi2D_m : ndarray of shape (nx, ny)
        Normalized gradient (stability) function for momentum
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the results

    Returns:
    --------
    dudz : Vertical derivative of u
    dvdz : Vertical derivative of v
    dwdz : Vertical derivative of w

    Notes:
    ------
    - dudz, dvdz: level k (1 <= k <= nz-2) holds (F[k] - F[k-1]) / dz.
      Level 0 comes from Monin-Obukhov similarity; level nz-1 stays zero.
    - dwdz: level k (0 <= k <= nz-2) holds (w[k+1] - w[k]) / dz; level
      nz-1 is set to zero. These index patterns presumably correspond to
      u, v on uvp nodes and w on w nodes of a staggered grid (confirm).
    """

    # Backward differences for u and v. The top level (nz-1) is left at
    # zero, unlike the scalar routine Derivz_TH, which fills it.
    # JAX arrays are immutable, so .at[].set() already returns a new array
    # and no explicit copy of ZeRo3D is needed.
    dudz = ZeRo3D.at[:, :, 1:nz - 1].set(jnp.diff(u[:, :, 0:nz - 1], axis=2) * idz)
    dvdz = ZeRo3D.at[:, :, 1:nz - 1].set(jnp.diff(v[:, :, 0:nz - 1], axis=2) * idz)

    # Bottom boundary: MOST gradient fi_m * u* / (kappa * 0.5*dz), projected
    # onto each horizontal component by its share of the near-surface speed.
    # Ugal is added to u only; presumably a Galilean-frame offset from the
    # config (confirm).
    dudz = dudz.at[:, :, 0].set(
        fi2D_m * ustar * (u[:, :, 0] + Ugal) / (M_sfc_loc * vonk * 0.5 * dz)
    )
    dvdz = dvdz.at[:, :, 0].set(
        fi2D_m * ustar * v[:, :, 0] / (M_sfc_loc * vonk * 0.5 * dz)
    )

    # Forward differences for w, with the top level explicitly zeroed.
    dwdz = ZeRo3D.at[:, :, 0:nz - 1].set(jnp.diff(w, axis=2) * idz)
    dwdz = dwdz.at[:, :, nz - 1].set(0.0)

    return dudz, dvdz, dwdz
236
237
238# ============================================================
239# Vertical derivatives for temperature
240# ============================================================
241
@jax.jit
def Derivz_TH(TH, ustar, qz_sfc, fi2D_h, ZeRo3D):
    """
    Vertical derivative of a scalar, with a Monin-Obukhov bottom value.

    Used for potential temperature, and also reused by `moistureGradients`
    for specific humidity.

    Parameters:
    -----------
    TH : ndarray of shape (nx, ny, nz)
        Scalar field (potential temperature, or humidity when reused)
    ustar : ndarray of shape (nx, ny)
        Friction velocity
    qz_sfc : ndarray of shape (nx, ny)
        Surface flux of the scalar (for heat: K m/s)
    fi2D_h : ndarray of shape (nx, ny)
        Normalized gradient (stability) function for heat
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the result

    Returns:
    --------
    dTHdz : ndarray of shape (nx, ny, nz)
        Level k >= 1 holds (TH[k] - TH[k-1]) / dz; level 0 holds
        fi_h * (-qz_sfc / ustar) / (kappa * 0.5*dz)
    """
    # The scale -qz_sfc / ustar, scaled by the stability function over the
    # first half level.
    surface_value = fi2D_h * (-qz_sfc / ustar) / (vonk * 0.5 * dz)

    return (
        ZeRo3D
        .at[:, :, 1:nz].set(jnp.diff(TH, axis=2) * idz)
        .at[:, :, 0].set(surface_value)
    )
274
275
# ============================================================
# Moisture gradients, and vertical derivatives for generic variables
# ============================================================
279
@jax.jit
def moistureGradients(
    Q,
    kx2, ky2,
    ustar, qm_sfc, MOSTfunctions, ZeRo3D):
    """
    Compute the x, y and z derivatives of specific humidity.

    Parameters:
    -----------
    Q : ndarray of shape (nx, ny, nz)
        Specific humidity (kg/kg)
    kx2, ky2 : ndarray
        Wavenumber arrays for spectral derivatives
    ustar : ndarray of shape (nx, ny)
        Friction velocity
    qm_sfc : ndarray of shape (nx, ny)
        Surface moisture flux. It is passed to `Derivz_TH` in the heat-flux
        slot (NOTE(review): units previously noted as "non-dim" — confirm)
    MOSTfunctions : tuple of six (nx, ny) arrays
        Stability functions; only the last one (fi2D_h) is used
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the z-derivative

    Returns:
    --------
    dQdx, dQdy, dQdz : ndarray of shape (nx, ny, nz)
    """
    # One forward transform feeds both horizontal derivatives.
    spectrum = jnp.fft.rfft2(Q, axes=(0, 1))
    dQdx, dQdy = Derivxy(spectrum, kx2), Derivxy(spectrum, ky2)

    # Moisture shares the heat stability function and the scalar
    # vertical-derivative routine.
    _psi_m, _psi_m0, _psi_h, _psi_h0, _fi_m, fi_heat = MOSTfunctions
    dQdz = Derivz_TH(Q, ustar, qm_sfc, fi_heat, ZeRo3D)

    return dQdx, dQdy, dQdz
307
308
@jax.jit
def Derivz_Generic_uvp(F, ZeRo3D):
    """
    Vertical derivative of a generic variable defined on uvp nodes.

    Parameters:
    -----------
    F : ndarray of shape (nx, ny, nz)
        Generic variable defined on uvp nodes
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the result

    Returns:
    --------
    dFdz : ndarray of shape (nx, ny, nz)
        Level k >= 1 holds (F[k] - F[k-1]) / dz; level 0 is zero
    """
    backward = ZeRo3D.at[:, :, 1:nz].set(jnp.diff(F, axis=2) * idz)
    # No surface information exists for a generic field: pin the bottom to zero.
    return backward.at[:, :, 0].set(0)
333
334
335# ============================================================
336# Vertical derivatives for a generic variable on w nodes
337# ============================================================
338
@jax.jit
def Derivz_Generic_w(F, ZeRo3D):
    """
    Vertical derivative of a generic variable defined on w nodes.

    Parameters:
    -----------
    F : ndarray of shape (nx, ny, nz)
        Generic variable defined on w nodes
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the result

    Returns:
    --------
    dFdz : ndarray of shape (nx, ny, nz)
        Level k <= nz-2 holds (F[k+1] - F[k]) / dz; level nz-1 is zero
    """
    forward = ZeRo3D.at[:, :, 0:nz - 1].set(jnp.diff(F, axis=2) * idz)
    # Top boundary condition: zero gradient at the last level.
    return forward.at[:, :, nz - 1].set(0)
362 return dFdz