# -*- coding: utf-8 -*-
"""
Created on Wed Nov  1 12:45:35 2017

@author: sdjackson
"""

import numpy as np
import matplotlib.pyplot as plt
import timeit

# Physical constants, all in SI (MKS) units
G = 6.67e-11      # Newtonian gravitational constant [m^3 kg^-1 s^-2]
Msun = 1.99e30    # solar mass [kg]
AU = 1.5e11       # astronomical unit [m]

def acc(r, m1, m2):
    """Return the Newtonian gravitational accelerations of two point masses.

    r  -- positions laid out as [x1, y1, x2, y2]
    m1 -- mass of body 1
    m2 -- mass of body 2

    Returns a numpy array [ax1, ay1, ax2, ay2]. Body 1 is pulled toward
    body 2 in proportion to m2, and body 2 toward body 1 in proportion to m1.
    """
    x1, y1, x2, y2 = r
    # Separation vector pointing from body 1 to body 2.
    dx = x2 - x1
    dy = y2 - y1
    dist_cubed = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)**3
    return np.array([G*m2*dx/dist_cubed, G*m2*dy/dist_cubed,
                     -G*m1*dx/dist_cubed, -G*m1*dy/dist_cubed])

def rk4(r, v, m1, m2, h, accel=None):
    """Advance positions and velocities by one classical 4th-order Runge-Kutta step.

    The second-order system r'' = accel(r, m1, m2) is integrated as the
    first-order pair r' = v, v' = accel(r, m1, m2).

    Parameters
    ----------
    r, v : arrays of equal shape supporting elementwise arithmetic (numpy
        arrays), e.g. [x1, y1, x2, y2] and [vx1, vy1, vx2, vy2] for two bodies.
    m1, m2 : masses, forwarded unchanged to the acceleration function.
    h : step size.
    accel : optional callable (r, m1, m2) -> acceleration array shaped like r.
        Defaults to the module's gravitational ``acc``.

    Returns
    -------
    (r_new, v_new) after one step of size h.
    """
    # Resolve the default lazily so callers may supply any force law.
    f = acc if accel is None else accel
    kr1 = v*h
    kv1 = f(r, m1, m2)*h
    kr2 = (v + kv1/2)*h
    kv2 = f(r + kr1/2, m1, m2)*h
    kr3 = (v + kv2/2)*h
    kv3 = f(r + kr2/2, m1, m2)*h
    kr4 = (v + kv3)*h
    kv4 = f(r + kr3, m1, m2)*h
    r_new = r + (kr1 + 2*(kr2 + kr3) + kr4)/6
    v_new = v + (kv1 + 2*(kv2 + kv3) + kv4)/6
    return (r_new, v_new)

def r0v0(m1, m2, a, e):
    """Return (r0, v0): body 1's starting distance and speed for a bound orbit.

    The convention is r = [r0, 0, -q*r0, 0], v = [0, v0, 0, -q*v0] with
    q = m1/m2. This keeps the centre of mass at rest at the origin and starts
    the relative orbit at periapsis. There the separation is r0*(1+q) = a*(1-e)
    and the relative speed v0*(1+q) follows from the vis-viva equation.

    Parameters
    ----------
    m1, m2 : masses of body 1 and body 2 (kg)
    a : semi-major axis of the relative orbit (m), must be > 0
    e : eccentricity, must satisfy 0 <= e < 1 (bound orbit)

    Raises
    ------
    ValueError
        If e is outside [0, 1) or a is not positive. For e >= 1 the
        formula divides by zero or takes the square root of a negative number.
    """
    if not 0.0 <= e < 1.0:
        raise ValueError("eccentricity e must satisfy 0 <= e < 1, got %r" % (e,))
    if not a > 0:
        raise ValueError("semi-major axis a must be positive, got %r" % (a,))
    q = m1/m2        # Definition of mass ratio
    r0 = (1.0-e)*a/(1.0+q)   # Formulae for r0, v0
    v0 = np.sqrt(G*(m1+m2)/a) * np.sqrt((1.0+e)/(1.0-e)) / (1.0+q)
    return (r0, v0)

# --- System parameters ------------------------------------------------------
m2 = Msun          # primary: one solar mass
m1 = m2/10.        # companion: a tenth of the primary
a = 5*AU           # semi-major axis of the relative orbit
e = 0.5            # orbital eccentricity

# Body 1's starting distance from the centre of mass, and its speed
r0, v0 = r0v0(m1, m2, a, e)

# --- Initial state ------------------------------------------------------------
# Body 2 starts at -q times body 1's position and velocity, so the centre of
# mass sits at the origin and is at rest.
q = m1/m2
r_i = np.array([r0, 0, -q*r0, 0])    # [x1, y1, x2, y2]
v_i = np.array([0, v0, 0, -q*v0])    # [vx1, vy1, vx2, vy2]

# Initial potential and kinetic energy, the first entries of the energy
# histories built by the integration below.
r12 = np.sqrt((r_i[2] - r_i[0])**2 + (r_i[3] - r_i[1])**2)
pot0 = -G*m1*m2/r12
ke0 = .5*m1*(v_i[0]**2 + v_i[1]**2) + .5*m2*(v_i[2]**2 + v_i[3]**2)

# --- Time stepping ------------------------------------------------------------
# Kepler's third law: P = sqrt(4 pi^2 a^3 / (G (m1 + m2)))
period = np.sqrt(4*np.pi**2*a**3/(G*(m1 + m2)))
nperiods = 10          # number of orbits to integrate
dt = period/1000.      # initial step: 1/1000 of a period
t0 = 0
t = 0

def rho(r,v,m1,m2,h,delta):
    """Return the adaptive-step ratio rho for an RK4 step of size h.

    The function takes two successive rk4 steps of size h (the fine result)
    and one step of size 2h (the coarse result), both from the same state.
    |x_fine - x_coarse|/30 is then the Richardson estimate of the fine
    result's error, and rho = h*delta / that estimate. So rho >= 1 means the
    error over the 2h interval is within delta per unit time.

    Only body 1's position (components 0 and 1) enters the estimate.
    If the two results agree exactly the division yields inf.
    """
    # Fine path: two half-size steps.
    r_mid, v_mid = rk4(r, v, m1, m2, h)
    fine, _ = rk4(r_mid, v_mid, m1, m2, h)
    # Coarse path: one double-size step from the same start.
    coarse, _ = rk4(r, v, m1, m2, 2*h)
    dx = fine[0] - coarse[0]
    dy = fine[1] - coarse[1]
    return 30*h*delta/np.sqrt(dx**2 + dy**2)

delta = 1  # Target accuracy: position error allowed per unit of simulated time

# --- Calibrate the initial step ------------------------------------------------
# RK4's error over a step grows like dt**5, so rho ~ dt**-4, and dt*rho**0.25
# is the step expected to give rho == 1. A fixed +/-1% search moves rho by
# about 4% per move and can keep jumping over the 1e-4 window. The factor is
# clamped against wild estimates (e.g. rho == inf), and the iteration count is
# bounded so this always terminates.
for _ in range(100):
    rho_dt = rho(r_i, v_i, m1, m2, dt, delta)
    if np.abs(rho_dt - 1) < 1e-4:
        break
    dt *= min(max(rho_dt**0.25, 0.5), 2.0)
else:
    print('Warning: initial dt did not reach rho = 1 within tolerance; dt =', dt)

# --- Adaptive RK4 integration ---------------------------------------------------
# Histories are collected in lists (O(1) append) and converted to arrays once;
# np.append inside the loop copies the whole history on every step.
R = [r_i]
V = [v_i]
pot_hist = [pot0]
ke_hist = [ke0]
t_hist = [t0]
start1 = timeit.default_timer()
j = 0  # number of accepted steps
while t < nperiods*period:
    r = rho(r_i, v_i, m1, m2, dt, delta)
    if r >= 1:
        # Accept. rho judged two consecutive steps of dt, so advance by exactly
        # those two steps; a single dt step here would make t += 2*dt run the
        # clock twice as fast as the simulated motion.
        r_half, v_half = rk4(r_i, v_i, m1, m2, dt)
        r_i, v_i = rk4(r_half, v_half, m1, m2, dt)
        R.append(r_i)
        V.append(v_i)
        r12 = np.sqrt((r_i[2]-r_i[0])**2+(r_i[3]-r_i[1])**2)
        pot_hist.append(-G*m1*m2/r12)
        ke_hist.append(.5*m1*(v_i[0]**2+v_i[1]**2)+.5*m2*(v_i[2]**2+v_i[3]**2))
        t += 2*dt
        t_hist.append(t)
        j += 1
        # Grow toward the step predicted to give rho == 1, at most doubling.
        dt *= min(r**0.25, 2.0)
    else:
        # Reject: retry from the same state with the step predicted to give rho == 1.
        dt *= r**0.25
R = np.array(R)
V = np.array(V)
pot0 = np.array(pot_hist)
ke0 = np.array(ke_hist)
t0 = np.array(t_hist)
stop1 = timeit.default_timer()

start2 = timeit.default_timer()
# Plotted quantity is ln(-E) of the total energy (E < 0 for a bound orbit).
plt.plot(np.log(-1*(ke0+pot0)))
plt.suptitle(r'Total Energy for %i Orbit(s), $\delta$ = %f' %(nperiods,delta))
plt.xlabel('Steps')
plt.ylabel('Energy')
plt.show()
stop2 = timeit.default_timer()
print('t =',stop1-start1+(stop2-start2)) #Time for energy plots/calculations
start3 = timeit.default_timer()
plt.plot(R[:,0],R[:,1],'ro',label='m1')
plt.plot(R[:,2],R[:,3],label='m2')
plt.legend(loc='best')
plt.xlabel('$x$ (meters)')
plt.ylabel('$y$ (meters)')
plt.title(r'%i Orbit(s), $\delta$ = %f' %(nperiods,delta))
plt.show()
stop3 = timeit.default_timer()
print('t =',stop1-start1+(stop3-start3)) #Time for orbit plots/calculations
print('Steps =',j)

# ================================ Problem 2 ===================================
# The RK4 run above rebound r_i/v_i to its final state; start Verlet from the
# original initial conditions instead.
r_i = np.array([r0, 0, -q*r0, 0])
v_i = np.array([0, v0, 0, -q*v0])
def Verlet(m1,m2,r_init,v_init,N,T=None,n_orbits=None,accel=None):
    """Integrate with the position (Stormer-)Verlet method; return positions.

    Parameters
    ----------
    m1, m2 : masses, forwarded to the acceleration function.
    r_init, v_init : initial positions / velocities as numpy arrays
        (e.g. [x1, y1, x2, y2] and [vx1, vy1, vx2, vy2]).
    N : number of steps per period; the step size is T/N.
    T : period to split into N steps; defaults to the module-level ``period``.
    n_orbits : number of periods to integrate; defaults to ``nperiods``.
    accel : callable (r, m1, m2) -> accelerations; defaults to ``acc``.

    Returns
    -------
    R : array of shape (int(n_orbits*N), len(r_init)); row i holds the
        positions at time i*T/N. Velocities are not returned.
    """
    # Defaults are resolved lazily so existing 5-argument calls behave as before.
    if T is None:
        T = period
    if n_orbits is None:
        n_orbits = nperiods
    f = acc if accel is None else accel
    h = T/N
    nsteps = int(n_orbits*N)
    R = np.zeros((nsteps, len(r_init)))
    if nsteps == 0:
        return R
    R[0, :] = r_init
    if nsteps > 1:
        # Verlet's recurrence needs two past positions; seed the second with a
        # second-order Taylor step.
        R[1, :] = r_init + v_init*h + 0.5*h**2*f(r_init, m1, m2)
    # x_{i+1} = 2 x_i - x_{i-1} + h^2 a(x_i)
    for i in range(1, nsteps - 1):
        R[i+1, :] = 2*R[i, :] - R[i-1, :] + h**2*f(R[i, :], m1, m2)
    return R

#Calculate velocities from the Verlet trajectory
start4 = timeit.default_timer()
R1 = Verlet(m1,m2,r_i,v_i,1000)
# Verlet(..., 1000) steps by period/1000, so differentiate with that same step.
# The module-level `dt` is the adaptive RK4 step from Problem 1, not this one.
dt_verlet = period/1000
V1 = np.empty_like(R1)
V1[0, :] = v_i  # exact initial velocity
# Interior: central difference. It is second order like Verlet and is centred
# on the same instant as the position (a forward difference is half a step late).
V1[1:-1, :] = (R1[2:, :] - R1[:-2, :])/(2*dt_verlet)
# The last sample has no successor, so use a backward difference. If this row
# were left at zero, KE would drop to 0 and the energy curve would jump.
V1[-1, :] = (R1[-1, :] - R1[-2, :])/dt_verlet

# Potential and kinetic energy at every stored step, vectorised over rows
# (row 0 is the initial state r_i, v_i).
r12 = np.sqrt((R1[:, 2]-R1[:, 0])**2 + (R1[:, 3]-R1[:, 1])**2)
pot0 = -G*m1*m2/r12
ke0 = .5*m1*(V1[:, 0]**2+V1[:, 1]**2) + .5*m2*(V1[:, 2]**2+V1[:, 3]**2)

# Plotted quantity is ln(-E) of the total energy (E < 0 for a bound orbit).
plt.plot(np.log(-1*(ke0+pot0)))
plt.xlabel('Steps')
plt.ylabel('Energy')
plt.title('Total Energy for %i Orbit(s) (Verlet)' %nperiods)
# Hand-picked y-ranges for particular parameter runs.
# NOTE(review): these may need retuning now that velocities use the Verlet step.
if m1==m2/1000:
    plt.ylim(81.4,82.4)
if m1==m2/2:
    plt.ylim(87.6,88.6)
if e==0.01:
    plt.ylim(81.82,81.88)
plt.show()
stop4 = timeit.default_timer()
print('t =',stop4-start4,'(Verlet)')

start5 = timeit.default_timer()
# Reuse the trajectory computed above rather than re-running Verlet per axis.
plt.plot(R1[:,0],R1[:,1],label='m1')
plt.plot(R1[:,2],R1[:,3],label='m2')
plt.xlabel('$x$ (meters)')
plt.ylabel('$y$ (meters)')
plt.legend(loc='best')
plt.title('%i Orbit(s) (Verlet)' %nperiods)
plt.show()
stop5 = timeit.default_timer()
print('t =',stop5-start5,'(Verlet)')