Strain Rate Computations

Source Code: StrainRates

StrainRates.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: StrainRates.py
 18=====================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-29
 24:Description: computes strain rate tensors and their magnitudes
 25"""
 26
 27# ============================================================
 28#  Imports
 29# ============================================================
 30
 31import jax
 32import jax.numpy as jnp
 33
 34# Import configuration from namelist
 35from ..config.ConfigLoader import *
 36from ..config import ConfigLoader as Config
 37
 38# Import derived variables
 39from ..config.DerivedVars import *
 40
 41# Import helper functions
 42from ..utilities.Utilities import StagGridAvg
 43
 44# Import FFT modules
 45from ..operations.FFT import FFT, FFT_pad
 46
 47# Import dealiasing modules
 48from ..operations.Dealiasing import Dealias1
 49
 50
 51# =================================================
 52# Compute strain rates on UVP nodes with dealiasing
 53# =================================================
 54
 55@jax.jit
 56def StrainsUVPnodes_Dealias(
 57        dudx, dvdx, dwdx,
 58        dudy, dvdy, dwdy,
 59        dudz, dvdz, dwdz,
 60        ZeRo3D, ZeRo3D_pad_fft):
 61    """
 62    Parameters:
 63    -----------
 64    dudx, dvdx, dwdx : ndarray
 65        Derivatives of velocity components in x-direction
 66    dudy, dvdy, dwdy : ndarray
 67        Derivatives of velocity components in y-direction
 68    dudz, dvdz, dwdz : ndarray
 69        Derivatives of velocity components in z-direction
 70    ZeRo3D : ndarray
 71        Pre-allocated zero array
 72    ZeRo3D_pad_fft : ndarray
 73        Pre-allocated zero array for dealiasing
 74
 75    Returns:
 76    --------
 77    S11, S22, S33, S12, S13, S23 : ndarray
 78        Strain rate tensor components
 79    S : ndarray
 80        Strain rate magnitude
 81    S11_pad, S22_pad, S33_pad, S12_pad, S13_pad, S23_pad : ndarray
 82        Dealiased strain rate tensor components
 83    S_pad : ndarray
 84        Dealiased strain rate magnitude
 85    """
 86
 87    # Initialize arrays
 88    uz, vz = ZeRo3D.copy(), ZeRo3D.copy()
 89    wx, wy = ZeRo3D.copy(), ZeRo3D.copy()
 90
 91    # Average derivatives to uvp nodes
 92    uz = uz.at[:, :, 1:nz - 1].set(StagGridAvg(dudz[:, :, 1:nz]))
 93    vz = vz.at[:, :, 1:nz - 1].set(StagGridAvg(dvdz[:, :, 1:nz]))
 94    wx = wx.at[:, :, 1:nz - 1].set(StagGridAvg(dwdx[:, :, 1:nz]))
 95    wy = wy.at[:, :, 1:nz - 1].set(StagGridAvg(dwdy[:, :, 1:nz]))
 96
 97    # Set lower boundary conditions (z = dz/2)
 98    uz = uz.at[:, :, 0].set(dudz[:, :, 0])
 99    vz = vz.at[:, :, 0].set(dvdz[:, :, 0])
100    wx = wx.at[:, :, 0].set(0.5 * dwdx[:, :, 1])
101    wy = wy.at[:, :, 0].set(0.5 * dwdy[:, :, 1])
102
103    # Set upper boundary conditions
104    uz = uz.at[:, :, nz - 1].set(dudz[:, :, nz - 1])
105    vz = vz.at[:, :, nz - 1].set(dvdz[:, :, nz - 1])
106    wx = wx.at[:, :, nz - 1].set(dwdx[:, :, nz - 1])
107    wy = wy.at[:, :, nz - 1].set(dwdy[:, :, nz - 1])
108
109    # Compute strain rate tensors on uvp nodes
110    S11, S22, S33 = dudx.copy(), dvdy.copy(), dwdz.copy()
111    S12 = 0.5 * (dudy + dvdx)
112    S13 = 0.5 * (uz + wx)
113    S23 = 0.5 * (vz + wy)
114
115    # Compute strain rate magnitude
116    S = jnp.sqrt(2 * (S11 ** 2 + S22 ** 2 + S33 ** 2 +
117                      2 * S12 ** 2 +
118                      2 * S13 ** 2 +
119                      2 * S23 ** 2))
120
121    # Dealiased variables - forward operations
122    S11_pad = Dealias1(FFT(S11), ZeRo3D_pad_fft)
123    S22_pad = Dealias1(FFT(S22), ZeRo3D_pad_fft)
124    S33_pad = Dealias1(FFT(S33), ZeRo3D_pad_fft)
125    S12_pad = Dealias1(FFT(S12), ZeRo3D_pad_fft)
126    S13_pad = Dealias1(FFT(S13), ZeRo3D_pad_fft)
127    S23_pad = Dealias1(FFT(S23), ZeRo3D_pad_fft)
128
129    # Compute dealiased strain rate magnitude
130    S_pad = jnp.sqrt(2 * (S11_pad ** 2 + S22_pad ** 2 + S33_pad ** 2 +
131                          2 * S12_pad ** 2 +
132                          2 * S13_pad ** 2 +
133                          2 * S23_pad ** 2))
134
135    return (
136        S11, S22, S33,
137        S12, S13, S23,
138        S,
139        S11_pad, S22_pad, S33_pad,
140        S12_pad, S13_pad, S23_pad,
141        S_pad)
142
143
144# ====================================================
145# Compute strain rates on UVP nodes without dealiasing
146# ====================================================
147
@jax.jit
def StrainsUVPnodes_NoDealias(
        dudx, dvdx, dwdx,
        dudy, dvdy, dwdy,
        dudz, dvdz, dwdz,
        ZeRo3D):
    """
    Strain-rate tensor and its magnitude on UVP nodes, without dealiasing.

    The cross terms du/dz, dv/dz, dw/dx and dw/dy are averaged onto UVP
    nodes with ``StagGridAvg`` at the interior levels 1 .. nz-2. Level 0
    (z = dz/2) and level nz-1 are filled explicitly.

    Parameters:
    -----------
    dudx, dvdx, dwdx : ndarray
        Derivatives of velocity components in x-direction
    dudy, dvdy, dwdy : ndarray
        Derivatives of velocity components in y-direction
    dudz, dvdz, dwdz : ndarray
        Derivatives of velocity components in z-direction
    ZeRo3D : ndarray
        Zero array used as the template for the interpolated cross terms

    Returns:
    --------
    S11, S22, S33, S12, S13, S23 : ndarray
        Strain rate tensor components
    S : ndarray
        Strain rate magnitude, sqrt(2 * S_ij * S_ij)
    """
    last = nz - 1

    # Each cross-term gradient paired with the value used at UVP level 0.
    # dw/dx and dw/dy use half of their level-1 value there.
    cross_terms = (
        (dudz, dudz[:, :, 0]),
        (dvdz, dvdz[:, :, 0]),
        (dwdx, 0.5 * dwdx[:, :, 1]),
        (dwdy, 0.5 * dwdy[:, :, 1]),
    )

    # Interior average, then bottom level, then top level (copied as-is).
    uz, vz, wx, wy = (
        ZeRo3D
        .at[:, :, 1:last].set(StagGridAvg(grad[:, :, 1:nz]))
        .at[:, :, 0].set(bottom)
        .at[:, :, last].set(grad[:, :, last])
        for grad, bottom in cross_terms
    )

    # Diagonal components are copies of the normal gradients
    S11 = dudx.copy()
    S22 = dvdy.copy()
    S33 = dwdz.copy()
    # Off-diagonal components are symmetric averages of the cross gradients
    S12 = 0.5 * (dudy + dvdx)
    S13 = 0.5 * (uz + wx)
    S23 = 0.5 * (vz + wy)

    # Accumulate S_ij S_ij in the same left-to-right order as the formula
    sum_sq = S11 ** 2 + S22 ** 2 + S33 ** 2
    sum_sq = sum_sq + 2 * S12 ** 2
    sum_sq = sum_sq + 2 * S13 ** 2
    sum_sq = sum_sq + 2 * S23 ** 2
    S = jnp.sqrt(2 * sum_sq)

    return S11, S22, S33, S12, S13, S23, S
212
213
214# =================================================
215# Compute strain rates on W nodes with dealiasing
216# =================================================
217
@jax.jit
def StrainsWnodes_Dealias(
        dudx, dvdx, dwdx,
        dudy, dvdy, dwdy,
        dudz, dvdz, dwdz,
        ZeRo3D, ZeRo3D_pad_fft):
    """
    Dealiased shear strains and strain-rate magnitude on W nodes.

    The gradients du/dx, du/dy, dv/dx, dv/dy and dw/dz are averaged onto
    W levels 1 .. nz-2 with ``StagGridAvg``, and level 0 takes the
    lowest input level (z = dz/2). du/dz and dv/dz are used unchanged.
    dw/dx and dw/dy are used unchanged except at level 0, which becomes
    the mean of input levels 0 and 1. All six tensor components are
    passed through ``FFT`` and ``Dealias1`` before the magnitude is formed.

    NOTE(review): level nz-1 of the averaged gradients is never assigned
    and keeps the ZeRo3D value. Confirm this top treatment is intended.

    Parameters:
    -----------
    dudx, dvdx, dwdx : ndarray
        Derivatives of velocity components in x-direction
    dudy, dvdy, dwdy : ndarray
        Derivatives of velocity components in y-direction
    dudz, dvdz, dwdz : ndarray
        Derivatives of velocity components in z-direction
    ZeRo3D : ndarray
        Zero array used as the template for the averaged gradients
    ZeRo3D_pad_fft : ndarray
        Zero array forwarded to ``Dealias1``

    Returns:
    --------
    S13_pad, S23_pad : ndarray
        Dealiased shear strain components at W nodes
    S_pad : ndarray
        Dealiased strain rate magnitude at W nodes
    """
    def to_w(grad):
        # Interior levels 1 .. nz-2 average adjacent input levels; level 0
        # copies the lowest input level.
        field = ZeRo3D.at[:, :, 1:nz - 1].set(
            StagGridAvg(grad[:, :, :nz - 1]))
        return field.at[:, :, 0].set(grad[:, :, 0])

    def surface_mean(grad):
        # Replace level 0 by the mean of input levels 0 and 1
        return grad.at[:, :, 0].set(0.5 * (grad[:, :, 0] + grad[:, :, 1]))

    ux, uy, vx, vy, wz = (to_w(g) for g in (dudx, dudy, dvdx, dvdy, dwdz))
    wx, wy = surface_mean(dwdx), surface_mean(dwdy)

    # (S11, S22, S33, S12, S13, S23) at W levels. du/dz and dv/dz are
    # used as given, without averaging.
    components = (
        ux, vy, wz,
        0.5 * (uy + vx),
        0.5 * (dudz + wx),
        0.5 * (dvdz + wy),
    )

    # Forward transform and pad every component
    S11_pad, S22_pad, S33_pad, S12_pad, S13_pad, S23_pad = (
        Dealias1(FFT(comp), ZeRo3D_pad_fft) for comp in components)

    # sqrt(2 S_ij S_ij), summed in the original left-to-right order
    total = S11_pad ** 2 + S22_pad ** 2 + S33_pad ** 2
    for off_diagonal in (S12_pad, S13_pad, S23_pad):
        total = total + 2 * off_diagonal ** 2
    S_pad = jnp.sqrt(2 * total)

    return S13_pad, S23_pad, S_pad
291
292
293# ====================================================
294# Compute strain rates on W nodes without dealiasing
295# ====================================================
296
@jax.jit
def StrainsWnodes_NoDealias(
        dudx, dvdx, dwdx,
        dudy, dvdy, dwdy,
        dudz, dvdz, dwdz,
        ZeRo3D):
    """
    Shear strains and strain-rate magnitude on W nodes, without dealiasing.

    du/dx, du/dy, dv/dx, dv/dy and dw/dz are averaged onto W levels
    1 .. nz-2 with ``StagGridAvg``, and level 0 takes the lowest input
    level (z = dz/2). du/dz and dv/dz enter unchanged. dw/dx and dw/dy
    enter unchanged except at level 0, which becomes the mean of input
    levels 0 and 1.

    NOTE(review): level nz-1 of the averaged gradients is never assigned
    and keeps the ZeRo3D value. Confirm this top treatment is intended.

    Parameters:
    -----------
    dudx, dvdx, dwdx : ndarray
        Derivatives of velocity components in x-direction
    dudy, dvdy, dwdy : ndarray
        Derivatives of velocity components in y-direction
    dudz, dvdz, dwdz : ndarray
        Derivatives of velocity components in z-direction
    ZeRo3D : ndarray
        Zero array used as the template for the averaged gradients

    Returns:
    --------
    S13, S23 : ndarray
        Shear strain components at W nodes
    S : ndarray
        Strain rate magnitude at W nodes, sqrt(2 * S_ij * S_ij)
    """
    interior = slice(1, nz - 1)   # W levels filled by averaging
    source = slice(None, nz - 1)  # input levels feeding those averages

    def staggered(grad):
        averaged = ZeRo3D.at[:, :, interior].set(
            StagGridAvg(grad[:, :, source]))
        return averaged.at[:, :, 0].set(grad[:, :, 0])

    ux = staggered(dudx)
    uy = staggered(dudy)
    vx = staggered(dvdx)
    vy = staggered(dvdy)
    wz = staggered(dwdz)

    # dw/dx and dw/dy: level 0 becomes the mean of input levels 0 and 1
    lowest_wx = 0.5 * (dwdx[:, :, 0] + dwdx[:, :, 1])
    lowest_wy = 0.5 * (dwdy[:, :, 0] + dwdy[:, :, 1])
    wx = dwdx.at[:, :, 0].set(lowest_wx)
    wy = dwdy.at[:, :, 0].set(lowest_wy)

    # Off-diagonal components. The diagonal ones are ux, vy and wz.
    S12 = 0.5 * (uy + vx)
    S13 = 0.5 * (dudz + wx)
    S23 = 0.5 * (dvdz + wy)

    # sqrt(2 S_ij S_ij); each off-diagonal term appears twice
    squares = (ux ** 2 + vy ** 2 + wz ** 2
               + 2 * S12 ** 2 + 2 * S13 ** 2 + 2 * S23 ** 2)
    S = jnp.sqrt(2 * squares)

    return S13, S23, S