RBFInterp.hpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. #include <vector>
  2. #include <iostream>
  3. #include <map>
  4. #include <cmath>
  5. #include "protos_lapack.hpp"
  6. #include "RBF_functions.hpp"
  7. using namespace std;
  8. using namespace lapack_c;
  9. #ifndef RBFInterp__hpp
  10. #define RBFInterp__hpp
  11. //-------------------------------------------------------------------------
  12. // Interpolation rbf.
  13. //-------------------------------------------------------------------------
  14. template<typename real,class Radial> class RBFInterp{
  15. // on met tout dans une classe paramétrée par le type de flottants,
  16. // et par Radial qui est le type (cllase) de fonction d'interpolation.
  17. double * matrix , *b; // pour le système linaéire
  18. real * image; // l'"image"
  19. real eps;
  20. int imgsize_x,imgsize_y; // tailles
  21. // work sert à lapack (tableau de travail).
  22. double rcond,innorm; double *work; bool rcondok;
  23. public:
  24. // Ce qui définit les points de mesure
  25. // (x,y) -> valeur
  26. // NB:
  27. // Il ne faut pas d'ex-aequo pour (x,y), d'où l'utilisation d'une map
  28. // pour décrire les noeuds de l'interpolation.
  29. //Clé de la map: paire d'entiers (ordonnée lexicographiquement)
  30. // Les clés sont uniques.
  31. typedef std::map<pair<int,int>,real> PointValue;
  32. // consructeur:
  33. RBFInterp(real epsilon):eps(epsilon){}
  34. //changer epsilon
  35. void set_epsilon(real epsilon){eps=epsilon;}
  36. // Calcule les RBF.
  37. void build(PointValue& values,bool comp_cond= false)
  38. // comp_cond : on calcule ou non le conditionnement L\infinity de
  39. // la matrice.
  40. {
  41. Radial R(eps); // fontion d'interpolation, à eps fixé.
  42. int np= values.size();
  43. matrix = new double[np*np];
  44. b= new double[np];
  45. // construire le système linéaire (on n'utilise pas la symétrie de la
  46. // matrice).
  47. // 1) matrice:
  48. auto indix= [np](int i, int j){return i*np+j;};//lambda function.
  49. int i=0;
  50. for(typename PointValue::const_iterator I=values.begin();
  51. I!=values.end();I++) //parcourir la map. Boucle sur les fonctions
  52. {
  53. auto v= I->first; //la fonction est centrée en ce point (paire i,j)
  54. int j=i;
  55. for(typename PointValue::const_iterator J=I;J!=values.end();++J)
  56. // boucle sur les points
  57. {
  58. //calculer la distance r entre le point I et le point J
  59. // et y applique la fonction radiale ->s.
  60. auto s = R(r(v,J->first));
  61. matrix[indix(i,j)]=s;
  62. if(i!=j)//matrice symétrique
  63. matrix[indix(j,i)] =s;
  64. ++j;
  65. }
  66. ++i;
  67. }
  68. //2) second membre.
  69. i=0;
  70. for(typename PointValue::iterator I= values.begin();I!=values.end();I++)
  71. b[i++]=I->second;
  72. // resolution (lapack + blas).
  73. int* ipiv= new int[np];
  74. int nrhs=1,lda=np,ldb=np,info;
  75. char trans='N'; char norm='I';
  76. //
  77. rcondok = comp_cond;
  78. if(comp_cond)
  79. {
  80. // norme oo de la matrice:
  81. work= new double[4*np];
  82. // norme infinie de la matrice:
  83. innorm= dlange_(&norm,&np,&np,matrix,&lda,work);
  84. }
  85. // factorisation PLU:
  86. dgetrf_(&np,&np,matrix,&lda,ipiv,&info);
  87. if(info != 0) throw LapackException(info);
  88. if(comp_cond)
  89. {
  90. int *iwork=new int[np];
  91. dgecon_(&norm,&np,matrix,&lda,&innorm,&rcond,work,iwork,&info);
  92. rcond= 1.0/rcond;
  93. if(info != 0) throw LapackException(info);
  94. delete[] work; delete[] iwork;
  95. }
  96. // résolution du système factorisé:
  97. dgetrs_(&trans,&np,&nrhs,matrix,&lda,ipiv,b,&ldb,&info);
  98. if(info != 0) throw LapackException(info);
  99. //
  100. delete[] matrix;
  101. delete[] ipiv;
  102. }
  103. real test(const PointValue& interpvalues,const PointValue& testvalues)
  104. {
  105. // Estimer l'erreur.
  106. // interpvalues: les noeuds d'interpolation
  107. // testvalues : les noeuds test.
  108. Radial R(eps);
  109. real ret=0.0;
  110. // boucle sur les noeuds test :
  111. for(typename PointValue::const_iterator K=testvalues.begin();
  112. K!=testvalues.end();K++)
  113. {
  114. int k=0;
  115. real u=0.0;
  116. for(typename PointValue::const_iterator L=interpvalues.begin();
  117. L!=interpvalues.end();L++)
  118. {
  119. auto dist= r(K->first,L->first);
  120. u+=b[k++]*R(dist);
  121. }
  122. ret=max(ret,abs(K->second-u));
  123. }
  124. return ret;
  125. }
  126. real Interpolator(PointValue& values,int nx,int i,int j,Radial R)
  127. {
  128. real s=0.0;
  129. int k=0;
  130. for(typename PointValue::const_iterator K=values.begin();
  131. K!=values.end();K++)
  132. s+=b[k++]* R(r(K->first,make_pair(i,j)));
  133. return s;
  134. }
  135. // interpolation sur une image nx par ny
  136. void Interp(PointValue& values,real *image,int nx,int ny,int increment=1)
  137. {
  138. Radial R(eps);
  139. auto indix= [nx](int i, int j){return i*nx+j;};
  140. #pragma omp parallel for
  141. for(int i=0;i<ny;i+=increment)
  142. for(int j=0;j<nx;j+=increment)
  143. {
  144. // real s=0.0;
  145. // int k=0;
  146. // for(typename PointValue::const_iterator K=values.begin();
  147. // K!=values.end();K++)
  148. // s+=b[k++]* R(r(K->first,make_pair(i,j)));
  149. //image[indix(i,j)] = s;
  150. image[indix(i,j)] = Interpolator(values,nx,i,j,R);
  151. }
  152. if(increment == 2)
  153. {
  154. for(int i=ny-3;i<ny;i++)
  155. for(int j=0;j<nx;j++)
  156. image[indix(i,j)] = Interpolator(values,nx,i,j,R);
  157. for(int j=ny-3;j<nx;j++)
  158. for(int i=0;i<ny;i++)
  159. image[indix(i,j)] = Interpolator(values,nx,i,j,R);
  160. }
  161. }
  162. ~RBFInterp()
  163. {
  164. delete[] b;
  165. }
  166. double cond() const
  167. {
  168. // accesseur pour le conditionnement
  169. if(rcondok)
  170. return rcond;
  171. else
  172. throw "rcond pas calculé";
  173. }
  174. double norme_matrice() const
  175. {
  176. if(rcondok)
  177. return innorm;
  178. else
  179. throw "rcond pas calculé";
  180. }
  181. private:
  182. // distance euclidienne de 2 points .
  183. inline float r(pair<int,int> x1,pair<int,int> x2){return sqrt(r2(x1,x2));}
  184. // distance euclidienne de 2 points au carré.
  185. inline float r2(pair<int,int> x1,pair<int,int> x2)
  186. {return pow(get<0>(x1)-get<0>(x2),2) + pow(get<1>(x1)-get<1>(x2),2);}
  187. };
  188. #endif