1#ifndef PSCF_AM_ITERATOR_TMPL_TPP
2#define PSCF_AM_ITERATOR_TMPL_TPP
11#include <pscf/inter/Interaction.h>
12#include <pscf/math/LuSolver.h>
13#include "NanException.h"
14#include <util/containers/FArray.h>
15#include <util/format/Dbl.h>
16#include <util/format/Int.h>
17#include <util/misc/Timer.h>
30 template <
typename Iterator,
typename T>
32 : errorType_(
"relNormResid"),
43 { setClassName(
"AmIteratorTmpl"); }
48 template <
typename Iterator,
typename T>
55 template <
typename Iterator,
typename T>
59 read(in,
"epsilon", epsilon_);
63 readOptional(in,
"maxItr", maxItr_);
67 readOptional(in,
"maxHist", maxHist_);
74 readOptional(in,
"verbose", verbose_);
78 readOptional(in,
"outputTime", outputTime_);
84 template <
typename Iterator,
typename T>
88 setup(isContinuation);
113 nBasis_ = fieldBasis_.size();
114 for (itr_ = 0; itr_ < maxItr_; ++itr_) {
118 fieldHists_.append(temp_);
123 Log::file() <<
"------------------------------- \n";
128 if (nBasis_ < maxHist_) {
129 lambda_ = 1.0 - pow(0.9, nBasis_ + 1);
139 resHists_.append(temp_);
146 error = computeError(verbose_);
148 Log::file() <<
", error = NaN" << std::endl;
152 Log::file() <<
", error = " <<
Dbl(error, 15) << std::endl;
160 if (error < epsilon_) {
167 Log::file() <<
"-------------------------------\n";
179 double total = timerTotal.
time();
181 Log::file() <<
"Iterator times contributions:\n";
184 << timerMDE.
time() <<
" s, "
185 << timerMDE.
time()/total <<
"\n";
187 << timerResid.
time() <<
" s, "
188 << timerResid.
time()/total <<
"\n";
190 << timerCoeff.
time() <<
" s, "
191 << timerCoeff.
time()/total <<
"\n";
193 << timerError.
time() <<
" s, "
194 << timerError.
time()/total <<
"\n";
196 << timerOmega.
time() <<
" s, "
197 << timerOmega.
time()/total <<
"\n";
233 Log::file() <<
"Iterator failed to converge.\n";
243 template <
typename Iterator,
typename T>
245 { maxItr_ = maxItr; }
250 template <
typename Iterator,
typename T>
252 { maxHist_ = maxHist; }
257 template <
typename Iterator,
typename T>
260 errorType_ = errorType;
262 if (!isValidErrorType()) {
263 std::string msg =
"Invalid iterator error type [";
265 msg +=
"] input in AmIteratorTmpl::setErrorType";
273 template <
typename Iterator,
typename T>
278 readOptional(in,
"errorType", errorType_);
280 if (!isValidErrorType()) {
281 std::string msg =
"Invalid iterator error type [";
283 msg +=
"] in parameter file";
289 template <
typename Iterator,
typename T>
293 if (errorType_ ==
"norm") errorType_ =
"normResid";
294 if (errorType_ ==
"rms") errorType_ =
"rmsResid";
295 if (errorType_ ==
"max") errorType_ =
"maxResid";
296 if (errorType_ ==
"relNorm") errorType_ =
"relNormResid";
300 valid = (errorType_ ==
"normResid"
301 || errorType_ ==
"rmsResid"
302 || errorType_ ==
"maxResid"
303 || errorType_ ==
"relNormResid");
312 template <
typename Iterator,
typename T>
316 if (isAllocatedAM_)
return;
319 nElem_ = nElements();
322 fieldHists_.allocate(maxHist_+1);
323 resHists_.allocate(maxHist_+1);
324 fieldBasis_.allocate(maxHist_);
325 resBasis_.allocate(maxHist_);
328 fieldTrial_.allocate(nElem_);
329 resTrial_.allocate(nElem_);
330 temp_.allocate(nElem_);
333 U_.allocate(maxHist_, maxHist_);
334 v_.allocate(maxHist_);
335 coeffs_.allocate(maxHist_);
337 isAllocatedAM_ =
true;
340 template <
typename Iterator,
typename T>
343 if (!isAllocatedAM_)
return;
346 Log::file() <<
"Clearing ring buffers\n";
360 template <
typename Iterator,
typename T>
365 if (itr_ == 0 && nBasis_ == 0) {
367 for (m = 0; m < maxHist_; ++m) {
370 for (n = 0; n < maxHist_; ++n) {
377 if (itr_ == 0)
return;
380 updateBasis(fieldBasis_, fieldHists_);
383 updateBasis(resBasis_, resHists_);
386 nBasis_ = fieldBasis_.size();
390 updateU(U_, resBasis_, nBasis_);
391 updateV(v_, resHists_[0], resBasis_, nBasis_);
398 coeffs_[0] = v_[0] / U_(0,0);
400 if (nBasis_ < maxHist_) {
409 for (
int i = 0; i < nBasis_; ++i) {
411 for (
int j = 0; j < nBasis_; ++j) {
412 tempU(i,j) = U_(i,j);
418 solver.allocate(nBasis_);
419 solver.computeLU(tempU);
420 solver.solve(tempv,tempcoeffs);
423 for (
int i = 0; i < nBasis_; ++i) {
424 coeffs_[i] = tempcoeffs[i];
428 if (nBasis_ == maxHist_) {
431 solver.computeLU(U_);
432 solver.solve(v_, coeffs_);
438 template <
typename Iterator,
typename T>
439 void AmIteratorTmpl<Iterator,T>::updateGuess()
443 setEqual(fieldTrial_, fieldHists_[0]);
444 setEqual(resTrial_, resHists_[0]);
450 addHistories(fieldTrial_, fieldBasis_, coeffs_, nBasis_);
451 addHistories(resTrial_, resBasis_, coeffs_, nBasis_);
456 addPredictedError(fieldTrial_, resTrial_,lambda_);
470 template <
typename Iterator,
typename T>
473 if (!isAllocatedAM()) {
476 if (!isContinuation) {
485 template <
typename Iterator,
typename T>
488 double normSq = dotProduct(a, a);
493 template <
typename Iterator,
typename T>
501 for (
int m = maxHist-1; m > 0; --m) {
502 for (
int n = maxHist-1; n > 0; --n) {
508 for (
int m = 0; m < nHist; ++m) {
509 double dotprod = dotProduct(resBasis[0],resBasis[m]);
515 template <
typename Iterator,
typename T>
518 T
const & resCurrent,
522 for (
int m = 0; m < nHist; ++m) {
523 v[m] = dotProduct(resCurrent, resBasis[m]);
527 template <
typename Iterator,
typename T>
537 double maxRes = maxAbs(resHists_[0]);
538 Log::file() <<
"Max Residual = " <<
Dbl(maxRes,15) <<
"\n";
541 double normRes = norm(resHists_[0]);
542 Log::file() <<
"Residual Norm = " <<
Dbl(normRes,15) <<
"\n";
545 double rmsRes = normRes/sqrt(nElem_);
546 Log::file() <<
"RMS Residual = " <<
Dbl(rmsRes,15) <<
"\n";
549 double normField = norm(fieldHists_[0]);
550 double relNormRes = normRes/normField;
551 Log::file() <<
"Relative Norm = " <<
Dbl(relNormRes,15) << std::endl;
557 if (errorType_ ==
"maxResid") {
559 }
else if (errorType_ ==
"normResid") {
561 }
else if (errorType_ ==
"rmsResid") {
563 }
else if (errorType_ ==
"relNormResid") {
566 UTIL_THROW(
"Invalid iterator error type in parameter file.");
572 if (errorType_ ==
"maxResid") {
573 error = maxAbs(resHists_[0]);
574 }
else if (errorType_ ==
"normResid") {
575 error = norm(resHists_[0]);
576 }
else if (errorType_ ==
"rmsResid") {
577 error = norm(resHists_[0])/sqrt(nElem_);
578 }
else if (errorType_ ==
"relNormResid") {
579 double normRes = norm(resHists_[0]);
580 double normField = norm(fieldHists_[0]);
581 error = normRes/normField;
583 UTIL_THROW(
"Invalid iterator error type in parameter file.");
Template for Anderson mixing iterator algorithm.
virtual void setup(bool isContinuation)
Initialize just before entry to iterative loop.
void setMaxHist(int maxHist)
Set value of maxHist (number of retained previous states)
virtual double computeError(int verbose)
Compute and return error used to test for convergence.
void allocateAM()
Allocate memory required by AM algorithm, if necessary.
virtual void clear()
Clear information about history.
void setErrorType(std::string errorType)
Set and validate value of errorType string.
~AmIteratorTmpl()
Destructor.
void readErrorType(std::istream &in)
Read and validate the optional errorType string parameter.
void setMaxItr(int maxItr)
Set value of maxItr.
AmIteratorTmpl()
Constructor.
virtual bool isValidErrorType()
Checks if a string is a valid error type.
void readParameters(std::istream &in)
Read all parameters and initialize.
int solve(bool isContinuation=false)
Iterate to a solution.
virtual double norm(T const &hist)
Find the L2 norm of a vector.
Exception thrown when not-a-number (NaN) is encountered.
Dynamically allocatable contiguous array template.
void allocate(int capacity)
Allocate the underlying C array.
Dynamically allocated Matrix.
void allocate(int capacity1, int capacity2)
Allocate memory for a matrix.
Wrapper for a double precision number, for formatted ostream output.
Wrapper for an int, for formatted ostream output.
static std::ostream & file()
Get log ostream by reference.
int capacity1() const
Get number of rows (range of the first array index).
Class for storing history of previous values in an array.
void start(TimePoint begin)
Start timing from an externally supplied time.
double time()
Return the accumulated time, in seconds.
void stop(TimePoint end)
Stop the clock at an externally supplied time.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
#define UTIL_THROW(msg)
Macro for throwing an Exception, reporting function, file and line number.
C++ namespace for polymer self-consistent field theory (PSCF).
Utility classes for scientific computation.