main.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #include "RBFInterp.hpp"
  2. #include "interp.hpp"
  3. #include "gnuplot.hpp"
  4. #include <iostream>
  5. #include <fstream>
  6. #include <cstdlib>
  7. #include <ctime>
  8. #include <algorithm>
  9. #include <string>
  10. #include "time_mes.hpp"
  11. using namespace std;
  12. // Fonction à interpoler z= f(x,y)
  13. float finit(int x,int y)
  14. {
  15. return 1.0+float(x*sin(float(x)/1000.)+2*y)/1000.;
  16. }
  17. int main()
  18. {
  19. typedef float real;
  20. // voir RBFInterp qui définit ceci:
  21. typedef typename RBFInterp<real,Gauss<real>>::PointValue PointValue;
  22. // qui est une map (i,j) -> valeur flottante.
  23. // taille de l'"image":
  24. int nx=1500,ny=1500;
  25. cout<<"nx : "<<nx<<" ny : "<<ny<<endl;
  26. real* image= new real[nx*ny];
  27. //
  28. std::srand(std::time(nullptr)); // seed with current time.
  29. // Choisir npts points d'interpolation au hasard dans l'image:
  30. PointValue Values;
  31. int npts= 50;
  32. for(int i=0;i<npts;i++)
  33. {
  34. // la position des points est choisie au hasard:
  35. auto vrand=float(std::rand())/RAND_MAX;
  36. int xx= static_cast<int>(nx*vrand);
  37. vrand=float(std::rand())/RAND_MAX;
  38. int yy= static_cast<int>(ny*vrand);
  39. auto k=make_pair(xx,yy);
  40. // valeur de l'image en ce point.
  41. Values[k]=finit(xx,yy);
  42. if(xx<0 || xx>=nx||yy<0 || yy>=ny)
  43. // vérifier que les points sont dans l'image.
  44. throw "pas bien";
  45. }
  46. cout<<"nb. values : "<<Values.size()<<endl;
  47. //
  48. double start,stop;
  49. // Choisir au hasard une partie des points d'interpolation:
  50. PointValue Test,NoTest;
  51. float portion_in_Notest = 0.8;
  52. for(typename PointValue::const_iterator I =Values.begin();I!=Values.end();I++)
  53. {
  54. float srand=float(std::rand())/RAND_MAX;
  55. if(srand<portion_in_Notest)
  56. // les points qui vont servir a construire l'interpolation.
  57. NoTest[I->first]= I->second;
  58. else
  59. // les points de test.
  60. Test[I->first]= I->second;
  61. }
  62. cout<<"estimation d'epsilon. nb. noeuds : "<<NoTest.size()<<
  63. " nb. points de test : "<<Test.size()<<endl;
  64. // interpolant:
  65. // définition de la fontion radiale utilisée (MQI semble mailleure)
  66. typedef MQI<real> typRad;
  67. //typedef TPS<real> typRad;
  68. //typedef TPSD<real> typRad;
  69. //typedef Gauss<real> typRad;
  70. real eps= 0.01; //valeur de départ de eps
  71. RBFInterp<real,typRad> A(eps);
  72. //
  73. // Les points d'interpolation :
  74. gnuplotfile(Values,"points");
  75. //--1--- Détermination d'un epsilon "optimal" pour les RBF.
  76. ofstream f; f.open("eps-err");//pour faire un graphique precision(epsilon)
  77. start=get_time();
  78. real epsm= eps;
  79. real testm=1.e10;
  80. for(int iter=1;iter<=25;iter++)
  81. {
  82. eps/= 10.;
  83. A.set_epsilon(eps);
  84. bool calcond=false; // on calcule ou non le conditionnement (pas
  85. // nécessaire, mais peut donner un renseignement
  86. //utile).
  87. try{
  88. A.build(NoTest,calcond);
  89. }
  90. catch(const LapackException& e )
  91. {
  92. // normalement quand epsilon est trop petit, le système
  93. // linéaire devient singulier. On va donc à priori venir ici
  94. // plutôt que de sortir normalement de la boucle.
  95. eps*=10;
  96. break;
  97. }
  98. auto test= A.test(NoTest,Test);
  99. cout<<"eps : "<<eps<<" test : "<<test<<endl;;
  100. if(calcond)
  101. {
  102. // si on a calculé le conditionnement de la matrice:
  103. if(calcond)
  104. {
  105. cout<<"norme matrice : "<<A.norme_matrice();
  106. cout<<" conditionnement : "<<A.cond();
  107. cout<<endl<<endl;
  108. }
  109. }
  110. f<<eps<<" "<<test<<endl;
  111. if(test<testm)
  112. {
  113. // la valeur epsm choisie pour eps.
  114. testm=test;
  115. epsm=eps;
  116. }
  117. }
  118. stop=get_time();
  119. cout <<endl<<"--- (durée) preparation :"<<stop-start<<endl;
  120. f.close();
  121. cout<<"estimation ok"<<endl;
  122. // On introduit le meilleur epsilon trouvé pour les RBF.
  123. A.set_epsilon(epsm);
  124. cout<<"epsilon utilisé : "<<epsm<<endl;
  125. start=get_time();
  126. try{
  127. A.build(Values);
  128. }
  129. catch(const LapackException& e )
  130. {
  131. cout<<"Lapack exception: "<<e.Info<<endl;
  132. throw;
  133. }
  134. stop=get_time();
  135. cout <<"--- (durée) interpolants : "<<stop-start<<endl;
  136. cout<<"interpolants ok"<<endl;
  137. start=get_time();
  138. int increment=2; //pour utiliser l'interp. quad.
  139. // (un point sur 2 dans caque direction)
  140. A.Interp(Values,image,nx,ny,increment);
  141. stop=get_time();
  142. cout <<"--- (durée) interpolation RBF : "<<stop-start<<endl;
  143. cout<<"interpolation ok"<<endl;
  144. start=get_time();
  145. interpquad(image,nx,ny);
  146. stop=get_time();
  147. cout <<"--- (durée) interpolation quad. : "<<stop-start<<endl;
  148. string aa;cout<<"g pour sortir les fichiers gnuplot ou tout autre caractère sinon :";
  149. cin>>aa;
  150. if(aa=="g")
  151. {
  152. cout<<"gnuplot files :"<<endl;
  153. // interpolation sur la grille :
  154. gnuplotfile(nx,ny,image,"image");
  155. }
  156. // erreur :
  157. real* diff= new real[nx*ny];
  158. auto indix= [nx](int i, int j){return i*nx+j;};
  159. #pragma omp parallel for
  160. for(int i=0;i<nx;i++)
  161. for(int j=0;j<ny;j++)
  162. diff[indix(i,j)]=abs(image[indix(i,j)]-finit(i,j));
  163. if(aa=="g")
  164. gnuplotfile(nx,ny,diff,"diff");
  165. // erreur relative
  166. #pragma omp parallel for
  167. for(int i=0;i<nx;i++)
  168. for(int j=0;j<ny;j++)
  169. diff[indix(i,j)]/=abs(finit(i,j));
  170. auto vmax=max_element(diff,diff+nx*ny);
  171. cout<<"erreur relative max: "<<*vmax<<endl;
  172. if(aa=="g")
  173. gnuplotfile(nx,ny,diff,"diffrel");
  174. delete[] image;
  175. delete[] diff;
  176. }