PSCF v1.3.3
rpc/fts/compressor/LrCompressor.tpp
1#ifndef RPC_LR_COMPRESSOR_TPP
2#define RPC_LR_COMPRESSOR_TPP
3
4/*
5* PSCF - Polymer Self-Consistent Field
6*
7* Copyright 2015 - 2025, The Regents of the University of Minnesota
8* Distributed under the terms of the GNU General Public License.
9*/
10
11#include "LrCompressor.h"
12#include <rpc/system/System.h>
13#include <rpc/solvers/Mixture.h>
14#include <rpc/field/Domain.h>
15#include <prdc/crystal/shiftToMinimum.h>
16#include <pscf/mesh/MeshIterator.h>
17#include <pscf/iterator/NanException.h>
18#include <util/global.h>
19#include <util/format/Dbl.h>
20
21
22namespace Pscf {
23namespace Rpc{
24
25 using namespace Util;
26
27 // Constructor
28 template <int D>
30 : Compressor<D>(system),
31 epsilon_(0.0),
32 itr_(0),
33 maxItr_(100),
34 totalItr_(0),
35 errorType_("rms"),
36 verbose_(0),
37 intra_(system),
38 isIntraCalculated_(false),
39 isAllocated_(false)
40 { setClassName("LrCompressor"); }
41
42 // Destructor
43 template <int D>
46
47 // Read parameters from file
48 template <int D>
49 void LrCompressor<D>::readParameters(std::istream& in)
50 {
51 // Default values
52 maxItr_ = 100;
53 errorType_ = "rms";
54
55 read(in, "epsilon", epsilon_);
56 readOptional(in, "maxItr", maxItr_);
57 readOptional(in, "verbose", verbose_);
58 readOptional(in, "errorType", errorType_);
59 }
60
61 // Initialize just before entry to iterative loop.
62 template <int D>
64 {
65 const int nMonomer = system().mixture().nMonomer();
66 IntVec<D> const & dimensions = system().domain().mesh().dimensions();
67 for (int i = 0; i < D; ++i) {
68 if (i < D - 1) {
69 kMeshDimensions_[i] = dimensions[i];
70 } else {
71 kMeshDimensions_[i] = dimensions[i]/2 + 1;
72 }
73 }
74
75 // Allocate memory required by AM algorithm if not done earlier.
76 if (!isAllocated_){
77 resid_.allocate(dimensions);
78 residK_.allocate(dimensions);
79 wFieldTmp_.allocate(nMonomer);
80 intraCorrelationK_.allocate(kMeshDimensions_);
81 for (int i = 0; i < nMonomer; ++i) {
82 wFieldTmp_[i].allocate(dimensions);
83 }
84 isAllocated_ = true;
85 }
86
87 // Compute intraCorrelation
88 if (!isIntraCalculated_){
89 intra_.computeIntraCorrelations(intraCorrelationK_);
90 isIntraCalculated_ = true;
91 }
92 }
93
94 /*
95 * Adjust pressure field find partial saddle point.
96 */
97 template <int D>
99 {
100 // Initialization and allocate operations on entry to loop.
101 setup();
102 UTIL_CHECK(isAllocated_);
103
104 // Start overall timer
105 timerTotal_.start();
106
107 // Solve MDE
108 timerMDE_.start();
109 system().compute();
110 ++mdeCounter_;
111 timerMDE_.stop();
112
113 // Iterative loop
114 for (itr_ = 0; itr_ < maxItr_; ++itr_) {
115
116 if (verbose_ > 2) {
117 Log::file() << "------------------------------- \n";
118 }
119
120 if (verbose_ > 0){
121 Log::file() << " Iteration " << Int(itr_,5);
122 }
123
124 // Compute residual vector
125 getResidual();
126 double error;
127 try {
128 error = computeError(verbose_);
129 } catch (const NanException&) {
130 Log::file() << ", error = NaN" << std::endl;
131 break; // Exit loop if a NanException is caught
132 }
133 if (verbose_ > 0) {
134 Log::file() << ", error = " << Dbl(error, 15) << std::endl;
135 }
136
137 // Check for convergence
138 if (error < epsilon_) {
139
140 // Successful completion (i.e., converged within tolerance)
141 timerTotal_.stop();
142
143 if (verbose_ > 2) {
144 Log::file() << "-------------------------------\n";
145 }
146 if (verbose_ > 0) {
147 Log::file() << " Converged\n";
148 }
149 if (verbose_ == 2) {
150 Log::file() << "\n";
151 computeError(2);
152 }
153 //mdeCounter_ += itr_;
154 totalItr_ += itr_;
155
156 return 0; // Success
157
158 } else{
159
160 // Not yet converged.
161 updateWFields();
162 timerMDE_.start();
163 system().compute();
164 ++mdeCounter_;
165 timerMDE_.stop();
166
167 }
168
169 }
170
171 // Failure: iteration counter itr reached maxItr without converging
172 timerTotal_.stop();
173
174 Log::file() << "Iterator failed to converge.\n";
175 return 1;
176
177 }
178
179 /*
180 * Compute the residual for the current system state
181 */
182 template <int D>
183 void LrCompressor<D>::getResidual()
184 {
185 const int nMonomer = system().mixture().nMonomer();
186 const int meshSize = system().domain().mesh().size();
187
188 // Initialize residuals
189 for (int i = 0 ; i < meshSize; ++i) {
190 resid_[i] = -1.0;
191 }
192
193 // Compute SCF residual vector elements
194 for (int j = 0; j < nMonomer; ++j) {
195 for (int k = 0; k < meshSize; ++k) {
196 resid_[k] += system().c().rgrid(j)[k];
197 }
198 }
199 }
200
201 // update system w field using linear response approximation
202 template <int D>
203 void LrCompressor<D>::updateWFields()
204 {
205 const int nMonomer = system().mixture().nMonomer();
206 const int meshSize = system().domain().mesh().size();
207 const double vMonomer = system().mixture().vMonomer();
208
209 // Convert residual to Fourier Space
210 system().domain().fft().forwardTransform(resid_, residK_);
211 MeshIterator<D> iter;
212 iter.setDimensions(residK_.dftDimensions());
213 for (iter.begin(); !iter.atEnd(); ++iter) {
214 residK_[iter.rank()][0] *= 1.0 / (vMonomer * intraCorrelationK_[iter.rank()]);
215 residK_[iter.rank()][1] *= 1.0 / (vMonomer * intraCorrelationK_[iter.rank()]);
216 }
217
218 // Convert back to real Space (destroys residK_)
219 system().domain().fft().inverseTransformUnsafe(residK_, resid_);
220
221 for (int i = 0; i < nMonomer; i++){
222 for (int k = 0; k < meshSize; k++){
223 wFieldTmp_[i][k] = system().w().rgrid(i)[k] + resid_[k];
224 }
225 }
226 system().w().setRGrid(wFieldTmp_);
227 }
228
229 template<int D>
230 void LrCompressor<D>::outputToLog()
231 {}
232
233 template<int D>
234 void LrCompressor<D>::outputTimers(std::ostream& out) const
235 {
236 // Output timing results, if requested.
237 double total = timerTotal_.time();
238 out << "\n";
239 out << "LrCompressor time contributions:\n";
240 out << " ";
241 out << "Total" << std::setw(22)<< "Per Iteration"
242 << std::setw(9) << "Fraction" << "\n";
243 out << "MDE solution: "
244 << Dbl(timerMDE_.time(), 9, 3) << " s, "
245 << Dbl(timerMDE_.time()/totalItr_, 9, 3) << " s, "
246 << Dbl(timerMDE_.time()/total, 9, 3) << "\n";
247 out << "\n";
248 }
249
250 template<int D>
252 {
253 timerTotal_.clear();
254 timerMDE_.clear();
255 mdeCounter_ = 0;
256 totalItr_ = 0;
257 }
258
259 template <int D>
260 double LrCompressor<D>::maxAbs(RField<D> const & a)
261 {
262 const int n = a.capacity();
263 double max = 0.0;
264 double value;
265 for (int i = 0; i < n; i++) {
266 value = a[i];
267 if (std::isnan(value)) { // if value is NaN, throw NanException
268 throw NanException("LrCompressor::dotProduct",__FILE__,__LINE__,0);
269 }
270 if (fabs(value) > max) {
271 max = fabs(value);
272 }
273 }
274 return max;
275 }
276
277 // Compute and return inner product of two vectors.
278 template <int D>
279 double LrCompressor<D>::dotProduct(RField<D> const & a,
280 RField<D> const & b)
281 {
282 const int n = a.capacity();
283 UTIL_CHECK(b.capacity() == n);
284 double product = 0.0;
285 for (int i = 0; i < n; i++) {
286 // if either value is NaN, throw NanException
287 if (std::isnan(a[i]) || std::isnan(b[i])) {
288 throw NanException("AmCompressor::dotProduct",__FILE__,__LINE__,0);
289 }
290 product += a[i] * b[i];
291 }
292 return product;
293 }
294
295 // Compute L2 norm of an RField
296 template <int D>
297 double LrCompressor<D>::norm(RField<D> const & a)
298 {
299 double normSq = dotProduct(a, a);
300 return sqrt(normSq);
301 }
302
303 // Compute and return the scalar error
304 template <int D>
305 double LrCompressor<D>::computeError(int verbose)
306 {
307 const int meshSize = system().domain().mesh().size();
308 double error = 0.0;
309
310 // Find max residual vector element
311 double maxRes = maxAbs(resid_);
312 // Find norm of residual vector
313 double normRes = norm(resid_);
314 // Find root-mean-squared residual element value
315 double rmsRes = normRes/sqrt(meshSize);
316 if (errorType_ == "max") {
317 error = maxRes;
318 } else if (errorType_ == "norm") {
319 error = normRes;
320 } else if (errorType_ == "rms") {
321 error = rmsRes;
322 } else {
323 UTIL_THROW("Invalid iterator error type in parameter file.");
324 }
325
326 if (verbose > 1) {
327 Log::file() << "\n";
328 Log::file() << "Max Residual = " << Dbl(maxRes,15) << "\n";
329 Log::file() << "Residual Norm = " << Dbl(normRes,15) << "\n";
330 Log::file() << "RMS Residual = " << Dbl(rmsRes,15) << "\n";
331
332 // Check if calculation has diverged (normRes will be NaN)
333 UTIL_CHECK(!std::isnan(normRes));
334 }
335 return error;
336 }
337
338}
339}
340#endif
An IntVec<D, T> is a D-component vector of elements of integer type T.
Definition IntVec.h:27
Iterator over points in a Mesh<D>.
void setDimensions(const IntVec< D > &dimensions)
Set the grid dimensions in all directions.
Exception thrown when not-a-number (NaN) is encountered.
Field of real double precision values on an FFT mesh.
Definition cpu/RField.h:29
System< D > const & system() const
Return const reference to parent system.
int mdeCounter_
Count how many times MDE has been solved.
ScalarParam< Type > & read(std::istream &in, const char *label, Type &value)
Add and read a new required ScalarParam < Type > object.
void setClassName(const char *className)
Set class name string.
LrCompressor(System< D > &system)
Constructor.
ScalarParam< Type > & readOptional(std::istream &in, const char *label, Type &value)
Add and read a new optional ScalarParam < Type > object.
int compress()
Iterate pressure field to obtain partial saddle point.
void readParameters(std::istream &in)
Read all parameters and initialize.
void setup()
Initialize just before entry to iterative loop.
void outputTimers(std::ostream &out) const
Output compressor timing contributions.
Main class, representing a complete physical system.
int capacity() const
Return allocated size.
Definition Array.h:159
Wrapper for a double precision number, for formatted ostream output.
Definition Dbl.h:40
Wrapper for an int, for formatted ostream output.
Definition Int.h:37
static std::ostream & file()
Get log ostream by reference.
Definition Log.cpp:59
File containing preprocessor macros for error handling.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
Definition global.h:68
#define UTIL_THROW(msg)
Macro for throwing an Exception, reporting function, file and line number.
Definition global.h:49
double maxAbs(Array< double > const &in)
Get maximum absolute magnitude of array elements.
Definition Reduce.cpp:34
double max(Array< double > const &in)
Get maximum of array elements.
Definition Reduce.cpp:20
Real periodic fields, SCFT and PS-FTS (CPU).
Definition param_pc.dox:2
PSCF package top-level namespace.
float product(float a, float b)
Product for float Data.
Definition product.h:22