#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2024 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import os
import gzip
import zipfile
import argparse
import subprocess
import cve_manager.defines   as df
from sys                     import argv
from shutil                  import copyfileobj, rmtree
from requests                import get as requesturl
from ax.sefunctions          import PrepareDir
from cve_manager.desc        import DOWNLOAD
from cve_manager.common      import NewArgParser, Init, ReviseCVEVolumes
from cve_manager.conf        import COMMON_SEC
from cve_manager.intf_common import ErrEncode, ERR_MAX

REQUEST_HEADERS = {'User-Agent': 'cve-manager'}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

class YearAction(argparse.Action):
	# Store the values of the "vols" option only when the NVD vulnerability
	# list is among the requested download targets; otherwise print a warning
	# and leave the default value of the option untouched.
	#
	# The requested targets are read from the "info" attribute of the
	# namespace. argparse fills the namespace in command line order, so
	# "-i/--info" has to precede "-v/--vols" for the volumes to be accepted.
	def __call__(self, parser, namespace, values, option_string=None):
		# No visible argument defines an "nvd_vul" attribute, so looking only
		# at it made this action reject every value; the attribute is still
		# honoured in case it is ever provided
		got_nvd = getattr(namespace, 'nvd_vul', False)
		if not got_nvd:
			# "info" is None until "-i/--info" has been parsed
			targets = getattr(namespace, 'info', None) or []
			got_nvd = any(
				t == 'all' or t == df.DT_NVD_VUL for t in targets
				)
		if got_nvd:
			setattr(namespace, self.dest, values)
		else:
			print('"year" param used only with "nvd_vul" param!')

argparser = argparse.ArgumentParser(description=DOWNLOAD)

# Download targets: any of the known ones, or "all" as a shortcut for them
argparser.add_argument(
	'-i', '--info',
	required=True,
	nargs='+',
	type=str,
	metavar='INFO',
	choices=[*df.ALL_DOWNLOADS, 'all'],
	help=f'Information to download ({", ".join(df.ALL_DOWNLOADS)} or all)',
	)

# Keep whatever has been downloaded already
argparser.add_argument(
	'-n', '--noreplace',
	action='store_true',
	help='Do not replace existing files',
	)

# NVD volumes; the values are filtered by YearAction before being stored
argparser.add_argument(
	'-v', '--vols',
	action=YearAction,
	nargs='+',
	type=str,
	default='all',
	metavar='<year>|recent',
	help='Volumes of the NVD CVE lists (all by default)',
	)

# Tells the module that it is being restarted after a termination
argparser.add_argument(
	'-r', '--retry',
	action='store_true',
	help='This is not a first run (the module has been terminated and now it '
	'is running again)',
	)

# Extending the parser with the project-wide arguments and parsing
argparser = NewArgParser(base=argparser, ptype='m')
args = argparser.parse_args()

# "all" stands for every known download target
if 'all' in args.info:
	args.info = df.ALL_DOWNLOADS

# Completion status of every requested target (nothing is completed yet)
checklist = dict.fromkeys(args.info, False)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Terminate with a return code formed from a list of those download targets
# that have not been completed

def Err(returncode=None):

	# Terminate the module (never returns).
	#
	# returncode: exit code to terminate with; when it is None the code is
	# produced by ErrEncode() from the checklist of download targets and the
	# list of all known downloads.
	#
	# The explicit code used to be merely returned to the caller, so the
	# module kept running after a fatal error; now both forms terminate.
	if returncode is None:
		returncode = ErrEncode(checklist, df.ALL_DOWNLOADS)
	raise SystemExit(returncode)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Initialising a printer and reading the configuration file

# Init() provides the printer used for all the status messages of the module
# and the configuration object
printer, conf = Init(args)
# Only the common section of the configuration is requested here; exactly one
# item is expected back for the one requested section
[common_params] = conf.Get([COMMON_SEC])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Select NVD volumes that will be downloaded

def SelectNVDVolumes():

	# Returns the list of NVD volumes to download: the value of the "vols"
	# argument as revised by ReviseCVEVolumes(). An empty revision result is
	# reported through Err(ERR_MAX).
	volumes = ReviseCVEVolumes(args.vols, printer)
	if not volumes:
		Err(ERR_MAX)
	if args.vols:
		return volumes
	# No volumes were given at all: the "recent" volume is added
	volumes.append('recent')
	return volumes

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Forming possible download targets

# URLs of the third party data (fetched by Request() with an HTTP GET)
NVD_FEEDS_URL = 'https://nvd.nist.gov/feeds/'
CPE_URL       = f'{NVD_FEEDS_URL}xml/cpe/dictionary/official-cpe-dictionary_v2.3.xml.gz'
# Prefix of the per-volume NVD files, see THIRDPARTY_TARGETS below
NVD_VUL_URL   = f'{NVD_FEEDS_URL}json/cve/1.1/'
FSTEC_VUL_URL = 'https://bdu.fstec.ru/files/documents/vulxml.zip'
LINUX_KERNEL_STREAMS = 'https://raw.githubusercontent.com/nluedtke/linux_kernel_cves/master/kern.json'
LINUX_KERNEL_FIXES   = 'https://raw.githubusercontent.com/nluedtke/linux_kernel_cves/master/data/stream_fixes.json'

# Sources of the ALT Linux data (passed to GetALTExtData() as "src")
ACLS_URL      = 'git.altlinux.org::acl/'
TIMELINES_URL = 'ftp://ftp.altlinux.org/pub/distributions/archive/'

# Download target -> list of URLs that make up this target.
# NOTE: SelectNVDVolumes() is called while this dict is being built, i.e. at
# module level, whether or not the NVD target has been requested
THIRDPARTY_TARGETS = {
	df.DT_CPE: [CPE_URL],
	df.DT_NVD_VUL: [f'{NVD_VUL_URL}nvdcve-1.1-{vol}.json.gz' for vol in SelectNVDVolumes()],
	df.DT_FSTEC_VUL: [FSTEC_VUL_URL],
	df.DT_LINUX_KERNEL_CVES: [LINUX_KERNEL_STREAMS, LINUX_KERNEL_FIXES],
	}

# Download target -> source (or sources) of this target.
# "master_branch" is read with [] and so has to be present in the common
# section; the distro lists are read with get() and may be missing, in which
# case the main section skips the target
ALTLINUX_TARGETS = {
	df.DT_ACL: ACLS_URL,
	df.DT_TIMELINES: f'{TIMELINES_URL}{common_params["master_branch"].lower()}/index/src/',
	df.DT_DISTRO_SRC: common_params.get('distro_lists_src'),
	df.DT_DISTRO_BIN: common_params.get('distro_lists_bin'),
	}

# URL -> name of the local file; Request() uses it instead of the last
# component of the URL
CUSTOM_FILE_NAMES = {
	LINUX_KERNEL_STREAMS: 'kernel_streams.json',
	LINUX_KERNEL_FIXES: 'kernel_fixes.json',
	}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Decorator function that prints out given message, tries to execute decorated
# function, checks the resulting string (including with the use of optional
# check function), prints a status message and returns the resulting string

def TryAndHandle(func):

	# Decorator. The decorated function is called as
	#   Wrapper(msg, params, optional_check=None, verbose=True)
	# where
	#   msg            - message printed before the call,
	#   params         - list of positional arguments passed on to func,
	#   optional_check - callable applied to a non-empty result; a false
	#                    value means that the result is not acceptable,
	#   verbose        - whether the result is printed on success.
	# Returns the result of func on success and an empty (false) value
	# otherwise, so callers can rely on the truthiness of the return value.
	def Wrapper(msg, params, optional_check=None, verbose=True):
		printer.LineBegin(msg)
		try:
			res = func(*params)
			err_msg = ''
		except Exception as e:
			# Best effort: any failure of func is reported below and turned
			# into an empty result instead of terminating the module
			res = ''
			err_msg = str(e)
		if res and optional_check and not optional_check(res):
			# The call went through but its outcome did not pass the check.
			# The result used to be returned as is, which made callers treat
			# a step reported as failed as a successful one
			res = ''
		if not res:
			printer.Err(err_msg)
			return res
		if verbose:
			printer.LineAddExtra(f'{res}')
		printer.Success()
		return res

	return Wrapper

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Downloading, unzipping and then removing a zip (gz) file

# Download a gz/zip file from url and place it into base/<gz/zip-name> dir
def Request(url, base):
	# Download url into the base directory.
	# Returns the path of the downloaded file, an empty string if the
	# download failed, or 'Noreplace' if the "noreplace" argument is set and
	# the (unpacked) data is in place already.
	file_name = CUSTOM_FILE_NAMES.get(url) or url.split('/')[-1]
	file_path = os.path.join(base, file_name)

	if args.noreplace:
		# Archives are removed after unpacking, so the unpacked path is what
		# has to be looked for. The extension is cut off as a suffix:
		# str.rstrip() removes a set of characters, not a suffix, and the zip
		# case was not handled at all
		unpacked_path = file_path
		for ext in ('.gz', '.zip'):
			if file_path.endswith(ext):
				unpacked_path = file_path[:-len(ext)]
				break
		if os.path.exists(unpacked_path):
			return 'Noreplace'

	@TryAndHandle
	def Retrieve(url, file_path):
		# NOTE(security): verification of the TLS certificate is switched off
		# only if all of the following holds: the configuration does not
		# force it, this is a retry run, the first requested target is the
		# FSTEC list and the FSTEC URL is the one being downloaded
		verify = common_params['always_verify_tls_certificate'] or \
			not args.retry or \
			args.info[0] != df.DT_FSTEC_VUL or \
			url != FSTEC_VUL_URL
		r = requesturl(url, headers=REQUEST_HEADERS, timeout=10, verify=verify)
		# Raises exactly when "r.ok" is false; the decorator turns the
		# exception into an empty result and prints the HTTP status instead
		# of an empty error message
		r.raise_for_status()
		with open(file_path, 'wb') as f:
			f.write(r.content)
		return file_path

	return Retrieve(f'Downloading "{url}"', [url, file_path])

# Unarchive a gz file
@TryAndHandle
def Ungzip(file_path):
	# Unpack the gz file next to itself and return the path of the result
	# (file_path without the ".gz" suffix).
	# Raises ValueError for a path without the suffix: the output path would
	# be the input path and the archive would be truncated while being read.
	suffix = '.gz'
	if not file_path.endswith(suffix):
		raise ValueError(f'"{file_path}" is not a {suffix} file')
	# The suffix is sliced off; str.rstrip() would remove any trailing run of
	# the characters ".", "g" and "z" instead
	out_path = file_path[:-len(suffix)]
	with gzip.open(file_path, 'rb') as f_in, open(out_path, 'wb') as f_out:
		copyfileobj(f_in, f_out)
	return out_path

# Unarchive a zip file
@TryAndHandle
def Unzip(file_path):
	# Extract the zip file into a directory named after the archive and
	# return the path of this directory (file_path without the ".zip" suffix).
	# Raises ValueError for a path without the suffix.
	suffix = '.zip'
	if not file_path.endswith(suffix):
		raise ValueError(f'"{file_path}" is not a {suffix} file')
	# The suffix is sliced off; str.rstrip() would remove any trailing run of
	# the characters ".", "z", "i" and "p" instead
	out_path = file_path[:-len(suffix)]
	with zipfile.ZipFile(file_path, 'r') as f_zip:
		f_zip.extractall(out_path)
	return out_path

# Remove a gz/zip file
@TryAndHandle
def Remove(file_path):
	# Delete file_path if it is a regular file; anything else is left alone.
	# The path is returned in both cases, so it is up to the check function
	# given to the decorator to find out whether the file is really gone.
	if not os.path.isfile(file_path):
		return file_path
	os.remove(file_path)
	return file_path

# Download a gz/zip file, unzip it to base/<gz/zip-name> and rm this gz/zip file
def GetContents(url, base):
	# Fetch url into base and, if an archive has been fetched, unpack it and
	# delete the archive. Returns True on success (including the case of
	# nothing to do) and False if any of the steps failed.

	# Downloading
	downloaded = Request(url, base)
	if not downloaded:
		return False
	if downloaded == 'Noreplace':
		return True

	# Picking the unarchiver by the extension of the downloaded file
	unarchivers = (('.gz', Ungzip), ('.zip', Unzip))
	unarchive = next(
		(func for ext, func in unarchivers if downloaded.endswith(ext)),
		None,
		)
	if unarchive is None:
		# This is not an archive file, nothing to do
		return True

	# Unpacking; the unpacked path has to exist afterwards
	if not unarchive(f'Unzipping "{downloaded}"', [downloaded], os.path.exists):
		return False

	# Removing the archive; its path must not exist afterwards
	def IsGone(path):
		return not os.path.exists(path)

	removed = Remove(f'Removing "{downloaded}"', [downloaded], IsGone,
		verbose=False)
	return bool(removed)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Fetching ALT Linux external data with rsync, wget, install or git

def GetALTExtData(target, src, dest):

	# Fetch the data of an ALT Linux target from src into dest with an
	# external tool picked by the target (and, for the distro lists, by the
	# kind of src):
	#   'acl'       - rsync,
	#   'timelines' - wget (recursive, only "d-t-s-evr.list" files),
	#   '*distro*'  - install of the "*.list" files of a local directory
	#                 (src is an absolute path) or git clone (git:// URL).
	# Returns True if the tool has finished with a zero return code and
	# False otherwise; a failure is printed as an error for a required
	# target and as a warning for any other.
	import glob

	printer.LineBegin(f'Downloading "{src}"')

	if target == 'acl':
		cmd = ['rsync', '--timeout=10', '-aq', src]
	elif target == 'timelines':
		cmd = ['wget', '--timeout=10', '-rq', src, '-A', 'd-t-s-evr.list',
			'-nH', '--cut-dirs', '3', '-P']
	elif 'distro' in target:
		if src.startswith('/'):
			# The pattern is expanded here and not by a shell, so paths with
			# spaces or shell metacharacters are passed on verbatim. With
			# nothing matched "install" gets no files and fails, as it did
			# with the unexpanded pattern
			lists = sorted(glob.glob(os.path.join(glob.escape(src), '*.list')))
			cmd = ['install', '-m', '644'] + lists + ['-Dt']
		elif src.startswith('git://'):
			# Destination dir should be removed before cloning with git
			if os.path.exists(dest):
				rmtree(dest)
			cmd = ['git', 'clone', src]
		else:
			printer.Err(f'Invalid value "{src}" of a conf parameter')
			return False
	else:
		printer.Err(f'Wrong target "{target}"')
		return False

	cmd.append(dest)

	try:
		completed_process = subprocess.run(
			cmd,
			stdout=subprocess.DEVNULL,
			stderr=subprocess.STDOUT,
			)
	except OSError as e:
		# The tool could not be started at all (e.g. it is not installed);
		# handled like any other failure instead of crashing the module
		returncode = None
		msg = f'"{" ".join(cmd)}" failed to start: {e}'
	else:
		returncode = completed_process.returncode
		msg = f'"{" ".join(cmd)}" return code is {returncode}'

	if returncode != 0:
		if target in df.REQUIRED_DOWNLOADS:
			printer.Err(msg)
		else:
			printer.Warn(msg)
		return False

	printer.Success()
	printer.LineEnd(f'\t- {dest}')

	return True

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Form a map of download targets out of a given map

def FormTargets(source_map):

	# Returns a new dict with those items of source_map whose keys are among
	# the requested download targets (the order of source_map is kept)
	selected = {}
	for name, source in source_map.items():
		if name in args.info:
			selected[name] = source
	return selected

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Convert a string of sources to a list of sources (git-sources will be placed
# in the beginning of the list)

def SrcList(src):

	# Returns a (list of sources, error message) pair.
	# A value that is not a list is wrapped into a one-element list. In a
	# list the only allowed git-source ("git://...") is moved to the
	# beginning, the order of the other sources is kept. If there is more
	# than one git-source, the list is returned unchanged together with a
	# non-empty error message.
	if not isinstance(src, list):
		return [src], ''

	# Positions of all the git-sources. Every position counts, including 0:
	# a git-source in the first position used to hide the following ones
	git_indexes = [i for i, el in enumerate(src) if el.startswith('git://')]

	if len(git_indexes) > 1:
		return src, 'More than one git-source'
	if not git_indexes:
		return src, ''

	index = git_indexes[0]
	return src[index:index + 1] + src[:index] + src[index + 1:], ''

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Nothing but the usage can be offered without arguments
	if len(argv) < 2:
		argparser.print_help()
		Err(ERR_MAX)

	# Checking and converting download path
	base, err = PrepareDir(common_params['download'])
	if not base:
		printer.Err(err)
		Err(ERR_MAX)

	# Forming a dict of download targets
	thirdparty_targets = FormTargets(THIRDPARTY_TARGETS)
	altlinux_targets = FormTargets(ALTLINUX_TARGETS)

	# Filtering excluded data sources from the dict: a "<name>_vul" target is
	# dropped if "<name>" is listed in the "excluded_vulsrc" parameter
	excluded_vulsrc = common_params.get('excluded_vulsrc', [])
	excluded_url_targets = [
		t for t in thirdparty_targets
		if t.endswith('_vul') and t[:-4] in excluded_vulsrc
		]
	if excluded_url_targets:
		thirdparty_targets = dict(
			item for item in thirdparty_targets.items()
			if item[0] not in excluded_url_targets
			)
		msg = f'{", ".join(excluded_url_targets)} ' \
			f'{"are" if len(excluded_url_targets) > 1 else "is"} excluded'
		printer.LineEnd(f'[NOTE: {msg}]')

	# Downloading a third party data; a target is complete only if every
	# URL of it has been fetched, and a failure of a required target is fatal
	for target, urls in thirdparty_targets.items():
		required = target in df.REQUIRED_DOWNLOADS
		complete = True
		for url in urls:
			if GetContents(url, base):
				continue
			complete = False
			if required:
				Err()
		checklist[target] = complete

	# Downloading an ALT Linux external data; targets without a configured
	# source are skipped
	for target, src in altlinux_targets.items():
		if not src:
			continue
		sources, err = SrcList(src)
		if err:
			printer.Err(err)
			Err()
		dest = os.path.join(base, target + '/')
		required = target in df.REQUIRED_DOWNLOADS
		complete = True
		for source in sources:
			if GetALTExtData(target, source, dest):
				continue
			complete = False
			if required:
				Err()
		checklist[target] = complete

	exit(0)
