#!python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

"""Fetch BOSS data files containing the spectra of specified observations and mirror them locally.
"""

from __future__ import division,print_function

import os.path
import multiprocessing

from astropy.utils.compat import argparse

from progressbar import ProgressBar, Percentage, Bar

import astropy.table

import bossdata.path
import bossdata.remote
import bossdata.plate

def fetch(remote_paths,response_queue):
    """Fetch each remote path into the local mirror and report the outcome on a queue.

    This is the target of the subprocesses launched by main(). Exactly one response
    tuple is put on the queue for every remote path, so that the parent process can
    count completed files:

    * ``(size_in_bytes,)`` when the file is available locally, or
    * ``(0, remote_path, message)`` when the file could not be fetched.

    Args:
        remote_paths: Iterable of remote path strings to fetch.
        response_queue: Queue with a ``put`` method that receives one response
            tuple per remote path.
    """
    mirror = bossdata.remote.Manager()
    for remote_path in remote_paths:
        try:
            local_path = mirror.get(remote_path, progress_min_size=None)
            num_bytes = os.path.getsize(local_path)
            if num_bytes > 0:
                response = (num_bytes,)
            else:
                # The parent treats a leading zero as an error and then reads the path
                # and message fields, so an empty file must use the 3-element form.
                response = (0, remote_path, 'Downloaded file is empty.')
        except Exception as e:
            # Catch everything at this worker boundary (KeyboardInterrupt still
            # propagates): an uncaught exception would end the subprocess without
            # a response and leave the parent waiting on the queue forever.
            response = (0, remote_path, str(e))
        response_queue.put(response)

def main():
    """Run the command-line program that mirrors BOSS data files locally.

    Parses the command line, reads the table of observations and builds the set of
    remote paths to fetch. Depending on the options, it then saves that list
    (optionally as globus transfer pairs), reports dry-run statistics, or downloads
    the files using parallel subprocesses running fetch().

    Returns:
        int: 0 on success, or -1 if the options are invalid, the remote paths could
            not be prepared, or some files were never processed.
    """
    # Queue.get() raises Empty on timeout. The module that defines it is named
    # Queue on Python 2 and queue on Python 3.
    try:
        from queue import Empty
    except ImportError:
        from Queue import Empty

    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose', action='store_true',
        help='Provide verbose output.')
    parser.add_argument('observations', type=str, default=None, metavar='FILE',
        help='File containing PLATE,MJD,FIBER columns that specify the observations to fetch.')
    parser.add_argument('--full', action='store_true',
        help='Fetch the full version of each spectrum data file.')
    parser.add_argument('--frame', action='store_true',
        help='Fetch spFrame files for each plate instead of individual spectra.')
    parser.add_argument('--cframe', action='store_true',
        help='Fetch spCFrame files for each plate instead of individual spectra.')
    parser.add_argument('--save', type=str, default=None, metavar='FILE',
        help='Filename for saving the list of data files to download.')
    parser.add_argument('--dry-run', action='store_true',
        help='Prepare the list of files to fetch but do not perform downloads.')
    parser.add_argument('--nproc', type=int, default=2,
        help='Number of subprocesses to use to parallelize downloads (1-5).')
    parser.add_argument('--globus', type=str, default=None,
        help='Globus transfer endpoints (remote#endpoint:local#endpoint).')
    parser.add_argument('--plate-name', type=str, default='PLATE',
        help='Name of PLATE column in the input file.')
    parser.add_argument('--mjd-name', type=str, default='MJD',
        help='Name of MJD column in the input file.')
    parser.add_argument('--fiber-name', type=str, default='FIBER',
        help='Name of FIBER column in the input file.')
    args = parser.parse_args()

    if args.nproc < 1 or args.nproc > 5:
        print('nproc must be 1-5.')
        return -1

    if args.full and (args.frame or args.cframe):
        print('Option --full is not compatible with --frame or --cframe.')
        return -1

    # Read the list of observations to fetch. Plain-text extensions are read as
    # ascii tables; otherwise the format is left for astropy to determine.
    root,ext = os.path.splitext(args.observations)
    if ext in ('.dat','.txt'):
        input_format = 'ascii'
    else:
        input_format = None
    table = astropy.table.Table.read(args.observations,format=input_format)

    try:
        finder = bossdata.path.Finder(verbose=args.verbose)
        mirror = bossdata.remote.Manager(verbose=args.verbose)
    except ValueError as e:
        print(e)
        return -1

    # Build a list of remote paths from the input plate-mjd-fiber values.
    # A set is used so that paths shared by several observations appear only once.
    try:
        remote_paths = set()
        if args.frame or args.cframe:
            # Calibration states to request: False selects spFrame files and
            # True selects spCFrame files.
            calibrated = []
            if args.frame:
                calibrated.append(False)
            if args.cframe:
                calibrated.append(True)
            # Cache of plan objects keyed by (plate, mjd) so that each plan file is
            # only looked up and parsed once.
            plans = {}
            for row in table:
                plate = row[args.plate_name]
                mjd = row[args.mjd_name]
                fiber = row[args.fiber_name]
                tag = (plate, mjd)
                if tag not in plans:
                    plan_remote_path = finder.get_plate_plan_path(plate=plate, mjd=mjd)
                    # The next line downloads the small plan file, if necessary.
                    plan_local_path = mirror.get(plan_remote_path)
                    plans[tag] = bossdata.plate.Plan(plan_local_path)
                plan = plans[tag]
                plate_path = finder.get_plate_path(plate=plate)
                for i in range(plan.num_science_exposures):
                    # Add the names of spFrame and/or spCFrame files, if they are available.
                    for cal in calibrated:
                        for band in ('blue', 'red'):
                            exposure_name = plan.get_exposure_name(
                                i, band, fiber, calibrated=cal)
                            if exposure_name:
                                remote_paths.add(os.path.join(plate_path, exposure_name))
        else:
            for row in table:
                remote_paths.add(finder.get_spec_path(
                    plate=row[args.plate_name], mjd=row[args.mjd_name],
                    fiber=row[args.fiber_name], lite=not args.full))
    except RuntimeError as e:
        print('Error while preparing paths: {}'.format(str(e)))
        return -1
    remote_paths = list(remote_paths)
    num_files = len(remote_paths)
    if num_files == 0:
        print('No files to fetch.')
        return 0

    if args.globus and args.save:
        try:
            # Unpacking raises ValueError unless there is exactly one ':' separator.
            remote_endpoint, local_endpoint = args.globus.split(':')
            globus_username = local_endpoint.split('#')[0]
        except ValueError as e:
            print('Error while parsing globus endpoints: {}'.format(str(e)))
            return -1
        if args.verbose:
            print('Globus endpoints (remote, local): {0}, {1}'.format(remote_endpoint, local_endpoint))
        with open(args.save,'w') as f:
            for remote_path in remote_paths:
                local_path = mirror.local_path(remote_path)
                f.write('{0}{1} {2}{3}\n'.format(remote_endpoint, remote_path, local_endpoint, local_path))
        print('Saved {0} globus transfer remote/local file name pairs to {1}.'.format(num_files,args.save))
        print('To initiate transfer, run:')
        print('\tssh {0}@cli.globusonline.org transfer -s 1 < {1}'.format(globus_username, args.save))
        return 0
    elif args.globus:
        print('Option --globus requires that an output file is specified via the --save option')
        return -1

    if args.save:
        with open(args.save,'w') as f:
            for remote_path in remote_paths:
                f.write(remote_path+'\n')
        print('Saved {0} remote file names to {1}.'.format(num_files,args.save))
        return 0

    if args.dry_run:
        # Compile statistics on the number of files and bytes involved in this fetch.
        Mb = 1 << 20
        num_local = 0
        local_bytes = 0
        for remote_path in remote_paths:
            local_path = mirror.local_path(remote_path)
            if os.path.isfile(local_path):
                num_local += 1
                try:
                    local_bytes += os.path.getsize(local_path)
                except os.error:
                    print('Unable to get size of local {}.'.format(local_path))
        print('{} / {} = {:.1f}% of files are already local and using {:.1f} Mb.'.format(
            num_local, num_files, 100 * num_local / num_files, local_bytes / Mb))
        print('This command would download {} new files.'.format(num_files - num_local))
        if num_local < num_files and num_local > 0:
            # Extrapolate from the mean size of the files that are already local.
            fetch_bytes = (local_bytes/num_local) * (num_files - num_local)
            print('Estimated total download size is {:.1f} Mb.'.format(fetch_bytes / Mb))
        return 0

    if args.verbose:
        print('Fetching {:d} files...'.format(num_files))
        progress_bar = ProgressBar(widgets=[Percentage(), Bar()], maxval=num_files).start()

    # Initialize a queue that subprocesses use to signal their progress.
    response_queue = multiprocessing.Queue()

    # Launch subprocesses to handle subsets of remote paths, using no more
    # subprocesses than there are files. The chunk size is rounded up.
    num_procs = min(args.nproc, num_files)
    chunk_size = (num_files + num_procs - 1)//num_procs
    processes = []
    for i in range(num_procs):
        # The last chunk will be shorter if the number of paths does not evenly divide
        # between the subprocesses.
        chunk = remote_paths[i*chunk_size:(i+1)*chunk_size]
        if not chunk:
            # Rounding the chunk size up can leave nothing for the final subprocesses.
            continue
        process = multiprocessing.Process(target=fetch, args=(chunk, response_queue))
        processes.append(process)
        process.start()

    # Monitor subprocess progress. Each response is either (num_bytes,) for a
    # fetched file or (0, remote_path, message) for a failure.
    num_fetched = 0
    num_bytes = 0
    try:
        while num_fetched < num_files:
            try:
                response = response_queue.get(timeout=1)
            except Empty:
                if any(process.is_alive() for process in processes):
                    continue
                # Every subprocess has exited. Check once more for a response that
                # arrived after the timeout, then give up instead of blocking forever.
                try:
                    response = response_queue.get(timeout=1)
                except Empty:
                    print('All subprocesses exited before every file was processed.')
                    break
            # Identify errors by the response length rather than a zero size, since
            # a successful response for an empty file is the 1-element tuple (0,).
            if len(response) > 1:
                print('Download error for {file}:\n{msg}'.format(file=response[1], msg=response[2]))
            else:
                num_bytes += response[0]
            num_fetched += 1
            if args.verbose:
                progress_bar.update(num_fetched)
        if args.verbose and num_fetched == num_files:
            progress_bar.finish()
        # Give subprocesses a chance to finish normally.
        for process in processes:
            process.join(timeout=1)
    except KeyboardInterrupt:
        print('Stopping after keyboard interrupt.')

    # Ensure that all subprocesses have terminated. This should never be necessary
    # after normal completion.
    for process in processes:
        if process.is_alive():
            print('Killing subprocess {}.'.format(process.name))
            process.terminate()

    if args.verbose:
        print('Fetched {:.1f} Mb for {:d} files.'.format(
            num_bytes/float(1<<20),num_fetched))

    if num_fetched != num_files:
        print('WARNING: {:d} of {:d} files were not fetched.'.format(
            num_files-num_fetched,num_files),'Re-run the command after any problems are fixed.')
        return -1

    return 0

if __name__ == '__main__':
    # Use main()'s return value as the process exit status so that shell callers
    # can detect failures; main() returns -1 from its error paths.
    import sys
    sys.exit(main())
