import sympy as sp

x = sp.symbols('x')

# Interpolation data: nodes x0..x3 and their ordinates y0..y3.
# They are declared as SymPy symbols so the script runs as-is and prints the
# general interpolant through four points; replace them with numbers to
# interpolate concrete data.
x0, x1, x2, x3 = sp.symbols('x0:4')
y0, y1, y2, y3 = sp.symbols('y0:4')
x_values = [x0, x1, x2, x3]
y_values = [y0, y1, y2, y3]

def lagrange_basis(i, x_values, var=None):
    """Return the i-th Lagrange basis polynomial for the given nodes.

    Computes L_i(var) = prod_{j != i} (var - x_j) / (x_i - x_j), which is 1 at
    ``x_values[i]`` and 0 at every other node.

    Args:
        i: Index of the node whose basis polynomial is built
            (0 <= i < len(x_values)).
        x_values: Sequence of interpolation nodes; they must be pairwise
            distinct.
        var: Variable or value to build the polynomial in. Defaults to the
            module-level symbol ``x``, so existing two-argument calls still
            return a symbolic expression in ``x``.

    Returns:
        The product above: a SymPy expression when ``var`` or the nodes are
        symbolic, a plain number when everything is numeric.

    Raises:
        IndexError: if ``i`` is not a valid non-negative node index.
        ValueError: if another node equals ``x_values[i]`` (the basis would
            require dividing by zero).
    """
    if var is None:
        var = x
    n = len(x_values)
    # A negative index would never match any j below and would silently
    # produce a wrong product, so only accept explicit in-range indices.
    if not 0 <= i < n:
        raise IndexError(f"node index {i} out of range for {n} nodes")
    x_i = x_values[i]
    term = 1
    for j, x_j in enumerate(x_values):
        if j == i:
            continue
        denominator = x_i - x_j
        # For symbolic nodes this is a structural comparison, so only nodes
        # that are literally identical are rejected.
        if denominator == 0:
            raise ValueError(
                f"duplicate interpolation node {x_i!r} at indices {i} and {j}"
            )
        term *= (var - x_j) / denominator
    return term

# Interpolating polynomial P(x) = sum_i y_i * L_i(x). Each node needs exactly
# one ordinate: a length mismatch is a data error, not something to silently
# truncate (extra y-values) or crash on part-way through (missing y-values).
if len(x_values) != len(y_values):
    raise ValueError(
        "x_values and y_values must have the same length "
        f"(got {len(x_values)} and {len(y_values)})"
    )

lagrange_polynomial = sum(
    y_i * lagrange_basis(i, x_values) for i, y_i in enumerate(y_values)
)

# Tidy the sum of rational basis terms into a more compact expression.
lagrange_polynomial = sp.simplify(lagrange_polynomial)
print(lagrange_polynomial)
