#    Copyright 2013 IBM Corp
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from rock import db
from rock import exception
from rock import objects
from rock.objects import base
from rock.objects import fields
from rock.openstack.common import jsonutils
from rock.openstack.common import log as logging
from rock import utils

# Module-level logger. NOTE(review): nothing in this module appears to use it
# yet; keep or drop once callers are confirmed.
LOG = logging.getLogger(__name__)


class ComputeNode(base.RockPersistentObject, base.RockObject):
    """Versioned object wrapping a single compute node DB record.

    ``accelerator_stats`` is kept on the object as a dict and written to the
    DB as a JSON string; ``host_ip`` is written to the DB as a string.
    """

    # Version 1.0: Initial version
    VERSION = '1.0'

    fields = {
        'id': fields.IntegerField(read_only=True),
        'service_id': fields.IntegerField(),
        'hypervisor_hostname': fields.StringField(nullable=True),
        'host_ip': fields.IPAddressField(nullable=True),
        'mac': fields.StringField(nullable=True),
        'accelerator_stats': fields.DictOfNullableStringsField(nullable=True),
        }

    def obj_make_compatible(self, primitive, target_version):
        # Only version 1.0 exists, so there is nothing to downgrade. Note that
        # this replaces the base-class implementation without calling it.
        pass

    @staticmethod
    def _from_db_object(context, compute, db_compute):
        """Populate *compute* from the DB row *db_compute* and return it.

        Every field except ``accelerator_stats`` is copied verbatim. That one
        is decoded from JSON, but only when the stored value is truthy.
        Otherwise it is not assigned: it keeps any value it already had, or
        stays unset on a fresh object. Pending changes are cleared
        afterwards, so the object is marked clean.
        """
        # Named ``plain_fields`` so the imported ``fields`` module is not
        # shadowed inside this function.
        plain_fields = set(compute.fields) - {'accelerator_stats'}
        for key in plain_fields:
            compute[key] = db_compute[key]

        accelerator_stats = db_compute['accelerator_stats']
        if accelerator_stats:
            compute['accelerator_stats'] = jsonutils.loads(accelerator_stats)

        compute._context = context
        compute.obj_reset_changes()
        return compute

    @base.remotable_classmethod
    def get_by_id(cls, context, compute_id):
        """Load the compute node whose primary key is *compute_id*."""
        db_compute = db.compute_node_get(context, compute_id)
        return cls._from_db_object(context, cls(), db_compute)

    @base.remotable_classmethod
    def get_by_service_id(cls, context, service_id):
        """Load the compute node attached to the service *service_id*."""
        db_compute = db.compute_node_get_by_service_id(context, service_id)
        return cls._from_db_object(context, cls(), db_compute)

    def _convert_stats_to_db_format(self, updates):
        """JSON-encode ``accelerator_stats`` in *updates* (modified in place).

        When the key is absent, *updates* is left alone. An explicit ``None``
        is kept as ``None``, so clearing this nullable field is persisted
        rather than silently dropped.
        """
        if 'accelerator_stats' not in updates:
            return
        accelerator_stats = updates['accelerator_stats']
        if accelerator_stats is not None:
            updates['accelerator_stats'] = jsonutils.dumps(accelerator_stats)

    def _convert_host_ip_to_db_format(self, updates):
        """Stringify ``host_ip`` in *updates* (modified in place).

        When the key is absent, *updates* is left alone. An explicit ``None``
        is kept so clearing the nullable field is persisted. The value is
        compared against ``None`` rather than tested for truthiness, because
        an IP address object whose numeric value is zero is falsy.
        """
        if 'host_ip' not in updates:
            return
        host_ip = updates['host_ip']
        updates['host_ip'] = None if host_ip is None else str(host_ip)

    @base.remotable
    def create(self, context):
        """Insert this node into the DB and refresh the object from the row.

        :raises exception.ObjectActionError: if ``id`` is already set.
        """
        if self.obj_attr_is_set('id'):
            raise exception.ObjectActionError(action='create',
                                              reason='already created')
        updates = self.obj_get_changes()
        self._convert_stats_to_db_format(updates)
        self._convert_host_ip_to_db_format(updates)

        db_compute = db.compute_node_create(context, updates)
        self._from_db_object(context, self, db_compute)

    @base.remotable
    def save(self, context, prune_stats=False):
        """Write changed fields to the DB and refresh from the updated row."""
        # NOTE(belliott) ignore prune_stats param, no longer relevant

        updates = self.obj_get_changes()
        # ``id`` is read-only and identifies the row; never send it as data.
        updates.pop('id', None)
        self._convert_stats_to_db_format(updates)
        self._convert_host_ip_to_db_format(updates)

        db_compute = db.compute_node_update(context, self.id, updates)
        self._from_db_object(context, self, db_compute)

    @base.remotable
    def destroy(self, context):
        """Delete this node's DB record."""
        db.compute_node_delete(context, self.id)

    @property
    def service(self):
        """Return the Service for ``service_id``, fetched once then cached."""
        if not hasattr(self, '_cached_service'):
            self._cached_service = objects.Service.get_by_id(self._context,
                                                             self.service_id)
        return self._cached_service


class ComputeNodeList(base.ObjectListBase, base.RockObject):
    """A list of ComputeNode objects built from DB query results."""

    # Version 1.0: Initial version
    VERSION = '1.0'
    fields = {
        'objects': fields.ListOfObjectsField('ComputeNode'),
        }

    @classmethod
    def _wrap_rows(cls, context, rows):
        """Turn raw compute-node DB rows into a populated list object."""
        return base.obj_make_list(context, cls(context), objects.ComputeNode,
                                  rows)

    @base.remotable_classmethod
    def get_all(cls, context):
        """Return every compute node known to the DB."""
        rows = db.compute_node_get_all(context)
        return cls._wrap_rows(context, rows)

    @base.remotable_classmethod
    def get_by_hypervisor(cls, context, hypervisor_match):
        """Return nodes matched by the DB hypervisor search for the pattern."""
        rows = db.compute_node_search_by_hypervisor(context, hypervisor_match)
        return cls._wrap_rows(context, rows)

    @base.remotable_classmethod
    def _get_by_service(cls, context, service_id, use_slave=False):
        """Return the nodes attached to the service with *service_id*."""
        service_row = db.service_get(context, service_id,
                                     with_compute_node=True,
                                     use_slave=use_slave)
        return cls._wrap_rows(context, service_row['compute_node'])

    @classmethod
    def get_by_service(cls, context, service, use_slave=False):
        """Return the nodes for *service* (any object with an ``id``)."""
        return cls._get_by_service(context, service.id, use_slave=use_slave)
