Filtering Computations

Source Code: Filtering

Filtering.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Filtering.py
 18===========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
 24:Description: performs main filtering operations for LES
 25"""
 26
 27
 28# ============================================================
 29#  Imports
 30# ============================================================
 31
 32import jax
 33import jax.numpy as jnp
 34
 35# Import configuration from namelist
 36from ..config.ConfigLoader import *
 37
 38# Import derived variables
 39from ..config.DerivedVars import *
 40
 41# Import constants and arrays
 42from ..initialization.Preprocess import ZeRo3D_fftIni
 43
 44
 45# ============================================================
 46# Explicit filtering (for FGR = 1, remove Nyquist)
 47# ============================================================
 48
 49@jax.jit
 50def Filtering_Explicit(F_fft):
 51    """
 52    Parameters:
 53    -----------
 54    F_fft : ndarray with shape (nx, ny/2, nz)
 55        rfft2 of a field (could be velocity components or scalars)
 56
 57    Returns:
 58    --------
 59    F_new : ndarray with shape (nx, ny, nz)
 60        Filtered field
 61    F_fft_new : ndarray with shape (nx, ny_rfft, nz)
 62        Filtered field in Fourier space
 63
 64    Notes:
 65    ------
 66    - For FGR = 1 (implicit filtering), Nyquist frequencies are removed
 67    """
 68
 69    # Calculate inverse of 2*FGR for cutoff wavenumber
 70    Inv_2FGR = 1.0 / (2.0 * FGR)
 71
 72    # Calculate cutoff indices for real FFT
 73    nx_cut = round(nx * Inv_2FGR)
 74    ny_cut = round(ny * Inv_2FGR)
 75
 76    # Initialize zero array for filtered spectrum
 77    F_fft_new = ZeRo3D_fftIni()
 78
 79    # Apply spectral cutoff filter
 80    # First quadrant
 81    F_fft_new = F_fft_new.at[:nx_cut, :ny_cut, :].set(F_fft[:nx_cut, :ny_cut, :])
 82
 83    # Second quadrant
 84    F_fft_new = F_fft_new.at[nx - nx_cut + 1:, :ny_cut, :].set(F_fft[nx - nx_cut + 1:, :ny_cut, :])
 85
 86    # Transform back to physical space
 87    F_new = jnp.fft.irfft2(F_fft_new, axes=(0, 1), s=(nx, ny))
 88
 89    return F_new, F_fft_new
 90
 91
 92# ============================================================
 93# Level 1 filtering (filter width = FGR*TFR)
 94# ============================================================
 95
 96@jax.jit
 97def Filtering_Level1(F_fft):
 98    """
 99    Parameters:
100    -----------
101    F_fft : ndarray with shape (nx, ny/2, nz)
102        rfft2 of a field (could be velocity components or scalars)
103
104    Returns:
105    --------
106    F_hat : ndarray with shape (nx, ny, nz)
107        Filtered field
108    """
109
110    # Calculate cutoff wavenumbers for filtering
111    mr = round(nx / (2 * FGR * TFR))  # Cutoff in x-direction for rfft
112    mc = round(ny / (2 * FGR * TFR))  # Cutoff in y-direction for rfft
113
114    # Initialize zero array for filtered spectrum
115    F_fft_hat = ZeRo3D_fftIni()
116
117    # Apply spectral cutoff filter
118    # First quadrant
119    F_fft_hat = F_fft_hat.at[:mr, :mc, :].set(F_fft[:mr, :mc, :])
120
121    # Second quadrant
122    F_fft_hat = F_fft_hat.at[nx - mr + 1:, :mc, :].set(F_fft[nx - mr + 1:, :mc, :])
123
124    # Transform back to physical space
125    F_hat = jnp.fft.irfft2(F_fft_hat, axes=(0, 1), s=(nx, ny))
126
127    return F_hat
128
129
130# ============================================================
131# Level 2 filtering (filter width = FGR*TFR*TFR)
132# ============================================================
133
@jax.jit
def Filtering_Level2(F_fft):
    """
    Sharp spectral cutoff at the second test-filter level
    (filter width = FGR*TFR*TFR).

    Parameters:
    -----------
    F_fft : ndarray with shape (nx, ny//2 + 1, nz)
        rfft2 over axes (0, 1) of a field (could be velocity components
        or scalars).

    Returns:
    --------
    F_hatd : ndarray with shape (nx, ny, nz)
        Filtered field with TFR**2 times the level-0 filter width,
        in physical space.

    Notes:
    ------
    - With p_x = round(nx / (2*FGR*TFR*TFR)) and
      p_y = round(ny / (2*FGR*TFR*TFR)), the retained coefficients are
      x-indices [0, p_x) and (nx - p_x, nx), y-indices [0, p_y); all
      others are zeroed.
    """

    # Denominator built in the same order as before (2 * FGR * TFR * TFR).
    width_factor = 2 * FGR * TFR * TFR
    p_x = round(nx / width_factor)
    p_y = round(ny / width_factor)

    # Non-negative x-wavenumbers followed by their negative mirror band.
    x_bands = (slice(None, p_x), slice(nx - p_x + 1, None))

    # Zero spectrum from ZeRo3D_fftIni (assumed to match F_fft's shape —
    # TODO confirm), then copy in the retained modes.
    filtered_spec = ZeRo3D_fftIni()
    for band in x_bands:
        filtered_spec = filtered_spec.at[band, :p_y, :].set(F_fft[band, :p_y, :])

    # Back to physical space on the full (nx, ny) horizontal grid.
    return jnp.fft.irfft2(filtered_spec, axes=(0, 1), s=(nx, ny))