#!python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

""" Plot a single BOSS spectrum.
"""

from __future__ import division,print_function

from astropy.utils.compat import argparse

import os.path

import numpy as np
import numpy.ma
import matplotlib.pyplot as plt
import astropy.table

import bossdata.path
import bossdata.remote
import bossdata.spec
import bossdata.bits
import bossdata.plate

def print_mask_summary(label, mask_values):
    """Print a per-bit count of the SPPIXMASK bits set in ``mask_values``.

    Parameters
    ----------
    label : str
        Text used as the heading prefix, e.g. 'Blue', 'Red' or 'Coadd (AND)'.
    mask_values : array_like
        Pixel mask values, decoded using ``bossdata.bits.SPPIXMASK``.

    Prints 'No pixels masked.' when no mask value is non-zero.
    """
    if not np.any(mask_values):
        print('No pixels masked.')
        return
    print('{0} pixel mask summary:'.format(label))
    bit_summary = bossdata.bits.summarize_bitmask_values(
        bossdata.bits.SPPIXMASK,mask_values)
    # items() works on both Python 2 and 3; iteritems() is Python 2 only.
    # NOTE(review): the summary is assumed to map bit name -> pixel count.
    for bit_name,bit_count in bit_summary.items():
        print('{0:5d} {1}'.format(bit_count,bit_name))

def _build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter,
        description = 'Plot a single BOSS spectrum.')
    parser.add_argument('--verbose', action = 'store_true',
        help = 'Provide verbose output.')
    parser.add_argument('--plate',type = int, default = 6641, metavar = 'PLATE',
        help = 'Plate number of spectrum to plot.')
    parser.add_argument('--mjd',type = int, default = 56383, metavar = 'MJD',
        help = 'Modified Julian date of plate observation to use.')
    parser.add_argument('--fiber',type = int,default = 30, metavar = 'FIBER',
        help = 'Fiber number identifying the spectrum of the requested PLATE-MJD to plot.')
    parser.add_argument('--exposure',type = int,default = None, metavar = 'EXP',
        help = 'Exposure sequence number starting from 0, or plot the coadd if not set.')
    parser.add_argument('--camera',type = str, choices = ['blue','red','both'], default = 'both',
        help = 'Camera to use when plotting a single exposure.')
    parser.add_argument('--allow-mask', type = str, default = None,
        help = 'SPPIXMASK bit names to allow in valid data. Separate multiple names with |.')
    parser.add_argument('--frame', action='store_true',
        help = 'Plot the spectrum from an uncalibrated spFrame file.')
    parser.add_argument('--cframe', action='store_true',
        help = 'Plot the spectrum from a calibrated spCFrame file.')
    parser.add_argument('--save-plot', type=str, default=None, const='', nargs='?', metavar='FILE',
        help = ('Save the generated plot to specified name ' +
            '(uses bossplot-{plate}-{mjd}-{fiber}.png if name omitted).'))
    parser.add_argument('--save-data', type=str, default=None, const='', nargs='?', metavar='FILE',
        help = ('Save the spectrum data to specified name ' +
            '(uses bossplot-{plate}-{mjd}-{fiber}.dat if name omitted).'))
    parser.add_argument('--no-display', action = 'store_true',
        help = 'Do not display the image on screen (useful for batch processing).')
    parser.add_argument('--scatter', action = 'store_true',
        help = 'Show scatter of flux instead of a flux error band.')
    parser.add_argument('--show-mask', action = 'store_true',
        help = 'Indicate pixels with invalid data using vertical lines.')
    parser.add_argument('--show-dispersion', action = 'store_true',
        help = 'Show the wavelength dispersion using the right-hand axis.')
    parser.add_argument('--show-sky', action = 'store_true',
        help = 'Show the subtracted sky flux instead of the object flux.')
    parser.add_argument('--add-sky', action = 'store_true',
        help = 'Add the subtracted sky to the object flux (overrides show-sky).')
    return parser

def _output_name(requested, args, extension):
    """Return the output file name for a ``--save-*`` option value.

    ``requested`` is '' when the option was given without a name, in which case
    the default ``bossplot-{plate}-{mjd}-{fiber}.{extension}`` is used.
    """
    if requested:
        return requested
    return 'bossplot-{plate}-{mjd}-{fiber}.{ext}'.format(
        plate=args.plate, mjd=args.mjd, fiber=args.fiber, ext=extension)

def main():
    """Plot (and optionally save) one BOSS spectrum selected on the command line.

    Returns -1 when the requested data cannot be located or loaded, and
    None otherwise.
    """
    args = _build_parser().parse_args()

    if args.exposure is None:
        if args.frame or args.cframe:
            print('Coadds not available from frame and cframe files.')
            return -1
        # Compare by value: `is not` on a string literal tests object identity,
        # which is implementation-dependent (SyntaxWarning on Python 3.8+).
        if args.camera != 'both':
            print('Ignoring camera = "{0}" for coadded spectrum.'.format(args.camera))
            args.camera = 'both'

    if args.allow_mask is None:
        pixel_quality_mask = None
    else:
        pixel_quality_mask = bossdata.bits.bitmask_from_text(
            bossdata.bits.SPPIXMASK,args.allow_mask)

    try:
        finder = bossdata.path.Finder(verbose=args.verbose)
        mirror = bossdata.remote.Manager(verbose=args.verbose)
    except ValueError as e:
        print(e)
        return -1

    # Load spectra into memory, downloading if necessary.
    try:
        if args.frame or args.cframe:
            frames = {}
            frame_path = finder.get_plate_path(plate=args.plate)
            plan_path = finder.get_plate_plan_path(plate=args.plate, mjd=args.mjd)
            plan = bossdata.plate.Plan(mirror.get(plan_path))
            if args.camera in ('red','both'):
                red_name = plan.get_exposure_name(
                    args.exposure, 'red', args.fiber, calibrated=args.cframe)
                if red_name is None:
                    print('Red camera data not available.')
                    return -1
                frames['red'] = bossdata.plate.FrameFile(
                    mirror.get(os.path.join(frame_path, red_name)),
                    index=1 + (args.fiber-1)//500, calibrated=args.cframe)
            if args.camera in ('blue','both'):
                blue_name = plan.get_exposure_name(
                    args.exposure, 'blue', args.fiber, calibrated=args.cframe)
                if blue_name is None:
                    print('Blue camera data not available.')
                    return -1
                frames['blue'] = bossdata.plate.FrameFile(
                    mirror.get(os.path.join(frame_path, blue_name)),
                    index=1 + (args.fiber-1)//500, calibrated=args.cframe)
        else:
            lite=(args.exposure is None)
            remote_paths = [finder.get_spec_path(plate=args.plate, mjd=args.mjd,fiber=args.fiber,
                                            lite=lite)]
            if lite:    # If lite, we can use the Full file if it exists but lite does not
                remote_paths.append(finder.get_spec_path(plate=args.plate, mjd=args.mjd,
                                fiber=args.fiber, lite=False))

            local_path = mirror.get(remote_paths)
            specfile = bossdata.spec.SpecFile(local_path)
            if args.verbose:
                print('Exposure summary:')
                print(specfile.exposure_table)
                relative_local_path = local_path.replace(mirror.local_root, '', 1)
                if relative_local_path != remote_paths[0]:
                    print("A substitution was made:\n\t{}\nwas substituted for\n\t{}.".format(
                        relative_local_path, remote_paths[0]))
    except RuntimeError as e:
        print(str(e))
        return -1

    # Initialize the plot.
    figure = plt.figure(figsize=(12,8))
    left_axis = plt.gca()
    figure.set_facecolor('white')
    # Raw strings: '\A' is not a valid Python escape sequence.
    left_axis.set_xlabel(r'Wavelength ($\AA$)')
    if args.frame:
        left_axis.set_ylabel('Flux (electrons)')
    else:
        left_axis.set_ylabel(r'Flux ($\times 10^{-17}$ erg/s/cm$^{2}$/$\AA$)')
    if args.show_dispersion:
        right_axis = left_axis.twinx()
        right_axis.set_ylabel('Dispersion (pixels)')

    # We will potentially plot two spectra.
    spectra = [ ]
    plot_colors = [ ]
    data_args = dict(include_wdisp=args.show_dispersion, include_sky=args.show_sky or args.add_sky)
    if args.exposure is None:
        spectra.append(specfile.get_valid_data(pixel_quality_mask=pixel_quality_mask, **data_args))
        plot_colors.append('black')
        if args.verbose:
            print('Showing coadd of {0:d} exposures:'.format(specfile.num_exposures))
            print_mask_summary('Coadd (AND)',specfile.get_pixel_mask())
    elif args.frame or args.cframe:
        fibers = np.array([args.fiber],dtype=int)
        if args.verbose:
            print('Showing exposure {:08d}.'.format(
                plan.exposures['science'][args.exposure]['EXPID']))
        if args.camera in ('blue','both'):
            spectra.append(frames['blue'].get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue',frames['blue'].get_pixel_masks(fibers)[0])
        if args.camera in ('red','both'):
            spectra.append(frames['red'].get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red',frames['red'].get_pixel_masks(fibers)[0])
    else:
        if args.verbose:
            print('Showing exposure {:08d}.'.format(
                specfile.exposure_table[args.exposure]['exp']))
        if args.camera in ('blue','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'blue',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue',specfile.get_pixel_mask(args.exposure,'blue'))
        if args.camera in ('red','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'red',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red',specfile.get_pixel_mask(args.exposure,'red'))

    # Save the spectrum data, if requested. `--save-data` given without a name
    # yields '' (falsy), which must still trigger a save with the default name.
    if args.save_data is not None:
        if len(spectra) > 1:
            print('WARNING: saving data for the first spectrum only.')
        save_name = _output_name(args.save_data, args, 'dat')
        # Only save un-masked rows. getmaskarray() always returns a full boolean
        # array, even when nothing is masked (mask is the scalar nomask).
        valid = ~np.ma.getmaskarray(spectra[0]['wavelength'])
        table = astropy.table.Table(spectra[0][valid])
        table.write(save_name, format='ascii.basic')
        if args.verbose:
            print('Saved data to {}'.format(save_name))

    wlen_min,wlen_max = +1e6,-1e6
    flux_lo, flux_hi = np.array([], dtype=float), np.array([], dtype=float)
    for data,plot_color in zip(spectra,plot_colors):

        wlen,dflux = data['wavelength'][:],data['dflux'][:]
        if args.add_sky:
            flux = data['sky'][:] + data['flux'][:]
        elif args.show_sky:
            flux = data['sky'][:]
        else:
            flux = data['flux'][:]

        if args.scatter:
            left_axis.scatter(wlen,flux,color=plot_color,marker='.',s=0.1)
        else:
            left_axis.fill_between(wlen,flux-dflux,flux+dflux,color=plot_color,alpha=0.5)

        if args.show_mask:
            # Masked rows are identified by the wavelength mask, the same
            # convention used for --save-data above.
            bad_pixels = np.flatnonzero(np.ma.getmaskarray(wlen))
            if len(bad_pixels) > 0:
                x_mask = [ ]
                y_mask = [ ]
                ymin,ymax = left_axis.get_ylim()
                for x in np.ma.getdata(wlen)[bad_pixels]:
                    x_mask.extend([x,x,None])
                    y_mask.extend([ymin,ymax,None])
                # Draw on left_axis explicitly: after twinx() the pyplot "current"
                # axes is the dispersion axis, whose y scale is unrelated.
                left_axis.plot(x_mask,y_mask,'-',color=plot_color,alpha=0.2)

        if args.show_dispersion:
            right_axis.plot(wlen,data['wdisp'][:],ls='-',color=plot_color)

        # Update the plot wavelength limits to include this data.
        wlen_min = min(wlen_min,np.ma.min(wlen))
        wlen_max = max(wlen_max,np.ma.max(wlen))

        # Update the list of fluxes that we will use to auto-range the vertical scale.
        flux_lo = numpy.ma.append(flux_lo, flux - dflux)
        flux_hi = numpy.ma.append(flux_hi, flux + dflux)

    # The x-axis limits are reset by the twinx() function so we set them here
    # (the twin shares this x axis, so both axes pick up the limits).
    left_axis.set_xlim(wlen_min,wlen_max)

    # Set the flux scale to show 99% of the valid flux +/- dflux range.
    # Leave matplotlib's autoscaling in place if no valid flux exists.
    valid = ~(np.ma.getmaskarray(flux_lo) | np.ma.getmaskarray(flux_hi))
    if np.any(valid):
        left_axis.set_ylim(np.percentile(np.ma.getdata(flux_lo)[valid], 0.5),
                           np.percentile(np.ma.getdata(flux_hi)[valid], 99.5))

    if args.save_plot is not None:
        save_name = _output_name(args.save_plot, args, 'png')
        figure.savefig(save_name)
        if args.verbose:
            print('Saved plot to {}'.format(save_name))
    if not args.no_display:
        plt.show()
    plt.close()

if __name__ == '__main__':
    # Propagate main()'s return value (-1 on error, None on success) as the
    # process exit status so that calling shells and batch jobs see failures.
    import sys
    sys.exit(main())
