adaptation.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import numpy as np
  2. from numpy import concatenate as cat
  3. import matplotlib.pyplot as plt
  4. import warnings
  5. from .preprocess import shape, discretization, boundaryCondition
  6. plt.rc('text', usetex=True)
  7. plt.rc('font', family='serif')
  8. # supress the deprecation warning
  9. warnings.filterwarnings("ignore", ".*GUI is implemented.*")
  10. class hdpg1d(object):
  11. """
  12. 1D HDG solver
  13. """
  14. def __init__(self, coeff):
  15. self.numEle = coeff.numEle
  16. self.numBasisFuncs = coeff.pOrder + 1
  17. self.tau_pos = coeff.tauPlus
  18. self.tau_neg = coeff.tauMinus
  19. self.c = coeff.convection
  20. self.kappa = coeff.diffusion
  21. self.coeff = coeff
  22. self.mesh = np.linspace(0, 1, self.numEle + 1)
  23. self.u = []
  24. self.estErrorList = [[], []]
  25. self.trueErrorList = [[], []]
  26. def plotU(self, counter):
  27. """Plot solution u with smooth higher oredr quadrature"""
  28. uSmooth = np.array([])
  29. uNode = np.zeros(self.numEle + 1)
  30. xSmooth = np.array([])
  31. u = self.u[int(len(self.u) / 2):len(self.u)]
  32. # quadrature rule
  33. gorder = 10 * self.numBasisFuncs
  34. xi, wi = np.polynomial.legendre.leggauss(gorder)
  35. shp, shpx = shape(xi, self.numBasisFuncs)
  36. for j in range(1, self.numEle + 1):
  37. xSmooth = np.hstack((xSmooth, (self.mesh[(j - 1)] + self.mesh[j]) / 2 + (
  38. self.mesh[j] - self.mesh[j - 1]) / 2 * xi))
  39. uSmooth = np.hstack(
  40. (uSmooth, shp.T.dot(u[(j - 1) * self.numBasisFuncs:j * self.numBasisFuncs])))
  41. uNode[j - 1] = u[(j - 1) * self.numBasisFuncs]
  42. uNode[-1] = u[-1]
  43. plt.figure(1)
  44. plt.plot(xSmooth, uSmooth, '-', color='C3')
  45. plt.plot(self.mesh, uNode, 'C3.')
  46. plt.xlabel('$x$', fontsize=17)
  47. plt.ylabel('$u$', fontsize=17)
  48. # plt.axis([-0.05, 1.05, 0, 1.3])
  49. plt.grid()
  50. plt.pause(5e-1)
  51. plt.clf()
  52. def meshAdapt(self, index):
  53. """Given the index list, adapt the mesh"""
  54. inValue = np.zeros(len(index))
  55. for i in np.arange(len(index)):
  56. inValue[i] = (self.mesh[index[i]] +
  57. self.mesh[index[i] - 1]) / 2
  58. self.mesh = np.sort(np.insert(self.mesh, 0, inValue))
  59. def solveLocal(self):
  60. """Solve the primal problem"""
  61. if 'matLocal' in locals():
  62. # if matLocal exists,
  63. # only change the mesh instead of initializing again
  64. matLocal.mesh = self.mesh
  65. else:
  66. matLocal = discretization(self.coeff, self.mesh)
  67. matGroup = matLocal.matGroup()
  68. A, B, _, C, D, E, F, G, H, L, R = matGroup
  69. # solve
  70. K = -cat((C.T, G), axis=1)\
  71. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))
  72. .dot(cat((C, E)))) + H
  73. F_hat = np.array([L]).T - cat((C.T, G), axis=1)\
  74. .dot(np.linalg.inv(np.bmat([[A, -B], [B.T, D]])))\
  75. .dot(np.array([cat((R, F))]).T)
  76. uFace = np.linalg.solve(K, F_hat)
  77. u = np.linalg.inv(np.bmat([[A, -B], [B.T, D]]))\
  78. .dot(np.array([np.concatenate((R, F))]).T -
  79. cat((C, E)).dot(uFace))
  80. return u.A1, uFace.A1
  81. def solveAdjoint(self):
  82. """Solve the adjoint problem"""
  83. # solve in the enriched space
  84. self.coeff.pOrder += 1
  85. if 'matAdjoint' in locals():
  86. matAdjoint.mesh = self.mesh
  87. else:
  88. matAdjoint = discretization(self.coeff, self.mesh)
  89. self.coeff.pOrder = self.coeff.pOrder - 1
  90. matGroup = matAdjoint.matGroup()
  91. A, B, _, C, D, E, F, G, H, L, R = matGroup
  92. # add adjoint LHS conditions
  93. F = np.zeros(len(F))
  94. R[-1] = -boundaryCondition(1)[1]
  95. # assemble global matrix LHS
  96. LHS = np.bmat([[A, -B, C],
  97. [B.T, D, E],
  98. [C.T, G, H]])
  99. # solve
  100. U = np.linalg.solve(LHS.T, cat((R, F, L)))
  101. return U[0:2 * len(C)], U[len(C):len(U)]
  102. def residual(self, U, hat_U, z, hat_z):
  103. enrich = 1
  104. if 'matResidual' in locals():
  105. matResidual.mesh = self.mesh
  106. else:
  107. matResidual = discretization(self.coeff, self.mesh, enrich)
  108. matGroup = matResidual.matGroup()
  109. A, B, BonU, C, D, E, F, G, H, L, R = matGroup
  110. LHS = np.bmat([[A, -B, C],
  111. [BonU, D, E]])
  112. RHS = cat((R, F))
  113. residual = np.zeros(self.numEle)
  114. numEnrich = self.numBasisFuncs + enrich
  115. for i in np.arange(self.numEle):
  116. primalResidual = (LHS.dot(cat((U, hat_U))) - RHS).A1
  117. uLength = self.numEle * numEnrich
  118. stepLength = i * numEnrich
  119. uDWR = primalResidual[stepLength:stepLength + numEnrich].dot(
  120. (1 - z)[stepLength:stepLength + numEnrich])
  121. qDWR = primalResidual[uLength + stepLength:uLength +
  122. stepLength + numEnrich]\
  123. .dot((1 - z)[uLength + stepLength:uLength +
  124. stepLength + numEnrich])
  125. residual[i] = uDWR + qDWR
  126. # sort residual index
  127. com_index = np.argsort(np.abs(residual))
  128. # select \theta% elements with the large error
  129. theta = 0.15
  130. refine_index = com_index[
  131. int(self.numEle * (1 - theta)):len(residual)] + 1
  132. return np.abs(np.sum(residual)), refine_index
    def adaptive(self):
        """Drive the adjoint-weighted adaptive refinement loop.

        Each pass does the following:
        - solves the primal problem (``solveLocal``) and the adjoint problem
          (``solveAdjoint``), and stores the primal solution in ``self.u``;
        - plots the solution on passes 0, 4, 9 and 19;
        - appends to ``trueErrorList`` and ``estErrorList``;
        - bisects the elements marked by ``residual``.

        The loop stops once the estimate is no longer above ``tol`` or after
        ``ceilCounter`` passes.
        """
        tol = 1e-10
        # Placeholder above tol so the loop body runs at least once. The first
        # progress line therefore prints this value, not a real estimate.
        estError = 10
        counter = 0
        ceilCounter = 30  # hard cap on the number of refinement passes
        while estError > tol and counter < ceilCounter:
            print("Iteration {}. Target function error {:.3e}.".format(
                counter, estError))
            # solve
            u, uFace = self.solveLocal()
            adjoint, adjointFace = self.solveAdjoint()
            self.u = u
            # plot the solution at certain counter
            if counter in [0, 4, 9, 19]:
                self.plotU(counter)
            # record error
            # NOTE(review): this appends a raw entry of the primal solution
            # (index numEle * numBasisFuncs - 1), not an error. It is
            # presumably the target-functional value, compared against a
            # reference elsewhere; confirm.
            self.trueErrorList[0].append(self.numEle)
            self.trueErrorList[1].append(
                u[self.numEle * self.numBasisFuncs - 1])
            # estimated error plus 1-based indices of the elements to refine
            estError, index = self.residual(u, uFace, adjoint, adjointFace)
            self.estErrorList[0].append(self.numEle)
            self.estErrorList[1].append(estError)
            # adapt: meshAdapt inserts one midpoint per marked element, so
            # the element count grows by len(index)
            index = index.tolist()
            self.meshAdapt(index)
            self.numEle = self.numEle + len(index)
            counter += 1