import numpy as np
import math
from scipy.optimize import fsolve

def f(t, y):
    """Right-hand side f(t, y) = y**2 of the first example ODE.

    The time argument ``t`` is accepted for the solver's f(t, y) calling
    convention but is not used.
    """
    return pow(y, 2)

def g(t, y):
    """Right-hand side g(t, y) = t / y of the second example ODE.

    For plain Python numbers this raises ZeroDivisionError when ``y == 0``.
    """
    ratio = t / y
    return ratio

def _rkf45_step(f, t, y, h):
    """Take one trial Runge-Kutta-Fehlberg step of size h from (t, y).

    Returns ``(y_next, err)``. ``y_next`` is the 4th-order estimate of
    y(t + h). ``err`` is |y5 - y4|, the absolute difference between the
    5th- and 4th-order estimates; it is not divided by h.
    """
    k1 = h * f(t, y)
    k2 = h * f(t + h / 4, y + k1 / 4)
    k3 = h * f(t + 3 * h / 8, y + 3 * k1 / 32 + 9 * k2 / 32)
    k4 = h * f(t + 12 * h / 13,
               y + 1932 * k1 / 2197 - 7200 * k2 / 2197 + 7296 * k3 / 2197)
    k5 = h * f(t + h,
               y + 439 * k1 / 216 - 8 * k2 + 3680 * k3 / 513 - 845 * k4 / 4104)
    k6 = h * f(t + h / 2,
               y - 8 * k1 / 27 + 2 * k2 - 3544 * k3 / 2565
               + 1859 * k4 / 4104 - 11 * k5 / 40)
    # Difference of the 5th- and 4th-order weight vectors applied to the k's.
    err = abs(k1 / 360 - 128 * k3 / 4275 - 2197 * k4 / 75240
              + k5 / 50 + 2 * k6 / 55)
    y_next = y + (25 * k1 / 216 + 1408 * k3 / 2565 + 2197 * k4 / 4104 - k5 / 5)
    return y_next, err


def RKF(f, y0, t0, tn, hmax, hmin, epsilon, max_iter=99):
    """Solve y' = f(t, y), y(t0) = y0 on [t0, tn] with adaptive RKF45.

    Each trial step of size h is accepted when the error estimate
    |y5 - y4| <= h * epsilon. Only accepted steps advance the solution.
    After every trial the step is rescaled by
    0.84 * (epsilon * h / err) ** (1/4). That factor is clamped to
    [0.1, 4], and the resulting step is capped at hmax.

    Args:
        f: right-hand side, called as ``f(t, y)``.
        y0: initial value y(t0).
        t0: start of the integration interval.
        tn: end of the interval. Only forward integration (tn > t0) is
            performed; otherwise just the initial point is returned.
        hmax: largest allowed step size.
        hmin: smallest allowed step size. If an ordinary (non-final) step
            shrinks below it, "步长太小" is printed and integration stops.
        epsilon: tolerance on the local error per unit step.
        max_iter: maximum number of trial steps, accepted plus rejected.
            If it is exhausted before tn is reached, a warning is printed.

    Returns:
        ``(t_values, y_values)``: lists of the accepted mesh points and
        their approximations, starting with ``(t0, y0)``.
    """
    t = t0
    y = y0
    t_values = [t]
    y_values = [y]
    if tn <= t0:
        return t_values, y_values

    # Never let the first trial step overshoot the end of the interval.
    h = min(hmax, tn - t0)

    for _ in range(max_iter):
        y_trial, err = _rkf45_step(f, t, y, h)

        if err <= h * epsilon:
            # Accept. A rejected step leaves (t, y) untouched, so the next
            # trial restarts from the last accepted point.
            t = t + h
            y = y_trial
            t_values.append(t)
            y_values.append(y)

        # err == 0 means the step was exact: grow by the maximum factor
        # instead of dividing by zero.
        if err == 0:
            delta = 4
        else:
            delta = 0.84 * math.pow(epsilon * h / err, 1 / 4)

        if delta <= 0.1:
            h = 0.1 * h
        elif delta >= 4:
            h = 4 * h
        else:
            h = delta * h
        h = min(h, hmax)

        if t >= tn:
            break
        if t + h > tn:
            # Shorten the final step so it lands on tn. A short final step
            # is not a failure, so hmin is deliberately not checked here.
            h = tn - t
        elif h < hmin:
            print("步长太小")
            break
    else:
        # Loop ran out of trial steps without reaching tn.
        print("达到最大迭代次数")

    return t_values, y_values

# Parameters for the first example: f (y' = y**2), y(0) = 1 on [0, 0.4]
y0, t0, tn = 1, 0, 0.4
hmax, hmin = 0.1, 1e-4
epsilon = 1e-10
print("方程一")
# Solve numerically
t_values, y_values = RKF(f, y0, t0, tn, hmax, hmin, epsilon)

# Print every accepted mesh point with its approximation
for t, y in zip(t_values, y_values):
    print(f"y({t:.6f}) = {y:.6f}")

# Parameters for the second example: g (y' = t / y), y(2) = 1 on [2, 2.6]
y0, t0, tn = 1, 2.0, 2.6
hmax, hmin = 0.1, 1e-6
epsilon = 1e-10
print("方程二")
# Solve numerically
t_values, y_values = RKF(g, y0, t0, tn, hmax, hmin, epsilon)

# Print every accepted mesh point with its approximation
for t, y in zip(t_values, y_values):
    print(f"y({t:.6f}) = {y:.6f}")