1#ifndef PSCF_AM_ITERATOR_TMPL_TPP
2#define PSCF_AM_ITERATOR_TMPL_TPP
11#include <pscf/inter/Interaction.h>
12#include <pscf/math/LuSolver.h>
13#include "NanException.h"
14#include <util/containers/FArray.h>
15#include <util/format/Dbl.h>
16#include <util/format/Int.h>
17#include <util/misc/Timer.h>
18#include <util/misc/FileMaster.h>
31 template <
typename Iterator,
typename T>
33 : errorType_(
"relNormResid"),
44 { setClassName(
"AmIteratorTmpl"); }
49 template <
typename Iterator,
typename T>
56 template <
typename Iterator,
typename T>
60 read(in,
"epsilon", epsilon_);
64 readOptional(in,
"maxItr", maxItr_);
68 readOptional(in,
"maxHist", maxHist_);
76 readOptional(in,
"verbose", verbose_);
80 readOptional(in,
"correctionRamp", r_);
84 readOptional(in,
"hasAmTest", hasAmTest_);
91 template <
typename Iterator,
typename T>
95 setup(isContinuation);
110 nBasis_ = fieldBasis_.size();
111 for (itr_ = 0; itr_ < maxItr_; ++itr_) {
115 fieldHists_.append(temp_);
120 Log::file() <<
"------------------------------- \n";
127 lambda_ = computeLambda(r_);
134 resHists_.append(temp_);
140 error_ = computeError(verbose_);
142 Log::file() <<
", error = NaN" << std::endl;
145 if (verbose_ > 0 && verbose_ < 3) {
158 correctionError_ = computeError(0);
160 projectionRatio_ += projectionError_/preError_;
161 correctionRatio_ += correctionError_/preError_;
164 preError_ = correctionError_;
169 if (error_ < epsilon_) {
176 Log::file() <<
"-------------------------------\n";
219 Log::file() <<
"Iterator failed to converge.\n";
229 template <
typename Iterator,
typename T>
231 { maxItr_ = maxItr; }
236 template <
typename Iterator,
typename T>
238 { maxHist_ = maxHist; }
243 template <
typename Iterator,
typename T>
246 errorType_ = errorType;
249 bool isValid = isValidErrorType();
253 std::string msg =
"Invalid iterator error type [";
255 msg +=
"] input in AmIteratorTmpl::setErrorType";
263 template <
typename Iterator,
typename T>
269 readOptional(in,
"errorType", errorType_);
272 bool isValid = isValidErrorType();
276 std::string msg =
"Invalid iterator error type [";
278 msg +=
"] in parameter file";
287 template <
typename Iterator,
typename T>
291 if (errorType_ ==
"norm") errorType_ =
"normResid";
292 if (errorType_ ==
"rms") errorType_ =
"rmsResid";
293 if (errorType_ ==
"max") errorType_ =
"maxResid";
294 if (errorType_ ==
"relNorm") errorType_ =
"relNormResid";
298 valid = (errorType_ ==
"normResid"
299 || errorType_ ==
"rmsResid"
300 || errorType_ ==
"maxResid"
301 || errorType_ ==
"relNormResid");
310 template <
typename Iterator,
typename T>
314 if (isAllocatedAM_)
return;
317 nElem_ = nElements();
320 fieldHists_.allocate(maxHist_+1);
321 resHists_.allocate(maxHist_+1);
322 fieldBasis_.allocate(maxHist_);
323 resBasis_.allocate(maxHist_);
326 fieldTrial_.allocate(nElem_);
327 resTrial_.allocate(nElem_);
328 temp_.allocate(nElem_);
331 U_.allocate(maxHist_, maxHist_);
332 v_.allocate(maxHist_);
333 coeffs_.allocate(maxHist_);
335 isAllocatedAM_ =
true;
341 template <
typename Iterator,
typename T>
344 if (!isAllocatedAM_)
return;
348 Log::file() <<
"Clearing AM field history and basis vectors.\n";
363 template <
typename Iterator,
typename T>
368 if (itr_ == 0 && nBasis_ == 0) {
370 for (m = 0; m < maxHist_; ++m) {
373 for (n = 0; n < maxHist_; ++n) {
380 if (itr_ == 0)
return;
383 if (fieldHists_.size() > 1) {
386 updateBasis(fieldBasis_, fieldHists_);
389 updateBasis(resBasis_, resHists_);
392 nBasis_ = fieldBasis_.size();
396 updateU(U_, resBasis_, nBasis_);
402 updateV(v_, resHists_[0], resBasis_, nBasis_);
408 coeffs_[0] = v_[0] / U_(0,0);
410 if (nBasis_ < maxHist_) {
419 for (
int i = 0; i < nBasis_; ++i) {
421 for (
int j = 0; j < nBasis_; ++j) {
422 tempU(i,j) = U_(i,j);
428 solver.allocate(nBasis_);
429 solver.computeLU(tempU);
430 solver.solve(tempv,tempcoeffs);
433 for (
int i = 0; i < nBasis_; ++i) {
434 coeffs_[i] = tempcoeffs[i];
438 if (nBasis_ == maxHist_) {
440 solver.allocate(maxHist_);
441 solver.computeLU(U_);
442 solver.solve(v_, coeffs_);
448 template <
typename Iterator,
typename T>
449 void AmIteratorTmpl<Iterator,T>::updateGuess()
453 setEqual(fieldTrial_, fieldHists_[0]);
454 setEqual(resTrial_, resHists_[0]);
460 addHistories(fieldTrial_, fieldBasis_, coeffs_, nBasis_);
461 addHistories(resTrial_, resBasis_, coeffs_, nBasis_);
483 projectionError_ = computeError(temp_, fieldTrial_, errorType_, 0);
490 addPredictedError(fieldTrial_, resTrial_,lambda_);
505 template <
typename Iterator,
typename T>
508 if (!isAllocatedAM()) {
514 if (!isContinuation) {
526 template <
typename Iterator,
typename T>
530 if (nBasis_ < maxHist_) {
531 lambda = 1.0 - pow(r, nBasis_ + 1);
541 template <
typename Iterator,
typename T>
544 double normSq = dotProduct(a, a);
549 template <
typename Iterator,
typename T>
557 for (
int m = maxHist-1; m > 0; --m) {
558 for (
int n = maxHist-1; n > 0; --n) {
564 for (
int m = 0; m < nHist; ++m) {
565 double dotprod = dotProduct(resBasis[0],resBasis[m]);
571 template <
typename Iterator,
typename T>
574 T
const & resCurrent,
578 for (
int m = 0; m < nHist; ++m) {
579 v[m] = dotProduct(resCurrent, resBasis[m]);
583 template <
typename Iterator,
typename T>
585 std::string errorType,
591 double maxRes = maxAbs(residTrial);
594 double normRes = norm(residTrial);
600 double rmsRes = normRes/sqrt(nElements());
603 double normField = norm(fieldTrial);
604 double relNormRes = normRes/normField;
607 if (errorType ==
"maxResid") {
609 }
else if (errorType ==
"normResid") {
611 }
else if (errorType ==
"rmsResid") {
613 }
else if (errorType ==
"relNormResid") {
616 UTIL_THROW(
"Invalid iterator error type in parameter file.");
621 Log::file() <<
"Max Residual = " <<
Dbl(maxRes,15) <<
"\n";
622 Log::file() <<
"Residual Norm = " <<
Dbl(normRes,15) <<
"\n";
623 Log::file() <<
"RMS Residual = " <<
Dbl(rmsRes,15) <<
"\n";
631 template <
typename Iterator,
typename T>
634 return computeError(resHists_[0], fieldHists_[0], errorType_, verbose);
637 template <
typename Iterator,
typename T>
641 double total = timerTotal_.time();
644 out <<
"Total" << std::setw(22)<<
"Per Iteration"
645 << std::setw(9) <<
"Fraction" <<
"\n";
646 out <<
"MDE solution: "
647 <<
Dbl(timerMDE_.time(), 9, 3) <<
" s, "
648 <<
Dbl(timerMDE_.time()/totalItr_, 9, 3) <<
" s, "
649 <<
Dbl(timerMDE_.time()/total, 9, 3) <<
"\n";
650 out <<
"residual computation: "
651 <<
Dbl(timerResid_.time(), 9, 3) <<
" s, "
652 <<
Dbl(timerResid_.time()/totalItr_, 9, 3) <<
" s, "
653 <<
Dbl(timerResid_.time()/total, 9, 3) <<
"\n";
654 out <<
"mixing coefficients: "
655 <<
Dbl(timerCoeff_.time(), 9, 3) <<
" s, "
656 <<
Dbl(timerCoeff_.time()/totalItr_, 9, 3) <<
" s, "
657 <<
Dbl(timerCoeff_.time()/total, 9, 3) <<
"\n";
658 out <<
"checking convergence: "
659 <<
Dbl(timerError_.time(), 9, 3) <<
" s, "
660 <<
Dbl(timerError_.time()/totalItr_, 9, 3) <<
" s, "
661 <<
Dbl(timerError_.time()/total, 9, 3) <<
"\n";
662 out <<
"updating guess: "
663 <<
Dbl(timerOmega_.time(), 9, 3) <<
" s, "
664 <<
Dbl(timerOmega_.time()/totalItr_, 9, 3) <<
" s, "
665 <<
Dbl(timerOmega_.time()/total, 9, 3)<<
"\n";
666 out <<
"total time: "
667 <<
Dbl(total, 9, 3) <<
" s, "
668 <<
Dbl(total/totalItr_, 9, 3) <<
" s \n";
673 out <<
"Average Projection Step Reduction Ratio: "
674 <<
Dbl(projectionRatio_/testCounter, 3, 3)<<
"\n";
675 out <<
"Average Correction Step Reduction Ratio: "
676 <<
Dbl(correctionRatio_/testCounter, 3, 3)<<
"\n";
681 template <
typename Iterator,
typename T>
Template for Anderson mixing iterator algorithm.
virtual void setup(bool isContinuation)
Initialize just before entry to iterative loop.
void setMaxHist(int maxHist)
Set value of maxHist (number of retained previous states)
virtual double computeLambda(double r)
Compute mixing parameter for correction step of Anderson mixing.
virtual double computeError(T &residTrial, T &fieldTrial, std::string errorType, int verbose)
Compute and return error used to test for convergence.
void allocateAM()
Allocate memory required by AM algorithm, if necessary.
virtual void clear()
Clear information about history.
void setErrorType(std::string errorType)
Set and validate value of errorType string.
void clearTimers()
Clear timers.
~AmIteratorTmpl()
Destructor.
void readErrorType(std::istream &in)
Read and validate the optional errorType string parameter.
void setMaxItr(int maxItr)
Set value of maxItr.
AmIteratorTmpl()
Constructor.
virtual bool isValidErrorType()
Checks if a string is a valid error type.
void readParameters(std::istream &in)
Read all parameters and initialize.
int solve(bool isContinuation=false)
Iterate to a solution.
void outputTimers(std::ostream &out)
Log output timing results.
virtual double norm(T const &hist)
Find the L2 norm of a vector.
Exception thrown when not-a-number (NaN) is encountered.
Dynamically allocatable contiguous array template.
void allocate(int capacity)
Allocate the underlying C array.
Dynamically allocated Matrix.
void allocate(int capacity1, int capacity2)
Allocate memory for a matrix.
Wrapper for a double precision number, for formatted ostream output.
Wrapper for an int, for formatted ostream output.
static std::ostream & file()
Get log ostream by reference.
int capacity1() const
Get number of rows (range of the first array index).
Class for storing history of previous values in an array.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
#define UTIL_THROW(msg)
Macro for throwing an Exception, reporting function, file and line number.
PSCF package top-level namespace.
Utility classes for scientific computation.