PSCF v1.3
rpc/fts/compressor/LrCompressor.tpp
1#ifndef RPC_LR_COMPRESSOR_TPP
2#define RPC_LR_COMPRESSOR_TPP
3
4/*
5* PSCF - Polymer Self-Consistent Field
6*
7* Copyright 2015 - 2025, The Regents of the University of Minnesota
8* Distributed under the terms of the GNU General Public License.
9*/
10
11#include "LrCompressor.h"
12#include <rpc/system/System.h>
13#include <prdc/crystal/shiftToMinimum.h>
14#include <pscf/mesh/MeshIterator.h>
15#include <pscf/iterator/NanException.h>
16#include <util/global.h>
17#include <util/format/Dbl.h>
18
19
20namespace Pscf {
21namespace Rpc{
22
23 using namespace Util;
24
25 // Constructor
// Sets default member values: epsilon_ = 0.0, maxItr_ = 0 (readParameters
// later substitutes 60 when "maxItr" is absent), errorType_ = "rmsResid",
// verbose_ = 0. Passes system to both the Compressor<D> base and the intra_
// member, and marks work arrays and intraCorrelationK_ as not yet
// allocated/computed so that setup() will do that work on first use.
26 template <int D>
// NOTE(review): the constructor signature line (original line 27) is
// missing from this listing.
28 : Compressor<D>(system),
29 epsilon_(0.0),
30 itr_(0),
31 maxItr_(0),
32 totalItr_(0),
33 errorType_("rmsResid"),
34 verbose_(0),
35 intra_(system),
36 isIntraCalculated_(false),
37 isAllocated_(false)
38 { setClassName("LrCompressor"); }
39
40 // Destructor
41 template <int D>
44
45 // Read parameters from file
46 template <int D>
47 void LrCompressor<D>::readParameters(std::istream& in)
48 {
49 maxItr_ = 60;
50 read(in, "epsilon", epsilon_);
51 readOptional(in, "maxItr", maxItr_);
52 readOptional(in, "verbose", verbose_);
53 readOptional(in, "errorType", errorType_);
54 }
55
56 // Initialize just before entry to iterative loop.
// What this does:
//  - Recomputes kMeshDimensions_ on every call.
//  - Allocates the work arrays only on the first call (guarded by
//    isAllocated_).
//  - Computes intraCorrelationK_ only on the first call (guarded by
//    isIntraCalculated_).
// NOTE(review): the function signature line (original line 58) is missing
// from this listing.
57 template <int D>
59 {
60 const int nMonomer = system().mixture().nMonomer();
61 IntVec<D> const & dimensions = system().domain().mesh().dimensions();
// kMeshDimensions_ copies the real-space mesh dimensions, except that the
// last one becomes dimensions[D-1]/2 + 1. This presumably matches the
// half-complex layout of a real-to-complex DFT (cf. residK_.dftDimensions()
// used in updateWFields) -- confirm.
62 for (int i = 0; i < D; ++i) {
63 if (i < D - 1) {
64 kMeshDimensions_[i] = dimensions[i];
65 } else {
66 kMeshDimensions_[i] = dimensions[i]/2 + 1;
67 }
68 }
69
70 // Allocate memory required by the linear-response (LR) algorithm if not done earlier.
// NOTE(review): isAllocated_ is never reset in this file. If the mesh
// dimensions could change between calls, these arrays would keep their old
// sizes -- confirm this cannot happen.
71 if (!isAllocated_){
72 resid_.allocate(dimensions);
73 residK_.allocate(dimensions);
74 wFieldTmp_.allocate(nMonomer);
75 intraCorrelationK_.allocate(kMeshDimensions_);
76 for (int i = 0; i < nMonomer; ++i) {
77 wFieldTmp_[i].allocate(dimensions);
78 }
79 isAllocated_ = true;
80 }
81
82 // Compute intraCorrelation
// Done once per object: isIntraCalculated_ is set true here and is not
// reset anywhere in this file.
// NOTE(review): if the unit cell or mesh can change between compress()
// calls, intraCorrelationK_ would be stale -- confirm.
83 if (!isIntraCalculated_){
84 intra_.computeIntraCorrelations(intraCorrelationK_);
85 isIntraCalculated_ = true;
86 }
87 }
88
89 /*
90 * Adjust pressure field to find a partial saddle point.
 *
 * Solves the MDE once, then repeats up to maxItr_ times:
 *   1. compute the residual and the scalar error;
 *   2. if error < epsilon_, stop and return 0;
 *   3. otherwise update the w fields (updateWFields) and re-solve the MDE.
 *
 * Returns 0 on convergence. Returns 1 if maxItr_ iterations pass without
 * convergence, or if computeError throws NanException.
 *
 * NOTE(review): the function signature line (original line 93) is missing
 * from this listing.
91 */
92 template <int D>
94 {
95 // Initialization and allocate operations on entry to loop.
96 setup();
97 UTIL_CHECK(isAllocated_);
98
99 // Start overall timer
100 timerTotal_.start();
101
102 // Solve MDE
103 timerMDE_.start();
104 system().compute();
105 ++mdeCounter_;
106 timerMDE_.stop();
107
108 // Iterative loop
109 for (itr_ = 0; itr_ < maxItr_; ++itr_) {
110
111 if (verbose_ > 2) {
112 Log::file() << "------------------------------- \n";
113 }
114
115 if (verbose_ > 0){
116 Log::file() << " Iteration " << Int(itr_,5);
117 }
118
119 // Compute residual vector
120 getResidual();
121 double error;
122 try {
123 error = computeError(verbose_);
124 } catch (const NanException&) {
// This message is written regardless of verbose_. After the break, control
// falls through to the failure path below, which returns 1.
125 Log::file() << ", error = NaN" << std::endl;
126 break; // Exit loop if a NanException is caught
127 }
128 if (verbose_ > 0) {
129 Log::file() << ", error = " << Dbl(error, 15) << std::endl;
130 }
131
132 // Check for convergence
133 if (error < epsilon_) {
134
135 // Successful completion (i.e., converged within tolerance)
136 timerTotal_.stop();
137
138 if (verbose_ > 2) {
139 Log::file() << "-------------------------------\n";
140 }
141 if (verbose_ > 0) {
142 Log::file() << " Converged\n";
143 }
144 if (verbose_ == 2) {
145 Log::file() << "\n";
146 computeError(2);
147 }
148 //mdeCounter_ += itr_;
// At this point itr_ equals the number of updateWFields() calls made
// during this compress() call.
149 totalItr_ += itr_;
150
151 return 0; // Success
152
153 } else{
154
155 // Not yet converged.
// NOTE(review): on the last pass (itr_ == maxItr_ - 1), this update and MDE
// solve still run, but the loop then exits without testing the new state.
// That makes one MDE solve per failed call wasted work.
156 updateWFields();
157 timerMDE_.start();
158 system().compute();
159 ++mdeCounter_;
160 timerMDE_.stop();
161
162 }
163
164 }
165
166 // Failure: iteration counter itr reached maxItr without converging
// Also reached via break when computeError throws NanException.
// NOTE(review): totalItr_ is not incremented on this path. As a result,
// the per-iteration average reported by outputTimers covers only
// converged calls.
167 timerTotal_.stop();
168
169 Log::file() << "Iterator failed to converge.\n";
170 return 1;
171
172 }
173
174 /*
175 * Compute the residual for the current system state
176 */
177 template <int D>
178 void LrCompressor<D>::getResidual()
179 {
180 const int nMonomer = system().mixture().nMonomer();
181 const int meshSize = system().domain().mesh().size();
182
183 // Initialize residuals
184 for (int i = 0 ; i < meshSize; ++i) {
185 resid_[i] = -1.0;
186 }
187
188 // Compute SCF residual vector elements
189 for (int j = 0; j < nMonomer; ++j) {
190 for (int k = 0; k < meshSize; ++k) {
191 resid_[k] += system().c().rgrid(j)[k];
192 }
193 }
194 }
195
196 // update system w field using linear response approximation
197 template <int D>
198 void LrCompressor<D>::updateWFields()
199 {
200 const int nMonomer = system().mixture().nMonomer();
201 const int meshSize = system().domain().mesh().size();
202 const double vMonomer = system().mixture().vMonomer();
203
204 // Convert residual to Fourier Space
205 system().domain().fft().forwardTransform(resid_, residK_);
206 MeshIterator<D> iter;
207 iter.setDimensions(residK_.dftDimensions());
208 for (iter.begin(); !iter.atEnd(); ++iter) {
209 residK_[iter.rank()][0] *= 1.0 / (vMonomer * intraCorrelationK_[iter.rank()]);
210 residK_[iter.rank()][1] *= 1.0 / (vMonomer * intraCorrelationK_[iter.rank()]);
211 }
212
213 // Convert back to real Space (destroys residK_)
214 system().domain().fft().inverseTransformUnsafe(residK_, resid_);
215
216 for (int i = 0; i < nMonomer; i++){
217 for (int k = 0; k < meshSize; k++){
218 wFieldTmp_[i][k] = system().w().rgrid(i)[k] + resid_[k];
219 }
220 }
221 system().w().setRGrid(wFieldTmp_);
222 }
223
224 template<int D>
225 void LrCompressor<D>::outputToLog()
226 {}
227
228 template<int D>
229 void LrCompressor<D>::outputTimers(std::ostream& out) const
230 {
231 // Output timing results, if requested.
232 double total = timerTotal_.time();
233 out << "\n";
234 out << "LrCompressor time contributions:\n";
235 out << " ";
236 out << "Total" << std::setw(22)<< "Per Iteration"
237 << std::setw(9) << "Fraction" << "\n";
238 out << "MDE solution: "
239 << Dbl(timerMDE_.time(), 9, 3) << " s, "
240 << Dbl(timerMDE_.time()/totalItr_, 9, 3) << " s, "
241 << Dbl(timerMDE_.time()/total, 9, 3) << "\n";
242 out << "\n";
243 }
244
// Reset compressor statistics: clears the total and MDE timers, and sets
// the MDE-solve counter (mdeCounter_) and iteration counter (totalItr_)
// to zero.
// NOTE(review): the function signature line (original line 246) is missing
// from this listing.
245 template<int D>
247 {
248 timerTotal_.clear();
249 timerMDE_.clear();
250 mdeCounter_ = 0;
251 totalItr_ = 0;
252 }
253
254 template <int D>
255 double LrCompressor<D>::maxAbs(RField<D> const & a)
256 {
257 const int n = a.capacity();
258 double max = 0.0;
259 double value;
260 for (int i = 0; i < n; i++) {
261 value = a[i];
262 if (std::isnan(value)) { // if value is NaN, throw NanException
263 throw NanException("LrCompressor::dotProduct",__FILE__,__LINE__,0);
264 }
265 if (fabs(value) > max) {
266 max = fabs(value);
267 }
268 }
269 return max;
270 }
271
272 // Compute and return inner product of two vectors.
273 template <int D>
274 double LrCompressor<D>::dotProduct(RField<D> const & a,
275 RField<D> const & b)
276 {
277 const int n = a.capacity();
278 UTIL_CHECK(b.capacity() == n);
279 double product = 0.0;
280 for (int i = 0; i < n; i++) {
281 // if either value is NaN, throw NanException
282 if (std::isnan(a[i]) || std::isnan(b[i])) {
283 throw NanException("AmCompressor::dotProduct",__FILE__,__LINE__,0);
284 }
285 product += a[i] * b[i];
286 }
287 return product;
288 }
289
290 // Compute L2 norm of an RField
291 template <int D>
292 double LrCompressor<D>::norm(RField<D> const & a)
293 {
294 double normSq = dotProduct(a, a);
295 return sqrt(normSq);
296 }
297
298 // Compute and return the scalar error
299 template <int D>
300 double LrCompressor<D>::computeError(int verbose)
301 {
302 const int meshSize = system().domain().mesh().size();
303 double error = 0;
304
305 // Find max residual vector element
306 double maxRes = maxAbs(resid_);
307 // Find norm of residual vector
308 double normRes = norm(resid_);
309 // Find root-mean-squared residual element value
310 double rmsRes = normRes/sqrt(meshSize);
311 if (errorType_ == "maxResid") {
312 error = maxRes;
313 } else if (errorType_ == "normResid") {
314 error = normRes;
315 } else if (errorType_ == "rmsResid") {
316 error = rmsRes;
317 } else {
318 UTIL_THROW("Invalid iterator error type in parameter file.");
319 }
320
321 if (verbose > 1) {
322 Log::file() << "\n";
323 Log::file() << "Max Residual = " << Dbl(maxRes,15) << "\n";
324 Log::file() << "Residual Norm = " << Dbl(normRes,15) << "\n";
325 Log::file() << "RMS Residual = " << Dbl(rmsRes,15) << "\n";
326
327 // Check if calculation has diverged (normRes will be NaN)
328 UTIL_CHECK(!std::isnan(normRes));
329 }
330 return error;
331 }
332
333}
334}
335#endif
An IntVec<D, T> is a D-component vector of elements of integer type T.
Definition IntVec.h:27
Iterator over points in a Mesh<D>.
void setDimensions(const IntVec< D > &dimensions)
Set the grid dimensions in all directions.
Exception thrown when not-a-number (NaN) is encountered.
Field of real double precision values on an FFT mesh.
Definition cpu/RField.h:29
System< D > const & system() const
Return const reference to parent system.
Compressor(System< D > &system)
Constructor.
int mdeCounter_
Count how many times MDE has been solved.
ScalarParam< Type > & read(std::istream &in, const char *label, Type &value)
Add and read a new required ScalarParam < Type > object.
void setClassName(const char *className)
Set class name string.
LrCompressor(System< D > &system)
Constructor.
ScalarParam< Type > & readOptional(std::istream &in, const char *label, Type &value)
Add and read a new optional ScalarParam < Type > object.
int compress()
Iterate pressure field to obtain partial saddle point.
void readParameters(std::istream &in)
Read all parameters and initialize.
void setup()
Initialize just before entry to iterative loop.
void outputTimers(std::ostream &out) const
Output compressor time contributions to a stream.
Main class, representing one complete system.
int capacity() const
Return allocated size.
Definition Array.h:159
Wrapper for a double precision number, for formatted ostream output.
Definition Dbl.h:40
Wrapper for an int, for formatted ostream output.
Definition Int.h:37
static std::ostream & file()
Get log ostream by reference.
Definition Log.cpp:59
File containing preprocessor macros for error handling.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
Definition global.h:68
#define UTIL_THROW(msg)
Macro for throwing an Exception, reporting function, file and line number.
Definition global.h:49
Real periodic fields, SCFT and PS-FTS (CPU).
Definition param_pc.dox:2
PSCF package top-level namespace.
Definition param_pc.dox:1
float product(float a, float b)
Product for float Data.
Definition product.h:22