PSCF v1.3
rpc/scft/iterator/AmIteratorBasis.tpp
1#ifndef RPC_AM_ITERATOR_BASIS_TPP
2#define RPC_AM_ITERATOR_BASIS_TPP
3
4/*
5* PSCF - Polymer Self-Consistent Field
6*
7* Copyright 2015 - 2025, The Regents of the University of Minnesota
8* Distributed under the terms of the GNU General Public License.
9*/
10
11#include "AmIteratorBasis.h"
12#include <rpc/system/System.h>
13#include <pscf/inter/Interaction.h>
14#include <pscf/iterator/NanException.h>
15#include <util/global.h>
16#include <cmath>
17
18namespace Pscf{
19namespace Rpc {
20
21 using namespace Util;
22
23 // Constructor
24 template <int D>
26 : Iterator<D>(system)
27 {
28 isSymmetric_ = true;
29 setClassName("AmIteratorBasis");
30 }
31
32 // Destructor
33 template <int D>
36
37 // Read parameters from file
38 template <int D>
40 {
41 // Call parent class readParameters
44
45 // Allocate local modified copy of Interaction class
46 interaction_.setNMonomer(system().mixture().nMonomer());
47
48 // Default parameter values
49 isFlexible_ = 1;
50 scaleStress_ = 10.0;
51
52 int np = system().domain().unitCell().nParameter();
53 UTIL_CHECK(np > 0);
54 UTIL_CHECK(np <= 6);
55 UTIL_CHECK(system().domain().unitCell().lattice() != UnitCell<D>::Null);
56
57 // Read optional isFlexible boolean (true by default)
58 readOptional(in, "isFlexible", isFlexible_);
59
60 // Populate flexibleParams_ based on isFlexible_ (all 0s or all 1s),
61 // then optionally overwrite with user input from param file
62 if (isFlexible_) {
63 flexibleParams_.clear();
64 for (int i = 0; i < np; i++) {
65 flexibleParams_.append(true); // Set all values to true
66 }
67 // Read optional flexibleParams_ array to overwrite current array
68 readOptionalFSArray(in, "flexibleParams", flexibleParams_, np);
69 if (nFlexibleParams() == 0) isFlexible_ = false;
70 } else { // isFlexible_ = false
71 flexibleParams_.clear();
72 for (int i = 0; i < np; i++) {
73 flexibleParams_.append(false); // Set all values to false
74 }
75 }
76
77 // Read optional scaleStress value
78 readOptional(in, "scaleStress", scaleStress_);
79 }
80
81 // Output timing results to log file.
82 template<int D>
83 void AmIteratorBasis<D>::outputTimers(std::ostream& out) const
84 {
85 // Output timing results, if requested.
86 out << "\n";
87 out << "Iterator times contributions:\n";
89 }
90
91 // Protected virtual function
92
93 // Setup before entering iteration loop
94 template <int D>
95 void AmIteratorBasis<D>::setup(bool isContinuation)
96 {
98 interaction_.update(system().interaction());
99 }
100
101 // Private virtual functions used to implement AM algorithm
102
103 // Assign one array to another
104 template <int D>
105 void AmIteratorBasis<D>::setEqual(DArray<double>& a,
106 DArray<double> const & b)
107 { a = b; }
108
109 // Compute and return inner product of two vectors.
110 template <int D>
111 double AmIteratorBasis<D>::dotProduct(DArray<double> const & a,
112 DArray<double> const & b)
113 {
114 const int n = a.capacity();
115 UTIL_CHECK(b.capacity() == n);
116 double product = 0.0;
117 for (int i = 0; i < n; i++) {
118 // if either value is NaN, throw NanException
119 if (std::isnan(a[i]) || std::isnan(b[i])) {
120 throw NanException("AmIteratorBasis::dotProduct", __FILE__,
121 __LINE__, 0);
122 }
123 product += a[i] * b[i];
124 }
125 return product;
126 }
127
128 // Compute and return maximum element of a vector.
129 template <int D>
130 double AmIteratorBasis<D>::maxAbs(DArray<double> const & a)
131 {
132 const int n = a.capacity();
133 double max = 0.0;
134 double value;
135 for (int i = 0; i < n; i++) {
136 value = a[i];
137 if (std::isnan(value)) { // if value is NaN, throw NanException
138 throw NanException("AmIteratorBasis::dotProduct", __FILE__,
139 __LINE__, 0);
140 }
141 if (fabs(value) > max) {
142 max = fabs(value);
143 }
144 }
145 return max;
146 }
147
148 // Update basis
149 template <int D>
150 void
151 AmIteratorBasis<D>::updateBasis(RingBuffer< DArray<double> > & basis,
152 RingBuffer< DArray<double> > const & hists)
153 {
154 // Make sure at least two histories are stored
155 UTIL_CHECK(hists.size() >= 2);
156
157 const int n = hists[0].capacity();
158 DArray<double> newbasis;
159 newbasis.allocate(n);
160
161 // New basis vector is difference between two most recent states
162 for (int i = 0; i < n; i++) {
163 newbasis[i] = hists[0][i] - hists[1][i];
164 }
165 basis.append(newbasis);
166 }
167
168 // Add linear combination of basis vectors to trial field.
169 template <int D>
170 void
171 AmIteratorBasis<D>::addHistories(DArray<double>& trial,
172 RingBuffer<DArray<double> > const & basis,
173 DArray<double> coeffs,
174 int nHist)
175 {
176 int n = trial.capacity();
177 for (int i = 0; i < nHist; i++) {
178 for (int j = 0; j < n; j++) {
179 // Not clear on the origin of the -1 factor
180 trial[j] += coeffs[i] * -1 * basis[i][j];
181 }
182 }
183 }
184
185 // Add predicted error to field trial.
186 template <int D>
187 void AmIteratorBasis<D>::addPredictedError(DArray<double>& fieldTrial,
188 DArray<double> const & resTrial,
189 double lambda)
190 {
191 int n = fieldTrial.capacity();
192 for (int i = 0; i < n; i++) {
193 fieldTrial[i] += lambda * resTrial[i];
194 }
195 }
196
197 // Private virtual functions to exchange data with parent system
198
199 // Does the system have an initial field guess?
200 template <int D>
201 bool AmIteratorBasis<D>::hasInitialGuess()
202 { return system().w().hasData(); }
203
204 // Compute and return number of elements in a residual vector
205 template <int D>
206 int AmIteratorBasis<D>::nElements()
207 {
208 const int nMonomer = system().mixture().nMonomer();
209 const int nBasis = system().domain().basis().nBasis();
210
211 int nEle = nMonomer*nBasis;
212 if (isFlexible()) {
213 nEle += nFlexibleParams();
214 }
215
216 return nEle;
217 }
218
219 // Get the current w fields and lattice parameters
220 template <int D>
221 void AmIteratorBasis<D>::getCurrent(DArray<double>& curr)
222 {
223 const int nMonomer = system().mixture().nMonomer();
224 const int nBasis = system().domain().basis().nBasis();
225 const DArray< DArray<double> > & currSys = system().w().basis();
226
227 // Straighten out fields into linear arrays
228 for (int i = 0; i < nMonomer; i++) {
229 for (int k = 0; k < nBasis; k++) {
230 curr[i*nBasis+k] = currSys[i][k];
231 }
232 }
233
234 // Add elements associated with unit cell parameters (if any)
235 if (isFlexible()) {
236 const int nParam = system().domain().unitCell().nParameter();
237 const FSArray<double,6> currParam
238 = system().domain().unitCell().parameters();
239 int counter = 0;
240 for (int i = 0; i < nParam; i++) {
241 if (flexibleParams_[i]) {
242 curr[nMonomer*nBasis + counter] = scaleStress_*currParam[i];
243 counter++;
244 }
245 }
246 UTIL_CHECK(counter == nFlexibleParams());
247 }
248
249 }
250
251 // Perform the main system computation (solve the MDE)
252 template <int D>
253 void AmIteratorBasis<D>::evaluate()
254 {
255 // Solve MDEs for current omega field
256 // (computes stress if isFlexible_ == true)
257 system().compute(isFlexible_);
258 }
259
260 // Compute the residual for the current system state
261 template <int D>
262 void AmIteratorBasis<D>::getResidual(DArray<double>& resid)
263 {
264 const int n = nElements();
265 const int nMonomer = system().mixture().nMonomer();
266 const int nBasis = system().domain().basis().nBasis();
267
268 // Initialize residual vector to zero
269 for (int i = 0 ; i < n; ++i) {
270 resid[i] = 0.0;
271 }
272
273 // Compute SCF residual vector elements
274 for (int i = 0; i < nMonomer; ++i) {
275 for (int j = 0; j < nMonomer; ++j) {
276 double chi = interaction_.chi(i,j);
277 double p = interaction_.p(i,j);
278 DArray<double> const & c = system().c().basis(j);
279 DArray<double> const & w = system().w().basis(j);
280 for (int k = 0; k < nBasis; ++k) {
281 int idx = i*nBasis + k;
282 resid[idx] += chi*c[k] - p*w[k];
283 }
284 }
285 }
286
287 // If iterator has mask, account for it in residual values
288 if (system().mask().hasData()) {
289 DArray<double> const & mask = system().mask().basis();
290 double sumChiInv = interaction_.sumChiInverse();
291 for (int i = 0; i < nMonomer; ++i) {
292 for (int k = 0; k < nBasis; ++k) {
293 int idx = i*nBasis + k;
294 resid[idx] -= mask[k] / sumChiInv;
295 }
296 }
297 }
298
299 // If iterator has external fields, account for them in the values
300 // of the residuals
301 if (system().h().hasData()) {
302 for (int i = 0; i < nMonomer; ++i) {
303 for (int j = 0; j < nMonomer; ++j) {
304 double p = interaction_.p(i,j);
305 DArray<double> const & h = system().h().basis(j);
306 for (int k = 0; k < nBasis; ++k) {
307 int idx = i*nBasis + k;
308 resid[idx] += p * h[k];
309 }
310 }
311 }
312 }
313
314 // If not canonical, account for incompressibility
315 if (!system().mixture().isCanonical()) {
316 if (!system().mask().hasData()) {
317 for (int i = 0; i < nMonomer; ++i) {
318 resid[i*nBasis] -= 1.0 / interaction_.sumChiInverse();
319 }
320 }
321 } else {
322 // Explicitly set homogeneous residual components
323 for (int i = 0; i < nMonomer; ++i) {
324 resid[i*nBasis] = 0.0;
325 }
326 }
327
328 // If variable unit cell, compute stress residuals
329 if (isFlexible()) {
330 const int nParam = system().domain().unitCell().nParameter();
331
332 // Combined -1 factor and stress scaling here. This is okay:
333 // - residuals only show up as dot products (U, v, norm)
334 // or with their absolute value taken (max), so the
335 // sign on a given residual vector element is not relevant
336 // as long as it is consistent across all vectors
337 // - The scaling is applied here and to the unit cell param
338 // storage, so that updating is done on the same scale,
339 // and then undone right before passing to the unit cell.
340
341 int counter = 0;
342 for (int i = 0; i < nParam ; i++) {
343 if (flexibleParams_[i]) {
344 double str = stress(i);
345
346 resid[nMonomer*nBasis + counter] = -1 * scaleStress_ * str;
347 counter++;
348 }
349 }
350 UTIL_CHECK(counter == nFlexibleParams());
351 }
352
353 }
354
355 // Update the current system field coordinates
356 template <int D>
357 void AmIteratorBasis<D>::update(DArray<double>& newGuess)
358 {
359 // Convert back to field format
360 const int nMonomer = system().mixture().nMonomer();
361 const int nBasis = system().domain().basis().nBasis();
362
363 DArray< DArray<double> > wField;
364 wField.allocate(nMonomer);
365
366 // Restructure in format of monomers, basis functions
367 for (int i = 0; i < nMonomer; i++) {
368 wField[i].allocate(nBasis);
369 for (int k = 0; k < nBasis; k++)
370 {
371 wField[i][k] = newGuess[i*nBasis + k];
372 }
373 }
374 // If canonical, explicitly set homogeneous field components
375 if (system().mixture().isCanonical()) {
376 double chi;
377 for (int i = 0; i < nMonomer; ++i) {
378 wField[i][0] = 0.0; // initialize to 0
379 for (int j = 0; j < nMonomer; ++j) {
380 chi = interaction_.chi(i,j);
381 wField[i][0] += chi * system().c().basis(j)[0];
382 }
383 }
384 // If iterator has external fields, include them in homogeneous field
385 if (system().h().hasData()) {
386 for (int i = 0; i < nMonomer; ++i) {
387 wField[i][0] += system().h().basis(i)[0];
388 }
389 }
390 }
391 system().w().setBasis(wField);
392
393 if (isFlexible()) {
394 const int nParam = system().domain().unitCell().nParameter();
395 const int begin = nMonomer*nBasis;
396
397 FSArray<double,6> parameters;
398 parameters = system().domain().unitCell().parameters();
399
400 double coeff = 1.0 / scaleStress_;
401 int counter = 0;
402 for (int i = 0; i < nParam; i++) {
403 if (flexibleParams_[i]) {
404 parameters[i] = coeff * newGuess[begin + counter];
405 counter++;
406 }
407 }
408 UTIL_CHECK(counter == nFlexibleParams());
409
410 system().setUnitCell(parameters);
411 }
412 }
413
414 // Output relevant system details to the iteration log.
415 template<int D>
416 void AmIteratorBasis<D>::outputToLog()
417 {
418 if (isFlexible() && verbose() > 1) {
419 const int nParam = system().domain().unitCell().nParameter();
420 const int nMonomer = system().mixture().nMonomer();
421 const int nBasis = system().domain().basis().nBasis();
422 int counter = 0;
423 for (int i = 0; i < nParam; i++) {
424 if (flexibleParams_[i]) {
425 double str = residual()[nMonomer*nBasis + counter] /
426 (-1.0 * scaleStress_);
427 Log::file()
428 << " Cell Param " << i << " = "
429 << Dbl(system().domain().unitCell().parameters()[i], 15)
430 << " , stress = "
431 << Dbl(str, 15)
432 << "\n";
433 counter++;
434 }
435 }
436 }
437 }
438
439}
440}
441#endif
Exception thrown when not-a-number (NaN) is encountered.
Base template for UnitCell<D> classes, D=1, 2 or 3.
Definition UnitCell.h:56
FSArrayParam< Type, N > & readOptionalFSArray(std::istream &in, const char *label, FSArray< Type, N > &array, int size)
Add and read an optional FSArray < Type, N > array parameter.
void setup(bool isContinuation)
Setup iterator just before entering iteration loop.
void outputTimers(std::ostream &out) const
Output timing results to log file.
void setClassName(const char *className)
Set class name string.
ScalarParam< Type > & readOptional(std::istream &in, const char *label, Type &value)
Add and read a new optional ScalarParam < Type > object.
AmIteratorBasis(System< D > &system)
Constructor.
void readParameters(std::istream &in)
Read all parameters and initialize.
Base class for iterative solvers for SCF equations.
Main class, representing one complete system.
int capacity() const
Return allocated size.
Definition Array.h:159
Dynamically allocatable contiguous array template.
Definition DArray.h:32
void allocate(int capacity)
Allocate the underlying C array.
Definition DArray.h:199
Wrapper for a double precision number, for formatted ostream output.
Definition Dbl.h:40
A fixed capacity (static) contiguous array with a variable logical size.
Definition FSArray.h:38
static std::ostream & file()
Get log ostream by reference.
Definition Log.cpp:59
Class for storing history of previous values in an array.
Definition RingBuffer.h:27
File containing preprocessor macros for error handling.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
Definition global.h:68
Real periodic fields, SCFT and PS-FTS (CPU).
Definition param_pc.dox:2
PSCF package top-level namespace.
Definition param_pc.dox:1
float product(float a, float b)
Product for float Data.
Definition product.h:22