Dealiasing Computations

Source Code: Dealiasing

Dealiasing.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Dealiasing.py
 18===========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
 24:Description: performs dealiasing
 25"""
 26
 27
 28# ============================================================
 29#  Imports
 30# ============================================================
 31
 32import jax
 33import jax.numpy as jnp
 34
 35# Import configuration from namelist
 36from ..config.ConfigLoader import *
 37
 38# Import derived variables
 39from ..config.DerivedVars import *
 40
 41# Import constants
 42from ..initialization.Preprocess import Constant
 43mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft = Constant()
 44
 45
 46# ============================================================
 47# First function for dealiasing
 48# ============================================================
 49
 50@jax.jit
 51def Dealias1(F_fft, ZeRo3D_pad_fft):
 52    """
 53    Parameters:
 54    -----------
 55    F_fft : jnp.ndarray
 56        Fourier transformed input array
 57    ZeRo3D_pad_fft : jnp.ndarray
 58        Pre-allocated zero-padded array
 59
 60    Returns:
 61    --------
 62    F_pad : jnp.ndarray
 63        Dealiased padded array in spatial domain
 64    """
 65
 66    # Allocate padded array
 67    F_pad_fft = ZeRo3D_pad_fft.copy()
 68
 69    # First quadrant
 70    F_pad_fft = F_pad_fft.at[:nx_rfft, :ny_rfft, :].set(F_fft[:nx_rfft, :ny_rfft, :])
 71
 72    # Second quadrant
 73    F_pad_fft = F_pad_fft.at[mx - (nx_rfft-1):, :ny_rfft, :].set(F_fft[(nx_rfft-1):, :ny_rfft, :])
 74
 75    # Transform back to spatial domain using irfft2
 76    F_pad = jnp.fft.irfft2(F_pad_fft, axes=(0, 1), s=(mx, my))
 77
 78    return F_pad
 79
 80
 81# ============================================================
 82# Second function for dealiasing
 83# ============================================================
 84
 85@jax.jit
 86def Dealias2(F_pad_fft, ZeRo3D_fft):
 87    """
 88    Parameters:
 89    -----------
 90    F_pad_fft : jnp.ndarray
 91        Fourier transformed padded array
 92    ZeRo3D_fft : jnp.ndarray
 93        Pre-allocated zero array for Fourier operations
 94
 95    Returns:
 96    --------
 97    F : jnp.ndarray
 98        Dealiased output array on regular grid
 99    Note: Nyquist is explicitly set to zero.
100    """
101
102    # Allocate array
103    F_fft = ZeRo3D_fft.copy()
104
105    # First quadrant
106    F_fft = F_fft.at[:(nx_rfft-1), :(ny_rfft-1), :].set(F_pad_fft[:(nx_rfft-1), :(ny_rfft-1), :])
107
108    # Second quadrant
109    F_fft = F_fft.at[nx_rfft:, :(ny_rfft-1), :].set(F_pad_fft[mx - (nx_rfft-2):, :(ny_rfft-1), :])
110
111    # Transform back to physical space and apply 9/4 scaling
112    F = (9.0 / 4.0) * jnp.fft.irfft2(F_fft, axes=(0, 1), s=(nx, ny))
113
114    return F