import numpy as np
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.time import Time
from astropy.io import fits
from astropy.wcs import WCS

from sbpy.dynamics import State
from sbpy.data import Ephem
from sbpy.dynamics import SynGenerator

# Load the Spitzer image of comet 48P and take the observation epoch from its
# FITS header.
image, header = fits.getdata(
    "https://sbpy.org/data/48p-spitzer-reach07.fits", header=True
)
obstime = Time(header["DATE_OBS"])

# Both Horizons queries below are made with location="@ssb".  The ephemerides
# are converted to State vectors with frame="icrs", and astropy's ICRS frame
# is centered on the Solar System barycenter, so the Horizons observer must be
# the barycenter for that conversion to be correct.

# Comet 48P: query the apparition closest to the observation epoch, build a
# barycentric ICRS state, then move it to a heliocentric ecliptic frame for
# the dust integration.
comet_eph = Ephem.from_horizons(
    "48P",
    id_type="designation",
    closest_apparition=True,
    epochs=obstime,
    location="@ssb",
)
comet = State.from_ephem(comet_eph, frame="icrs").transform_to(
    "heliocentriceclipticiau76"
)

# Observer: the Spitzer Space Telescope (Horizons target "-79"), kept in ICRS.
spitzer_eph = Ephem.from_horizons(
    "-79", id_type=None, epochs=obstime, location="@ssb"
)
observer = State.from_ephem(spitzer_eph, frame="icrs")

# Syndyne generator: four beta values and 51 dust ages spanning 0-365 days.
# Each Horizons query covers a single epoch, so element [0] of each state is
# the one to use.
betas = [1, 0.1, 0.01, 0.001]
ages = np.linspace(0, 365, 51) * u.day
dust = SynGenerator(comet[0], betas, ages, observer=observer[0])

# World coordinate system for the image, re-anchored on the comet.  The
# comet's predicted RA/Dec as seen from the observer becomes the reference
# value (crval).  It is pinned to a fixed reference pixel (crpix, 1-based FITS
# convention), presumably where the comet appears in this particular image.
coords0 = observer.observe(comet)[0].unmasked
wcs = WCS(header)
wcs.wcs.crval = (coords0.ra.deg, coords0.dec.deg)
wcs.wcs.crpix = (209, 99)

# Figure: draw the image with a narrow display range (vmin/vmax chosen for
# this image) and an inverted gray colormap.
fig, ax = plt.subplots(num=1, clear=True, figsize=(6.5, 3.25))
ax.imshow(image, origin="lower", vmin=49.1, vmax=49.5, cmap="gray_r")

# Remember the image's axis limits so the overlays cannot expand the view.
xlim, ylim = ax.get_xlim(), ax.get_ylim()

# Overlay the syndynes, projected onto the image through the WCS.
syndynes = dust.syndynes()
syndynes.plot(ax, wcs=wcs)

# Overlay the comet's orbit over +/- 1 day around the observation
# (50 samples, numpy's linspace default made explicit).
dt = np.linspace(-1, 1, num=50) * u.d
orbit = dust.source_orbit(dt)
orbit.plot(ax, wcs=wcs, color="tab:cyan", lw=1, label="Orbit")

# Restore the saved image limits, then finish the figure.
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.legend()
fig.tight_layout()