12#include "Propagator.h" 
   14#include <prdc/cuda/WaveList.h> 
   15#include <prdc/cuda/FFT.h> 
   16#include <prdc/cuda/resources.h> 
   17#include <prdc/crystal/UnitCell.h> 
   19#include <pscf/mesh/Mesh.h> 
   20#include <pscf/mesh/MeshIterator.h> 
   21#include <pscf/solvers/BlockTmpl.tpp> 
   36      __global__ 
void _realMulVConjVV(cudaReal* a, cudaComplex 
const * b,
 
   37                                      cudaComplex 
const * c,
 
   38                                      cudaReal 
const * d, 
const int n)
 
   40         int nThreads = blockDim.x * gridDim.x;
 
   41         int startID = blockIdx.x * blockDim.x + threadIdx.x;
 
   43         for (
int i = startID; i < n; i += 
nThreads) {
 
   52            a[i] = ((bt.x * ct.x) + (bt.y * ct.y)) * d[i];
 
   59      __global__ 
void _richardsonEx(cudaReal* qNew, 
 
   62                                    cudaReal 
const * expW2, 
const int n)
 
   64         int nThreads = blockDim.x * gridDim.x;
 
   65         int startID = blockIdx.x * blockDim.x + threadIdx.x;
 
   67         for (
int i = startID; i < n; i += 
nThreads) {
 
   68            q2 = qr2[i] * expW2[i];
 
   69            qNew[i] = (4.0 * q2 - qr[i]) / 3.0;
 
   76      __global__ 
void _addEqMulVVc(cudaReal* a, 
 
   82         int nThreads = blockDim.x * gridDim.x;
 
   83         int startID = blockIdx.x * blockDim.x + threadIdx.x;
 
   84         for (
int i = startID; i < n; i += 
nThreads) {
 
   85            a[i] += b[i] * c[i] * d;
 
   92      __global__ 
void _addEqMulVVV(cudaReal* a, 
 
   98         int nThreads = blockDim.x * gridDim.x;
 
   99         int startID = blockIdx.x * blockDim.x + threadIdx.x;
 
  100         for (
int i = startID; i < n; i += 
nThreads) {
 
  101            a[i] += b[i]*c[i]*d[i];
 
  132      _realMulVConjVV<<<nBlocks, nThreads>>>(a.
cArray(), b.
cArray(),
 
  159      _richardsonEx<<<nBlocks, nThreads>>>(qNew.
cArray(), qr.
cArray(),
 
  183      _addEqMulVVc<<<nBlocks, nThreads>>>(a.
cArray(), b.
cArray(),
 
  205      _addEqMulVVV<<<nBlocks, nThreads>>>(a.
cArray(), b.
cArray(),
 
  218      unitCellPtr_(nullptr),
 
  219      waveListPtr_(nullptr),
 
  227      useBatchedFFT_(true),
 
 
  256      unitCellPtr_ = &cell;
 
  257      waveListPtr_ = &wavelist;
 
 
  271      useBatchedFFT_ = useBatchedFFT;
 
  277      expW_.allocate(
mesh().dimensions());
 
  278      expKsq_.allocate(kMeshDimensions_);
 
  279      expKsq2_.allocate(kMeshDimensions_);
 
  281         expW2_.allocate(
mesh().dimensions());
 
  282         qrPair_.allocate(2 * 
mesh().size());
 
  283         qkPair_.allocate(2 * kSize_);
 
  286         expWInv_.allocate(
mesh().dimensions());
 
  287         qk_.allocate(
mesh().dimensions());
 
  300         tempNs = floor(
length() / (2.0 * 
ds) + 0.5);
 
  305         ds_ = 
length()/double(ns_ - 1);
 
  321      fftBatchedPair_.setup(
mesh().dimensions(), 2);
 
  324      if (useBatchedFFT_) {
 
  326         fftBatchedAll_.setup(
mesh().dimensions(), ns_);
 
  327         q0kBatched_.allocate(ns_ * kSize_);
 
  328         q1kBatched_.allocate(ns_ * kSize_);
 
 
  351         tempNs = floor(
length() / (2.0 * dsTarget_) + 0.5);
 
  356         ds_ = 
length()/double(ns_-1);
 
  365            if (useBatchedFFT_) {
 
  367               q0kBatched_.deallocate();
 
  368               q1kBatched_.deallocate();
 
  369               q0kBatched_.allocate(ns_ * kSize_);
 
  370               q1kBatched_.allocate(ns_ * kSize_);
 
  371               fftBatchedAll_.resetBatchSize(ns_);
 
 
  396      UTIL_CHECK(nParams_ == unitCell().nParameter());
 
 
  405   void Block<D>::computeExpKsq()
 
  411      if (!waveListPtr_->hasKSq()) {
 
  412         waveListPtr_->computeKSq();
 
  416      double bSqFactor = -1.0 * kuhn() * kuhn() / 6.0;
 
  423      VecOp::expVc(expKsq2_, waveListPtr_->kSq(), bSqFactor / 2.0);
 
  436      int nx = 
mesh().size();
 
 
  466      int nx = 
mesh().size();
 
  479      qr.associate(qrPair_, 0, 
mesh().dimensions());
 
  480      qr2.associate(qrPair_, nx, 
mesh().dimensions());
 
  481      qk.associate(qkPair_, 0, 
mesh().dimensions());
 
  482      qk2.associate(qkPair_, kSize_, 
mesh().dimensions());
 
  486      fftBatchedPair_.forwardTransform(qrPair_, qkPair_); 
 
  489      fftBatchedPair_.inverseTransformUnsafe(qkPair_, qrPair_); 
 
  491      fft().forwardTransform(qr2, qk2); 
 
  493      fft().inverseTransformUnsafe(qk2, qr2); 
 
  494      richardsonEx(qout, qr, qr2, expW2_); 
 
 
  514      int nx = 
mesh().size();
 
 
  531      int nx = 
mesh().size();
 
  542      fft().forwardTransform(qin, qk_);
 
  544      fft().inverseTransformUnsafe(qk_, qout);       
 
 
  556      int nx = 
mesh().size();
 
  565      fft().forwardTransform(qin, qk_);        
 
  567      fft().inverseTransformUnsafe(qk_, qout);       
 
 
  577      int nx = 
mesh().size();
 
  591      addEqMulVVc(
cField(), p0.
q(0), p1.
q(ns_ - 1), 1.0);
 
  592      addEqMulVVc(
cField(), p0.
q(ns_ - 1), p1.
q(0), 1.0);
 
  594      for (
int j = 1; j < ns_ - 1; j += 2) {
 
  596         addEqMulVVc(
cField(), p0.
q(j), p1.
q(ns_ - 1 - j), 4.0);
 
  598      for (
int j = 2; j < ns_ - 2; j += 2) {
 
  600         addEqMulVVc(
cField(), p0.
q(j), p1.
q(ns_ - 1 - j), 2.0);
 
 
  613      int nx = 
mesh().size();
 
  628      for (
int j = 1; j < ns_ - 1; ++j) {
 
  629         addEqMulVVV(
cField(), p0.
q(j), p1.
q(ns_ - 1 - j), expWInv_);
 
 
  644      int nx = 
mesh().size();
 
  662      if (!waveListPtr_->hasdKSq()) {
 
  663         waveListPtr_->computedKSq();
 
  667      double dels, normal, increment;
 
  675      for (i = 0; i < nParams_; ++i) {
 
  680      if (useBatchedFFT_) {
 
  683         UTIL_CHECK(
mesh().dimensions() == fftBatchedAll_.meshDimensions());
 
  690         fftBatchedAll_.forwardTransform(p0.
qAll(), q0kBatched_);
 
  691         fftBatchedAll_.forwardTransform(p1.
qAll(), q1kBatched_);
 
  695         if (!q0k_.isAllocated()) {
 
  696            q0k_.allocate(
mesh().dimensions());
 
  697            q1k_.allocate(
mesh().dimensions());
 
  703      for (j = 0; j < ns_ ; ++j) {
 
  705         if (useBatchedFFT_) { 
 
  708            q0k_.associate(q0kBatched_, j * kSize_, 
mesh().dimensions());
 
  709            q1k_.associate(q1kBatched_, (ns_-1-j) * kSize_, 
 
  710                           mesh().dimensions());
 
  714            fft().forwardTransform(p0.
q(j), q0k_);
 
  715            fft().forwardTransform(p1.
q(ns_-1-j), q1k_);
 
  720         if (j != 0 && j != ns_ - 1) {
 
  729         for (n = 0; n < nParams_ ; ++n) {
 
  732            realMulVConjVV(rTmp, q0k_, q1k_, waveListPtr_->dKSq(n));
 
  736            increment *= 
kuhn() * 
kuhn() * dels / normal;
 
  740         if (useBatchedFFT_) {
 
  748      for (i = 0; i < nParams_; ++i) {
 
  749         stress_[i] -= (dQ[i] * prefactor);
 
 
  763      int nx = 
mesh().size();
 
  778      if (useBatchedFFT_) {
 
  780         UTIL_CHECK(meshDim == fftBatchedAll_.meshDimensions());
 
  787         fftBatchedAll_.forwardTransform(p0.
qAll(), q0kBatched_);
 
  788         fftBatchedAll_.forwardTransform(p1.
qAll(), q1kBatched_);
 
  790         if (!q0k_.isAllocated()) {
 
  791            q0k_.allocate(meshDim);
 
  792            q1k_.allocate(meshDim);
 
  797      if (!waveListPtr_->hasdKSq()) {
 
  798         waveListPtr_->computedKSq();
 
  804      for (
int i = 0; i < nParams_; ++i) {
 
  814         if (useBatchedFFT_) { 
 
  817            q0k_.associate(q0kBatched_, 0, meshDim);
 
  818            q1k_.associate(q1kBatched_, (ns_- 2) * kSize_, meshDim);
 
  822            fft().forwardTransform(p0.
q(0), q0k_);
 
  823            fft().forwardTransform(p1.
q(ns_ - 2), q1k_);
 
  825         for (
int n = 0; n < nParams_ ; ++n) {
 
  826            realMulVConjVV(rTmp, q0k_, q1k_, waveListPtr_->dKSq(n));
 
  829            increment *= 0.5*bSq;
 
  832         if (useBatchedFFT_) {
 
  839      for (
int j = 1; j < ns_ - 2; ++j) {
 
  842         if (useBatchedFFT_) { 
 
  845            q0k_.associate(q0kBatched_, j * kSize_, meshDim);
 
  846            q1k_.associate(q1kBatched_, (ns_- 2 -j) * kSize_, meshDim);
 
  850            fft().forwardTransform(p0.
q(j), q0k_);
 
  851            fft().forwardTransform(p1.
q(ns_ - 2 - j), q1k_);
 
  855         for (
int n = 0; n < nParams_ ; ++n) {
 
  858            realMulVConjVV(rTmp, q0k_, q1k_, waveListPtr_->dKSq(n));
 
  868         if (useBatchedFFT_) {
 
  878         if (useBatchedFFT_) { 
 
  879            q0k_.associate(q0kBatched_, (ns_-2)*kSize_, meshDim);
 
  880            q1k_.associate(q1kBatched_, 0, meshDim);
 
  883            fft().forwardTransform(p0.
q(ns_-2), q0k_);
 
  884            fft().forwardTransform(p1.
q(0), q1k_);
 
  886         for (
int n = 0; n < nParams_ ; ++n) {
 
  887            realMulVConjVV(rTmp, q0k_, q1k_, waveListPtr_->dKSq(n));
 
  890            increment *= 0.5*bSq;
 
  893         if (useBatchedFFT_) {
 
  901      for (
int i = 0; i < nParams_; ++i) {
 
  902         stress_.append(-1.0*prefactor*dQ[i]);
 
 
virtual void setKuhn(double kuhn)
Dynamic array on the GPU device with aligned data.
int capacity() const
Return allocated capacity.
Data * cArray()
Return pointer to underlying C array.
virtual void setLength(double length)
Set the length of this block (only valid for thread model).
An IntVec<D, T> is a D-component vector of elements of integer type T.
Description of a regular grid of points in a periodic domain.
Fourier transform wrapper.
static void computeKMesh(IntVec< D > const &rMeshDimensions, IntVec< D > &kMeshDimensions, int &kSize)
Compute dimensions and size of k-space mesh for DFT of real data.
Fourier transform of a real field on an FFT mesh.
Field of real double precision values on an FFT mesh.
Class to calculate and store properties of wavevectors.
int nParameter() const
Get the number of parameters in the unit cell.
Base template for UnitCell<D> classes, D=1, 2 or 3.
void setLength(double newLength)
Set or reset block length (only used in thread model).
double ds() const
Contour length step size.
void allocate(double ds, bool useBatchedFFT=true)
Allocate memory and set contour step size for thread model.
double kuhn() const
Get monomer statistical segment length.
void stepHalfBondBead(RField< D > const &qin, RField< D > &qout)
Compute a half-bond operator for the bead model.
Mesh< D > const & mesh() const
Return associated spatial Mesh by const reference.
void stepBondBead(RField< D > const &qin, RField< D > &qout)
Compute a bond operator for the bead model.
int nBead() const
Get the number of beads in this block, in the bead model.
void setKuhn(double kuhn)
Set or reset monomer statistical segment length.
Propagator< D > & propagator(int directionId)
Get a Propagator for a specified direction.
void associate(Mesh< D > const &mesh, FFT< D > const &fft, UnitCell< D > const &cell, WaveList< D > &wavelist)
Create permanent associations with related objects.
void computeStressBead(double prefactor)
Compute stress contribution for this block, using bead model.
void setupSolver(RField< D > const &w)
Set solver for this block.
void computeConcentrationBead(double prefactor)
Compute the concentration for this block, using the bead model.
void computeConcentrationThread(double prefactor)
Compute unnormalized concentration for block by integration.
void stepFieldBead(RField< D > &q)
Apply the exponential field operator for the bead model.
void clearUnitCellData()
Clear all internal data that depends on lattice parameters.
RField< D > & cField()
Get the associated monomer concentration field.
FFT< D > const & fft() const
Return associated FFT<D> object by const reference.
void computeStressThread(double prefactor)
Compute stress contribution for this block, using thread model.
void stepBead(RField< D > const &qin, RField< D > &qout)
Compute one step of solution of MDE for the bead model.
double length() const
Get the length of this block, in the thread model.
void stepThread(RField< D > const &qin, RField< D > &qout)
Compute step of integration loop, from i to i+1.
MDE solver for one direction of one block.
DeviceArray< cudaReal > const & qAll()
Return the full array of q-fields as an unrolled 1D array.
bool isHeadEnd() const
Is the head vertex a chain end?
bool isSolved() const
Has the modified diffusion equation been solved?
bool isTailEnd() const
Is the tail vertex a chain end?
RField< D > const & q(int i) const
Return q-field at specified slice.
int capacity() const
Return allocated size.
A fixed capacity (static) contiguous array with a variable logical size.
void clear()
Set logical size to zero.
void append(Data const &data)
Append data to the end of the array.
#define UTIL_CHECK(condition)
Assertion macro suitable for serial or parallel production code.
void setThreadsLogical(int nThreadsLogical)
Given total number of threads, set 1D execution configuration.
int nThreads()
Get the number of threads per block for execution.
int nBlocks()
Get the current number of blocks for execution.
bool isThread()
Is the thread model in use?
bool isBead()
Is the bead model in use?
double sum(Array< double > const &in)
Compute sum of array elements.
void divSV(Array< double > &a, double b, Array< double > const &c)
Vector division, a[i] = b / c[i].
void mulEqV(Array< double > &a, Array< double > const &b)
Vector multiplication in-place, a[i] *= b[i].
void mulEqS(Array< double > &a, double b)
Vector multiplication in-place, a[i] *= b.
void eqS(Array< double > &a, double b)
Vector assignment, a[i] = b.
void expVc(DeviceArray< cudaReal > &a, DeviceArray< cudaReal > const &b, cudaReal const c)
Vector exponentiation w/ coefficient, a[i] = exp(b[i]*c), kernel wrapper.
void mulVVPair(DeviceArray< cudaReal > &a1, DeviceArray< cudaReal > &a2, DeviceArray< cudaReal > const &b1, DeviceArray< cudaReal > const &b2, DeviceArray< cudaReal > const &s)
Vector multiplication in pairs, ax[i] = bx[i] * s[i], kernel wrapper.
void mulEqVPair(DeviceArray< cudaReal > &a1, DeviceArray< cudaReal > &a2, DeviceArray< cudaReal > const &s)
In-place vector multiplication in pairs, ax[i] *= s[i], kernel wrapper.
Fields, FFTs, and utilities for periodic boundary conditions (CUDA)
Periodic fields and crystallography.
SCFT and PS-FTS with real periodic fields (GPU)
PSCF package top-level namespace.