#!/usr/bin/env python
# Copyright (c) 2010-2012 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Consistency checks between the actual content of the volume files and the KV.
This needs work.
"""
import argparse
import logging.handlers
import os
from os.path import basename, dirname, normpath

from swift.obj.rpc_http import RpcError, StatusCode
import glob
import sys

import time

from swift.common.storage_policy import get_policy_string
from swift.obj import vfile
from swift.obj.fmgr_pb2 import STATE_RW
from swift.obj.header import read_volume_header, HeaderException, \
    read_object_header
from swift.obj.vfile_utils import get_socket_path_from_volume_path, \
    next_aligned_offset, change_user, get_mountpoint_from_volume_path
from swift.common.utils import ismount
from swift.obj import rpc_http as rpc

# lseek() "whence" values used to skip holes in sparse volume files (these
# match the Linux <unistd.h> values). os.SEEK_DATA / os.SEEK_HOLE are not
# available in the Python version this tool targets, so they are defined
# here. Only SEEK_DATA is used in this file.
SEEK_DATA = 3
SEEK_HOLE = 4


class VolumeCheckException(Exception):
    """
    Raised when a volume fails a consistency check.

    :param msg: human readable description of the problem
    :param volume_path: path of the volume that failed the check
    :param args: any extra positional values, forwarded to Exception
    """

    def __init__(self, msg, volume_path, *args):
        super(VolumeCheckException, self).__init__(msg, volume_path, *args)
        self.volume_path = volume_path
        self.msg = msg


def check_volume_header(volume_path, header, socket_path):
    """
    Ensure the volume header reports the volume as read-write.

    :param volume_path: path to the volume file (used in the error message)
    :param header: volume header, as returned by read_volume_header()
    :param socket_path: accepted for API symmetry, not used
    :raises VolumeCheckException: if header.state is not STATE_RW
    """
    if header.state == STATE_RW:
        return
    raise VolumeCheckException(
        "volume {} not in state RW. ({})".format(volume_path, header.state),
        volume_path)


class Vfiles(object):
    """
    Iterate over the object (vfile) headers stored in a volume file.

    Each iteration returns a ``(header_offset, header)`` tuple, until the end
    of the volume file is reached. Iteration starts at ``offset``, where a
    valid header is expected. The next header is located from the previous
    header information (``total_size``). If no header is found there (e.g.
    the next file has been deleted), the volume is scanned for the next
    object header:

      - attempt to read a header and return it if successful; the next
        position becomes the header position + object length (rounded up
        to 4k)
      - if no header is found, first attempt to skip a hole (lseek with
        SEEK_DATA) and try to read a header again
      - if still no header is found, move forward 4k and try again, until a
        header is found or EOF is reached.

    Headers are expected to be aligned on 4k boundaries (we don't search for
    a header where there could be user data).

    :param volume_file: volume file object, opened in binary mode
    :param offset: offset at which to start looking for a header
    """

    def __init__(self, volume_file, offset):
        volume_file.seek(0, os.SEEK_END)
        self.volume_size = volume_file.tell()
        self.volume_file = volume_file
        # NOTE(review): the 4k-aligned start offset is kept as an attribute
        # but next() starts from the unaligned ``offset``.
        self.offset = next_aligned_offset(offset, 4096)
        self.next_offset = offset

    def __iter__(self):
        return self

    def next(self):
        """
        Return the next ``(header_offset, header)`` tuple.

        :raises StopIteration: when the end of the volume file is reached
        """
        seek_data_done = False

        while True:
            if self.next_offset >= self.volume_size:
                raise StopIteration

            # Always (re)position on next_offset: a failed header read leaves
            # the file position somewhere past it.
            header_offset = self.next_offset
            self.volume_file.seek(header_offset)
            try:
                vh = read_object_header(self.volume_file)
            except HeaderException:
                vh = None
                logger.debug("no vfile at offset {}".format(header_offset))

            # found a header
            if vh:
                if vh.total_size > 0:
                    self.next_offset = header_offset + vh.total_size
                else:
                    # A corrupted total_size would otherwise make us return
                    # this same header forever (or move backwards).
                    logger.warning(
                        "invalid total_size in header at offset {} ({}), "
                        "skipping 4k".format(header_offset, vh))
                    self.next_offset = header_offset + 4096
                logger.debug("found a header, set next offset to: {}".format(
                    self.next_offset))
                # should not happen
                aligned_next_offset = next_aligned_offset(self.next_offset,
                                                          4096)
                if aligned_next_offset != self.next_offset:
                    logger.warning(
                        "total_size of header not aligned on 4k ({})".format(
                            vh))
                    self.next_offset = aligned_next_offset
                return header_offset, vh

            if not seek_data_done:
                seek_data_done = True
                # that's ugly: our python version (2.7.12) does not support
                # SEEK_HOLE/SEEK_DATA. Skip holes via lseek using the
                # underlying file descriptor; the python file is re-seeked
                # at the top of the loop.
                logger.debug("SEEK_DATA")
                try:
                    data_offset = os.lseek(self.volume_file.fileno(),
                                           self.next_offset, SEEK_DATA)
                except OSError:
                    # lseek() with SEEK_DATA sometimes returns ENXIO despite
                    # the offset being greater than zero and smaller than
                    # the file size. If that happens, fall back to skipping
                    # 4k below.
                    pass
                else:
                    self.next_offset = next_aligned_offset(data_offset, 4096)
                    continue

            logger.debug("SEEK+4k")
            self.next_offset += 4096

    # Python 3 iterator protocol.
    __next__ = next


def register_volume(volume_path, socket_path):
    """
    Register a volume in the KV using the information from its header.

    The volume's next usable offset is the end of the volume file, rounded
    up to a 4k boundary.

    :param volume_path: path to the volume file
    :param socket_path: KV RPC socket path
    """
    with open(volume_path, "rb") as f:
        vh = read_volume_header(f)
        # Get the next usable offset in the volume
        f.seek(0, os.SEEK_END)
        offset = next_aligned_offset(f.tell(), 4096)

    # Adjacent literals rather than a backslash continuation, which embedded
    # the following line's indentation into the logged message.
    logger.info(
        "Registering volume {} in KV. Partition: {} Type: {} "
        "State: {}".format(volume_path, vh.partition, vh.type, vh.state))

    rpc.register_volume(socket_path, vh.partition, vh.type, vh.volume_idx,
                        offset, vh.state, repair_tool=True)


def check_volume(volume_path, socket_path=None, force_full_check=False):
    """
    Check a volume file, and the objects it contains, against the KV.

    The volume header is read and must be in state RW. If the volume is
    missing from the KV, it is registered (``--repair``, after a prompt
    unless ``--no_prompt``). Otherwise a VolumeCheckException is raised.
    The object headers found in the volume are then compared with their KV
    entries. Objects missing from the KV are passed to
    handle_obj_missing_in_kv().

    Reads the module-level ``args`` (``repair``, ``no_prompt``) and
    ``logger``, both set up in ``__main__``.

    :param volume_path: path to the volume file
    :param socket_path: KV RPC socket path, derived from volume_path if not
        provided
    :param force_full_check: scan the whole volume, starting at the first
        object offset from the volume header, instead of starting at the next
        offset recorded in the KV
    :returns: False if the volume entry could not be looked up in the KV,
        None otherwise
    :raises VolumeCheckException: if the volume is not in state RW, is
        missing from the KV and was not registered, or an object header does
        not match its KV entry
    """
    if not socket_path:
        socket_path = get_socket_path_from_volume_path(volume_path)

    # read and check the volume header
    with open(volume_path, "rb") as f:
        vh = read_volume_header(f)
    check_volume_header(volume_path, vh, socket_path)
    # TODO: check the volume size with the header value

    # check that the volume exists in the KV
    missing_in_kv = False
    try:
        rpc.get_volume(socket_path, vh.volume_idx, repair_tool=True)
    except RpcError as e:
        if e.code != StatusCode.NotFound:
            logger.exception(e)
            logger.warning('Error while checking volume entry in KV, exiting')
            return False
        logger.warning("Missing volume: {} in the KV".format(vh.volume_idx))
        missing_in_kv = True

    if missing_in_kv:
        do_register = args.repair and (
            args.no_prompt or
            confirm_action("Add missing volume {} to the KV?".format(
                volume_path)))
        if not do_register:
            # Nothing below can be checked against a volume the KV does not
            # know about.
            raise VolumeCheckException("Volume not in KV", volume_path)
        register_volume(volume_path, socket_path)

    # TODO: add check for volume state (header vs KV)

    if force_full_check or missing_in_kv:
        # A freshly registered volume has no objects in the KV, and its KV
        # next offset is the end of the file: scan it from the beginning so
        # its objects are checked (and registered) too.
        start_offset = vh.first_obj_offset
    else:
        start_offset = rpc.get_next_offset(socket_path, vh.volume_idx,
                                           repair_tool=True)

    volume_dir = dirname(volume_path)
    with open(volume_path, "rb") as f:
        for offset, header in Vfiles(f, start_offset):
            logger.debug("start check: {} {}".format(offset, header))
            objname = "{}{}".format(header.ohash, header.filename)
            # Get object information from the KV
            try:
                obj = rpc.get_object(socket_path, objname, repair_tool=True)
            except RpcError as e:
                if e.code == StatusCode.NotFound:
                    handle_obj_missing_in_kv(socket_path, volume_path, header,
                                             offset, args)
                else:
                    # TODO: handle this. Without a KV entry there is nothing
                    # to compare the header with, so skip this object.
                    logger.exception(e)
                continue
            except Exception as e:
                logger.exception(e)
                continue

            # check header and kv consistency
            check_header_vs_obj(offset, obj, header, volume_path)

            # check that vfile can be opened and metadata deserialized
            check_open_vfile(objname, volume_dir, socket_path)


def check_open_vfile(name, volume_dir, socket_path):
    """
    Open an object through the vfile layer, then close it again.

    Any exception raised while opening (e.g. while deserializing metadata)
    propagates to the caller.

    :param name: object name (object hash followed by filename)
    :param volume_dir: directory containing the volume files
    :param socket_path: KV RPC socket path
    """
    reader = vfile.VFileReader._get_vfile(
        name, volume_dir, socket_path, logger, repair_tool=True)
    reader.close()


def check_header_vs_obj(file_offset, obj, header, volume_path):
    """
    Check that an object header found in a volume agrees with its KV entry.

    :param file_offset: offset at which the object header was found
    :param obj: object entry returned by rpc.get_object()
    :param header: object header read from the volume
    :param volume_path: path to the volume file
    :raises VolumeCheckException: if the KV offset or volume index differs
        from what was found in the volume file
    """
    # Check offset in file
    if file_offset != obj.offset:
        err_txt = ("Header/KV inconsistency. Name: {} File offset: {} "
                   "Position from RPC: {}".format(obj.name, file_offset,
                                                   obj.offset))
        logger.warning(err_txt)
        logger.warning("header: {}".format(header))
        logger.warning("rpc obj: {}".format(obj))
        # volume_path is a required argument of VolumeCheckException;
        # omitting it raised TypeError instead of the intended error.
        raise VolumeCheckException(err_txt, volume_path)

    # Check volume index
    volume_file_index = vfile.get_volume_index(basename(volume_path))
    if volume_file_index != obj.volume_index:
        txt = "Volume index error, KV volume index: {}, actual index: {}"
        err_txt = txt.format(obj.volume_index, volume_file_index)
        raise VolumeCheckException(err_txt, volume_path)


def handle_obj_missing_in_kv(socket_path, volume_path, header, offset, args):
    """
    Report an object found in a volume but absent from the KV.

    In repair mode the object is registered, either directly (no_prompt) or
    after the user confirms.

    :param socket_path: KV RPC socket path
    :param volume_path: path to the volume file
    :param header: object header read from the volume
    :param offset: offset of the object header in the volume
    :param args: parsed command line options (uses .repair and .no_prompt)
    """
    full_name = "{}{}".format(header.ohash, header.filename)
    logger.warn("Missing file in the KV. Volume: {}. full name: {}".format(
        volume_path, full_name))
    logger.warn("Offset: {}, total length: {}".format(
        offset, header.total_size))

    if not args.repair:
        return
    if args.no_prompt or confirm_action("Add object to the KV?"):
        register_object(full_name, header, volume_path, offset, socket_path)


def register_object(objname, header, volume_path, offset, socket_path=None):
    """
    Register an object found in a volume into the KV.

    RPC failures are logged and not re-raised.

    :param objname: object name (object hash followed by filename)
    :param header: object header read from the volume
    :param volume_path: path to the volume file
    :param offset: offset of the object header in the volume
    :param socket_path: KV RPC socket path, derived from volume_path if not
        provided
    """
    logger.debug("Registering {}".format(objname))
    socket_path = socket_path or get_socket_path_from_volume_path(volume_path)

    # There is no end marker: the header's total_size is the only source for
    # where the object ends.
    # Todo: check the volume index against the obj hash
    vol_idx = vfile.get_volume_index(basename(volume_path))
    end_offset = offset + header.total_size
    try:
        rpc.register_object(socket_path, objname, vol_idx, offset,
                            end_offset, repair_tool=True)
    except RpcError as err:
        logger.warn("Failed to register object {}".format(objname))
        logger.exception(err)


def confirm_action(message):
    """
    Ask the user a yes/no question on the terminal.

    :param message: question to display; " (y/n)" is appended to it
    :returns: True if the user answered exactly "y", False otherwise
    """
    try:
        read_line = raw_input  # noqa: F821  (Python 2)
    except NameError:
        # Python 3: raw_input() was renamed input().
        read_line = input
    return read_line("{} (y/n)".format(message)) == "y"


# Accepted --log_level names and the logging level each one maps to.
log_levels = dict(
    critical=logging.CRITICAL,
    error=logging.ERROR,
    warning=logging.WARNING,
    info=logging.INFO,
    debug=logging.DEBUG,
)

if __name__ == "__main__":
    # NOTE: args and logger are module globals on purpose: the functions
    # above read them.
    parser = argparse.ArgumentParser()

    # log level
    parser.add_argument("--log_level", choices=sorted(log_levels),
                        default="info",
                        help="logging level, defaults to info")

    # check one volume
    parser.add_argument("--volume", help="path to volume")

    # check all volumes on the disk
    parser.add_argument("--disk_path", help="/srv/node/disk-xx")
    parser.add_argument("--policy_idx", help="policy index")

    # by default, we will only check the volume, repair is to create missing
    # entries
    parser.add_argument("--repair", action="store_true",
                        help="creates missing files in the KV")
    parser.add_argument("--no_prompt", action="store_true",
                        help="No prompt. In repair mode, do not prompt and "
                             "take automatic action")

    # force full check
    parser.add_argument("--force_full", action="store_true", default=False,
                        help="Force check of the whole volume")

    # NOTE(review): store_true with default=True makes args.keepuser always
    # True, so change_user("swift") below is never called. Left unchanged;
    # confirm the intended default before fixing.
    parser.add_argument("--keepuser", action='store_true', default=True,
                        help="Do not attempt to switch to swift user")
    parser.add_argument("--mount_check", action='store_true', default=False,
                        help="Wait until disk is mounted")

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(log_levels[args.log_level])
    handler = logging.handlers.SysLogHandler(address='/dev/log')
    handler.setFormatter(logging.Formatter('losf.volcheck: %(message)s'))
    logger.addHandler(handler)

    # Exactly one of --volume / --disk_path is required, and --disk_path
    # needs --policy_idx (otherwise nothing would be checked).
    if (bool(args.volume) == bool(args.disk_path) or
            (args.disk_path and not args.policy_idx)):
        parser.print_help()
        sys.exit(0)

    if not args.keepuser:
        change_user("swift")

    if args.volume:
        if args.mount_check:
            mountpoint = get_mountpoint_from_volume_path(args.volume)
            while not ismount(mountpoint):
                logger.info(
                    "Waiting for disk {} to be mounted".format(mountpoint))
                time.sleep(5)

        socket_path = get_socket_path_from_volume_path(args.volume)
        # Query the KV state once: a loop here spun forever on a KV that was
        # not clean, and the volume was never checked.
        if not args.force_full and rpc.get_kv_state(socket_path).isClean:
            logger.info("LOSF DB {} is clean, skipping".format(socket_path))
            sys.exit(0)
        check_volume(args.volume, socket_path=socket_path,
                     force_full_check=args.force_full)

    if args.policy_idx and args.disk_path:
        losf_dir = get_policy_string('losf', args.policy_idx)
        volume_topdir = os.path.join(args.disk_path, losf_dir, 'volumes')
        if args.mount_check:
            mountpoint = dirname(dirname(normpath(volume_topdir)))
            while not ismount(mountpoint):
                logger.debug(
                    "Waiting for disk {} to be mounted".format(mountpoint))
                time.sleep(1)
        socket_path = os.path.join(dirname(normpath(volume_topdir)),
                                   "rpc.socket")
        if not args.force_full and rpc.get_kv_state(socket_path).isClean:
            logger.info("LOSF DB {} is clean, skipping".format(socket_path))
            sys.exit(0)

        lock_suffix = ".writelock"
        failed_at_least_once = False
        for lock_path in glob.iglob(os.path.join(volume_topdir,
                                                 "*" + lock_suffix)):
            # Strip only the trailing suffix; str.replace() would also alter
            # a ".writelock" appearing elsewhere in the path.
            volume_path = lock_path[:-len(lock_suffix)]
            logger.info("Checking volume {}".format(volume_path))
            if not os.path.exists(volume_path):
                logger.warning(
                    "writelock file found but volume does not exist, "
                    "remove it")
                os.remove(lock_path)
                continue
            try:
                check_volume(volume_path, force_full_check=args.force_full)
            except Exception:
                # Keep going with the other volumes, but log the cause.
                logger.warning("check_volume failed on {}".format(volume_path),
                               exc_info=True)
                failed_at_least_once = True

        # Mark kv as clean
        # FIXME: check failed_at_least_once, and don't mark KV clean if True.
        # However, if we do this, if you get a single IO error on the drive
        # preventing the check of a volume, the whole drive becomes
        # unavailable. Probably this should be more fine-grained (KV available
        # but keeps a list of unchecked volumes)
        if args.repair:
            # This will be done each time, even if we have not had to repair
            # a volume, change this
            logger.info("Marking KV as clean ({})".format(socket_path))
            rpc.set_kv_state(socket_path, True)
