# Copyright (c) 2012 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""
Track accelerators, such as physical functions (PF) and virtual functions
(VF), for a rock device.  Provides the scheduler with useful information
about availability through the ComputeNode model.
"""
import collections
import copy

from rock.agent import acc_device
from rock.agent import acc_stats
from rock.common import acc_type
from rock.common import dict_util
from rock import exception
from rock import utils
from rock.i18n import _
from rock import objects
from rock.openstack.common import jsonutils
from rock.openstack.common import log as logging

AGENT_RESOURCE_SEMAPHORE = "agent_resources"

LOG = logging.getLogger(__name__)


class AcceleratorTracker(object):
    """Rock agent helper class for keeping track of accelerator resource usage as instances
    are built and destroyed.
    Manages all of the accelerator (Rock) devices that can be used by the
    nova-compute node.

    This class fetches accelerator device information from the accelerator
    drivers and tracks the usage of these devices.

    It's called by rock-controller node resource tracker to allocate and free
    devices to/from instances, and to update the available accelerator
    devices information from hardware driver periodically. The accelerator
    information is updated to DB when devices information is changed.
    """
    primary_key = ["address"]

    def __init__(self, context, host, driver, agent_node=None, node_name=None):
        """Set up the tracker and load the accelerators already recorded.

        :param context: request context, used for the accelerator lookup
        :param host: host name kept on the tracker as ``self.host``
        :param driver: iterable of accelerator driver objects
        :param agent_node: compute node record of this agent, if already
                           known; its ``'id'`` selects the accelerators
                           that are loaded
        :param node_name: hypervisor host name of this node
        """
        self.context = context
        self.host = host
        self.driver = driver
        self.agent_node = agent_node
        self.node_name = node_name

        self.stats = acc_stats.RockDeviceStats()
        self.stale = {}
        self.add_accelerators = []
        self.local_device_address = set()

        # Without a node record there is nothing to look up yet.
        known_accelerators = []
        if self.agent_node:
            known_accelerators = list(
                objects.AcceleratorList.get_by_compute_node(
                    context, self.agent_node['id']))
        self.rock_accelerators = known_accelerators

        self._initial_instance_usage()

    def _initial_instance_usage(self):
        self.allocations = collections.defaultdict(list)
        self.claims = collections.defaultdict(list)
        for accelerator in self.rock_accelerators:
            if acc_type.FUNCTION_VF == accelerator["function_type"]:
                self.local_device_address.add(accelerator["address"])
            uuid = accelerator['instance_uuid']
            if accelerator['status'] == 'claimed':
                self.claims[uuid].append(accelerator)
            elif accelerator['status'] == 'allocated':
                self.allocations[uuid].append(accelerator)
                self.stats.attach_accelerator(accelerator)
            # elif accelerator['status'] == 'available':
            #     self.stats.detach_accelerator(accelerator)

    def update_rock_capability(self, context):
        """Audit the accelerators reported by every driver and persist them.

        Returns immediately when no compute node record can be resolved
        for this agent.  Otherwise, for each driver in ``self.driver``:

        * if no tracked accelerator has the driver's device type, a VF
          accelerator is created from the driver's VF resource and queued
          in ``self.add_accelerators``;
        * every tracked accelerator of that device type is refreshed via
          ``_update_accelerator``;
        * PF/VF relationships are reconciled via ``_update_relationship``.

        Finally the accelerators are written to the DB and the node's
        accelerator statistics are refreshed.

        :param context: request context used for the DB writes
        """
        LOG.audit(_("Auditing locally available agent accelerators"))
        if not self.check_compute_node():
            return

        for driver_item in self.driver:
            # NOTE(review): the result is not used here; the call is kept
            # in case the driver relies on it being made - confirm and drop.
            driver_item.get_relation()
            dev_type = driver_item.get_device_type()

            type_is_tracked = any(
                tracked['device_type'] == dev_type
                for tracked in self.rock_accelerators)
            if not type_is_tracked:
                new_vf = objects.Accelerator.create(
                    driver_item.get_vf_resource())
                new_vf.compute_node_id = self.agent_node['id']
                self.add_accelerators.append(new_vf)
                self.local_device_address.add(new_vf["address"])

            for position, tracked in enumerate(self.rock_accelerators):
                if tracked['device_type'] == dev_type:
                    self._update_accelerator(tracked, driver_item, position)

            self._update_relationship(driver_item)

        self._syn_resource_db(context)
        self.update_rock_stat(context)

    def _update_accelerator(self, accelerator, acc_driver, index):
        """Refresh one tracked accelerator against what its driver reports.

        VF accelerators are looked up in the driver:

        * not reported any more -> status becomes 'removed';
        * reported but 'claimed'/'allocated' -> the driver data is parked
          in ``self.stale`` (keyed by address) so that ``_free_accelerator``
          can apply it once the accelerator is released;
        * otherwise the accelerator is updated in place from the driver
          data.

        PF accelerators are recorded in ``self.address_id_dict``
        (address -> id) and their address is added to
        ``self.local_device_address``.

        Any other function type is logged and marked 'removed'.

        :param accelerator: tracked accelerator object; callers pass
                            ``self.rock_accelerators[index]``
        :param acc_driver: driver object for the accelerator's device type
        :param index: position of ``accelerator`` in
                      ``self.rock_accelerators``
        """
        function_type = accelerator['function_type']
        if accelerator.compute_node_id != self.agent_node['id']:
            accelerator.compute_node_id = self.agent_node['id']

        if acc_type.FUNCTION_VF == function_type:
            driver_acc = acc_driver.get_vf_resource(accelerator)
            if not driver_acc:
                self.rock_accelerators[index].status = 'removed'
            elif accelerator['status'] in ('claimed', 'allocated'):
                # Do not touch an accelerator that is in use.  Keep the
                # driver's data (not the accelerator itself, which would
                # turn the later update into a no-op self-update).
                self.stale[accelerator['address']] = driver_acc
            else:
                acc_device.update_accelerator(accelerator, driver_acc)
        elif acc_type.FUNCTION_PF == function_type:
            # Create the mapping only once; testing for any other attribute
            # name would reset it for every PF and keep only the last one.
            if not hasattr(self, 'address_id_dict'):
                self.address_id_dict = {}
            self.address_id_dict[accelerator['address']] = accelerator['id']
            self.local_device_address.add(accelerator["address"])
        else:
            LOG.warn('Can not find accelerator with type-%s and id-%s ' % (function_type, accelerator['id']))
            self.rock_accelerators[index].status = 'removed'

    def _update_relationship(self, driver):
        relationship = driver.get_relation()
        total_pf_address = None
        if relationship and relationship['total_pf_address']:
            total_pf_address = relationship.pop('total_pf_address')
        total_resource = driver.get_total_resource(self.node_name)
        capability_used = []

        for accelerator in self.rock_accelerators:
            #  find used capability
            if relationship.has_key(accelerator["address"]) and accelerator['status'] in ('claimed', 'allocated'):
                capability_used.append(accelerator['acc_capability'])

        if capability_used and len(capability_used) >= 0:
            for use_element in capability_used:
                total_resource = dict_util._dict_subtract(total_resource, use_element)

        if not hasattr(self, 'pf_address_dict'):
            self.pf_address_dict = {}
        else:
            for key in self.pf_address_dict.keys():
                self.pf_address_dict[key] = 0
        for index, accelerator in enumerate(self.rock_accelerators):
            #  refresh total pf accelerator capability
            if accelerator['address'] == total_pf_address:
                temp_capability = {}
                for key, element in total_resource.items():
                    if isinstance(element, str):
                        temp_capability[key] = element
                    else:
                        temp_capability[key] = str(jsonutils.dumps(element))
                self.rock_accelerators[index]['acc_capability'] = temp_capability
                break

        for vf_address, pf_address in relationship.items():
            if pf_address in self.pf_address_dict.keys():
                self.pf_address_dict[pf_address] += 1
            else:
                self.pf_address_dict[pf_address] = 1
            self.create_acc_from_vf(pf_address, vf_address, driver)

        for index, accelerator in enumerate(self.rock_accelerators):
            if acc_type.FUNCTION_PF == accelerator['function_type'] and accelerator[
                'address'] in self.pf_address_dict.keys() and accelerator['vf_number'] != self.pf_address_dict[
                accelerator['address']]:
                self.rock_accelerators[index]['vf_number'] = self.pf_address_dict[accelerator['address']]
            #  update pf id for vf parents
            if acc_type.FUNCTION_VF == accelerator['function_type'] and hasattr(self,'address_id_dict') \
                    and relationship.has_key(accelerator['address']) \
                    and self.address_id_dict.has_key(relationship[accelerator['address']]):
                parent_pf_id = self.address_id_dict[relationship[accelerator['address']]]
                if accelerator['belong_pf_id'] != parent_pf_id:
                    LOG.audit('Change accelerator[id=%s] parent from %s to %s' % (
                        accelerator['id'], accelerator['belong_pf_id'], parent_pf_id))
                    self.rock_accelerators[index]['belong_pf_id'] = parent_pf_id

    def create_acc_from_vf(self, pf_address, vf_address, driver):
        if pf_address not in self.local_device_address and vf_address in self.local_device_address:
            vf_accelerator = None
            for acc in self.rock_accelerators:
                # LOG.info("acc address : %s , vf_address : %s , equal : %s" % (acc["address"], vf_address, acc['address'] == vf_address))
                if acc['address'] == vf_address:
                    vf_accelerator = copy.deepcopy(acc)
                    break
            if vf_accelerator:
                self.add_accelerators.append(self.build_accelerator(pf_address, vf_accelerator))
                self.local_device_address.add(pf_address)
        if vf_address not in self.local_device_address:
            vf_accelerator = objects.Accelerator.create(
                driver.get_vf_resource(objects.Accelerator.create({"address": vf_address})))
            self.add_accelerators.append(vf_accelerator)
            self.local_device_address.add(vf_address)

    def build_accelerator(self, address, accelerator):
        """Build a PF accelerator modelled on an existing accelerator.

        All fields of ``accelerator`` are copied except the identity and
        ownership ones listed below; the PF-specific values are then set
        explicitly (status 'available', no instance, no capability, no
        parent PF, zero VFs, this agent's compute node id).

        :param address: address of the new PF accelerator
        :param accelerator: accelerator object whose fields are copied
        :returns: the new PF accelerator object
        """
        not_inherited = ('status', 'instance_uuid', 'address', 'id',
                         'extra_info', 'belong_pf_id', "request_id",
                         "created_at", "deleted_at")

        pf_accelerator = objects.Accelerator.create({"address": address})
        for field_name in accelerator.fields.keys():
            if field_name in not_inherited:
                continue
            pf_accelerator[field_name] = accelerator[field_name]

        pf_values = (
            ('address', address),
            ('acc_capability', None),
            ('belong_pf_id', None),
            ('function_type', acc_type.FUNCTION_PF),
            ('instance_uuid', None),
            ('status', 'available'),
            ('compute_node_id', self.agent_node['id']),
            ('vf_number', 0),
        )
        for field_name, value in pf_values:
            pf_accelerator[field_name] = value
        return pf_accelerator

    def update_rock_stat(self, context):
        stat_pools = self.stats.syn_pools()
        temp_capability = {}
        for key, element in stat_pools.items():
            if isinstance(element, str):
                temp_capability[key] = element
            else:
                temp_capability[key] = str(jsonutils.dumps(element))
        if self.agent_node['accelerator_stats'] != temp_capability:
            self.agent_node['accelerator_stats'] = temp_capability
            LOG.info("Save compute node ====> %s" % self.agent_node)
            self.agent_node.save(context)

    @utils.synchronized(AGENT_RESOURCE_SEMAPHORE)
    def _syn_resource_db(self, context):
        """Write tracked and newly discovered accelerators to the DB.

        * Every tracked accelerator gets ``self.host`` as its host and is
          saved when it has pending changes.
        * Every accelerator queued in ``self.add_accelerators`` is given
          the node id, saved and moved into ``self.rock_accelerators``.
        * Accelerators whose status is 'deleted' after saving are dropped
          from the tracked list, and the queue is emptied.

        :param context: request context used for the saves
        """
        for acc in self.rock_accelerators:
            if acc['host'] != self.host:
                acc['host'] = self.host
            if acc.obj_what_changed():
                # NOTE(review): this only has an effect if
                # obj_get_changes() hands out the object's own change
                # mapping rather than a fresh dict - confirm the intent.
                changes = acc.obj_get_changes()
                # ``in`` instead of dict.has_key(), which Python 3 removed.
                if 'status' in changes:
                    del changes['status']
                acc.save(context)

        self.set_compute_node_id(self.agent_node['id'])
        for acc in self.add_accelerators:
            LOG.info("Add new accelerator ====> %s" % acc)
            acc.save(context)
            self.rock_accelerators.append(acc)

        self.rock_accelerators = [acc for acc in self.rock_accelerators
                                  if acc['status'] != 'deleted']
        self.add_accelerators = []

    def check_compute_node(self):
        if not self.agent_node:
            if self.node_name:
                node_list = objects.ComputeNodeList.get_by_hypervisor(self.context, self.node_name)
                equal_name_node = [node for node in node_list if node['hypervisor_hostname'] == self.node_name]
                if equal_name_node:
                    self.agent_node = equal_name_node[0]
                    return True
                else:
                    return False
            else:
                return False
        return True

    def allocate_accelerator(self, acc_list):
        """Mark every accelerator in ``acc_list`` as allocated.

        The list is first accounted for in ``self.stats``; each
        accelerator is then allocated without an instance.

        :param acc_list: accelerators to allocate
        :returns: the same ``acc_list`` object
        """
        LOG.info("======================agent tracker allocate_acc----------")
        self.stats.allocate_accelerator(acc_list)
        no_instance = None
        for accelerator in acc_list:
            acc_device.allocate(accelerator, no_instance)
        return acc_list

    def delocate_accelerator(self, acc_list):
        self.stats.free_accelerator(acc_list)
        for acc in acc_list:
            if acc['status'] in ['claimed', 'allocated']:
                acc_device.free(acc, None)
        # LOG.info("result ====> %s" % acc_list)
        return acc_list

    def _free_instance(self, instance):
        # Note(yjiang5): When a instance is resized, the devices in the
        # destination node are claimed to the instance in prep_resize stage.
        # However, the instance contains only allocated virtual accelerators
        # information, not the claimed one. So we can't use
        # instance['rocket_devices'] to check the devices to be freed.
        for accelerator in self.rock_accelerators:
            if (accelerator['status'] in ('claimed', 'allocated') and
                        accelerator['instance_uuid'] == instance['uuid']):
                self._free_accelerator(accelerator)

    def _free_accelerator(self, accelerator, instance=None):
        """Free one accelerator and bring its record up to date.

        After the accelerator is freed, any entry parked for its address
        in ``self.stale`` is removed and, when truthy, applied to the
        accelerator.  The accelerator is then removed from ``self.stats``.

        :param accelerator: accelerator object to free
        :param instance: instance the accelerator is freed from, if any
        """
        acc_device.free(accelerator, instance)

        parked_update = self.stale.pop(accelerator['address'], None)
        if parked_update:
            acc_device.update_accelerator(accelerator, parked_update)

        self.stats.remove_accelerator(accelerator)

    def set_compute_node_id(self, node_id):
        """Set the compute node id that this object is tracking for.

        In current resource tracker implementation, the
        compute_node entry is created in the last step of
        update_available_resoruces, thus we have to lazily set the
        compute_node_id at that time.
        """

        if self.agent_node['id'] and self.agent_node['id'] != node_id:
            raise exception.AcceleratorDeviceInvalidNodeid(node_id=self.agent_node['id'],
                                                           new_node_id=node_id)
        for accelerator in self.add_accelerators:
            accelerator.compute_node_id = node_id

    def update_usage(context, instance):
        pass