import numpy as np
import matplotlib.pyplot as plt
import time
import os
import math

# Création du dossier pour les résultats
output_dir = "results"
os.makedirs(output_dir, exist_ok=True)
print(f"Les figures seront enregistrées dans le dossier : {output_dir}\n")

# ==========================================
# 1. Méthode de quadrature d'ordre p=0 (Rectangles)
# ==========================================
print("--- 1. Méthode des Rectangles ---")

# 1.1 Implémentation [cite: 6]
def rectangle_gauche(f, a, b, n):
    h = (b - a) / n
    x = np.linspace(a, b, n+1)[:-1] # On prend les bornes gauches
    return h * np.sum(f(x))

def rectangle_droite(f, a, b, n):
    h = (b - a) / n
    x = np.linspace(a, b, n+1)[1:] # On prend les bornes droites
    return h * np.sum(f(x))

def point_milieu(f, a, b, n):
    h = (b - a) / n
    x_milieu = np.linspace(a, b, n+1)[:-1] + h/2
    return h * np.sum(f(x_milieu))

# Fonction cible pour les tests numériques
def f_target(x):
    return np.sqrt(x) * np.exp(x**2)

# 1.2 Calcul pour n = 10, 100, 1000 [cite: 7]
ns = [10, 100, 1000]
print("\nApproximations pour f(x) = sqrt(x)exp(x^2) :")
for n in ns:
    val_g = rectangle_gauche(f_target, 0, 1, n)
    val_d = rectangle_droite(f_target, 0, 1, n)
    val_m = point_milieu(f_target, 0, 1, n)
    print(f"n={n:4d} | G: {val_g:.6f} | D: {val_d:.6f} | M: {val_m:.6f}")

# 1.3 Représentation graphique de l'évolution [cite: 8]
n_vals = range(10, 1001, 10)
vals_g = [rectangle_gauche(f_target, 0, 1, n) for n in n_vals]
vals_d = [rectangle_droite(f_target, 0, 1, n) for n in n_vals]
vals_m = [point_milieu(f_target, 0, 1, n) for n in n_vals]

plt.figure(figsize=(10, 6))
plt.plot(n_vals, vals_g, label='Rectangle Gauche')
plt.plot(n_vals, vals_d, label='Rectangle Droite')
plt.plot(n_vals, vals_m, label='Point Milieu')
plt.xlabel('n')
plt.ylabel('Valeur approchée')
plt.title('Convergence des méthodes (Rectangles)')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, '1_3_convergence_rectangles.png'))
plt.close()

# 1.4 Définition des fonctions tests f0 et f(x)=x^2 [cite: 9]
def f0(x): return np.full_like(x, 4.0)
def F0(x): return 4.0 * x

def f_sq(x): return x**2
def F_sq(x): return (x**3) / 3.0

# 1.5 Fonction erreur de quadrature [cite: 10]
def get_quadrature_error(method, f, F, a, b, n):
    I_approx = method(f, a, b, n)
    I_exact = F(b) - F(a)
    return abs(I_exact - I_approx)

# 1.6 Fonction plotError [cite: 11]
def plotError(funcs_dict, f, F, a, b, title_suffix):
    ns_error = np.unique(np.logspace(1, 3, 20).astype(int)) # De 10 à 1000
    plt.figure(figsize=(10, 6))
    
    for name, method in funcs_dict.items():
        errors = [get_quadrature_error(method, f, F, a, b, n) for n in ns_error]
        plt.loglog(ns_error, errors, '.-', label=name) # Echelle log-log souvent préférée pour l'erreur
        
    plt.xlabel('n (nombre de points)')
    plt.ylabel('Erreur absolue $|I_{ex} - I_n|$')
    plt.title(f'Erreur de quadrature : {title_suffix}')
    plt.legend()
    plt.grid(True, which="both", ls="-")
    filename = f"error_{title_suffix.replace(' ', '_').replace('(', '').replace(')', '')}.png"
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()

# 1.7 Tracer l'erreur pour f0 et f_sq [cite: 12]
methods_p0 = {
    'Rect Gauche': rectangle_gauche,
    'Rect Droite': rectangle_droite,
    'Pt Milieu': point_milieu
}
plotError(methods_p0, f0, F0, 0, 1, "f0(x)=4")
plotError(methods_p0, f_sq, F_sq, 0, 1, "f(x)=x^2")


# ==========================================
# 2. Méthode de quadrature d'ordre p=1 (Trapèzes)
# ==========================================
print("\n--- 2. Méthode des Trapèzes ---")

# 2.1 Implémentation [cite: 14]
def trapeze(f, a, b, n):
    h = (b - a) / n
    x = np.linspace(a, b, n+1)
    y = f(x)
    # Formule: h * ( (y0 + yn)/2 + sum(y1...yn-1) )
    return h * (0.5*y[0] + 0.5*y[-1] + np.sum(y[1:-1]))

# 2.2 Calcul pour la cible [cite: 15, 16]
print("Approximations (Trapèzes) pour f(x) = sqrt(x)exp(x^2) :")
for n in ns:
    val_t = trapeze(f_target, 0, 1, n)
    print(f"n={n:4d} | Trapèzes: {val_t:.6f}")

# 2.3 Fonctions f1 et cos(x) [cite: 17]
def f1(x): return 3*x + 2
def F1(x): return 1.5*(x**2) + 2*x

def f_cos(x): return np.cos(x)
def F_cos(x): return np.sin(x)

# 2.4 & 2.5 Tracer l'erreur pour f1 et cos(x) [cite: 18, 19]
methods_p1 = {'Trapèzes': trapeze}
plotError(methods_p1, f1, F1, 0, 1, "f1(x)=3x+2")
plotError(methods_p1, f_cos, F_cos, 0, np.pi/2, "f(x)=cos(x)")


# ==========================================
# 3. Méthode de quadrature d'ordre p=2 (Simpson)
# ==========================================
print("\n--- 3. Méthode de Simpson ---")

# 3.1 Implémentation [cite: 21]
def simpson(f, a, b, n):
    if n % 2 != 0: n += 1 # Simpson requiert un nombre pair d'intervalles
    h = (b - a) / n
    x = np.linspace(a, b, n+1)
    y = f(x)
    # Simpson composite: h/3 * (y0 + yn + 4*sum(impairs) + 2*sum(pairs))
    return (h/3) * (y[0] + y[-1] + 4*np.sum(y[1:-1:2]) + 2*np.sum(y[2:-2:2]))

# 3.2 Calcul pour la cible [cite: 22, 23]
print("Approximations (Simpson) pour f(x) = sqrt(x)exp(x^2) :")
for n in ns:
    val_s = simpson(f_target, 0, 1, n)
    print(f"n={n:4d} | Simpson: {val_s:.6f}")

# 3.3 Fonctions f2 et cos(x) [cite: 24]
def f2(x): return x**2 - 2*x + 6
def F2(x): return (x**3)/3 - x**2 + 6*x

# Note: f_cos et F_cos déjà définis plus haut

# 3.4 & 3.5 Tracer l'erreur [cite: 25, 26]
methods_p2 = {'Simpson': simpson}
plotError(methods_p2, f2, F2, 0, 1, "f2(x)=x^2-2x+6")
plotError(methods_p2, f_cos, F_cos, 0, np.pi/2, "f(x)=cos(x)_Simpson")


# ==========================================
# 4. Comparaison des méthodes
# ==========================================
print("\n--- 4. Comparaison ---")

# 4.1 Comparaison des valeurs approchées [cite: 29]
# (Déjà affiché dans les boucles précédentes, on fait un récapitulatif ici)
print("Récapitulatif pour f_target (n=1000) :")
print(f"Rect G  : {rectangle_gauche(f_target, 0, 1, 1000):.8f}")
print(f"Milieu  : {point_milieu(f_target, 0, 1, 1000):.8f}")
print(f"Trapèze : {trapeze(f_target, 0, 1, 1000):.8f}")
print(f"Simpson : {simpson(f_target, 0, 1, 1000):.8f}")

# 4.2 & 4.3 Log-log plot pour sin(x) [cite: 33, 34]
def f_sin(x): return np.sin(x)
def F_sin(x): return -np.cos(x)

all_methods = {
    'Rect Gauche': rectangle_gauche,
    'Milieu': point_milieu,
    'Trapèzes': trapeze,
    'Simpson': simpson
}

plotError(all_methods, f_sin, F_sin, 0, np.pi, "Comparaison_f(x)=sin(x)")


# ==========================================
# 5. Vectorisation et vitesse d'exécution
# ==========================================
print("\n--- 5. Vectorisation ---")

# 5.1 Temps d'exécution pour n = 10^5 [cite: 44]
n_perf = 10**5
print(f"Mesure du temps pour n = {n_perf} (f_target) :")

for name, method in all_methods.items():
    start = time.time()
    res = method(f_target, 0, 1, n_perf)
    end = time.time()
    print(f"{name:12s} : {end - start:.6f} s")

# 5.2 Ré-implémentation Trapèzes (version vectorisée vs boucle explicite) [cite: 45]
# Note : Ma fonction 'trapeze' définie en 2.1 était DÉJÀ vectorisée (utilise np.sum).
# Pour l'exercice, je vais créer une version "naïve" avec une boucle for pour comparer.

def trapeze_naive(f, a, b, n):
    h = (b - a) / n
    s = 0.5 * (f(a) + f(b))
    for i in range(1, n):
        s += f(a + i*h)
    return s * h

def trapeze_vec(f, a, b, n):
    # C'est la même que 'trapeze' définie plus haut
    h = (b - a) / n
    x = np.linspace(a, b, n+1)
    y = f(x)
    return h * (0.5*y[0] + 0.5*y[-1] + np.sum(y[1:-1]))

# Vérification
val_naive = trapeze_naive(f_target, 0, 1, 100)
val_vec = trapeze_vec(f_target, 0, 1, 100)
print(f"\nVérification méthodes Trapèzes (n=100) -> Naïve: {val_naive:.6f}, Vec: {val_vec:.6f}")

# 5.3 Calcul vitesses pour n = 2^22 [cite: 46]
n_large = 2**22 # ~ 4 millions
print(f"Test de performance pour n = 2^22 ({n_large}) :")

start = time.time()
trapeze_naive(f_sq, 0, 1, n_large) # On utilise f_sq pour aller plus vite que exp/sqrt
end = time.time()
t_naive = end - start
print(f"Temps Naïf     : {t_naive:.4f} s")

start = time.time()
trapeze_vec(f_sq, 0, 1, n_large)
end = time.time()
t_vec = end - start
print(f"Temps Vectorisé: {t_vec:.4f} s")

# 5.4 Facteur d'efficacité [cite: 47]
if t_vec > 0:
    ratio = t_naive / t_vec
    print(f"Facteur d'efficacité (Naïf / Vectorisé) : {ratio:.2f}")
else:
    print("Temps vectorisé trop court pour calculer le ratio.")

print("\nTerminé. Vérifiez le dossier 'results'.")
