Derivative Computations

Source Code: Derivatives

Derivatives.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Derivatives.py
 18===========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
 24:Description: computes derivatives in x, y and z directions
 25"""
 26
 27
 28# ============================================================
 29#  Imports
 30# ============================================================
 31
 32import jax
 33import jax.numpy as jnp
 34
 35# Import configuration from namelist
 36from ..config.ConfigLoader import *
 37
 38# Import derived variables
 39from ..config.DerivedVars import *
 40
 41
 42# ============================================================
 43#  Compute spatial derivatives
 44# ============================================================
 45
 46@jax.jit
 47def velocityGradients(
 48        u, v, w,
 49        u_fft, v_fft, w_fft,
 50        kx2, ky2,
 51        ustar, M_sfc_loc, MOSTfunctions, ZeRo3D):
 52    """
 53    Parameters:
 54    -----------
 55    u, v, w : ndarray of shape (nx, ny, nz)
 56        Velocity components in x, y, and z directions in physical space
 57    u_fft, v_fft, w_fft : ndarray of shape (nx, ny//2 + 1, nz)
 58        Pre-computed Fourier transforms of the velocity components
 59    kx2, ky2 : ndarray of shape (nx, ny//2 + 1, nz)
 60        Pre-computed wavenumber arrays for spectral derivatives
 61    M_sfc_loc : ndarray of shape (nx, ny)
 62        Near-surface wind speed used for boundary conditions
 63    ustar : ndarray of shape (nx, ny)
 64        Friction velocity for boundary condition calculations
 65    ZeRo3D : ndarray of shape (nx, ny, nz)
 66        Pre-allocated zero array for storing derivative results
 67
 68    Returns:
 69    --------
 70    dudx, dvdx, dwdx :
 71        x-derivatives of the velocity components
 72    dudy, dvdy, dwdy :
 73        y-derivatives of the velocity components
 74    dudz, dvdz, dwdz :
 75        z-derivatives of the velocity components
 76
 77    Notes:
 78    ------
 79    - Horizontal derivatives (x, y) are computed using spectral methods via `Derivxy`
 80    - Vertical derivatives (z) are computed using finite differences via `Derivz_M`
 81    - Boundary conditions for vertical derivatives are handled in `Derivz_M`
 82    """
 83
 84    # X derivatives
 85    dudx, dvdx, dwdx = (Derivxy(u_fft, kx2),
 86                        Derivxy(v_fft, kx2),
 87                        Derivxy(w_fft, kx2))
 88
 89    # Y derivatives
 90    dudy, dvdy, dwdy = (Derivxy(u_fft, ky2),
 91                        Derivxy(v_fft, ky2),
 92                        Derivxy(w_fft, ky2))
 93
 94    # unpack MOST functions
 95    (psi2D_m, psi2D_m0,
 96     psi2D_h, psi2D_h0,
 97     fi2D_m, fi2D_h) = MOSTfunctions
 98
 99    # Z derivatives
100    dudz, dvdz, dwdz = Derivz_M(u, v, w, ustar, M_sfc_loc, fi2D_m, ZeRo3D)
101
102    # Return all derivatives
103    return dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz
104
105
@jax.jit
def potentialTemperatureGradients(
        TH,
        kx2, ky2,
        ustar, qz_sfc, MOSTfunctions, ZeRo3D):
    """
    Compute the x, y and z derivatives of potential temperature.

    Parameters:
    -----------
    TH : ndarray of shape (nx, ny, nz)
        Potential temperature in physical space, stored as the anomaly
        TH' = TH - T_0 throughout the simulation
    kx2, ky2 : ndarray of shape (nx, ny//2 + 1, nz)
        Pre-computed wavenumber arrays for spectral derivatives
    ustar : ndarray of shape (nx, ny)
        Friction velocity used in the surface boundary condition
    qz_sfc : ndarray of shape (nx, ny)
        Surface sensible heat flux, unit: K m/s
    MOSTfunctions : tuple of six ndarrays of shape (nx, ny)
        (psi2D_m, psi2D_m0, psi2D_h, psi2D_h0, fi2D_m, fi2D_h);
        only fi2D_h is used here
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the vertical-derivative output

    Returns:
    --------
    dTHdx, dTHdy, dTHdz : ndarray of shape (nx, ny, nz)
        x-, y- and z-derivatives of potential temperature

    Notes:
    ------
    - Horizontal derivatives (x, y) are computed spectrally via `Derivxy`
    - Vertical derivatives (z) are computed with finite differences via
      `Derivz_TH`, which also applies the surface boundary condition
    """

    # TH already holds the anomaly TH - T_0, so it is transformed directly.
    # Keeping the FFT input small (~0-5 K rather than ~265 K) reduces
    # round-off in the transform.
    TH_fft = jnp.fft.rfft2(TH, axes=(0, 1))
    dTHdx = Derivxy(TH_fft, kx2)
    dTHdy = Derivxy(TH_fft, ky2)

    # Only the heat gradient function is needed for the surface BC;
    # the six-way unpack still checks the tuple layout.
    (_, _, _, _, _, fi2D_h) = MOSTfunctions

    # Z derivative by finite differences. A constant offset such as T_0
    # cancels in the level-to-level differences, so the anomaly form gives
    # no extra benefit here.
    dTHdz = Derivz_TH(TH, ustar, qz_sfc, fi2D_h, ZeRo3D)

    return dTHdx, dTHdy, dTHdz
154
155
156# ============================================================
157#  Compute spectral derivatives in x or y direction
158# ============================================================
159
@jax.jit
def Derivxy(F_fft, kxy2):
    """
    Spectral x- or y-derivative of a field given its horizontal real FFT.

    Parameters:
    -----------
    F_fft : ndarray with shape (nx, ny//2 + 1, nz)
        `rfft2` of a real 3D field over axes (0, 1).
    kxy2 : ndarray shape (nx, ny//2 + 1, nz)
        Pre-computed wavenumbers (use kx2 for d/dx, ky2 for d/dy).

    Returns:
    --------
    dFdxy : ndarray of shape (nx, ny, nz)
        Real x- or y-derivative field in physical space.

    Notes:
    ------
    - Differentiation is multiplication by 1j * kxy2 in Fourier space,
      followed by `irfft2` over axes (0, 1) onto an (nx, ny) grid taken
      from the configuration.
    - This function does NOT zero the Nyquist modes itself. Any such
      filtering must already be encoded in `kxy2`.
      NOTE(review): confirm that kx2/ky2 carry zeros at the Nyquist
      wavenumbers.
    """
    # Spectral derivative: d/dx <-> multiplication by i*k
    ik_F = 1j * kxy2 * F_fft

    # Back to physical space; s fixes the output grid so odd/even ny is
    # resolved unambiguously from the half-spectrum.
    return jnp.fft.irfft2(ik_F, s=(nx, ny), axes=(0, 1))
183
184
185# ============================================================
186#  Finite difference-based vertical derivatives for velocity
187# ============================================================
188
@jax.jit
def Derivz_M(u, v, w, ustar, M_sfc_loc, fi2D_m, ZeRo3D):
    """
    Vertical derivatives of the velocity components by finite differences,
    with a Monin-Obukhov surface value for u and v.

    Parameters:
    -----------
    u : ndarray of shape (nx, ny, nz)
        Longitudinal velocity component
    v : ndarray of shape (nx, ny, nz)
        Lateral velocity component
    w : ndarray of shape (nx, ny, nz)
        Vertical velocity component
    ustar : ndarray of shape (nx, ny)
        Friction velocity
    M_sfc_loc : ndarray of shape (nx, ny)
        Near-surface wind speed
    fi2D_m : ndarray of shape (nx, ny)
        Normalized gradient (stability) function for momentum
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the outputs

    Returns:
    --------
    dudz, dvdz : ndarray of shape (nx, ny, nz)
        Level 0 holds the MOST surface value. Levels 1 .. nz-2 hold
        (f[k] - f[k-1]) * idz. Level nz-1 is zero.
    dwdz : ndarray of shape (nx, ny, nz)
        Levels 0 .. nz-2 hold (w[k+1] - w[k]) * idz. Level nz-1 is zero.

    Notes:
    ------
    idz is a configuration constant, presumably 1/dz.
    """

    # u and v: two-point differences between adjacent levels, stored at
    # k = 1 .. nz-2. Level nz-1 keeps the zero from ZeRo3D.
    # (.at[] is functional in JAX, so ZeRo3D itself is never modified.)
    dudz = ZeRo3D.at[:, :, 1:nz - 1].set(
        jnp.diff(u[:, :, 0:nz - 1], axis=2) * idz)
    dvdz = ZeRo3D.at[:, :, 1:nz - 1].set(
        jnp.diff(v[:, :, 0:nz - 1], axis=2) * idz)

    # Surface level: fi_m * u* / (kappa * z) at z = 0.5*dz, split between
    # the components by u/M and v/M. Ugal (a configuration velocity) is
    # added back to u only.
    dudz = dudz.at[:, :, 0].set(
        fi2D_m * ustar * (u[:, :, 0] + Ugal) / (M_sfc_loc * vonk * 0.5 * dz)
    )
    dvdz = dvdz.at[:, :, 0].set(
        fi2D_m * ustar * v[:, :, 0] / (M_sfc_loc * vonk * 0.5 * dz)
    )

    # w: differences between adjacent levels, stored at k = 0 .. nz-2
    dwdz = ZeRo3D.at[:, :, 0:nz - 1].set(jnp.diff(w, axis=2) * idz)
    dwdz = dwdz.at[:, :, nz - 1].set(0.0)  # Top boundary condition

    return dudz, dvdz, dwdz
236
237
238# ============================================================
239#  Vertical derivatives for temperature
240# ============================================================
241
@jax.jit
def Derivz_TH(TH, ustar, qz_sfc, fi2D_h, ZeRo3D):
    """
    Vertical derivative of potential temperature by finite differences,
    with a Monin-Obukhov surface value.

    Parameters:
    -----------
    TH : ndarray of shape (nx, ny, nz)
        Potential temperature
    ustar : ndarray of shape (nx, ny)
        Friction velocity
    qz_sfc : ndarray of shape (nx, ny)
        Surface sensible heat flux, unit: K m/s
    fi2D_h : ndarray of shape (nx, ny)
        Normalized gradient function for heat
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the output

    Returns:
    --------
    ndarray of shape (nx, ny, nz)
        dTHdz. Level 0 holds fi2D_h * (-qz_sfc / ustar) / (vonk * 0.5 * dz).
        Levels 1 .. nz-1 hold (TH[k] - TH[k-1]) * idz.
    """

    # Surface value from similarity theory, evaluated at height 0.5*dz
    flux_scale = -qz_sfc / ustar
    surface_gradient = fi2D_h * flux_scale / (vonk * 0.5 * dz)

    # Differences between adjacent levels fill levels 1 .. nz-1
    level_gradient = jnp.diff(TH, axis=2) * idz
    dTHdz = ZeRo3D.at[:, :, 1:nz].set(level_gradient)

    return dTHdz.at[:, :, 0].set(surface_gradient)
274
275
276# ============================================================
277#  Moisture gradients, and vertical derivatives for a generic variable on uvp nodes
278# ============================================================
279
@jax.jit
def moistureGradients(
        Q,
        kx2, ky2,
        ustar, qm_sfc, MOSTfunctions, ZeRo3D):
    """
    Compute the x, y and z derivatives of specific humidity.

    The horizontal components are spectral (via `Derivxy`). The vertical
    component reuses `Derivz_TH`, so its surface value is built from the
    heat gradient function fi2D_h. Moisture is thereby treated with the
    same similarity function as heat.

    Parameters:
    -----------
    Q : ndarray of shape (nx, ny, nz)
        Specific humidity (kg/kg)
    kx2, ky2 : ndarray
        Wavenumber arrays for spectral derivatives
    ustar : ndarray of shape (nx, ny)
        Friction velocity
    qm_sfc : ndarray of shape (nx, ny)
        Surface moisture flux. NOTE(review): an earlier docstring called
        this non-dimensional, but `Derivz_TH` divides it by ustar to form a
        gradient, so it presumably carries units of Q times m/s. Confirm.
    MOSTfunctions : tuple of six (nx, ny) stability arrays
        Only the last entry (fi2D_h) is used
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the vertical-derivative output

    Returns:
    --------
    dQdx, dQdy, dQdz : ndarray of shape (nx, ny, nz)
    """
    # Vertical part: finite differences plus MOST surface value
    _, _, _, _, _, fi2D_h = MOSTfunctions
    dQdz = Derivz_TH(Q, ustar, qm_sfc, fi2D_h, ZeRo3D)

    # Horizontal part: one forward transform shared by both derivatives
    Q_hat = jnp.fft.rfft2(Q, axes=(0, 1))
    dQdx, dQdy = Derivxy(Q_hat, kx2), Derivxy(Q_hat, ky2)

    return dQdx, dQdy, dQdz
307
308
@jax.jit
def Derivz_Generic_uvp(F, ZeRo3D):
    """
    Vertical derivative of a generic field stored on uvp nodes.

    Parameters:
    -----------
    F : ndarray of shape (nx, ny, nz)
        Generic variable defined on uvp nodes
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the output

    Returns:
    --------
    ndarray of shape (nx, ny, nz)
        dFdz. Level 0 is zero, and levels 1 .. nz-1 hold
        (F[k] - F[k-1]) * idz.
    """
    # Bottom level has no level beneath it; pin it to zero
    dFdz = ZeRo3D.at[:, :, 0].set(0)

    # Differences between adjacent levels fill levels 1 .. nz-1
    return dFdz.at[:, :, 1:nz].set(jnp.diff(F, axis=2) * idz)
333
334
335# ============================================================
336#  Vertical derivatives for a generic variable on w nodes
337# ============================================================
338
@jax.jit
def Derivz_Generic_w(F, ZeRo3D):
    """
    Vertical derivative of a generic field stored on w nodes.

    Parameters:
    -----------
    F : ndarray of shape (nx, ny, nz)
        Generic variable defined on w nodes
    ZeRo3D : ndarray of shape (nx, ny, nz)
        Zero array used as the template for the output

    Returns:
    --------
    ndarray of shape (nx, ny, nz)
        dFdz. Levels 0 .. nz-2 hold (F[k+1] - F[k]) * idz, and level nz-1
        is zero.
    """
    top = nz - 1

    # Differences between adjacent levels fill levels 0 .. nz-2
    level_gradient = jnp.diff(F, axis=2) * idz
    dFdz = ZeRo3D.at[:, :, :top].set(level_gradient)

    # Top boundary condition
    return dFdz.at[:, :, top].set(0)
362    return dFdz