PSCF v1.2
cuda/FFT.tpp
1#ifndef PRDC_CUDA_FFT_TPP
2#define PRDC_CUDA_FFT_TPP
3
4/*
5* PSCF Package
6*
7* Copyright 2016 - 2022, The Regents of the University of Minnesota
8* Distributed under the terms of the GNU General Public License.
9*/
10
11#include "FFT.h"
12#include "VecOp.h"
13
14/*
15* A note about const_casts:
16*
17* The cuFFT library is used in this file to perform discrete Fourier
18* transforms. cuFFT's complex-to-real inverse transform overwrites its
19* input array, but all other out-of-place transforms leave the input
20* array unaltered. However, all transforms in the cuFFT library require
21* non-const pointers to the input array, even though they do not alter
22* the array.
23*
24* In order to maintain const-correctness in PSCF, this FFT class accepts
25* const input arrays for its methods that perform a Fourier transform,
26* unless the transform is expected to modify / overwrite its input (as
27* is the case for complex-to-real inverse transforms). This denotes to
28* the caller of the method that the input array will not be altered,
29* which is an accurate representation of the expected behavior.
30*
31* However, the const-correctness of this FFT class creates a conflict
32* with the cuFFT library's requirement of non-const inputs. This conflict
33* is resolved using a const_cast, in which the const pointer to the input
34* array is made non-const when passed into cuFFT functions. The use of
35* const_cast is reserved only for these few select cases in which we are
36* confident that the input array will not be modified.
37*
38* For more information about the relevant cuFFT methods, see the cuFFT
39* documentation at https://docs.nvidia.com/cuda/cufft/. Unfortunately,
40* there is no explicitly documented guarantee that the transforms do not
41* modify their input, though it is implied. In Section 2.4, it is stated
42* that "[o]ut-of-place complex-to-real FFT will always overwrite input
43* buffer." No such claim is made for any other out-of-place transforms,
44* implying that they do not overwrite their inputs. Further, "[t]he
45* cuFFT API is modeled after FFTW," (beginning of Section 2), and FFTW
46* is much more explicit in their documentation
47* (https://www.fftw.org/fftw3_doc/index.html). In Section 4.3.2, it is
48* stated that, by default, "an out-of-place transform must not change
49* its input array," except for complex-to-real transforms, in which case
50* "no input-preserving algorithms are implemented." Finally, we note
51* that the unit tests for this FFT class check that the input array is
52* unaltered, allowing developers to continually ensure that the cuFFT
53* functions do not modify their input unexpectedly.
54*/
55
56namespace Pscf {
57namespace Prdc {
58namespace Cuda {
59
60 using namespace Util;
61
62 /*
63 * Default constructor.
64 */
65 template <int D>
67 : meshDimensions_(0),
68 rSize_(0),
69 kSize_(0),
70 rcfPlan_(0),
71 criPlan_(0),
72 ccPlan_(0),
73 isSetup_(false)
74 {}
75
76 /*
77 * Destructor.
78 */
79 template <int D>
80 FFT<D>::~FFT()
81 {
82 if (rcfPlan_) {
83 cufftDestroy(rcfPlan_);
84 }
85 if (criPlan_) {
86 cufftDestroy(criPlan_);
87 }
88 if (ccPlan_) {
89 cufftDestroy(ccPlan_);
90 }
91 }
92
93 /*
94 * Setup grid dimensions, plans and work space.
95 */
96 template <int D>
97 void FFT<D>::setup(IntVec<D> const & meshDimensions)
98 {
99 // Precondition
100 UTIL_CHECK(!isSetup_);
101
102 // Set mesh dimensions and sizes
103 rSize_ = 1;
104 kSize_ = 1;
105 for (int i = 0; i < D; ++i) {
106 UTIL_CHECK(meshDimensions[i] > 0);
107 meshDimensions_[i] = meshDimensions[i];
108 rSize_ *= meshDimensions[i];
109 if (i < D - 1) {
110 kSize_ *= meshDimensions[i];
111 } else {
112 kSize_ *= (meshDimensions[i]/2 + 1);
113 }
114 }
115
116 // Reallocate kFieldCopy_ array if necessary
117 if (kFieldCopy_.isAllocated()) {
118 if (kFieldCopy_.capacity() != kSize_) {
119 kFieldCopy_.deallocate();
120 kFieldCopy_.allocate(meshDimensions);
121 }
122 }
123
124 // Make FFTW plans (explicit specializations)
125 makePlans();
126
127 isSetup_ = true;
128 }
129
130 /*
131 * Compute forward (real-to-complex) discrete Fourier transform.
132 */
133 template <int D>
134 void FFT<D>::forwardTransform(RField<D> const & rField,
135 RFieldDft<D>& kField) const
136 {
137 // Preconditions
138 UTIL_CHECK(isSetup_);
139 UTIL_CHECK(rField.capacity() == rSize_);
140 UTIL_CHECK(kField.capacity() == kSize_);
141
142 // Perform transform
143 // (See note at top of file explaining this use of const_cast)
144 cufftResult result;
145 #ifdef SINGLE_PRECISION
146 result = cufftExecR2C(rcfPlan_, const_cast<cudaReal*>(rField.cArray()),
147 kField.cArray());
148 #else
149 result = cufftExecD2Z(rcfPlan_, const_cast<cudaReal*>(rField.cArray()),
150 kField.cArray());
151 #endif
152 if (result != CUFFT_SUCCESS) {
153 UTIL_THROW("Failure in cufft real-to-complex forward transform");
154 }
155
156 // Rescale output data in-place
157 cudaReal scale = 1.0/cudaReal(rSize_);
158 VecOp::mulEqS(kField, scale);
159 }
160
161 /*
162 * Compute inverse (complex-to-real) DFT, overwriting the input.
163 */
164 template <int D>
165 void FFT<D>::inverseTransformUnsafe(RFieldDft<D>& kField,
166 RField<D>& rField) const
167 {
168 // Preconditions
169 UTIL_CHECK(isSetup_);
170 UTIL_CHECK(rField.capacity() == rSize_);
171 UTIL_CHECK(kField.capacity() == kSize_);
172 UTIL_CHECK(rField.meshDimensions() == meshDimensions_);
173 UTIL_CHECK(kField.meshDimensions() == meshDimensions_);
174
175 cufftResult result;
176 #ifdef SINGLE_PRECISION
177 result = cufftExecC2R(criPlan_, kField.cArray(), rField.cArray());
178 #else
179 result = cufftExecZ2D(criPlan_, kField.cArray(), rField.cArray());
180 #endif
181 if (result != CUFFT_SUCCESS) {
182 UTIL_THROW( "Failure in cufft complex-to-real inverse transform");
183 }
184
185 }
186
187 /*
188 * Compute inverse (complex-to-real) DFT without overwriting input.
189 */
190 template <int D>
191 void FFT<D>::inverseTransformSafe(RFieldDft<D> const & kField,
192 RField<D>& rField) const
193 {
194 // if kFieldCopy_ has been previously allocated, check size is correct
195 if (kFieldCopy_.isAllocated()) {
196 UTIL_CHECK(kFieldCopy_.capacity() == kSize_);
197 UTIL_CHECK(kFieldCopy_.meshDimensions() == meshDimensions_);
198 }
199
200 // make copy of kField (allocates kFieldCopy_ if necessary)
201 kFieldCopy_ = kField;
202
203 // Perform transform using copy of kField
204 inverseTransformUnsafe(kFieldCopy_, rField);
205 }
206
207 // Complex-to-Complex Transforms
208
209 /*
210 * Execute forward complex-to-complex transform.
211 */
212 template <int D>
213 void FFT<D>::forwardTransform(CField<D> const & rField, CField<D>& kField)
214 const
215 {
216 // Preconditions
217 UTIL_CHECK(isSetup_);
218 UTIL_CHECK(rField.capacity() == rSize_);
219 UTIL_CHECK(kField.capacity() == rSize_);
220 UTIL_CHECK(rField.meshDimensions() == meshDimensions_);
221 UTIL_CHECK(kField.meshDimensions() == meshDimensions_);
222
223 // Perform transform
224 // (See note at top of file explaining this use of const_cast)
225 cufftResult result;
226 #ifdef SINGLE_PRECISION
227 result = cufftExecC2C(ccPlan_, const_cast<cudaComplex*>(rField.cArray()),
228 kField.cArray(), CUFFT_FORWARD);
229 #else
230 result = cufftExecZ2Z(ccPlan_, const_cast<cudaComplex*>(rField.cArray()),
231 kField.cArray(), CUFFT_FORWARD);
232 #endif
233 if (result != CUFFT_SUCCESS) {
234 UTIL_THROW("Failure in cufft complex-to-complex forward transform");
235 }
236
237 // Rescale output data in-place
238 cudaReal scale = 1.0/cudaReal(rSize_);
239 VecOp::mulEqS(kField, scale);
240 }
241
242 /*
243 * Execute inverse (complex-to-complex) transform.
244 */
245 template <int D>
246 void FFT<D>::inverseTransform(CField<D> const & kField, CField<D>& rField)
247 const
248 {
249 // Preconditions
250 UTIL_CHECK(isSetup_);
251 UTIL_CHECK(rField.capacity() == rSize_);
252 UTIL_CHECK(kField.capacity() == rSize_);
253 UTIL_CHECK(rField.meshDimensions() == meshDimensions_);
254 UTIL_CHECK(kField.meshDimensions() == meshDimensions_);
255
256 // Perform transform
257 // (See note at top of file explaining this use of const_cast)
258 cufftResult result;
259 #ifdef SINGLE_PRECISION
260 result = cufftExecC2C(ccPlan_, const_cast<cudaComplex*>(kField.cArray()),
261 rField.cArray(), CUFFT_INVERSE);
262 #else
263 result = cufftExecZ2Z(ccPlan_, const_cast<cudaComplex*>(kField.cArray()),
264 rField.cArray(), CUFFT_INVERSE);
265 #endif
266 if (result != CUFFT_SUCCESS) {
267 UTIL_THROW( "Failure in cufft complex-to-complex inverse transform");
268 }
269
270 }
271
272}
273}
274}
275#endif
FFT()
Default constructor.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
Definition global.h:68
#define UTIL_THROW(msg)
Macro for throwing an Exception, reporting function, file and line number.
Definition global.h:51
PSCF package top-level namespace.
Definition param_pc.dox:1
Utility classes for scientific computation.