# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import logging
import os


from awscli.clidriver import CLIOperationCaller
from awscli.customizations.emr import constants
from awscli.customizations.emr import exceptions
from botocore.exceptions import WaiterError, NoCredentialsError
from botocore import xform_name

LOG = logging.getLogger(__name__)


def parse_tags(raw_tags_list):
    """
    Convert raw ``key[=value]`` tag strings into Key/Value dicts.

    Each entry is split on its first ``=`` only; an entry without any ``=``
    gets an empty string as its value. A falsy ``raw_tags_list`` (``None``
    or an empty list) yields an empty list.
    """
    if not raw_tags_list:
        return []

    parsed = []
    for raw_tag in raw_tags_list:
        # partition() cuts at the first '=' and leaves the value empty
        # when the separator is absent.
        tag_key, _, tag_value = raw_tag.partition('=')
        parsed.append({'Key': tag_key, 'Value': tag_value})
    return parsed


def parse_key_value_string(key_value_string):
    """
    Parse a comma-separated ``key[=value]`` string into Key/Value dicts.

    Example input: "k1=v1,k2='v  2',k3,k4". The string is split on every
    comma, then each item is split on its first ``=``; items without ``=``
    get an empty value. Quotes are kept verbatim as part of the value.

    Returns ``None`` when ``key_value_string`` is ``None``.
    """
    if key_value_string is None:
        return None

    split_items = (
        item.partition('=') for item in key_value_string.split(','))
    return [
        {'Key': item_key, 'Value': item_value}
        for item_key, _, item_value in split_items
    ]


def apply_boolean_options(
        true_option, true_option_name, false_option, false_option_name):
    """
    Resolve a pair of mutually exclusive on/off options to a boolean.

    Returns ``True`` when only ``true_option`` is truthy and ``False``
    otherwise (including when neither option was given).

    Raises ``ValueError`` when both options are truthy; the option names
    are only used to build that error message.
    """
    both_given = true_option and false_option
    if not both_given:
        return bool(true_option)

    raise ValueError(
        'aws: error: cannot use both ' + true_option_name +
        ' and ' + false_option_name + ' options together.')


# To be deprecated: this function is slated to be renamed to apply_dict,
# which currently has the same body.
def apply(params, key, value):
    """
    Store ``value`` under ``key`` in ``params`` when ``value`` is truthy.

    ``params`` is modified in place and also returned.
    """
    if not value:
        return params

    params[key] = value
    return params


def apply_dict(params, key, value):
    """
    Store ``value`` under ``key`` in ``params`` when ``value`` is truthy.

    Falsy values (``None``, ``''``, ``0``, empty containers) leave
    ``params`` untouched. ``params`` is modified in place and returned.
    """
    if not value:
        return params

    params[key] = value
    return params


def apply_params(src_params, src_key, dest_params, dest_key):
    """
    Copy ``src_params[src_key]`` to ``dest_params[dest_key]`` if it is set.

    Nothing is copied when ``src_key`` is missing from ``src_params`` or
    maps to a falsy value. ``dest_params`` is modified in place and
    returned.
    """
    candidate = src_params.get(src_key)
    if candidate:
        dest_params[dest_key] = candidate

    return dest_params


def build_step(
        jar, name='Step',
        action_on_failure=constants.DEFAULT_FAILURE_ACTION,
        args=None,
        main_class=None,
        properties=None):
    """
    Assemble a step dict whose ``HadoopJarStep`` entry runs ``jar``.

    ``jar`` is required and is validated with ``check_required_field``.
    ``name`` and ``action_on_failure`` are added to the step only when
    truthy; ``args``, ``main_class`` and ``properties`` are added to the
    ``HadoopJarStep`` config only when truthy.
    """
    check_required_field(
        structure='HadoopJarStep', name='Jar', value=jar)

    # 'Jar' is always present; the remaining keys are optional.
    jar_config = {'Jar': jar}
    optional_fields = (
        ('Args', args),
        ('MainClass', main_class),
        ('Properties', properties),
    )
    for field_name, field_value in optional_fields:
        apply_dict(jar_config, field_name, field_value)

    step = {}
    apply_dict(step, 'Name', name)
    apply_dict(step, 'ActionOnFailure', action_on_failure)
    step['HadoopJarStep'] = jar_config
    return step


def build_bootstrap_action(
        path,
        name='Bootstrap Action',
        args=None):
    """
    Assemble a bootstrap action dict for the script located at ``path``.

    The result holds ``Name`` (when truthy) and a ``ScriptBootstrapAction``
    config containing ``Path`` plus ``Args`` (when truthy).

    Raises ``MissingParametersError`` when ``path`` is ``None``.
    """
    if path is None:
        raise exceptions.MissingParametersError(
            object_name='ScriptBootstrapActionConfig', missing='Path')

    # 'Args' is inserted before 'Path' to keep the resulting key order.
    script_config = {}
    apply_dict(script_config, 'Args', args)
    script_config['Path'] = path

    ba_config = {}
    apply_dict(ba_config, 'Name', name)
    apply_dict(ba_config, 'ScriptBootstrapAction', script_config)
    return ba_config


def build_s3_link(relative_path='', region='us-east-1'):
    """
    Build an ``s3://<region>.elasticmapreduce<relative_path>`` URI.

    A ``region`` of ``None`` falls back to 'us-east-1'. ``relative_path``
    is appended as-is, so it must carry its own leading slash.
    """
    effective_region = 'us-east-1' if region is None else region
    return 's3://{0}.elasticmapreduce{1}'.format(
        effective_region, relative_path)


def get_script_runner(region='us-east-1'):
    """
    Return the S3 link built from ``constants.SCRIPT_RUNNER_PATH``.

    A ``region`` of ``None`` falls back to 'us-east-1'.
    """
    effective_region = 'us-east-1' if region is None else region
    return build_s3_link(
        region=effective_region,
        relative_path=constants.SCRIPT_RUNNER_PATH)


def check_required_field(structure, name, value):
    """
    Ensure a required field has a truthy value.

    Raises ``MissingParametersError`` naming the ``structure`` and the
    missing field ``name`` when ``value`` is falsy; otherwise returns
    ``None``.
    """
    if value:
        return

    raise exceptions.MissingParametersError(
        missing=name, object_name=structure)


def check_empty_string_list(name, value):
    """
    Reject a list of strings that is effectively empty.

    Raises ``EmptyListError`` for parameter ``name`` when ``value`` is
    falsy, or when it holds exactly one element that is blank after
    stripping whitespace. Otherwise returns ``None``.
    """
    if value:
        lone_blank = len(value) == 1 and value[0].strip() == ""
        if not lone_blank:
            return

    raise exceptions.EmptyListError(param=name)


def call(session, operation_name, parameters, region_name=None,
         endpoint_url=None, verify=None):
    """
    Invoke ``operation_name`` on a freshly created 'emr' client.

    ``parameters`` is expanded into keyword arguments of the client
    method, whose return value is passed back unchanged.

    Raises ``NoCredentialsError`` when the session has no credentials.
    """
    # Check for credentials before creating the client: a missing region
    # could otherwise raise first, and the credentials error is the more
    # helpful message.
    credentials = session.get_credentials()
    if credentials is None:
        raise NoCredentialsError()

    emr_client = session.create_client(
        'emr', region_name=region_name, endpoint_url=endpoint_url,
        verify=verify)
    LOG.debug('Calling ' + str(operation_name))
    operation = getattr(emr_client, operation_name)
    return operation(**parameters)


def get_example_file(command):
    """
    Open the ``.rst`` example file for ``command`` and return the handle.

    The path 'awscli/examples/emr/<command>.rst' is relative, so it is
    resolved against the current working directory. The file object is
    returned open; closing it is left to the caller.
    """
    example_path = 'awscli/examples/emr/' + command + '.rst'
    return open(example_path)


def dict_to_string(dict, indent=2):
    """
    Serialize ``dict`` to a JSON string via ``json.dumps``.

    ``indent`` is forwarded to ``json.dumps`` unchanged.
    """
    # NOTE: the parameter name shadows the builtin ``dict``; it is kept
    # because callers may pass it by keyword.
    serialized = json.dumps(dict, indent=indent)
    return serialized


def get_client(session, parsed_globals):
    """
    Create an 'emr' client configured from the parsed global options.

    The region comes from ``get_region``; the endpoint URL and SSL
    verification setting are read from ``parsed_globals``.
    """
    region_name = get_region(session, parsed_globals)
    client_kwargs = {
        'region_name': region_name,
        'endpoint_url': parsed_globals.endpoint_url,
        'verify': parsed_globals.verify_ssl,
    }
    return session.create_client('emr', **client_kwargs)


def get_cluster_state(session, parsed_globals, cluster_id):
    """
    Return ``Cluster.Status.State`` from a ``describe_cluster`` call.

    A new client is created through ``get_client`` for the request.
    """
    emr_client = get_client(session, parsed_globals)
    response = emr_client.describe_cluster(ClusterId=cluster_id)
    cluster_status = response['Cluster']['Status']
    return cluster_status['State']


def find_master_dns(session, parsed_globals, cluster_id):
    """
    Return the cluster's 'MasterPublicDnsName' from ``describe_cluster``.

    A new client is created through ``get_client`` for the request.
    """
    emr_client = get_client(session, parsed_globals)
    cluster = emr_client.describe_cluster(ClusterId=cluster_id)['Cluster']
    return cluster['MasterPublicDnsName']


def which(program):
    """
    Locate an executable file named ``program`` on the ``PATH``.

    Each ``PATH`` entry (with surrounding double quotes stripped) is
    joined with ``program``; the first candidate that is a regular file
    and executable by the current user is returned.

    Returns ``None`` when no candidate matches, or when the ``PATH``
    environment variable is not set at all.
    """
    search_path = os.environ.get("PATH")
    if search_path is None:
        # No PATH means there is nothing to search; report "not found"
        # instead of raising KeyError.
        return None

    for directory in search_path.split(os.pathsep):
        # Entries may be wrapped in double quotes; drop them before joining.
        directory = directory.strip('"')
        candidate = os.path.join(directory, program)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate

    return None


def call_and_display_response(session, operation_name, parameters,
                              parsed_globals):
    """
    Run an 'emr' operation through ``CLIOperationCaller.invoke``.

    Whatever ``invoke`` returns is discarded; this function returns
    ``None``.
    """
    operation_caller = CLIOperationCaller(session)
    operation_caller.invoke(
        'emr', operation_name, parameters, parsed_globals)


def display_response(session, operation_name, result, parsed_globals):
    """
    Hand ``result`` to ``CLIOperationCaller`` for display.

    Returns ``None``.
    """
    operation_caller = CLIOperationCaller(session)
    # This relies on a private method of CLIOperationCaller. It should be
    # replaced once that functionality is available outside the class.
    operation_caller._display_response(
        operation_name, result, parsed_globals)


def get_region(session, parsed_globals):
    """
    Return the region to use for a command.

    ``parsed_globals.region`` wins when it is not ``None``; otherwise the
    session's 'region' config variable is returned.
    """
    explicit_region = parsed_globals.region
    if explicit_region is not None:
        return explicit_region

    return session.get_config_variable('region')


def join(values, separator=',', lastSeparator='and'):
    """
    Render values as a human-readable list.

    [1, 2, 3] -> '1, 2 and 3'

    Every value is converted with ``str``. All but the last item are
    joined with ``separator`` followed by a space; the last item is
    attached with ``lastSeparator``. No values gives '' and a single
    value is returned on its own.
    """
    items = [str(item) for item in values]
    if not items:
        return ""
    if len(items) == 1:
        return items[0]

    spaced_separator = '%s ' % separator
    leading = spaced_separator.join(items[:-1])
    return ' '.join((leading, lastSeparator, items[-1]))


def split_to_key_value(string):
    """
    Split ``string`` on its first ``=`` into a ``(key, value)`` tuple.

    When ``string`` contains no ``=`` the whole string is the key and the
    value is ''. A 2-tuple is returned in both cases, so callers get the
    same type whether or not a ``=`` was present.
    """
    key, _, value = string.partition('=')
    return key, value


def get_cluster(cluster_id, session, region,
                endpoint_url, verify_ssl):
    """
    Fetch the 'Cluster' entry of a ``describe_cluster`` response.

    Returns ``None`` when the response is ``None`` or has no 'Cluster'
    key.
    """
    response = call(
        session, 'describe_cluster', {'ClusterId': cluster_id},
        region, endpoint_url, verify_ssl)
    if response is None:
        return None

    return response.get('Cluster')


def get_release_label(cluster_id, session, region,
                      endpoint_url, verify_ssl):
    """
    Return the 'ReleaseLabel' of the cluster found by ``get_cluster``.

    Returns ``None`` when no cluster is returned or it has no
    'ReleaseLabel' key.
    """
    cluster = get_cluster(
        cluster_id, session, region, endpoint_url, verify_ssl)
    if cluster is None:
        return None

    return cluster.get('ReleaseLabel')
