PSCF v1.2
FFTBatched.tpp
1#ifndef PRDC_CUDA_FFT_BATCHED_TPP
2#define PRDC_CUDA_FFT_BATCHED_TPP
3
4/*
5* PSCF Package
6*
7* Copyright 2016 - 2022, The Regents of the University of Minnesota
8* Distributed under the terms of the GNU General Public License.
9*/
10
11#include "FFTBatched.h"
12#include "VecOp.h"
13
14/*
15* A note about const_casts:
16*
17* The cuFFT library is used in this file to perform discrete Fourier
18* transforms. cuFFT's complex-to-real inverse transform overwrites its
19* input array, but all other out-of-place transforms leave the input
20* array unaltered. However, all transforms in the cuFFT library require
21* non-const pointers to the input array, even though they do not alter
22* the array.
23*
24* In order to maintain const-correctness in PSCF, this FFT class accepts
25* const input arrays for its methods that perform a Fourier transform,
26* unless the transform is expected to modify / overwrite its input (as
27* is the case for complex-to-real inverse transforms). This denotes to
28* the caller of the method that the input array will not be altered,
29* which is an accurate representation of the expected behavior.
30*
31* However, the const-correctness of this FFT class creates a conflict
32* with the cuFFT library's requirement of non-const inputs. This conflict
33* is resolved using a const_cast, in which the const pointer to the input
34* array is made non-const when passed into cuFFT functions. The use of
35* const_cast is reserved only for these few select cases in which we are
36* confident that the input array will not be modified.
37*
38* For more information about the relevant cuFFT methods, see the cuFFT
39* documentation at https://docs.nvidia.com/cuda/cufft/. Unfortunately,
40* there is no explicitly documented guarantee that the transforms do not
41* modify their input, though it is implied. In Section 2.4, it is stated
42* that "[o]ut-of-place complex-to-real FFT will always overwrite input
43* buffer." No such claim is made for any other out-of-place transforms,
44* implying that they do not overwrite their inputs. Further, "[t]he
45* cuFFT API is modeled after FFTW," (beginning of Section 2), and FFTW
46* is much more explicit in its documentation
47* (https://www.fftw.org/fftw3_doc/index.html). In Section 4.3.2, it is
48* stated that, by default, "an out-of-place transform must not change
49* its input array," except for complex-to-real transforms, in which case
50* "no input-preserving algorithms are implemented." Finally, we note
51* that the unit tests for this FFT class check that the input array is
52* unaltered, allowing developers to continually ensure that the cuFFT
53* functions do not modify their input unexpectedly.
54*/
55
56namespace Pscf {
57namespace Prdc {
58namespace Cuda {
59
60 using namespace Util;
61
62 /*
63 * Default constructor.
64 */
65 template <int D>
67 : meshDimensions_(0),
68 kMeshDimensions_(0),
69 rSize_(0),
70 kSize_(0),
71 fPlan_(0),
72 iPlan_(0),
73 isSetup_(false)
74 {}
75
76 /*
77 * Destructor.
78 */
79 template <int D>
81 {
82 if (fPlan_) {
83 cufftDestroy(fPlan_);
84 }
85 if (iPlan_) {
86 cufftDestroy(iPlan_);
87 }
88 }
89
90 /*
91 * Set up FFT calculation (store grid dimensions and make FFT plan)
92 */
93 template <int D>
94 void FFTBatched<D>::setup(const IntVec<D>& meshDimensions, int batchSize)
95 {
96 // Preconditions
97 UTIL_CHECK(!isSetup_);
98
99 // Set Mesh dimensions
100 rSize_ = 1;
101 kSize_ = 1;
102 for (int i = 0; i < D; ++i) {
103 UTIL_CHECK(meshDimensions[i] > 0);
104 meshDimensions_[i] = meshDimensions[i];
105 if (i < D - 1) {
106 kMeshDimensions_[i] = meshDimensions[i];
107 } else {
108 kMeshDimensions_[i] = (meshDimensions[i]/2 + 1);
109 }
110 rSize_ *= meshDimensions_[i];
111 kSize_ *= kMeshDimensions_[i];
112 }
113
114 // Make FFT plans
115 makePlans(batchSize);
116
117 isSetup_ = true;
118 }
119
123 template <int D>
125 {
126 UTIL_CHECK(isSetup_);
127
128 if (batchSize == batchSize_) {
129 // Nothing to do
130 return;
131 } else {
132 // Remake FFT plans
133 makePlans(batchSize);
134 }
135 }
136
137 /*
138 * Make plans for variable batch size
139 */
140 template <int D>
141 void FFTBatched<D>::makePlans(int batchSize)
142 {
143 batchSize_ = batchSize;
144
145 UTIL_CHECK(kSize_ > 0);
146 UTIL_CHECK(rSize_ > 0);
147
148 int n[D];
149 for(int i = 0; i < D; i++) {
150 UTIL_CHECK(meshDimensions_[i] > 0);
151 n[i] = meshDimensions_[i];
152 }
153
154 #ifdef SINGLE_PRECISION
155 if (cufftPlanMany(&fPlan_, D, n, //plan, rank, n
156 NULL, 1, rSize_, //inembed, istride, idist
157 NULL, 1, kSize_, //onembed, ostride, odist
158 CUFFT_R2C, batchSize_) != CUFFT_SUCCESS) {
159 std::cout<<"FFTBatched: plan creation failed "<<std::endl;
160 exit(1);
161 }
162 if (cufftPlanMany(&iPlan_, D, n, //plan, rank, n
163 NULL, 1, kSize_, //inembed, istride, idist
164 NULL, 1, rSize_, //onembed, ostride, odist
165 CUFFT_C2R, batchSize_) != CUFFT_SUCCESS) {
166 std::cout<<"FFTBatched: plan creation failed "<<std::endl;
167 exit(1);
168 }
169 #else
170 if (cufftPlanMany(&fPlan_, D, n, // plan, rank, n
171 NULL, 1, rSize_, // inembed, istride, idist
172 NULL, 1, kSize_, // onembed, ostride, odist
173 CUFFT_D2Z, batchSize_) != CUFFT_SUCCESS) {
174 std::cout<<"FFTBatched: plan creation failed "<<std::endl;
175 exit(1);
176 }
177 if (cufftPlanMany(&iPlan_, D, n, // plan, rank, n
178 NULL, 1, kSize_, // inembed, istride, idist
179 NULL, 1, rSize_, // onembed, ostride, odist
180 CUFFT_Z2D, batchSize_) != CUFFT_SUCCESS) {
181 std::cout<<"FFTBatched: plan creation failed "<<std::endl;
182 exit(1);
183 }
184 #endif
185 }
186
187 /*
188 * Execute forward transform.
189 */
190 template <int D>
192 DeviceArray<cudaComplex>& kFields)
193 const
194 {
195 // Preconditions
196 UTIL_CHECK(isSetup_);
197 UTIL_CHECK(rFields.capacity() == rSize_ * batchSize_);
198 UTIL_CHECK(kFields.capacity() == kSize_ * batchSize_);
199
200 // Perform FFT
201 // (See note at top of file explaining this use of const_cast)
202 cufftResult result;
203 #ifdef SINGLE_PRECISION
204 result = cufftExecR2C(fPlan_, const_cast<cudaReal*>(rFields.cArray()),
205 kFields.cArray());
206 #else
207 result = cufftExecD2Z(fPlan_, const_cast<cudaReal*>(rFields.cArray()),
208 kFields.cArray());
209 #endif
210
211 if (result != CUFFT_SUCCESS) {
212 UTIL_THROW("Failure in cufft real-to-complex forward transform.");
213 }
214
215 // Rescale output data in-place
216 cudaReal scale = 1.0/cudaReal(rSize_);
217 VecOp::mulEqS(kFields, scale);
218 }
219
220 /*
221 * Execute inverse (complex-to-real) transform.
222 */
223 template <int D>
224 void
226 DeviceArray<cudaReal>& rFields)
227 const
228 {
229 // Preconditions
230 UTIL_CHECK(isSetup_);
231 UTIL_CHECK(kFields.capacity() == kSize_ * batchSize_);
232 UTIL_CHECK(rFields.capacity() == rSize_ * batchSize_);
233
234 // Perform FFT
235 cufftResult result;
236 #ifdef SINGLE_PRECISION
237 result = cufftExecC2R(iPlan_, kFields.cArray(), rFields.cArray());
238 #else
239 result = cufftExecZ2D(iPlan_, kFields.cArray(), rFields.cArray());
240 #endif
241
242 if (result != CUFFT_SUCCESS) {
243 UTIL_THROW("Failure in cufft complex-to-real inverse transform.");
244 }
245 }
246
247}
248}
249}
250#endif
Dynamic array on the GPU device with aligned data.
Definition rpg/System.h:32
int capacity() const
Return allocated capacity.
Data * cArray()
Return pointer to underlying C array.
An IntVec<D, T> is a D-component vector of elements of integer type T.
Definition IntVec.h:27
Batched Fourier transform wrapper for real data.
Definition FFTBatched.h:32
virtual ~FFTBatched()
Destructor.
void forwardTransform(DeviceArray< cudaReal > const &rFields, DeviceArray< cudaComplex > &kFields) const
Compute batched forward (real-to-complex) DFTs.
void inverseTransformUnsafe(DeviceArray< cudaComplex > &kFields, DeviceArray< cudaReal > &rFields) const
Compute inverse (complex-to-real) Fourier transform.
void setup(IntVec< D > const &meshDimensions, int batchSize)
Set up FFT calculation (get grid dimensions and make FFT plan)
FFTBatched()
Default constructor.
void resetBatchSize(int batchSize)
Set the batch size to a new value.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
Definition global.h:68
#define UTIL_THROW(msg)
Macro for throwing an Exception, reporting function, file and line number.
Definition global.h:51
void mulEqS(DeviceArray< cudaReal > &a, cudaReal const b, const int beginIdA, const int n)
Vector multiplication in-place, a[i] *= b, kernel wrapper (cudaReal).
Definition VecOp.cu:1875
PSCF package top-level namespace.
Definition param_pc.dox:1
Utility classes for scientific computation.