# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

# Only the (de)serialization utils were kept, to reduce the number of
# requirements.

"""Utility methods for working with WSGI servers."""

import datetime
import errno
import os
import signal

import eventlet
from eventlet import wsgi
from oslo_config import cfg
from oslo_log import log as logging
from oslo_serialization import jsonutils

from sahara import exceptions
from sahara.i18n import _
from sahara.i18n import _LE
from sahara.i18n import _LI
from sahara.openstack.common import sslutils

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Configuration options declared by this module.
wsgi_opts = [
    cfg.IntOpt('max_header_line',
               default=16384,
               help="Maximum line size of message headers to be accepted. "
                    "max_header_line may need to be increased when using "
                    "large tokens (typically those generated by the "
                    "Keystone v3 API with big service catalogs)."),
]

# Register the options on the global config object so they can be read as
# CONF.<name> (Server.__init__ reads CONF.max_header_line).
CONF = cfg.CONF
CONF.register_opts(wsgi_opts)


class ActionDispatcher(object):
    """Route a call to the method whose name matches a given action."""

    def dispatch(self, *args, **kwargs):
        """Look up the handler for ``action`` and call it.

        The ``action`` keyword (``'default'`` when omitted) is removed from
        ``kwargs`` and its string form is looked up as an attribute of this
        object. When no such attribute exists, ``self.default`` is used.
        All remaining arguments are forwarded to the handler and its result
        is returned.
        """
        handler_name = str(kwargs.pop('action', 'default'))
        fallback = self.default
        handler = getattr(self, handler_name, fallback)
        return handler(*args, **kwargs)

    def default(self, data):
        """Fallback handler; subclasses are expected to override it."""
        raise NotImplementedError()


class DictSerializer(ActionDispatcher):
    """Base request body serializer; its fallback yields an empty string."""

    def serialize(self, data, action='default'):
        """Serialize ``data`` with the handler selected by ``action``."""
        serialized = self.dispatch(data, action=action)
        return serialized

    def default(self, data):
        """Fallback handler: ignore ``data`` and return an empty string."""
        return ""


class JSONDictSerializer(DictSerializer):
    """Default JSON request body serialization."""

    def default(self, data):
        """Serialize ``data`` to a JSON string.

        Objects handed to the encoder's ``default`` hook are converted as
        follows: ``datetime.datetime`` values become ISO 8601 strings with
        the microseconds dropped; anything else becomes its text form.

        :param data: the object to serialize
        :returns: the result of ``jsonutils.dumps``
        """
        # ``unicode`` only exists as a builtin on Python 2; referencing it
        # unconditionally raises NameError on Python 3, where ``str`` is
        # the text type.
        try:
            text_type = unicode  # noqa: F821
        except NameError:
            text_type = str

        def sanitizer(obj):
            if isinstance(obj, datetime.datetime):
                # Strip sub-second precision before formatting.
                _dtime = obj - datetime.timedelta(microseconds=obj.microsecond)
                return _dtime.isoformat()
            return text_type(obj)

        return jsonutils.dumps(data, default=sanitizer)


class TextDeserializer(ActionDispatcher):
    """Base request body deserializer; its fallback yields an empty dict."""

    def deserialize(self, datastring, action='default'):
        """Deserialize ``datastring`` with the handler chosen by ``action``."""
        deserialized = self.dispatch(datastring, action=action)
        return deserialized

    def default(self, datastring):
        """Fallback handler: ignore ``datastring`` and return an empty dict."""
        return dict()


class JSONDeserializer(TextDeserializer):
    """Deserializer that parses the request body as JSON."""

    def _from_json(self, datastring):
        """Parse ``datastring`` as JSON.

        :raises: exceptions.MalformedRequestBody when parsing raises
                 ValueError
        """
        try:
            parsed = jsonutils.loads(datastring)
        except ValueError:
            raise exceptions.MalformedRequestBody(_("cannot understand JSON"))
        return parsed

    def default(self, datastring):
        """Return the parsed JSON document under the ``'body'`` key."""
        body = self._from_json(datastring)
        return {'body': body}


class Server(object):
    """Server class to manage multiple WSGI sockets and applications.

    ``start()`` opens a listening socket on ``(CONF.host, CONF.port)`` and
    serves the application either from a green thread in the current
    process (``CONF.api_workers == 0``) or from ``CONF.api_workers`` forked
    child processes. ``wait()`` then blocks until serving is over.
    """

    def __init__(self, threads=500):
        """Initialize the server state.

        :param threads: size of the eventlet GreenPool each server uses
        """
        # Apply the configured header line limit to eventlet's WSGI server.
        eventlet.wsgi.MAX_HEADER_LINE = CONF.max_header_line
        self.threads = threads
        # PIDs of the forked workers, as recorded by the parent process.
        self.children = []
        # Set to False by the SIGTERM handler to end wait_on_children().
        self.running = True

    def start(self, application):
        """Run a WSGI server with the given application.

        :param application: The application to run in the WSGI server
        """
        def kill_children(*args):
            """Kills the entire process group."""
            LOG.error(_LE('SIGTERM received'))
            # Ignore SIGTERM in this process before signalling the group,
            # which includes this process itself.
            signal.signal(signal.SIGTERM, signal.SIG_IGN)
            self.running = False
            os.killpg(0, signal.SIGTERM)

        def hup(*args):
            """Relays SIGHUP to the entire process group.

            This process ignores SIGHUP while the signal is sent to the
            group and re-installs this handler afterwards. Children reset
            SIGHUP to its default disposition in run_child().
            """
            LOG.error(_LE('SIGHUP received'))
            signal.signal(signal.SIGHUP, signal.SIG_IGN)
            os.killpg(0, signal.SIGHUP)
            signal.signal(signal.SIGHUP, hup)

        self.application = application
        self.sock = eventlet.listen((CONF.host, CONF.port), backlog=500)
        if sslutils.is_enabled():
            LOG.info(_LI("Using HTTPS for port %s"), CONF.port)
            self.sock = sslutils.wrap(self.sock)

        if CONF.api_workers == 0:
            # Useful for profiling, test, debug etc.
            # No fork and no signal handlers: serve from a green thread in
            # this process.
            self.pool = eventlet.GreenPool(size=self.threads)
            self.pool.spawn_n(self._single_run, application, self.sock)
            return

        LOG.debug("Starting %d workers", CONF.api_workers)
        signal.signal(signal.SIGTERM, kill_children)
        signal.signal(signal.SIGHUP, hup)
        while len(self.children) < CONF.api_workers:
            self.run_child()

    def wait_on_children(self):
        """Reap worker processes, replacing dead ones, until stopped.

        Loops while ``self.running`` is true. A worker that exited or was
        killed by a signal is removed from ``self.children`` and a new one
        is forked. Once the loop ends, the listening socket is shut down
        and closed.
        """
        while self.running:
            try:
                pid, status = os.wait()
                if os.WIFEXITED(status) or os.WIFSIGNALED(status):
                    if pid in self.children:
                        LOG.error(_LE('Removing dead child %s'), pid)
                        self.children.remove(pid)
                        self.run_child()
            except OSError as err:
                # EINTR (interrupted call) and ECHILD (nothing to wait for)
                # are tolerated; anything else is re-raised.
                if err.errno not in (errno.EINTR, errno.ECHILD):
                    raise
            except KeyboardInterrupt:
                LOG.info(_LI('Caught keyboard interrupt. Exiting.'))
                os.killpg(0, signal.SIGTERM)
                break
        eventlet.greenio.shutdown_safe(self.sock)
        self.sock.close()
        LOG.debug('Server exited')

    def wait(self):
        """Wait until all servers have completed running."""
        try:
            if self.children:
                # Multi-process mode: supervise the forked workers.
                self.wait_on_children()
            else:
                # No forked workers: wait for the green pool to drain.
                self.pool.waitall()
        except KeyboardInterrupt:
            # Deliberately treated as a normal way to stop waiting.
            pass

    def run_child(self):
        """Fork one worker process that serves on the shared socket.

        In the parent, the new PID is appended to ``self.children``. In the
        child, SIGHUP and SIGTERM are reset to their default dispositions,
        the WSGI server is run, and the process exits once the server
        returns; this method never returns in the child.
        """
        pid = os.fork()
        if pid == 0:
            signal.signal(signal.SIGHUP, signal.SIG_DFL)
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            self.run_server()
            LOG.debug('Child %d exiting normally', os.getpid())
            # Exit instead of returning: a return would drop the child back
            # into the parent's control flow (the start() loop or
            # wait_on_children()), where it would fork workers of its own
            # or spin on os.wait() with no children to reap.
            os._exit(0)
        else:
            LOG.info(_LI('Started child %s'), pid)
            self.children.append(pid)

    def run_server(self):
        """Run a WSGI server.

        Serves ``self.application`` on ``self.sock`` using a fresh
        GreenPool, then waits for the pool's green threads to finish.
        """
        self.pool = eventlet.GreenPool(size=self.threads)
        wsgi.server(self.sock,
                    self.application,
                    custom_pool=self.pool,
                    log=LOG,
                    debug=False)
        self.pool.waitall()

    def _single_run(self, application, sock):
        """Start a WSGI server in a new green thread."""
        LOG.info(_LI("Starting single process server"))
        eventlet.wsgi.server(sock, application,
                             custom_pool=self.pool,
                             log=LOG,
                             debug=False)
