#! /usr/bin/python3.6
#
# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import signal
import struct
import subprocess
import sys
import time
from typing import Optional, List, Dict, Set
import fcntl
import atexit

try:
    import ovs.doca_tcpdump_util
    import ovs.dpif_doca_tcpdump_hooks
except ModuleNotFoundError:
    print(u"""\
ERROR: Missing dependencies.
Please install the Open vSwitch python libraries: python3-doca-openvswitch (version 3.2.0044).
Alternatively, install them from source: ( cd ovs/python ; python3 setup.py install ).
Alternatively, check that your PYTHONPATH is pointing to the correct location.""")
    sys.exit(1)


# PCAP Constants
# These are the field values that get_global_header() packs, in this order,
# into the global header written at the start of the stream fed to tcpdump.

# Magic number identifying the stream as PCAP.
# NOTE(review): per the libpcap savefile format this value denotes
# microsecond-resolution timestamps, which matches the ts_usec field written
# by write_packet_header() - confirm against the format reference.
PCAP_MAGIC_NUMBER = 0xa1b2c3d4
# File format version (major.minor).
PCAP_VERSION_MAJOR = 2
PCAP_VERSION_MINOR = 4
# Timezone-offset field; written as zero.
PCAP_THISZONE = 0
# Timestamp-accuracy field; written as zero.
PCAP_SIGFIGS = 0
# Snapshot length advertised in the header. write_packet_header() does not
# truncate packets to this value; it writes len(data) bytes as-is.
PCAP_SNAPLEN = 65535
PCAP_NETWORK_TYPE = 1  # Ethernet


def get_global_header() -> bytes:
    """Build the 24-byte global header that starts a PCAP stream.

    The fields are packed in the order used by the classic libpcap file
    format: magic number, major and minor version, timezone offset,
    timestamp accuracy, snapshot length and link-layer type.

    The ``=`` prefix keeps the native byte order (readers detect the byte
    order from the magic number) but forces *standard* field sizes with no
    padding, so the layout no longer depends on the platform's C type sizes
    or alignment rules. The timezone offset is packed as a signed 32-bit
    value, which is how the format defines that field.

    Returns:
        The packed header bytes.
    """
    return struct.pack(
        '=IHHiIII',
        PCAP_MAGIC_NUMBER,
        PCAP_VERSION_MAJOR,
        PCAP_VERSION_MINOR,
        PCAP_THISZONE,
        PCAP_SIGFIGS,
        PCAP_SNAPLEN,
        PCAP_NETWORK_TYPE
    )


def write_packet_header(fd, data: bytes) -> None:
    """Write one PCAP record (16-byte record header plus payload) to *fd*.

    The record header holds the capture timestamp split into seconds and
    microseconds, followed by the captured length and the original length,
    which are both ``len(data)`` because the packet is written in full.

    Args:
        fd: Writable binary file object (anything with ``write`` and
            ``flush``), e.g. the stdin pipe of the tcpdump subprocess.
        data: Raw packet bytes.
    """
    # Sample the clock exactly once. Reading it separately for the seconds
    # and the microseconds lets a second boundary slip in between the two
    # reads, which yields a microsecond field >= 1_000_000 (an invalid
    # record timestamp).
    now = time.time()
    ts_sec = int(now)
    ts_usec = int((now - ts_sec) * 1_000_000)
    packet_len = len(data)

    # '=' keeps native byte order but fixes every field at 4 bytes with no
    # padding, independent of the platform's C type sizes.
    fd.write(struct.pack('=IIII', ts_sec, ts_usec, packet_len, packet_len))
    fd.write(data)
    # Flush per packet so the reader sees each record as soon as it arrives.
    fd.flush()


def parse_interface_config(config_str: str) -> ovs.doca_tcpdump_util.InterfaceConfig:
    """Turn one interface specification into an InterfaceConfig.

    Two forms are accepted:

    * ``name`` - the interface is paired with the single ``RX`` hook.
    * ``name:hook1,hook2`` - every non-empty, whitespace-stripped hook name
      is resolved through ``DOCATcpdumpHook.verify_hook_exists``.

    If ``verify_hook_exists`` raises ``ValueError`` the message is printed
    and the process exits with status 1.
    """
    hook_enum = ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook
    iface_name, colon, hooks_str = config_str.partition(':')

    # No colon: only an interface name was given, so use the default hook.
    if not colon:
        return ovs.doca_tcpdump_util.InterfaceConfig(config_str.strip(), {hook_enum.RX})

    selected = set()
    for raw_name in hooks_str.split(','):
        candidate = raw_name.strip()
        if not candidate:
            # Tolerate stray commas such as "eth0:hook1,,hook2".
            continue
        try:
            selected.add(hook_enum.verify_hook_exists(candidate))
        except ValueError as e:
            print(f"Error: {e}")
            sys.exit(1)

    return ovs.doca_tcpdump_util.InterfaceConfig(iface_name.strip(), selected)


def create_argument_parser() -> argparse.ArgumentParser:
    """Build the parser for the options this wrapper handles itself.

    The built-in argparse help is disabled and ``-h``/``--help`` is declared
    as an ordinary flag, so the caller decides how help is shown. The epilog
    documents the interface specification syntax and ends with the values of
    every ``DOCATcpdumpHook`` member.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    hook_names = ", ".join(h.value for h in ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook)
    usage_notes = """
Interface specification format:
  eth0                    # Single interface, default hook
  eth0:hook1,hook2        # Interface with specific hooks
  eth0+eth1               # Multiple interfaces, default hooks
  eth0:hook1,hook2+eth1:hook3   # Multiple interfaces with specific hooks
  any                     # All interfaces, default hooks

Use plus (+) to separate multiple interfaces
Use comma (,) to separate multiple hooks per interface
Available hooks: """ + hook_names

    parser = argparse.ArgumentParser(
        prog=os.path.basename(sys.argv[0]),
        description='Dump software traffic from an Open vSwitch port using tcpdump',
        add_help=False,  # help is exposed as a plain flag below
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_notes,
    )

    # Options owned by this script, registered in the order they appear in
    # the help output.
    option_table = (
        (('-i', '--interface'),
         {'help': 'Interface specification (see format below)'}),
        (('--list-interfaces',),
         {'action': 'store_true', 'help': 'List available OVS interfaces'}),
        (('--list-hooks',),
         {'action': 'store_true', 'help': 'List available DPIF-DOCA hooks'}),
        (('-h', '--help'),
         {'action': 'store_true', 'help': 'Show this help message and exit'}),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    return parser


def handle_special_commands(args: argparse.Namespace) -> bool:
    """Handle the informational options that replace a capture run.

    The options are checked in this order and only the first match is
    acted on: ``--list-interfaces``, ``--list-hooks``, ``--help``.

    Args:
        args: Parsed options; must provide the ``list_interfaces``,
            ``list_hooks`` and ``help`` attributes.

    Returns:
        True if one of the options was handled and the program should exit,
        False if normal processing should continue.

    Raises:
        SystemExit: With status 1 when listing the interfaces fails, so the
            failure is visible in the exit status like the other error
            paths of this script (previously the caller exited with 0).
    """
    if args.list_interfaces:
        try:
            ports = ovs.doca_tcpdump_util.OVSInterfaceManager.list_ports()
            for line in ports:
                print(line)
        except RuntimeError as e:
            print(f"Error: {e}")
            sys.exit(1)
        return True

    if args.list_hooks:
        print("Available DOCA tcpdump hooks:")
        for hook in ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook:
            print(f"  {hook.value:<15} - {hook.description}")
        return True

    if args.help:
        # Build a fresh parser just to render its help text.
        parser = create_argument_parser()
        parser.print_help()
        return True

    return False


def get_interface_configs(interface_spec: str) -> Dict[str, ovs.doca_tcpdump_util.InterfaceConfig]:
    """Build the per-interface configuration map from a ``-i`` value.

    The specification is a ``+``-separated list of entries, each parsed by
    ``parse_interface_config``. An empty/missing specification or the
    literal ``any`` produces an empty mapping. Every entry whose name is not
    ``any`` is checked with ``OVSInterfaceManager.verify_interface_exists``;
    if that raises ``ValueError`` or ``RuntimeError`` the message is printed
    and the process exits with status 1.

    Returns:
        Mapping of interface name to its configuration. A name that occurs
        more than once keeps its last entry.
    """
    configs = {}

    if not interface_spec or interface_spec == "any":
        return configs

    for raw_part in interface_spec.split('+'):
        part = raw_part.strip()
        if not part:
            # Ignore empty segments such as a trailing '+'.
            continue

        config = parse_interface_config(part)
        if config.name != "any":
            try:
                ovs.doca_tcpdump_util.OVSInterfaceManager.verify_interface_exists(config.name)
            except (ValueError, RuntimeError) as e:
                print(f"Error: {e}")
                sys.exit(1)
        configs[config.name] = config

    return configs


def verify_tcpdump_args(args: List[str]) -> None:
    """Check the pass-through tcpdump arguments with a ``tcpdump -d`` dry run.

    If the dry run produces tcpdump's usage text, that text is relayed and
    the process exits before any capture is set up:

    * usage text on stdout is printed and the exit status is 0;
    * usage text on stderr is printed to stderr and tcpdump's own exit
      status is propagated.

    Any other outcome of the dry run is ignored and the function returns.

    Args:
        args: Arguments to hand to tcpdump unchanged.
    """
    dry_run = subprocess.run(
        ["tcpdump", "-d"] + args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        # Accepted by every Python 3 release ("text" is only an alias added
        # in 3.7), so no version switch is needed.
        universal_newlines=True
    )

    if "usage:" in dry_run.stdout.lower():
        print(dry_run.stdout)
        sys.exit(0)

    # stderr was captured but never looked at before, so a usage message
    # written there went unnoticed and the real tcpdump was started anyway.
    # NOTE(review): assumes tcpdump emits its usage text on stderr when it
    # rejects an option - confirm for the tcpdump versions in use.
    if "usage:" in dry_run.stderr.lower():
        print(dry_run.stderr, file=sys.stderr)
        sys.exit(dry_run.returncode)


class TcpdumpManager:
    """Owns the tcpdump subprocess and the pipe that feeds it PCAP data.

    Typical order of use: ``create_pipe()``, ``init_pcap_header()``, any
    number of ``write_packet()`` calls, then ``close()``.
    """

    def __init__(self, args: List[str]):
        """Remember the extra tcpdump arguments; nothing is started yet.

        Args:
            args: Arguments appended to the tcpdump command line.
        """
        self.args = args
        # Set by create_pipe(); both stay None until then.
        self.process: Optional[subprocess.Popen] = None
        self.stdin_fd = None

    def create_pipe(self) -> None:
        """Start ``tcpdump -r -`` so that it reads PCAP data from its stdin."""
        self.process = subprocess.Popen(
            ["tcpdump", "-r", "-"] + self.args,
            stdin=subprocess.PIPE,
            # Run tcpdump in its own process group.
            preexec_fn=os.setpgrp
        )
        self.stdin_fd = self.process.stdin

    def init_pcap_header(self) -> None:
        """Send the PCAP global header; no-op if the pipe does not exist."""
        if self.stdin_fd:
            self.stdin_fd.write(get_global_header())
            self.stdin_fd.flush()

    def write_packet(self, data: bytes) -> None:
        """Send one packet record to tcpdump; no-op if the pipe does not exist.

        Args:
            data: Raw packet bytes.
        """
        if self.stdin_fd:
            write_packet_header(self.stdin_fd, data)

    def close(self) -> None:
        """Close the pipe and wait for tcpdump to exit.

        Safe to call when nothing was started and safe to call repeatedly.
        """
        if self.stdin_fd:
            try:
                self.stdin_fd.close()
            except OSError:
                # Closing flushes buffered data; if tcpdump has already
                # exited that flush fails (BrokenPipeError). There is no
                # reader left, so dropping the data is all that can be done.
                # The process must still be reaped below.
                pass
        if self.process:
            self.process.wait()


class OVSDocaTcpdump:
    """Main class that orchestrates the DOCA tcpdump functionality.

    It wires a ``DOCATcpdumpManager`` (packet source) to a ``TcpdumpManager``
    (packet sink) and forwards every received packet from one to the other.
    """

    def __init__(self):
        """Create an idle instance; the components are built in setup_components()."""
        self.tcpdump_manager: Optional[TcpdumpManager] = None
        self.doca_manager: Optional[ovs.doca_tcpdump_util.DOCATcpdumpManager] = None
        self.running = False

    def setup_signal_handlers(self) -> None:
        """Route SIGINT and SIGTERM to _handle_signal."""
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)

    def _handle_signal(self, signum: int, frame) -> None:
        """Shut down in response to a signal and exit with status 0.

        Args:
            signum: Number of the received signal (unused).
            frame: Interrupted stack frame (unused).
        """
        if self.tcpdump_manager:
            try:
                self.tcpdump_manager.close()
            except Exception:
                # Best effort: a failure to close tcpdump must not prevent
                # the capture from being stopped below.
                pass
        if self.doca_manager:
            self.doca_manager.stop_capture()
        sys.exit(0)

    def process_arguments(self, args: List[str]) -> tuple:
        """Process and validate command line arguments.

        Options known to this script are parsed; everything else is passed
        through to tcpdump. Informational options (see
        handle_special_commands) end the program with status 0.

        Args:
            args: Command line arguments without the program name.

        Returns:
            ``(interfaces, tcpdump_args)``: the interface configuration map
            and the list of arguments to forward to tcpdump.
        """
        parser = create_argument_parser()

        # Parse known args, allowing unknown args to pass through
        parsed_args, unknown_args = parser.parse_known_args(args)

        # Handle special commands that should exit
        if handle_special_commands(parsed_args):
            sys.exit(0)

        # Get interface configurations
        interfaces = get_interface_configs(parsed_args.interface)

        # Get tcpdump arguments (original args minus our custom ones)
        tcpdump_args = unknown_args

        # Verify tcpdump arguments
        verify_tcpdump_args(tcpdump_args)

        return interfaces, tcpdump_args

    def setup_components(self, interfaces: Dict[str, ovs.doca_tcpdump_util.InterfaceConfig], tcpdump_args: List[str]) -> None:
        """Set up all components for packet capture.

        The DOCA side is created and started first, then the tcpdump
        subprocess is launched and primed with the PCAP global header.

        Args:
            interfaces: Interface configuration map handed to the DOCA
                manager's ``start_capture``.
            tcpdump_args: Extra arguments for the tcpdump command line.
        """
        # Create DOCA manager
        self.doca_manager = ovs.doca_tcpdump_util.DOCATcpdumpManager()
        self.doca_manager.setup_socket_server()

        # Start DOCA tcpdump with interface configurations
        self.doca_manager.start_capture(interfaces)

        # Create tcpdump manager
        self.tcpdump_manager = TcpdumpManager(tcpdump_args)
        self.tcpdump_manager.create_pipe()
        self.tcpdump_manager.init_pcap_header()

    def run_capture_loop(self) -> None:
        """Forward received packets to tcpdump until told to stop.

        The loop ends when a signal handler exits the process, on
        ``KeyboardInterrupt``, or when tcpdump closes its end of the pipe.
        ``cleanup()`` always runs on the way out.
        """
        self.running = True

        try:
            while self.running:
                batch = self.doca_manager.receive_packet_batch()
                if not batch:
                    continue
                for pkt in batch["packets"]:
                    try:
                        self.tcpdump_manager.write_packet(pkt['data'])
                    except BrokenPipeError:
                        # tcpdump has exited on its own (for example after
                        # reaching a packet count limit). Nothing can
                        # consume further packets, so stop capturing
                        # instead of crashing with a traceback.
                        self.running = False
                        break
        except KeyboardInterrupt:
            self._handle_signal(signal.SIGINT, None)
        finally:
            self.cleanup()

    def cleanup(self) -> None:
        """Clean up all resources.

        The DOCA manager is cleaned up even when closing tcpdump fails;
        otherwise a dead tcpdump would leave the DOCA side without cleanup.
        """
        try:
            if self.tcpdump_manager:
                self.tcpdump_manager.close()
        except OSError:
            # The pipe to tcpdump is already broken; there is nothing left
            # to flush and the failure must not mask the shutdown.
            pass
        finally:
            if self.doca_manager:
                self.doca_manager.cleanup()

    def run(self, args: List[str]) -> None:
        """Main entry point to run the application.

        Args:
            args: Command line arguments without the program name.
        """
        self.setup_signal_handlers()
        interfaces, tcpdump_args = self.process_arguments(args)
        self.setup_components(interfaces, tcpdump_args)
        self.run_capture_loop()


def acquire_lock():
    """Acquire the single-instance file lock without blocking.

    The lock file lives in ``$OVS_RUNDIR`` (falling back to
    ``ovs.dirs.RUNDIR``). Once the lock is held, the file content is
    replaced with this process's PID.

    Returns:
        The open lock file descriptor on success, or None if the file could
        not be opened, locked or written. No descriptor is left open when
        None is returned.
    """
    import ovs.dirs
    lock_file = os.path.join(os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR), 'ovs-doca-tcpdump.lock')

    try:
        # Deliberately no O_TRUNC: truncating on open would erase the PID
        # recorded by the instance that currently holds the lock. An
        # explicit mode keeps the file from being created executable.
        lock_fd = os.open(lock_file, os.O_CREAT | os.O_WRONLY, 0o644)
    except OSError:
        return None

    try:
        fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        # Only the lock holder rewrites the file: drop any stale content,
        # then record our PID.
        os.ftruncate(lock_fd, 0)
        os.write(lock_fd, str(os.getpid()).encode())
        os.fsync(lock_fd)
    except OSError:
        # Closing also drops the lock if it had already been taken.
        os.close(lock_fd)
        return None

    return lock_fd

def release_lock(lock_fd):
    """Release the single-instance lock and remove its lock file.

    Best effort: operating system errors are ignored. Calling this with
    None or with a descriptor that is no longer open does nothing.

    Args:
        lock_fd: Descriptor returned by acquire_lock(), or None.
    """
    if lock_fd is None:
        return

    try:
        held = os.fstat(lock_fd)
    except OSError:
        # The descriptor is not open any more (for example the lock was
        # already released by an earlier call). In particular, do not touch
        # the lock file: it may belong to another instance by now.
        return

    try:
        import ovs.dirs
        lock_file = os.path.join(os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR), 'ovs-doca-tcpdump.lock')
        # Remove the file while the lock is still held, and only if the
        # path still refers to the file we locked. Unlinking after the
        # unlock could delete the lock file of an instance that started in
        # between.
        on_disk = os.stat(lock_file)
        if (on_disk.st_dev, on_disk.st_ino) == (held.st_dev, held.st_ino):
            os.unlink(lock_file)
    except OSError:
        pass
    finally:
        try:
            fcntl.flock(lock_fd, fcntl.LOCK_UN)
        except OSError:
            pass
        try:
            os.close(lock_fd)
        except OSError:
            pass

# Descriptor of the single-instance lock held by this process, or None when
# no lock is held. Module-level so the atexit hook registered in main() can
# read its current value.
_lock_fd = None

def main():
    """Take the single-instance lock, then run the OVS DOCA tcpdump application.

    Exits with status 1 if the lock cannot be acquired. The lock is released
    when the application returns or raises (including ``SystemExit``).
    """
    global _lock_fd

    # Try to acquire lock
    _lock_fd = acquire_lock()
    if _lock_fd is None:
        import ovs.dirs
        lock_file = os.path.join(os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR), 'ovs-doca-tcpdump.lock')
        print("Error: Another instance of ovs-doca-tcpdump is already running.")
        print(f"Lock file: {lock_file}")
        sys.exit(1)

    # Register cleanup function. The lambda looks up _lock_fd when it runs,
    # so it sees the value left behind by the finally block below.
    atexit.register(lambda: release_lock(_lock_fd))

    try:
        app = OVSDocaTcpdump()
        app.run(sys.argv[1:])
    finally:
        release_lock(_lock_fd)
        # Mark the lock as released so the atexit hook does not release it a
        # second time using a descriptor number that is already closed.
        _lock_fd = None

if __name__ == "__main__":
    main()
