# adaptation.py (6.8 KB)
  1. import numpy as np
  2. from numpy import concatenate as cat
  3. from copy import copy
  4. import matplotlib.pyplot as plt
  5. import warnings
  6. from .preprocess import shape, discretization, boundaryCondition
  7. plt.rc('text', usetex=True)
  8. plt.rc('font', family='serif')
  9. # supress the deprecation warning
  10. warnings.filterwarnings("ignore", ".*GUI is implemented.*")
  11. class hdpg1d(object):
  12. """
  13. 1D HDG solver
  14. """
  15. def __init__(self, coeff):
  16. self.numEle = coeff.numEle
  17. self.numBasisFuncs = coeff.pOrder + 1
  18. self.coeff = coeff
  19. self.mesh = np.linspace(0, 1, self.numEle + 1)
  20. self.primalSoln = None
  21. self.adjointSoln = None
  22. self.estErrorList = [[], []]
  23. self.trueErrorList = [[], []]
  24. def separateSoln(self, soln):
  25. """Separate gradState (q and u), stateFace from the given soln"""
  26. gradState, stateFace = np.split(
  27. soln, [len(soln) - self.numEle + 1])
  28. return gradState, stateFace
  29. def plotState(self, counter):
  30. """Plot solution u with smooth higher oredr quadrature"""
  31. uSmooth = np.array([])
  32. uNode = np.zeros(self.numEle + 1)
  33. xSmooth = np.array([])
  34. gradState, _ = self.separateSoln(self.primalSoln)
  35. halfLenState = int(len(gradState) / 2)
  36. state = gradState[halfLenState:2 * halfLenState]
  37. # quadrature rule
  38. gorder = 10 * self.numBasisFuncs
  39. xi, wi = np.polynomial.legendre.leggauss(gorder)
  40. shp, shpx = shape(xi, self.numBasisFuncs)
  41. for j in range(1, self.numEle + 1):
  42. xSmooth = np.hstack((xSmooth, (self.mesh[(j - 1)] + self.mesh[j]) / 2 + (
  43. self.mesh[j] - self.mesh[j - 1]) / 2 * xi))
  44. uSmooth = np.hstack(
  45. (uSmooth, shp.T.dot(state[(j - 1) * self.numBasisFuncs:j * self.numBasisFuncs])))
  46. uNode[j - 1] = state[(j - 1) * self.numBasisFuncs]
  47. uNode[-1] = state[-1]
  48. plt.figure(1)
  49. plt.plot(xSmooth, uSmooth, '-', color='C3')
  50. plt.plot(self.mesh, uNode, 'C3.')
  51. plt.xlabel('$x$', fontsize=17)
  52. plt.ylabel('$u$', fontsize=17)
  53. # plt.axis([-0.05, 1.05, 0, 1.3])
  54. plt.grid()
  55. plt.pause(5e-1)
  56. plt.clf()
  57. def meshAdapt(self, index):
  58. """Given the index list, adapt the mesh"""
  59. inValue = np.zeros(len(index))
  60. for i in np.arange(len(index)):
  61. inValue[i] = (self.mesh[index[i]] +
  62. self.mesh[index[i] - 1]) / 2
  63. self.mesh = np.sort(np.insert(self.mesh, 0, inValue))
  64. def solveLocal(self):
  65. """Solve the primal problem"""
  66. if 'matLocal' in locals():
  67. # if matLocal exists,
  68. # only change the mesh instead of initializing again
  69. matLocal.mesh = self.mesh
  70. else:
  71. matLocal = discretization(self.coeff, self.mesh)
  72. matGroup = matLocal.matGroup()
  73. A, B, _, C, D, E, F, G, H, L, R = matGroup
  74. # solve
  75. K = -cat((C.T, G), axis=1)\
  76. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))
  77. .dot(cat((C, E)))) + H
  78. F_hat = np.array([L]).T - cat((C.T, G), axis=1)\
  79. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]])))\
  80. .dot(np.array([cat((R, F))]).T)
  81. stateFace = np.linalg.solve(K, F_hat)
  82. gradState = np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))\
  83. .dot(np.array([np.concatenate((R, F))]).T -
  84. cat((C, E)).dot(stateFace))
  85. self.primalSoln = cat((gradState.A1, stateFace.A1))
  86. def solveAdjoint(self):
  87. """Solve the adjoint problem"""
  88. # solve in the enriched space
  89. _coeff = copy(self.coeff)
  90. _coeff.pOrder = _coeff.pOrder + 1
  91. if 'matAdjoint' in locals():
  92. matAdjoint.mesh = self.mesh
  93. else:
  94. matAdjoint = discretization(_coeff, self.mesh)
  95. matGroup = matAdjoint.matGroup()
  96. A, B, _, C, D, E, F, G, H, L, R = matGroup
  97. # add adjoint LHS conditions
  98. F = np.zeros(len(F))
  99. R[-1] = -boundaryCondition(1)[1]
  100. # assemble global matrix LHS
  101. LHS = np.bmat([[A, -B, C],
  102. [B.T, D, E],
  103. [C.T, G, H]])
  104. # solve
  105. soln = np.linalg.solve(LHS.T, cat((R, F, L)))
  106. self.adjointSoln = soln
  107. def residual(self):
  108. enrich = 1
  109. if 'matResidual' in locals():
  110. matResidual.mesh = self.mesh
  111. else:
  112. matResidual = discretization(self.coeff, self.mesh, enrich)
  113. matGroup = matResidual.matGroup()
  114. A, B, BonU, C, D, E, F, G, H, L, R = matGroup
  115. LHS = np.bmat([[A, -B, C],
  116. [BonU, D, E]])
  117. RHS = cat((R, F))
  118. residual = np.zeros(self.numEle)
  119. numEnrich = self.numBasisFuncs + enrich
  120. primalGradState, primalStateFace = self.separateSoln(self.primalSoln)
  121. adjointGradState, adjointStateFace = self.separateSoln(
  122. self.adjointSoln)
  123. for i in np.arange(self.numEle):
  124. primalResidual = (LHS.dot(self.primalSoln) - RHS).A1
  125. uLength = self.numEle * numEnrich
  126. stepLength = i * numEnrich
  127. uDWR = primalResidual[stepLength:stepLength + numEnrich].dot(
  128. (1 - adjointGradState)[stepLength:stepLength + numEnrich])
  129. qDWR = primalResidual[uLength + stepLength:uLength +
  130. stepLength + numEnrich]\
  131. .dot((1 - adjointGradState)[uLength + stepLength:uLength +
  132. stepLength + numEnrich])
  133. residual[i] = uDWR + qDWR
  134. # sort residual index
  135. residualIndex = np.argsort(np.abs(residual))
  136. # select top \theta% elements with the largest error
  137. theta = 0.15
  138. refineIndex = residualIndex[
  139. int(self.numEle * (1 - theta)):len(residual)] + 1
  140. return np.abs(np.sum(residual)), refineIndex
  141. def adaptive(self):
  142. tol = 1e-10
  143. estError = 10
  144. counter = 0
  145. ceilCounter = 30
  146. while estError > tol and counter < ceilCounter:
  147. print("Iteration {}. Target function error {:.3e}.".format(
  148. counter, estError))
  149. # solve
  150. self.solveLocal()
  151. self.solveAdjoint()
  152. # plot the solution at certain counter
  153. if counter in [0, 4, 9, 19]:
  154. self.plotState(counter)
  155. # record error
  156. self.trueErrorList[0].append(self.numEle)
  157. self.trueErrorList[1].append(
  158. self.primalSoln[self.numEle * self.numBasisFuncs - 1])
  159. estError, index = self.residual()
  160. self.estErrorList[0].append(self.numEle)
  161. self.estErrorList[1].append(estError)
  162. # adapt
  163. index = index.tolist()
  164. self.meshAdapt(index)
  165. self.numEle = self.numEle + len(index)
  166. counter += 1