adaptation.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import numpy as np
  2. from numpy import concatenate as cat
  3. from copy import copy
  4. import matplotlib.pyplot as plt
  5. import warnings
  6. from .preprocess import shape, discretization, boundaryCondition
# Module-level matplotlib configuration: plt.rc changes the global rcParams,
# so importing this module affects every figure in the process.  Text is
# typeset through LaTeX (needs a LaTeX installation when figures are drawn)
# in a serif font.
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
# Suppress warnings whose message matches ".*GUI is implemented.*".
# NOTE(review): presumably the matplotlib deprecation warning raised by
# plt.pause() (used in plotState) on some backends -- confirm for the
# matplotlib version in use.
warnings.filterwarnings("ignore", ".*GUI is implemented.*")
  11. class hdpg1d(object):
  12. """
  13. 1D HDG solver
  14. """
  15. def __init__(self, coeff):
  16. self.numEle = coeff.numEle
  17. self.numBasisFuncs = coeff.pOrder + 1
  18. self.coeff = coeff
  19. self.mesh = np.linspace(0, 1, self.numEle + 1)
  20. self.enrichOrder = 1
  21. self.primalSoln = None
  22. self.adjointSoln = None
  23. self.estErrorList = [[], []]
  24. self.trueErrorList = [[], []]
  25. def separateSoln(self, soln):
  26. """Separate gradState (q and u), stateFace from the given soln"""
  27. gradState, stateFace = np.split(
  28. soln, [len(soln) - self.numEle + 1])
  29. return gradState, stateFace
  30. def plotState(self, counter):
  31. """Plot solution u with smooth higher oredr quadrature"""
  32. stateSmooth = np.array([])
  33. stateNode = np.zeros(self.numEle + 1)
  34. xSmooth = np.array([])
  35. gradState, _ = self.separateSoln(self.primalSoln)
  36. halfLenState = int(len(gradState) / 2)
  37. state = gradState[halfLenState:2 * halfLenState]
  38. # quadrature rule
  39. gorder = 10 * self.numBasisFuncs
  40. xi, wi = np.polynomial.legendre.leggauss(gorder)
  41. shp, shpx = shape(xi, self.numBasisFuncs)
  42. for j in range(1, self.numEle + 1):
  43. xSmooth = np.hstack((xSmooth, (self.mesh[(j - 1)] + self.mesh[j]) / 2 + (
  44. self.mesh[j] - self.mesh[j - 1]) / 2 * xi))
  45. stateSmooth = np.hstack(
  46. (stateSmooth, shp.T.dot(state[(j - 1) * self.numBasisFuncs:j * self.numBasisFuncs])))
  47. stateNode[j - 1] = state[(j - 1) * self.numBasisFuncs]
  48. stateNode[-1] = state[-1]
  49. plt.figure(1)
  50. plt.plot(xSmooth, stateSmooth, '-', color='C3')
  51. plt.plot(self.mesh, stateNode, 'C3.')
  52. plt.xlabel('$x$', fontsize=17)
  53. plt.ylabel('$u$', fontsize=17)
  54. # plt.axis([-0.05, 1.05, 0, 1.3])
  55. plt.grid()
  56. plt.pause(5e-1)
  57. def meshAdapt(self, index):
  58. """Given the index list, adapt the mesh"""
  59. inValue = np.zeros(len(index))
  60. for i in np.arange(len(index)):
  61. inValue[i] = (self.mesh[index[i]] +
  62. self.mesh[index[i] - 1]) / 2
  63. self.mesh = np.sort(np.insert(self.mesh, 0, inValue))
  64. def solvePrimal(self):
  65. """Solve the primal problem"""
  66. if 'matLocal' in locals():
  67. # if matLocal exists,
  68. # only change the mesh instead of initializing again
  69. matLocal.mesh = self.mesh
  70. else:
  71. matLocal = discretization(self.coeff, self.mesh)
  72. matGroup = matLocal.matGroup()
  73. A, B, _, C, D, E, F, G, H, L, R = matGroup
  74. # solve by exploiting the local global separation
  75. K = -cat((C.T, G), axis=1)\
  76. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))
  77. .dot(cat((C, E)))) + H
  78. F_hat = np.array([L]).T - cat((C.T, G), axis=1)\
  79. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]])))\
  80. .dot(np.array([cat((R, F))]).T)
  81. stateFace = np.linalg.solve(K, F_hat)
  82. gradState = np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))\
  83. .dot(np.array([np.concatenate((R, F))]).T -
  84. cat((C, E)).dot(stateFace))
  85. self.primalSoln = cat((gradState.A1, stateFace.A1))
  86. def solveAdjoint(self):
  87. """Solve the adjoint problem"""
  88. # solve in the enriched space
  89. _coeff = copy(self.coeff)
  90. _coeff.pOrder = _coeff.pOrder + 1
  91. if 'matAdjoint' in locals():
  92. matAdjoint.mesh = self.mesh
  93. else:
  94. matAdjoint = discretization(_coeff, self.mesh)
  95. matGroup = matAdjoint.matGroup()
  96. A, B, _, C, D, E, F, G, H, L, R = matGroup
  97. # add adjoint LHS conditions
  98. F = np.zeros(len(F))
  99. R[-1] = -boundaryCondition(1)[1]
  100. # assemble global matrix LHS
  101. LHS = np.bmat([[A, -B, C],
  102. [B.T, D, E],
  103. [C.T, G, H]])
  104. # solve in one shoot
  105. soln = np.linalg.solve(LHS.T, cat((R, F, L)))
  106. self.adjointSoln = soln
  107. def DWResidual(self):
  108. if 'matResidual' in locals():
  109. matResidual.mesh = self.mesh
  110. else:
  111. matResidual = discretization(
  112. self.coeff, self.mesh, self.enrichOrder)
  113. matGroup = matResidual.matGroup()
  114. A, B, BonQ, C, D, E, F, G, H, L, R = matGroup
  115. LHS = np.bmat([[A, -B, C],
  116. [BonQ, D, E]])
  117. RHS = cat((R, F))
  118. residual = np.zeros(self.numEle)
  119. numEnrich = self.numBasisFuncs + self.enrichOrder
  120. adjointGradState, adjointStateFace = self.separateSoln(
  121. self.adjointSoln)
  122. for i in np.arange(self.numEle):
  123. primalResidual = (LHS.dot(self.primalSoln) - RHS).A1
  124. uLength = self.numEle * numEnrich
  125. stepLength = i * numEnrich
  126. uDWR = primalResidual[stepLength:stepLength + numEnrich].dot(
  127. (1 - adjointGradState)[stepLength:stepLength + numEnrich])
  128. qDWR = primalResidual[uLength + stepLength:uLength +
  129. stepLength + numEnrich]\
  130. .dot((1 - adjointGradState)[uLength + stepLength:uLength +
  131. stepLength + numEnrich])
  132. residual[i] = uDWR + qDWR
  133. # sort residual index
  134. residualIndex = np.argsort(np.abs(residual))
  135. # select top \theta% elements with the largest error
  136. theta = 0.15
  137. refineIndex = residualIndex[
  138. int(self.numEle * (1 - theta)):len(residual)] + 1
  139. return np.abs(np.sum(residual)), refineIndex
  140. def adaptive(self):
  141. TOL = 1e-10
  142. estError = 10
  143. nodeCount = 0
  144. maxCount = 30
  145. while estError > TOL and nodeCount < maxCount:
  146. # solve
  147. self.solvePrimal()
  148. self.solveAdjoint()
  149. # plot the solution at certain counter
  150. if nodeCount in [0, 4, 9, 19, maxCount]:
  151. plt.clf()
  152. self.plotState(nodeCount)
  153. # record error
  154. self.trueErrorList[0].append(self.numEle)
  155. self.trueErrorList[1].append(
  156. self.primalSoln[self.numEle * self.numBasisFuncs - 1])
  157. estError, index = self.DWResidual()
  158. self.estErrorList[0].append(self.numEle)
  159. self.estErrorList[1].append(estError)
  160. # adapt
  161. index = index.tolist()
  162. self.meshAdapt(index)
  163. self.numEle = self.numEle + len(index)
  164. nodeCount += 1
  165. print("Iteration {}. Estimated target function error {:.3e}."
  166. .format(nodeCount, estError))
  167. if nodeCount == maxCount:
  168. print("Max iteration number is reached "
  169. "while the convergence criterion is not satisfied.\n"
  170. "Check the problem statement or "
  171. "raise the max iteration number, then try again.\n")