import numpy as np
import matplotlib.pyplot as plt

# ==============================================================================
# EXERCICE 1 : DÉCOMPOSITION LU ET RÉSOLUTION
# ==============================================================================

def LU(A):
    """
    Calcule la décomposition LU d'une matrice A carrée.
    Retourne L (triangulaire inférieure avec 1 sur diag) et U (triangulaire supérieure).
    """
    n = A.shape[0]
    L = np.eye(n)            # Initialisation de L à l'identité
    U = A.copy().astype(float) # Initialisation de U à A (copie pour ne pas modifier l'original)

    # Algorithme du pivot de Gauss sans échange de lignes
    for k in range(n-1):     # Boucle sur les colonnes pivots
        pivot = U[k, k]
        if pivot == 0:
            raise ValueError("Pivot nul rencontré, la décomposition impossible sans permutation.")
            
        for i in range(k+1, n): # Boucle sur les lignes sous le pivot
            coef = U[i, k] / pivot
            L[i, k] = coef
            # Mise à jour de la ligne i de U : L_i <- L_i - coef * L_k
            U[i, k:] = U[i, k:] - coef * U[k, k:]
            U[i, k] = 0 # Force le 0 pour éviter les erreurs d'arrondi
            
    return L, U

def descente(L, b):
    """ Résout Ly = b (Système triangulaire inférieur) """
    n = len(b)
    y = np.zeros(n)
    for i in range(n):
        somme = np.dot(L[i, :i], y[:i]) # Somme des L[i,j]*y[j] pour j < i
        y[i] = (b[i] - somme) / L[i, i]
    return y

def remontee(U, y):
    """ Résout Ux = y (Système triangulaire supérieur) """
    n = len(y)
    x = np.zeros(n)
    for i in range(n-1, -1, -1):
        somme = np.dot(U[i, i+1:], x[i+1:]) # Somme des U[i,j]*x[j] pour j > i
        x[i] = (y[i] - somme) / U[i, i]
    return x

def solve_LU(L, U, b):
    """ Résout le système Ax = b en utilisant L et U """
    y = descente(L, b)
    x = remontee(U, y)
    return x

# --- TEST DE L'EXERCICE 1 ---
print("--- TEST EXERCICE 1 ---")
A_test = np.array([[2, 3, 5], 
                   [2, 2, 4], 
                   [2, 1, 4]], dtype=float)
b_test = np.array([1.5, 2, 2.5])

L_calc, U_calc = LU(A_test)
print("Matrice L :\n", L_calc)
print("Matrice U :\n", U_calc)

x_sol = solve_LU(L_calc, U_calc, b_test)
print("Solution x trouvée :", x_sol)
print("Vérification A.x :", A_test @ x_sol)
print("\n")


# ==============================================================================
# EXERCICE 2 : RÉSOLUTION NUMÉRIQUE EDO
# ==============================================================================

def mat(N):
    """
    Construit la matrice du système M = (1/h^2)A + I
    pour l'équation -u'' + u = f avec u(0)=u(1)=0.
    """
    h = 1 / (N - 1)
    
    # Construction de la matrice A (Laplacien discret 1D)
    # Diagonale principale = 2, diagonales adjacentes = -1
    # Les bords (lignes 0 et N-1) restent à 0 dans A pour respecter les conditions aux limites
    # après ajout de l'identité.
    A = np.zeros((N, N))
    
    # Remplissage des points intérieurs
    for i in range(1, N-1):
        A[i, i] = 2
        A[i, i-1] = -1
        A[i, i+1] = -1
        
    # Matrice finale M
    I = np.eye(N)
    M = (1 / h**2) * A + I
    
    # Note : Aux bords i=0 et i=N-1, A[i,:] est nul, donc M[i,i] = 1.
    # L'équation devient 1*u[0] = f[0] = 0. C'est cohérent.
    return M

def vecF(N, func):
    """
    Construit le vecteur second membre f.
    """
    x = np.linspace(0, 1, N)
    f_vec = np.zeros(N)
    
    # On évalue f seulement sur les points intérieurs.
    # Les bords sont à 0 (conditions homogènes).
    for i in range(1, N-1):
        f_vec[i] = func(x[i])
        
    return f_vec

# --- Fonctions mathématiques données ---
def u_exact_func(x):
    return np.sin(np.pi * x)**2

def f1_func(x):
    # f correspondante à u_exact
    return np.sin(np.pi * x)**2 - 2 * (np.pi**2) * (1 - 2 * np.sin(np.pi * x)**2)

def f2_func(x):
    # Second membre de la question 7
    return (x**2) * np.cos(x) - x * np.sin(x)

# --- Résolution principale (Question 3 à 6) ---
N = 1000
x_plot = np.linspace(0, 1, N)

# 1. Construction du système
M_sys = mat(N)
b_f1 = vecF(N, f1_func)

# 2. Résolution via LU
L_M, U_M = LU(M_sys)
u_approx = solve_LU(L_M, U_M, b_f1)
u_exacte = u_exact_func(x_plot)

# 3. Tracé Comparatif
plt.figure(figsize=(10, 6))
plt.plot(x_plot, u_exacte, 'k-', linewidth=2, label=r'$u_{exact}(x)$')
plt.plot(x_plot, u_approx, 'r--', linewidth=2, label=r'$u_{approx}(x)$ (LU)')
plt.title(f"Résolution de -u''+u=f (N={N})")
plt.xlabel('x')
plt.ylabel('u')
plt.legend()
plt.grid(True)
plt.show()

# --- Question 7 : Second membre f2 ---
# On réutilise L_M et U_M déjà calculés (gain de temps énorme)
b_f2 = vecF(N, f2_func)
u_approx_2 = solve_LU(L_M, U_M, b_f2)

plt.figure(figsize=(10, 6))
plt.plot(x_plot, u_approx_2, 'b-', label='Solution pour f2')
plt.title("Solution avec $f(x) = x^2 \cos(x) - x \sin(x)$")
plt.xlabel('x')
plt.grid(True)
plt.legend()
plt.show()


# ==============================================================================
# BONUS : ÉTUDE DE CONVERGENCE (Question 8)
# ==============================================================================

print("Calcul de la convergence en cours...")
N_values = [10, 50, 100, 200, 500] # Valeurs de N à tester
errors = []
h_values = []

for n_val in N_values:
    h = 1 / (n_val - 1)
    h_values.append(h)
    
    # Résolution
    M_val = mat(n_val)
    b_val = vecF(n_val, f1_func)
    L_val, U_val = LU(M_val)
    u_num = solve_LU(L_val, U_val, b_val)
    
    # Comparaison avec exact
    x_val = np.linspace(0, 1, n_val)
    u_ex = u_exact_func(x_val)
    
    # Calcul norme L2 de l'erreur
    # (Norme vectorielle classique normalisée par la racine de N pour ne pas dépendre de la taille)
    err = np.linalg.norm(u_num - u_ex) / np.sqrt(n_val)
    errors.append(err)

# Tracé log-log
plt.figure(figsize=(8, 6))
plt.loglog(h_values, errors, 'o-', label='Erreur numérique')
# Pente théorique en h^2 (ordre 2)
plt.loglog(h_values, [10*h**2 for h in h_values], 'k--', label='Pente $O(h^2)$')
plt.xlabel('Pas h (log)')
plt.ylabel('Erreur (log)')
plt.title('Convergence de la méthode des différences finies')
plt.legend()
plt.grid(True, which="both", ls="-")
plt.show()