Strain Rate Computations
Source Code: StrainRates
StrainRates.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: StrainRates.py
18=====================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-29
24:Description: computes strain rate tensors and their magnitudes
25"""
26
27# ============================================================
28# Imports
29# ============================================================
30
31import jax
32import jax.numpy as jnp
33
34# Import configuration from namelist
35from ..config.ConfigLoader import *
36from ..config import ConfigLoader as Config
37
38# Import derived variables
39from ..config.DerivedVars import *
40
41# Import helper functions
42from ..utilities.Utilities import StagGridAvg
43
44# Import FFT modules
45from ..operations.FFT import FFT, FFT_pad
46
47# Import dealiasing modules
48from ..operations.Dealiasing import Dealias1
49
50
51# =================================================
52# Compute strain rates on UVP nodes with dealiasing
53# =================================================
54
@jax.jit
def StrainsUVPnodes_Dealias(
        dudx, dvdx, dwdx,
        dudy, dvdy, dwdy,
        dudz, dvdz, dwdz,
        ZeRo3D, ZeRo3D_pad_fft):
    """
    Build the strain-rate tensor and its magnitude on UVP nodes, both on the
    native grid and after dealiasing.

    The cross terms du/dz, dv/dz, dw/dx and dw/dy are moved onto UVP nodes.
    Interior levels 1..nz-2 are the average of neighbouring vertical levels
    (via ``StagGridAvg``). Levels 0 and nz-1 are filled by explicit boundary
    rules. The remaining derivatives are used as supplied.

    Parameters:
    -----------
    dudx, dvdx, dwdx : ndarray
        Derivatives of velocity components in x-direction
    dudy, dvdy, dwdy : ndarray
        Derivatives of velocity components in y-direction
    dudz, dvdz, dwdz : ndarray
        Derivatives of velocity components in z-direction
    ZeRo3D : ndarray
        Pre-allocated zero array; serves as the template (shape and dtype)
        for the interpolated cross-derivative fields
    ZeRo3D_pad_fft : ndarray
        Pre-allocated zero array handed to ``Dealias1``

    Returns:
    --------
    S11, S22, S33, S12, S13, S23 : ndarray
        Strain rate tensor components
    S : ndarray
        Strain rate magnitude, sqrt(2 * S_ij * S_ij)
    S11_pad, S22_pad, S33_pad, S12_pad, S13_pad, S23_pad : ndarray
        Dealiased strain rate tensor components
    S_pad : ndarray
        Strain rate magnitude built from the dealiased components
    """

    def _onto_uvp(deriv, bottom):
        # Interior levels first, then the bottom level, then the top level,
        # which is copied straight from the input derivative.
        field = ZeRo3D.at[:, :, 1:nz - 1].set(StagGridAvg(deriv[:, :, 1:nz]))
        field = field.at[:, :, 0].set(bottom)
        return field.at[:, :, nz - 1].set(deriv[:, :, nz - 1])

    def _magnitude(s11, s22, s33, s12, s13, s23):
        # sqrt(2 * S_ij S_ij); each off-diagonal term appears twice by symmetry
        return jnp.sqrt(2 * (s11 ** 2 + s22 ** 2 + s33 ** 2 +
                             2 * s12 ** 2 +
                             2 * s13 ** 2 +
                             2 * s23 ** 2))

    # Lower boundary (z = dz/2): u_z and v_z take the level-0 value directly.
    # w_x and w_y take half of the level-1 value (presumably the mean with a
    # zero surface value, since w vanishes at the ground -- confirm).
    uz = _onto_uvp(dudz, dudz[:, :, 0])
    vz = _onto_uvp(dvdz, dvdz[:, :, 0])
    wx = _onto_uvp(dwdx, 0.5 * dwdx[:, :, 1])
    wy = _onto_uvp(dwdy, 0.5 * dwdy[:, :, 1])

    # Normal components are the diagonal velocity gradients
    S11 = dudx.copy()
    S22 = dvdy.copy()
    S33 = dwdz.copy()

    # Shear components: symmetric part of the velocity-gradient tensor
    S12 = 0.5 * (dudy + dvdx)
    S13 = 0.5 * (uz + wx)
    S23 = 0.5 * (vz + wy)

    S = _magnitude(S11, S22, S33, S12, S13, S23)

    # Forward FFT + Dealias1 for every component, then the padded magnitude
    (S11_pad, S22_pad, S33_pad,
     S12_pad, S13_pad, S23_pad) = (
        Dealias1(FFT(component), ZeRo3D_pad_fft)
        for component in (S11, S22, S33, S12, S13, S23))

    S_pad = _magnitude(S11_pad, S22_pad, S33_pad, S12_pad, S13_pad, S23_pad)

    return (S11, S22, S33, S12, S13, S23, S,
            S11_pad, S22_pad, S33_pad, S12_pad, S13_pad, S23_pad,
            S_pad)
142
143
144# ====================================================
145# Compute strain rates on UVP nodes without dealiasing
146# ====================================================
147
148@jax.jit
149def StrainsUVPnodes_NoDealias(
150 dudx, dvdx, dwdx,
151 dudy, dvdy, dwdy,
152 dudz, dvdz, dwdz,
153 ZeRo3D):
154 """
155 Parameters:
156 -----------
157 dudx, dvdx, dwdx : ndarray
158 Derivatives of velocity components in x-direction
159 dudy, dvdy, dwdy : ndarray
160 Derivatives of velocity components in y-direction
161 dudz, dvdz, dwdz : ndarray
162 Derivatives of velocity components in z-direction
163 ZeRo3D : ndarray
164 Pre-allocated zero array
165
166 Returns:
167 --------
168 S11, S22, S33, S12, S13, S23 : ndarray
169 Strain rate tensor components
170 S : ndarray
171 Strain rate magnitude
172 """
173
174 # Initialize arrays
175 uz, vz = ZeRo3D.copy(), ZeRo3D.copy()
176 wx, wy = ZeRo3D.copy(), ZeRo3D.copy()
177
178 # Average derivatives to uvp nodes
179 uz = uz.at[:, :, 1:nz - 1].set(StagGridAvg(dudz[:, :, 1:nz]))
180 vz = vz.at[:, :, 1:nz - 1].set(StagGridAvg(dvdz[:, :, 1:nz]))
181 wx = wx.at[:, :, 1:nz - 1].set(StagGridAvg(dwdx[:, :, 1:nz]))
182 wy = wy.at[:, :, 1:nz - 1].set(StagGridAvg(dwdy[:, :, 1:nz]))
183
184 # Set lower boundary conditions
185 uz = uz.at[:, :, 0].set(dudz[:, :, 0])
186 vz = vz.at[:, :, 0].set(dvdz[:, :, 0])
187 wx = wx.at[:, :, 0].set(0.5 * dwdx[:, :, 1])
188 wy = wy.at[:, :, 0].set(0.5 * dwdy[:, :, 1])
189
190 # Set upper boundary conditions
191 uz = uz.at[:, :, nz - 1].set(dudz[:, :, nz - 1])
192 vz = vz.at[:, :, nz - 1].set(dvdz[:, :, nz - 1])
193 wx = wx.at[:, :, nz - 1].set(dwdx[:, :, nz - 1])
194 wy = wy.at[:, :, nz - 1].set(dwdy[:, :, nz - 1])
195
196 # Compute strain rate tensors on uvp nodes
197 S11, S22, S33 = dudx.copy(), dvdy.copy(), dwdz.copy()
198 S12 = 0.5 * (dudy + dvdx)
199 S13 = 0.5 * (uz + wx)
200 S23 = 0.5 * (vz + wy)
201
202 # Compute strain rate magnitude
203 S = jnp.sqrt(2 * (S11 ** 2 + S22 ** 2 + S33 ** 2 +
204 2 * S12 ** 2 +
205 2 * S13 ** 2 +
206 2 * S23 ** 2))
207
208 return (
209 S11, S22, S33,
210 S12, S13, S23,
211 S)
212
213
214# =================================================
215# Compute strain rates on W nodes with dealiasing
216# =================================================
217
@jax.jit
def StrainsWnodes_Dealias(
        dudx, dvdx, dwdx,
        dudy, dvdy, dwdy,
        dudz, dvdz, dwdz,
        ZeRo3D, ZeRo3D_pad_fft):
    """
    Compute dealiased shear strain components and the dealiased strain-rate
    magnitude on W nodes.

    du/dx, du/dy, dv/dx, dv/dy and dw/dz are averaged onto w-levels 1..nz-2
    with ``StagGridAvg``. Level 0 takes the level-0 input value directly, and
    level nz-1 keeps the zero from ``ZeRo3D``. du/dz and dv/dz are used as
    given. dw/dx and dw/dy are used as given except at level 0, which is the
    mean of input levels 0 and 1.

    Parameters:
    -----------
    dudx, dvdx, dwdx : ndarray
        Derivatives of velocity components in x-direction
    dudy, dvdy, dwdy : ndarray
        Derivatives of velocity components in y-direction
    dudz, dvdz, dwdz : ndarray
        Derivatives of velocity components in z-direction
    ZeRo3D : ndarray
        Pre-allocated zero array; template for the averaged fields
    ZeRo3D_pad_fft : ndarray
        Pre-allocated array handed to ``Dealias1``

    Returns:
    --------
    S13_pad, S23_pad : ndarray
        Dealiased shear strain components at W nodes
    S_pad : ndarray
        Dealiased strain rate magnitude at W nodes
    """

    def _onto_w(deriv):
        # Average onto interior w-levels, then set the bottom level (z = dz/2)
        # from level 0. Level nz-1 is never written and stays zero
        # (NOTE(review): presumably the top lid -- confirm).
        field = ZeRo3D.at[:, :, 1:nz - 1].set(StagGridAvg(deriv[:, :, :nz - 1]))
        return field.at[:, :, 0].set(deriv[:, :, 0])

    def _bottom_mean(deriv):
        # Unchanged except level 0, which becomes the mean of levels 0 and 1
        return deriv.at[:, :, 0].set(0.5 * (deriv[:, :, 0] + deriv[:, :, 1]))

    S11 = _onto_w(dudx)
    S22 = _onto_w(dvdy)
    S33 = _onto_w(dwdz)
    S12 = 0.5 * (_onto_w(dudy) + _onto_w(dvdx))
    S13 = 0.5 * (dudz + _bottom_mean(dwdx))
    S23 = 0.5 * (dvdz + _bottom_mean(dwdy))

    # All six components are dealiased because the magnitude needs them all
    padded = [Dealias1(FFT(s), ZeRo3D_pad_fft)
              for s in (S11, S22, S33, S12, S13, S23)]
    S11_pad, S22_pad, S33_pad, S12_pad, S13_pad, S23_pad = padded

    S_pad = jnp.sqrt(2 * (S11_pad ** 2 + S22_pad ** 2 + S33_pad ** 2 +
                          2 * S12_pad ** 2 +
                          2 * S13_pad ** 2 +
                          2 * S23_pad ** 2))

    return (S13_pad, S23_pad,
            S_pad)
291
292
293# ====================================================
294# Compute strain rates on W nodes without dealiasing
295# ====================================================
296
@jax.jit
def StrainsWnodes_NoDealias(
        dudx, dvdx, dwdx,
        dudy, dvdy, dwdy,
        dudz, dvdz, dwdz,
        ZeRo3D):
    """
    Compute shear strain components and the strain-rate magnitude on W nodes
    (no dealiasing).

    du/dx, du/dy, dv/dx, dv/dy and dw/dz are averaged onto w-levels 1..nz-2
    via ``StagGridAvg``. Level 0 copies input level 0, and level nz-1 stays
    zero. du/dz and dv/dz are used directly. dw/dx and dw/dy are used
    directly except that level 0 becomes the mean of input levels 0 and 1.

    Parameters:
    -----------
    dudx, dvdx, dwdx : ndarray
        Derivatives of velocity components in x-direction
    dudy, dvdy, dwdy : ndarray
        Derivatives of velocity components in y-direction
    dudz, dvdz, dwdz : ndarray
        Derivatives of velocity components in z-direction
    ZeRo3D : ndarray
        Pre-allocated zero array; template for the averaged fields

    Returns:
    --------
    S13, S23 : ndarray
        Shear strain components at W nodes
    S : ndarray
        Strain rate magnitude at W nodes
    """
    interior = slice(1, nz - 1)

    def _to_w_levels(grad):
        # interior average, then bottom level (z = dz/2) from input level 0;
        # the last level is left at zero (NOTE(review): top lid? confirm)
        avg = ZeRo3D.at[:, :, interior].set(StagGridAvg(grad[:, :, :nz - 1]))
        return avg.at[:, :, 0].set(grad[:, :, 0])

    def _mean_first_two(grad):
        # level 0 replaced by the mean of levels 0 and 1
        first = 0.5 * (grad[:, :, 0] + grad[:, :, 1])
        return grad.at[:, :, 0].set(first)

    ux, uy = _to_w_levels(dudx), _to_w_levels(dudy)
    vx, vy = _to_w_levels(dvdx), _to_w_levels(dvdy)
    wz = _to_w_levels(dwdz)
    wx, wy = _mean_first_two(dwdx), _mean_first_two(dwdy)

    S12 = 0.5 * (uy + vx)
    S13 = 0.5 * (dudz + wx)
    S23 = 0.5 * (dvdz + wy)

    # sqrt(2 * S_ij S_ij) with the normal components ux, vy, wz
    S = jnp.sqrt(2 * (ux ** 2 + vy ** 2 + wz ** 2 +
                      2 * S12 ** 2 +
                      2 * S13 ** 2 +
                      2 * S23 ** 2))

    return (S13, S23,
            S)