adaptation.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import numpy as np
  2. from numpy import concatenate as cat
  3. from scipy.sparse import csr_matrix
  4. import scipy.sparse.linalg as spla
  5. from copy import copy
  6. import matplotlib.pyplot as plt
  7. import warnings
  8. from .preprocess import shape, discretization, boundaryCondition
  9. plt.rc('text', usetex=True)
  10. plt.rc('font', family='serif')
  11. # supress the deprecation warning
  12. warnings.filterwarnings("ignore", ".*GUI is implemented.*")
  13. class hdpg1d(object):
  14. """
  15. 1D HDG solver
  16. """
  17. def __init__(self, coeff):
  18. self.numEle = coeff.numEle
  19. self.numBasisFuncs = coeff.pOrder + 1
  20. self.coeff = coeff
  21. self.mesh = np.linspace(0, 1, self.numEle + 1)
  22. self.enrichOrder = 1
  23. self.primalSoln = None
  24. self.adjointSoln = None
  25. self.estErrorList = [[], []]
  26. self.trueErrorList = [[], []]
  27. def separateSoln(self, soln):
  28. """Separate gradState (q and u), stateFace from the given soln"""
  29. gradState, stateFace = np.split(
  30. soln, [len(soln) - self.numEle + 1])
  31. return gradState, stateFace
  32. def plotState(self, counter):
  33. """Plot solution u with smooth higher oredr quadrature"""
  34. stateSmooth = np.array([])
  35. stateNode = np.zeros(self.numEle + 1)
  36. xSmooth = np.array([])
  37. gradState, _ = self.separateSoln(self.primalSoln)
  38. halfLenState = int(len(gradState) / 2)
  39. state = gradState[halfLenState:2 * halfLenState]
  40. # quadrature rule
  41. gorder = 10 * self.numBasisFuncs
  42. xi, wi = np.polynomial.legendre.leggauss(gorder)
  43. shp, shpx = shape(xi, self.numBasisFuncs)
  44. for j in range(1, self.numEle + 1):
  45. xSmooth = np.hstack((xSmooth, (self.mesh[(j - 1)] + self.mesh[j]) / 2 + (
  46. self.mesh[j] - self.mesh[j - 1]) / 2 * xi))
  47. stateSmooth = np.hstack(
  48. (stateSmooth, shp.T.dot(state[(j - 1) * self.numBasisFuncs:j * self.numBasisFuncs])))
  49. stateNode[j - 1] = state[(j - 1) * self.numBasisFuncs]
  50. stateNode[-1] = state[-1]
  51. plt.figure(1)
  52. plt.plot(xSmooth, stateSmooth, '-', color='C3')
  53. plt.plot(self.mesh, stateNode, 'C3.')
  54. plt.xlabel('$x$', fontsize=17)
  55. plt.ylabel('$u$', fontsize=17)
  56. # plt.axis([-0.05, 1.05, 0, 1.3])
  57. plt.grid()
  58. plt.pause(5e-1)
  59. def meshAdapt(self, index):
  60. """Given the index list, adapt the mesh"""
  61. inValue = np.zeros(len(index))
  62. for i in np.arange(len(index)):
  63. inValue[i] = (self.mesh[index[i]] +
  64. self.mesh[index[i] - 1]) / 2
  65. self.mesh = np.sort(np.insert(self.mesh, 0, inValue))
  66. def solvePrimal(self):
  67. """Solve the primal problem"""
  68. if 'matLocal' in locals():
  69. # if matLocal exists,
  70. # only change the mesh instead of initializing again
  71. matLocal.mesh = self.mesh
  72. else:
  73. matLocal = discretization(self.coeff, self.mesh)
  74. matGroup = matLocal.matGroup()
  75. A, B, _, C, D, E, F, G, H, L, R = matGroup
  76. # solve by exploiting the local global separation
  77. K = -cat((C.T, G), axis=1)\
  78. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))
  79. .dot(cat((C, E)))) + H
  80. sK = csr_matrix(K)
  81. F_hat = np.array([L]).T - cat((C.T, G), axis=1)\
  82. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]])))\
  83. .dot(np.array([cat((R, F))]).T)
  84. def invRHS(vec):
  85. """Construct preconditioner"""
  86. matVec = spla.spsolve(sK, vec)
  87. return matVec
  88. n = len(F_hat)
  89. preconditioner = spla.LinearOperator((n, n), invRHS)
  90. stateFace = spla.gmres(sK, F_hat, M=preconditioner)[0]
  91. # stateFace = np.linalg.solve(K, F_hat)
  92. gradState = np.linalg.inv(np.asarray(np.bmat([[A, -B], [B.T, D]]))).dot(
  93. cat((R, F)) - cat((C, E)).dot(stateFace))
  94. self.primalSoln = cat((gradState, stateFace))
  95. def solveAdjoint(self):
  96. """Solve the adjoint problem"""
  97. # solve in the enriched space
  98. _coeff = copy(self.coeff)
  99. _coeff.pOrder = _coeff.pOrder + 1
  100. if 'matAdjoint' in locals():
  101. matAdjoint.mesh = self.mesh
  102. else:
  103. matAdjoint = discretization(_coeff, self.mesh)
  104. matGroup = matAdjoint.matGroup()
  105. A, B, _, C, D, E, F, G, H, L, R = matGroup
  106. # add adjoint LHS conditions
  107. F = np.zeros(len(F))
  108. R[-1] = -boundaryCondition('adjoint')[1]
  109. # assemble global matrix LHS
  110. LHS = np.bmat([[A, -B, C],
  111. [B.T, D, E],
  112. [C.T, G, H]])
  113. sLHS = csr_matrix(LHS)
  114. RHS = cat((R, F, L))
  115. # solve in one shoot using GMRES
  116. def invRHS(vec):
  117. """Construct preconditioner"""
  118. matVec = spla.spsolve(sLHS, vec)
  119. return matVec
  120. n = len(RHS)
  121. preconditioner = spla.LinearOperator((n, n), invRHS)
  122. soln = spla.gmres(sLHS, RHS, M=preconditioner)[0]
  123. # soln = np.linalg.solve(LHS.T, RHS)
  124. self.adjointSoln = soln
  125. def DWResidual(self):
  126. if 'matResidual' in locals():
  127. matResidual.mesh = self.mesh
  128. else:
  129. matResidual = discretization(
  130. self.coeff, self.mesh, self.enrichOrder)
  131. matGroup = matResidual.matGroup()
  132. A, B, BonQ, C, D, E, F, G, H, L, R = matGroup
  133. LHS = np.bmat([[A, -B, C],
  134. [BonQ, D, E]])
  135. RHS = cat((R, F))
  136. residual = np.zeros(self.numEle)
  137. numEnrich = self.numBasisFuncs + self.enrichOrder
  138. adjointGradState, adjointStateFace = self.separateSoln(
  139. self.adjointSoln)
  140. for i in np.arange(self.numEle):
  141. primalResidual = (LHS.dot(self.primalSoln) - RHS).A1
  142. uLength = self.numEle * numEnrich
  143. stepLength = i * numEnrich
  144. uDWR = primalResidual[stepLength:stepLength + numEnrich].dot(
  145. (1 - adjointGradState)[stepLength:stepLength + numEnrich])
  146. qDWR = primalResidual[uLength + stepLength:uLength +
  147. stepLength + numEnrich]\
  148. .dot((1 - adjointGradState)[uLength + stepLength:uLength +
  149. stepLength + numEnrich])
  150. residual[i] = uDWR + qDWR
  151. # sort residual index
  152. residualIndex = np.argsort(np.abs(residual))
  153. # select top \theta% elements with the largest error
  154. theta = 0.15
  155. refineIndex = residualIndex[
  156. int(self.numEle * (1 - theta)):len(residual)] + 1
  157. return np.abs(np.sum(residual)), refineIndex
  158. def adaptive(self):
  159. TOL = self.coeff.TOL
  160. estError = 10
  161. nodeCount = 0
  162. maxCount = self.coeff.MAXIT
  163. while estError > TOL and nodeCount < maxCount:
  164. # solve
  165. self.solvePrimal()
  166. self.solveAdjoint()
  167. # plot the solution at certain counter
  168. if nodeCount in [0, 4, 9, 19, maxCount]:
  169. plt.clf()
  170. self.plotState(nodeCount)
  171. # record error
  172. self.trueErrorList[0].append(self.numEle)
  173. self.trueErrorList[1].append(
  174. self.primalSoln[self.numEle * self.numBasisFuncs - 1])
  175. estError, index = self.DWResidual()
  176. self.estErrorList[0].append(self.numEle)
  177. self.estErrorList[1].append(estError)
  178. # adapt
  179. index = index.tolist()
  180. self.meshAdapt(index)
  181. self.numEle = self.numEle + len(index)
  182. nodeCount += 1
  183. print("Iteration {}. Estimated target function error {:.3e}."
  184. .format(nodeCount, estError))
  185. if nodeCount == maxCount:
  186. print("Max iteration number is reached "
  187. "while the convergence criterion is not satisfied.\n"
  188. "Check the problem statement or "
  189. "raise the max iteration number, then try again.\n")