import numpy as np
import matplotlib.pyplot as plt
import os

# Création du dossier de résultats
output_dir = "results"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"Dossier '{output_dir}' créé.")
else:
    print(f"Les images seront enregistrées dans le dossier '{output_dir}'.")

# ==============================================================================
# EXERCICE 2 : Implémentation de l'interpolation de Lagrange
# ==============================================================================

# ------------------------------------------------------------------------------
# 1. Définition des fonctions de base (Questions 1 à 6)
# ------------------------------------------------------------------------------

# Q1: Définir la fonction f1
def f1(x):
    """Fonction f1(x) = sin(x)"""
    return np.sin(x)

# Q2: Définir la subdivision régulière
def subdivision_reguliere(xmin, xmax, N):
    """
    Renvoie un tableau de N points équirépartis sur [xmin, xmax].
    Utilise np.linspace comme suggéré.
    """
    return np.linspace(xmin, xmax, N)

# Q3: Création de x_inter_1 pour N=5 sur [-3pi, 3pi]
xmin_1 = -3 * np.pi
xmax_1 = 3 * np.pi
N_1 = 5
x_inter_1 = subdivision_reguliere(xmin_1, xmax_1, N_1)

# Q4: Création de y_inter_1
y_inter_1 = f1(x_inter_1)

# Q5: Fonction poly_lagrange (Polynôme de base L_k)
def poly_lagrange(k, x, x_inter):
    """
    Calcule la valeur du k-ième polynôme de Lagrange L_k en un point x.
    L_k(x) vaut 1 en x_k et 0 aux autres noeuds.
    """
    n = len(x_inter)
    resultat = 1.0
    xk = x_inter[k]
    
    for i in range(n):
        if i != k:
            xi = x_inter[i]
            resultat = resultat * (x - xi) / (xk - xi)
    
    return resultat

# Q6: Fonction poly_interpolation (Polynôme P)
def poly_interpolation(x, x_inter, y_inter):
    """
    Calcule la valeur du polynôme d'interpolation P en un point x.
    P(x) est la somme pondérée des L_k(x) par les y_k.
    """
    n = len(x_inter)
    somme = 0.0
    for k in range(n):
        lk = poly_lagrange(k, x, x_inter)
        somme = somme + y_inter[k] * lk
    return somme

# ------------------------------------------------------------------------------
# 2. Visualisation pour N=5 (Questions 7 à 11)
# ------------------------------------------------------------------------------
print("--- Génération du graphique pour N=5 ---")

# Q7: Définir N_plot assez grand
N_plot = 200

# Q8: Générer tableau X_1
X_1 = np.linspace(xmin_1, xmax_1, N_plot)

# Q9: Générer tableau Y_1 (valeurs exactes)
Y_1 = f1(X_1)

# Q10: Générer tableau P_1 (valeurs interpolées)
P_1 = np.array([poly_interpolation(x, x_inter_1, y_inter_1) for x in X_1])

# Q11: Représentation graphique
plt.figure(figsize=(10, 6))
plt.plot(x_inter_1, y_inter_1, 'ro', label="Points d'interpolation", zorder=5) # (a)
plt.plot(X_1, Y_1, 'g-', label="f1(x) = sin(x) (Fonction)", linewidth=2)      # (b)
plt.plot(X_1, P_1, 'b-', label="P1(x) (Interpolation N=5)", linewidth=1.5)   # (c)
plt.title(f"Q11: Interpolation de Lagrange de sin(x) avec N={N_1}")
plt.legend()
plt.grid(True)

# Sauvegarde
filename = os.path.join(output_dir, "fig1_interpolation_sinus_N5.png")
plt.savefig(filename, dpi=300)
plt.close()
print(f"Image sauvegardée : {filename}")

# ------------------------------------------------------------------------------
# 3. Généralisation et étude du nombre de points (Questions 12 à 15)
# ------------------------------------------------------------------------------

# Q12: Implémenter interpolation_lagrange (fonction "wrapper")
def interpolation_lagrange(f, xmin, xmax, Ninter, Xplot):
    """
    Renvoie le triplet (x_inter, y_inter, P_vals).
    """
    x_inter = subdivision_reguliere(xmin, xmax, Ninter)
    y_inter = f(x_inter)
    P_vals = np.array([poly_interpolation(x, x_inter, y_inter) for x in Xplot])
    return x_inter, y_inter, P_vals

# Q13 & Q14: Tracés pour N = 3, 4, 7, 10, 15, 20
print("--- Génération des sous-figures pour N multiple ---")
liste_N = [3, 4, 7, 10, 15, 20]
plt.figure(figsize=(15, 10))

for i, N in enumerate(liste_N):
    xi, yi, Pi = interpolation_lagrange(f1, xmin_1, xmax_1, N, X_1)
    
    plt.subplot(2, 3, i+1)
    plt.plot(X_1, Y_1, 'g-', alpha=0.5, label='Exacte')
    plt.plot(X_1, Pi, 'b-', label=f'Poly N={N}')
    plt.plot(xi, yi, 'ro', markersize=4)
    plt.title(f"N = {N}")
    plt.ylim(-2, 2)
    plt.legend()

plt.suptitle("Q13/14: Évolution de l'interpolation selon N")
plt.tight_layout()

# Sauvegarde
filename = os.path.join(output_dir, "fig2_evolution_N.png")
plt.savefig(filename, dpi=300)
plt.close()
print(f"Image sauvegardée : {filename}")

# --- RÉPONSES THÉORIQUES (Q13, Q14, Q15) ---
# Q13/14 : Pour N petit, l'approximation est mauvaise. Elle s'améliore au centre
# avec N, mais les bords restent difficiles.
# Q15 : Visuellement, N=20 donne un résultat satisfaisant pour sin(x).

# ------------------------------------------------------------------------------
# 4. Calcul de l'erreur (Questions 16 à 19)
# ------------------------------------------------------------------------------

# Q16: Procédure error_inf
def error_inf(f, P_vals, X_test):
    f_vals = f(X_test)
    diff = np.abs(f_vals - P_vals)
    return np.max(diff)

# Q17: Évolution de l'erreur pour f1 (sinus)
print("--- Calcul et tracé des erreurs ---")
N_range = range(2, 25)
errors_f1 = []
X_test_dense = np.linspace(xmin_1, xmax_1, 2000)

for N in N_range:
    _, _, Pi = interpolation_lagrange(f1, xmin_1, xmax_1, N, X_test_dense)
    err = error_inf(f1, Pi, X_test_dense)
    errors_f1.append(err)

# Q18: Fonction f2 et calcul d'erreur
def f2(x):
    return 1 / (1 + 25 * x**2)

xmin_2, xmax_2 = -1, 1
X_test_2 = np.linspace(xmin_2, xmax_2, 2000)
errors_f2 = []

for N in N_range:
    _, _, Pi = interpolation_lagrange(f2, xmin_2, xmax_2, N, X_test_2)
    err = error_inf(f2, Pi, X_test_2)
    errors_f2.append(err)

# Tracé des erreurs en échelle logarithmique
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.semilogy(N_range, errors_f1, 'b-o')
plt.title("Erreur L-infini : f1(x)=sin(x)")
plt.xlabel("Nombre de points N")
plt.ylabel("Erreur (log)")
plt.grid(True, which="both")

plt.subplot(1, 2, 2)
plt.semilogy(N_range, errors_f2, 'r-o')
plt.title("Erreur L-infini : f2(x) (Runge)")
plt.xlabel("Nombre de points N")
plt.grid(True, which="both")

plt.suptitle("Q17/18: Comparaison de l'erreur d'interpolation")

# Sauvegarde
filename = os.path.join(output_dir, "fig3_erreurs_log.png")
plt.savefig(filename, dpi=300)
plt.close()
print(f"Image sauvegardée : {filename}")

# Q19: Tracer f2 et son polynôme (ex: N=15) pour voir le problème
print("--- Génération du phénomène de Runge ---")
N_runge = 15
xi_r, yi_r, Pi_r = interpolation_lagrange(f2, xmin_2, xmax_2, N_runge, X_test_2)

plt.figure(figsize=(8, 6))
plt.plot(X_test_2, f2(X_test_2), 'g-', label="f2(x) (Runge)")
plt.plot(X_test_2, Pi_r, 'b-', label=f"Interpolation N={N_runge}")
plt.plot(xi_r, yi_r, 'ro', label="Points")
plt.title(f"Q19: Phénomène de Runge (N={N_runge})")
plt.legend()
plt.ylim(-1, 2)

# Sauvegarde
filename = os.path.join(output_dir, "fig4_phenomene_runge.png")
plt.savefig(filename, dpi=300)
plt.close()
print(f"Image sauvegardée : {filename}")

# --- RÉPONSES THÉORIQUES (Q19) ---
# Q19 : L'erreur augmente avec N pour f2. Le polynôme oscille aux bords.
# C'est le phénomène de Runge.