# coding: utf-8

# Copyright 2014-2025 Álvaro Justen <https://github.com/turicas/rows/>
#    This program is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
#    Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
#    any later version.
#    This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
#    warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for
#    more details.
#    You should have received a copy of the GNU Lesser General Public License along with this program.  If not, see
#    <http://www.gnu.org/licenses/>.

from __future__ import unicode_literals

import binascii
import datetime
import json
import locale
import re
import uuid
from base64 import b64decode, b64encode
from collections import defaultdict
from decimal import Decimal, InvalidOperation
from unicodedata import normalize

from rows.compat import BINARY_TYPE, ORDERED_DICT, PYTHON_KEYWORDS_LOWER, PYTHON_VERSION, TEXT_TYPE

if PYTHON_VERSION < (3, 0, 0):
    from itertools import izip_longest as zip_longest  # noqa
else:
    from itertools import zip_longest  # noqa


# Order matters here
__all__ = [
    "BoolField",
    "IntegerField",
    "FloatField",
    "DatetimeField",
    "DateField",
    "DecimalField",
    "PercentField",
    "JSONField",
    "EmailField",
    "UUIDField",
    "TextField",
    "BinaryField",
    "Field",
]
# Lower-cased, stripped text values that `is_null` treats as "no value" (in addition to `None` and empty strings)
NULL = ("-", "null", "none", "nil", "n/a", "na")
# Byte-string counterpart of `NULL`, used by `is_null` when the value's type is exactly `BINARY_TYPE`
NULL_BYTES = (b"-", b"null", b"none", b"nil", b"n/a", b"na")
# The two CamelCase patterns are applied one after the other by `camel_to_snake` to insert "_" between words
REGEXP_CAMELCASE_1 = re.compile("(.)([A-Z][a-z]+)")
REGEXP_CAMELCASE_2 = re.compile("([a-z0-9])([A-Z])")
# Matches any character that is neither a digit nor "-"; `DecimalField.deserialize` removes those characters
REGEXP_ONLY_NUMBERS = re.compile(r"[^0-9\-]")
# One or more consecutive underscores; `slug` collapses them when the separator is "_"
REGEXP_SEPARATOR = re.compile("(_+)")
# A word character immediately followed by a word boundary; `slug` appends the separator after each match
REGEXP_WORD_BOUNDARY = re.compile("(\\w\\b)")
SHOULD_NOT_USE_LOCALE = True  # This variable is changed by rows.locale_manager
# Default value of the `permitted_chars` parameter of `slug`
SLUG_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_"

# State and settings used by `cached_type_deserialize`
_cache_resize_len = 31000  # position in the frequency-sorted entries that defines the eviction threshold
# Only values that are instances of these types are cached
_cacheable_types = (TEXT_TYPE, BINARY_TYPE, int, float, bool, type(None), datetime.date, datetime.datetime, uuid.UUID)
_max_cache_size = 32000  # an eviction pass runs when the cache reaches exactly this number of entries
_deserialization_cache = {}  # maps a cache key to a `[result, hit count]` list
_deserialization_error = object()  # sentinel returned (instead of raising) when `true_behavior` is false

def cached_type_deserialize(type_, value, true_behavior=True):
    """
    LFU cache for type deserialization

    Calls `type_.deserialize(value)` and memoizes successful results.

    When `true_behavior` is `True`, the `ValueError`/`TypeError` raised by the field is propagated if the value can't
    be deserialized; otherwise the `_deserialization_error` sentinel is returned. Failures are never cached.

    Only values whose type is listed in `_cacheable_types` are cached. The cache key is the full
    `(field class, value type, value, locale marker)` tuple rather than its hash: distinct values may share a hash
    (in CPython `hash(-1) == hash(-2)`), and keying by hash alone would return another value's cached result.

    When the cache reaches `_max_cache_size` entries, only the entries used more often than the one at position
    `_cache_resize_len` of the frequency-sorted entries are kept.
    """
    from locale import getlocale
    global _deserialization_cache

    should_cache = isinstance(value, _cacheable_types)
    cache_key = None
    if should_cache:
        # The locale is part of the key because locale-aware fields may parse the same text differently
        cache_key = (type_, type(value), value, SHOULD_NOT_USE_LOCALE or getlocale())
        entry = _deserialization_cache.get(cache_key)
        if entry is not None:
            entry[1] += 1  # Count the hit so frequently used entries survive eviction
            return entry[0]

    try:
        result = type_.deserialize(value)
    except (ValueError, TypeError):
        if true_behavior:
            raise
        return _deserialization_error

    if should_cache:
        _deserialization_cache[cache_key] = [result, 1]
        if len(_deserialization_cache) == _max_cache_size:
            frequencies = sorted(entry[1] for entry in _deserialization_cache.values())
            min_freq = frequencies[_cache_resize_len]
            _deserialization_cache = {
                key: entry for key, entry in _deserialization_cache.items() if entry[1] > min_freq
            }
    return result


def value_error(value, cls):
    """Raise `ValueError` telling that `value` can't be converted by `cls`

    The value is shown through its `repr`, truncated to 50 characters (plus
    "...") so huge values do not flood the error message.
    """
    shown = repr(value)
    if len(shown) > 50:
        shown = "{}...".format(shown[:50])
    message = "Value '{}' can't be {}".format(shown, cls.__name__)
    raise ValueError(message)


class Field(object):
    """Base class every field type inherits from

    The behaviour implemented here is a pass-through: values are returned
    untouched, except for null values, which are normalised.
    """

    TYPE = (type(None),)
    # TODO: add "name" so we can import automatically from schema CSV

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Prepare a value to be exported

        `None` becomes an empty string; any other value is returned as is.
        Subclasses should always return an unicode value, except for
        BinaryField.
        """

        return "" if value is None else value

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Convert a value just after importing it

        Subclasses should always return a value of type `cls.TYPE` or `None`.
        Here, values that are already of `cls.TYPE` and non-null values are
        returned untouched; null values (see `is_null`) become `None`.
        """

        if isinstance(value, cls.TYPE) or not is_null(value):
            return value
        return None


class BinaryField(Field):
    """Field class to represent byte arrays

    Values are exported as base64-encoded text and imported back from it.
    Is not locale-aware (does not need to be)
    """

    TYPE = (BINARY_TYPE,)

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the base64 text for a binary `value` (`""` for `None`)

        Raises `ValueError` if `value` is neither `None` nor binary.
        """
        if value is None:
            return ""
        if not isinstance(value, BINARY_TYPE):
            value_error(value, cls)
        try:
            encoded = b64encode(value)
        except (TypeError, binascii.Error):
            # Could not encode: hand the raw value back
            return value
        return encoded.decode("ascii")

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return the bytes represented by `value` (`b""` for `None`)

        Binary values are returned untouched and text values are decoded from
        base64; anything else (or invalid base64) raises `ValueError`.
        """
        if value is None:
            return b""
        if isinstance(value, BINARY_TYPE):
            return value
        if not isinstance(value, TEXT_TYPE):
            value_error(value, cls)
        try:
            return b64decode(value)
        except (TypeError, ValueError, binascii.Error):
            raise ValueError("Can't decode base64")


class UUIDField(Field):
    """Field class to represent UUIDs

    Is not locale-aware (does not need to be)
    """

    TYPE = (uuid.UUID,)

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the text representation of an UUID (`""` for `None`)

        Raises `ValueError` if `value` is neither `None` nor an `uuid.UUID`.
        """
        if value is None:
            return ""
        # This is a classmethod: the type tuple must be read from `cls`
        # (there is no `self` in scope here)
        if not isinstance(value, cls.TYPE):
            value_error(value, cls)
        return TEXT_TYPE(value)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return an `uuid.UUID` parsed from `value`

        UUID instances are returned untouched and null values (see `is_null`)
        become `None`, so what `serialize` produces can be read back. Text may
        be written with (36 characters) or without (32 characters) dashes;
        binary values are decoded as ASCII. Raises `ValueError` for anything
        else.
        """
        if isinstance(value, cls.TYPE):
            return value
        elif is_null(value):
            return None

        value = as_string(value, encoding="ascii").strip()
        if len(value) not in (36, 32):  # with dashes and without dashes
            value_error(value, cls)
        return uuid.UUID(value)


class BoolField(Field):
    """Field class to represent booleans

    Is not locale-aware (if you need to, please customize by changing its
    attributes like `TRUE_VALUES` and `FALSE_VALUES`)
    """

    TYPE = (bool,)
    SERIALIZED_VALUES = {True: "true", False: "false", None: ""}
    TRUE_VALUES = ("true", "t", "yes")
    FALSE_VALUES = ("false", "f", "no")

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the text registered for `value` in `SERIALIZED_VALUES`"""
        # TODO: should we serialize `None` as well or give it to the plugin?
        serialized = cls.SERIALIZED_VALUES[value]
        return serialized

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return `True`, `False` or `None` for the given `value`

        Booleans are returned untouched and null values become `None`. Other
        values are compared (as lower-cased text) against `TRUE_VALUES` and
        `FALSE_VALUES`; `ValueError` is raised when there is no match.
        """
        if isinstance(value, cls.TYPE):
            return value
        if is_null(value):
            return None

        text = as_string(value).lower()
        for outcome, accepted in ((True, cls.TRUE_VALUES), (False, cls.FALSE_VALUES)):
            if text in accepted:
                return outcome
        raise ValueError("Value is not boolean")


class IntegerField(Field):
    """Field class to represent integer

    Is locale-aware
    """

    TYPE = (int,)

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the text representation of an integer (`""` for `None`)

        When the locale is in use, the `grouping` keyword argument is passed
        to `locale.format_string`.
        """
        if value is None:
            return ""

        if SHOULD_NOT_USE_LOCALE:
            return TEXT_TYPE(value)
        else:
            grouping = kwargs.get("grouping", None)
            return locale.format_string("%d", value, grouping=grouping)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return an `int` parsed from `value` (`None` for null values)

        Integers are returned untouched. Floats are accepted only when they
        have no fractional part. Raises `ValueError` when the value can't be
        interpreted as an integer.
        """
        if isinstance(value, cls.TYPE):
            return value
        elif is_null(value):
            return None
        elif isinstance(value, float):
            try:
                new_value = int(value)
            except OverflowError:
                # `int()` raises OverflowError for infinite floats; report it
                # as ValueError, like every other conversion failure, so
                # callers catching ValueError/TypeError are not surprised
                value_error(value, cls)
            if new_value != value:
                raise ValueError("It's float, not integer")
            value = new_value
        value = as_string(value)
        return int(value) if SHOULD_NOT_USE_LOCALE else locale.atoi(value)


class FloatField(Field):
    """Field class to represent float

    Is locale-aware
    """

    TYPE = (float,)

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the text representation of a float (`""` for `None`)

        When the locale is in use, the `grouping` keyword argument is passed
        to `locale.format_string`.
        """
        if value is None:
            return ""
        if not SHOULD_NOT_USE_LOCALE:
            use_grouping = kwargs.get("grouping", None)
            return locale.format_string("%f", value, grouping=use_grouping)
        return TEXT_TYPE(value)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return a `float` parsed from `value` (`None` for null values)

        Floats are returned untouched; other values are converted to text and
        parsed with `float` or, when the locale is in use, `locale.atof`.
        """
        if isinstance(value, cls.TYPE):
            return value
        if is_null(value):
            return None

        text = as_string(value)
        parser = float if SHOULD_NOT_USE_LOCALE else locale.atof
        return parser(text)


class DecimalField(Field):
    """Field class to represent decimal data (as Python's decimal.Decimal)

    Is locale-aware
    """

    TYPE = (Decimal,)

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the text representation of a decimal (`""` for `None`)

        When the locale is in use, the value is formatted through
        `locale.format_string` with as many decimal places as its text
        representation has; the `grouping` keyword argument is passed along.
        """
        if value is None:
            return ""

        value_as_string = TEXT_TYPE(value)
        if SHOULD_NOT_USE_LOCALE:
            return value_as_string

        grouping = kwargs.get("grouping", None)
        if "." in value_as_string:
            decimal_places = len(value_as_string.split(".")[1])
            string_format = "%.{}f".format(decimal_places)
        else:
            string_format = "%d"
        return locale.format_string(string_format, value, grouping=grouping)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return a `Decimal` parsed from `value` (`None` for null values)

        Decimals are returned untouched and `int`/`float` values are converted
        through their text representation. Other values are parsed directly
        by `Decimal` or, when the locale is in use, by splitting on the
        locale's decimal point and discarding the remaining separators.
        Raises `ValueError` when the value can't be interpreted.
        """
        if isinstance(value, cls.TYPE):
            return value
        elif is_null(value):
            return None
        elif type(value) in (int, float):
            return Decimal(TEXT_TYPE(value))

        if SHOULD_NOT_USE_LOCALE:
            try:
                return Decimal(value)
            except InvalidOperation:
                value_error(value, cls)

        locale_vars = locale.localeconv()
        decimal_separator = locale_vars["decimal_point"]
        interesting_vars = (
            "decimal_point",
            "mon_decimal_point",
            "mon_thousands_sep",
            "negative_sign",
            "positive_sign",
            "thousands_sep",
        )
        chars = (
            locale_vars[x].replace(".", r"\.").replace("-", r"\-")
            for x in interesting_vars
        )
        interesting_chars = "".join(set(chars))
        # Anything besides digits, spaces and the locale's own symbols is invalid
        regexp = re.compile(r"[^0-9{} ]".format(interesting_chars))
        text = as_string(value)
        if regexp.findall(text):
            value_error(text, cls)

        parts = [
            REGEXP_ONLY_NUMBERS.subn("", number)[0]
            for number in text.split(decimal_separator)
        ]
        if len(parts) > 2:
            raise ValueError("Can't deserialize with this locale.")
        if len(parts) == 2 and "-" in parts[1]:
            # A sign is only meaningful before the integer part
            value_error(text, cls)
        try:
            result = Decimal(parts[0])
            if len(parts) == 2:
                decimal_places = len(parts[1])
                fraction = Decimal(parts[1]) / (10 ** decimal_places)
                # The fraction must move the number away from zero: adding it
                # to a negative integer part would turn "-1.5" into -0.5
                # (`is_signed` is also true for "-0", so "-0.5" works too)
                if result.is_signed():
                    result = result - fraction
                else:
                    result = result + fraction
        except InvalidOperation:
            value_error(text, cls)
        return result


class PercentField(DecimalField):
    """Field class to represent percent values

    Values are stored as fractions (`Decimal("0.5")` means 50%).
    Is locale-aware (inherit this behaviour from `rows.DecimalField`)
    """

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return `value` as percent text (`""` for `None`, `"0.00%"` for zero)

        The value is multiplied by 100 by shifting its exponent, so the
        significant digits are kept exactly as they are
        (`Decimal("0.42010")` becomes `"42.010%"`).
        """
        if value is None:
            return ""
        elif value == Decimal("0"):
            return "0.00%"

        if not isinstance(value, Decimal):
            value = Decimal(TEXT_TYPE(value))
        # Shifting the exponent (instead of multiplying and chopping the last
        # two characters of the text) also works for values with fewer than
        # two decimal places: Decimal("1") must become "100%", not "1%"
        percent = value.scaleb(2)
        if percent.is_finite() and percent.as_tuple().exponent > 0:
            # Avoid scientific notation (like "5E+1") in the final text
            percent = percent.quantize(Decimal(1))
        value = super(PercentField, cls).serialize(percent, *args, **kwargs)
        return "{}%".format(value)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return the fraction represented by a percent text

        Decimals are returned untouched and null values become `None`. Any
        other value must contain "%" (otherwise `ValueError` is raised); the
        number is parsed by `DecimalField` and divided by 100.
        """
        if isinstance(value, cls.TYPE):
            return value
        elif is_null(value):
            return None

        value = as_string(value)
        if "%" not in value:
            value_error(value, cls)
        value = value.replace("%", "")
        return super(PercentField, cls).deserialize(value) / 100


class DateField(Field):
    """Field class to represent date

    Is not locale-aware (does not need to be)
    """

    TYPE = (datetime.date,)
    INPUT_FORMAT = "%Y-%m-%d"
    OUTPUT_FORMAT = "%Y-%m-%d"

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return `value` formatted with `OUTPUT_FORMAT` (`""` for `None`)"""
        if value is None:
            return ""
        formatted = value.strftime(cls.OUTPUT_FORMAT)
        return TEXT_TYPE(formatted)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return a `datetime.date` parsed with `INPUT_FORMAT`

        Date instances are returned untouched and null values become `None`.
        """
        # TODO: add locale support?
        if isinstance(value, cls.TYPE):
            return value
        if is_null(value):
            return None
        parsed = datetime.datetime.strptime(as_string(value), cls.INPUT_FORMAT)
        return parsed.date()


class DatetimeField(Field):
    """Field class to represent date-time

    Is not locale-aware (does not need to be)
    """

    TYPE = (datetime.datetime,)
    DATETIME_REGEXP = re.compile(
        "^([0-9]{4})-([0-9]{2})-([0-9]{2})[ T]"
        "([0-9]{2}):([0-9]{2}):([0-9]{2})$"
    )

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return the ISO format of `value` (`""` for `None`)"""
        return "" if value is None else TEXT_TYPE(value.isoformat())

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return a `datetime.datetime` parsed from `value`

        Datetime instances are returned untouched and null values become
        `None`. Text must match `DATETIME_REGEXP` (date and time separated by
        a space or "T"); otherwise `ValueError` is raised.
        """
        if isinstance(value, cls.TYPE):
            return value
        if is_null(value):
            return None

        text = as_string(value)
        # TODO: may use iso8601
        found = cls.DATETIME_REGEXP.findall(text)
        if found:
            year, month, day, hour, minute, second = (int(piece) for piece in found[0])
            return datetime.datetime(year, month, day, hour, minute, second)
        value_error(text, cls)


class TextField(Field):
    """Field class to represent unicode strings

    Is not locale-aware (does not need to be)
    """

    TYPE = (TEXT_TYPE,)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return `value` as text (`None` and text values are kept as is)"""
        needs_conversion = value is not None and not isinstance(value, cls.TYPE)
        return as_string(value) if needs_conversion else value


class EmailField(TextField):
    """Field class to represent e-mail addresses

    Is not locale-aware (does not need to be)
    """

    EMAIL_REGEXP = re.compile(
        r"^[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]+$", flags=re.IGNORECASE
    )

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return `value` as text (`""` for `None`)"""
        return "" if value is None else TEXT_TYPE(value)

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return the e-mail address found in `value` (`None` for null values)

        Raises `ValueError` when the value does not match `EMAIL_REGEXP`.
        """
        # `is_null` already answers `True` for `None`
        if is_null(value):
            return None

        text = as_string(value)
        matches = cls.EMAIL_REGEXP.findall(text)
        if matches:
            return matches[0]
        value_error(text, cls)


class JSONField(Field):
    """Field class to represent JSON-encoded strings

    Is not locale-aware (does not need to be)
    """

    TYPE = (list, dict)

    @classmethod
    def serialize(cls, value, *args, **kwargs):
        """Return `value` encoded as JSON text"""
        encoded = json.dumps(value)
        return encoded

    @classmethod
    def deserialize(cls, value, *args, **kwargs):
        """Return the object decoded from JSON text

        Lists and dicts are returned untouched and null values become `None`
        (both handled by `Field.deserialize`); anything else is passed to
        `json.loads`.
        """
        candidate = super(JSONField, cls).deserialize(value)
        must_parse = candidate is not None and not isinstance(candidate, cls.TYPE)
        return json.loads(candidate) if must_parse else candidate


def as_string(value, encoding=None):
    """Return `value` as text

    Text is returned untouched; binary values are decoded using `encoding`
    (`ValueError` is raised when no encoding is given); any other value is
    converted with `TEXT_TYPE`.
    """
    if isinstance(value, TEXT_TYPE):
        return value
    if not isinstance(value, BINARY_TYPE):
        return TEXT_TYPE(value)
    if encoding is None:
        raise ValueError("Binary is not supported")
    return value.decode(encoding)


def is_null(value):
    """Tell whether `value` represents "no value"

    `None`, blank values and (case-insensitively) the markers listed in
    `NULL`/`NULL_BYTES` are considered null.
    """
    if value is None:
        return True

    if type(value) is BINARY_TYPE:
        normalized, markers = value.strip().lower(), NULL_BYTES
    else:
        normalized, markers = as_string(value).strip().lower(), NULL
    return not normalized or normalized in markers


def unique_values(values):
    """Return the non-null values of `values`, without repetition

    The order of first appearance is preserved and values are compared by
    equality (they do not need to be hashable).
    """
    uniques = []
    for candidate in values:
        if is_null(candidate) or candidate in uniques:
            continue
        uniques.append(candidate)
    return uniques


def get_items(*indexes):
    """Build a callable that picks the given indexes from an object

    The callable always returns a tuple, even when only one index is
    requested.

    Works like `operator.itemgetter`, except that `None` is used for indexes
    the object is too short to have (instead of raising IndexError).
    """

    def fetch(obj):
        picked = []
        for index in indexes:
            picked.append(obj[index] if index < len(obj) else None)
        return tuple(picked)

    return fetch


def slug(text, separator="_", permitted_chars=SLUG_CHARS):
    """Generate a slug for the `text`.

    >>> str(slug(' ÁLVARO  justen% '))
    'alvaro_justen'
    >>> str(slug(' ÁLVARO  justen% ', separator='-'))
    'alvaro-justen'
    """

    original = TEXT_TYPE(text or "")

    # Drop everything that has no ASCII equivalent
    # Example: u' ÁLVARO  justen% ' -> ' ALVARO  justen% '
    decomposed = normalize("NFKD", original.strip())
    ascii_text = decomposed.encode("ascii", "ignore").decode("ascii")

    # Put a separator right after the end of each word
    with_separators = REGEXP_WORD_BOUNDARY.sub("\\1" + re.escape(separator), ascii_text)

    # Keep only the permitted characters (and the separator), lowercased
    # Example: u'_ALVARO__justen%_' -> u'_alvaro__justen_'
    allowed = frozenset(permitted_chars) | frozenset([separator])
    filtered = "".join(filter(allowed.__contains__, with_separators)).lower()

    # Collapse repeated occurrences of the separator
    # Example: u'_alvaro__justen_' -> u'_alvaro_justen_'
    if separator == "_":
        repeated_separator = REGEXP_SEPARATOR
    else:
        repeated_separator = re.compile("(" + re.escape(separator) + "+)")
    collapsed = repeated_separator.sub(separator, filtered)

    # Remove separators from both ends
    # Example: u'_alvaro_justen_' -> u'alvaro_justen'
    return collapsed.strip(separator)


def camel_to_snake(value):
    """Convert a CamelCase name to its slugged snake_case version

    Empty (or blank) values result in an empty string.
    """
    text = TEXT_TYPE(value or "").strip()
    if text:
        # Adapted from <https://stackoverflow.com/a/1176023/1299446>
        first_pass = REGEXP_CAMELCASE_1.sub(r"\1_\2", text)
        second_pass = REGEXP_CAMELCASE_2.sub(r"\1_\2", first_pass)
        return slug(second_pass)
    return ""


def make_unique_name(name, existing_names, name_format="{name}_{index}", start=2, max_size=None):
    """Return a name, based on `name`, that is not in `existing_names`.

    `name` itself is returned when it is not taken. Otherwise candidates are
    built with `name_format`, using increasing indexes (from `start`), until
    a free one is found. When `max_size` is given, a candidate longer than it
    is rebuilt with `name` shortened by the exceeding number of characters.
    """
    candidate = name
    counter = start
    while candidate in existing_names:
        candidate = name_format.format(name=name, index=counter)
        exceeding = 0 if max_size is None else len(candidate) - max_size
        if exceeding > 0:
            candidate = name_format.format(name=name[:-exceeding], index=counter)
        counter += 1

    return candidate


def make_header(field_names, permit_not=False, max_size=None, prefix="field_"):
    """Return unique and slugged field names.

    Each name is slugged ("^" is also kept when `permit_not` is true) and,
    when `max_size` is given, cut to that size and slugged again. Empty names
    become `prefix` plus the column position, names starting with a digit
    get `prefix` prepended and Python keywords and repeated names get a
    numeric suffix (see `make_unique_name`).
    """
    allowed = SLUG_CHARS + "^" if permit_not else SLUG_CHARS

    result = []
    for position, raw_name in enumerate(field_names):
        current = slug(raw_name, permitted_chars=allowed)
        if max_size is not None:
            current = slug(current[:max_size], permitted_chars=allowed)

        if not current:
            current = "{}{}".format(prefix, position)
        elif current[0].isdigit():
            current = "{}{}".format(prefix, current)
        elif current in PYTHON_KEYWORDS_LOWER:
            # The keyword itself is listed as taken, so a suffix is always added
            current = make_unique_name(
                name=current, existing_names=[current] + result, start=1, max_size=max_size
            )

        if current in result:
            current = make_unique_name(
                name=current, existing_names=result, start=2, max_size=max_size
            )
        result.append(current)

    return result


# Field classes tried by default during type detection. The order defines the
# priority: `TypeDetector.priority` picks the first type that is still
# possible for a column. Each class is listed only once - a repeated entry
# would make the detector try the same type twice for every column without
# changing the outcome.
DEFAULT_TYPES = (
    BoolField,
    IntegerField,
    FloatField,
    DecimalField,
    PercentField,
    DatetimeField,
    DateField,
    JSONField,
    TextField,
    BinaryField,
)


def _unique_list_values(values):
    result = []
    for value in values:
        if value not in result:
            result.append(value)
    return result


class TypeDetector(object):
    """Detect data types based on a list of Field classes

    Every column starts with all `field_types` as candidates; each value fed
    to the detector removes the candidates that can't deserialize it. The
    `fields` property then picks, for each column, the first remaining
    candidate (or `fallback_type` when nothing remains or when the column
    only had null values).
    """

    def __init__(
        self,
        field_names=None,
        field_types=DEFAULT_TYPES,
        fallback_type=TextField,
        skip_indexes=None,
    ):
        """Create the detector

        `field_names` are the known column names (may be shorter than the
        data), `field_types` the candidate Field classes in priority order,
        `fallback_type` the class used when no candidate fits and
        `skip_indexes` the column indexes to ignore.
        """
        self.field_names = field_names or []
        self.field_types = list(field_types)
        self.fallback_type = fallback_type
        # Column index -> candidate types still possible for that column
        self._possible_types = defaultdict(lambda: list(self.field_types))
        # Column index -> whether only null values were seen so far
        self._is_empty = defaultdict(lambda: True)
        self._samples = []
        self._skip = skip_indexes or tuple()

    def process_row(self, row):
        """Update the candidate types using the values of a single row"""
        for index, value in enumerate(row):
            if index in self._skip:
                continue
            # Emptiness depends only on the value, so it is checked once per
            # value (and not once per candidate type)
            if self._is_empty[index] and not is_null(value):
                self._is_empty[index] = False
            candidates = self._possible_types[index]
            for type_ in candidates[:]:  # Copy: candidates are removed while looping
                if cached_type_deserialize(type_, value, true_behavior=False) is _deserialization_error:
                    candidates.remove(type_)

    # TODO: create two kinds of `feed`: by row and by column (some formats will have it by column)

    def feed(self, data, batch_size=512):
        """Update the candidate types using many rows at once

        `data` is processed in batches of `batch_size` rows; inside a batch
        each column has its repeated values removed before being tested. The
        number of columns is taken from the first row.
        """
        if not isinstance(data, list):
            data = list(data)  # Must have all values in memory and indexable
        if not data:
            return
        indices = [index for index in range(len(data[0])) if index not in self._skip]
        if not indices:
            return

        possible_types, is_empty = self._possible_types, self._is_empty
        # Walk through the batches by offset: re-slicing the remaining data
        # after each batch would copy the whole tail over and over
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            for col_index in indices:
                col_values = _unique_list_values(row[col_index] for row in batch)
                if is_empty[col_index] and any(not is_null(value) for value in col_values):
                    is_empty[col_index] = False
                for type_ in possible_types[col_index][:]:
                    if any(
                        cached_type_deserialize(type_, value, true_behavior=False) is _deserialization_error
                        for value in col_values
                    ):
                        possible_types[col_index].remove(type_)

    def priority(self, *field_types):
        """Decide the priority between each possible type"""

        return field_types[0] if field_types else self.fallback_type

    def define_field_type(self, is_empty, possible_types):
        """Return the type of a column given its emptiness and candidates"""
        if is_empty:
            return self.fallback_type
        else:
            return self.priority(*possible_types)

    @property
    def fields(self):
        """Ordered mapping of field name -> detected Field class

        Skipped columns are not included.
        """
        possible, skip = self._possible_types, self._skip

        if possible:
            # Create a header with placeholder values for each detected column
            # and then join these placeholders with the original header - the
            # original header may have fewer columns than the detected ones,
            # so we end with a full header having a name for every possible
            # column.
            placeholders = make_header(range(max(possible.keys()) + 1))
            header = [a or b for a, b in zip_longest(self.field_names, placeholders)]
        else:
            header = self.field_names

        return ORDERED_DICT(
            [
                (
                    field_name,
                    self.define_field_type(
                        is_empty=self._is_empty[index],
                        possible_types=possible[index] if index in possible else [],
                    ),
                )
                for index, field_name in enumerate(header)
                if index not in skip
            ]
        )


def detect_types(
    field_names,
    field_values,
    field_types=DEFAULT_TYPES,
    skip_indexes=None,
    type_detector=TypeDetector,
    fallback_type=TextField,
    *args,
    **kwargs
):
    """Detect column types (or "where the magic happens")

    Creates a `type_detector`, feeds it with all the rows in `field_values`
    and returns its `fields` (a mapping of field name to Field class).
    """

    # TODO: look strategy of csv.Sniffer.has_header
    # TODO: may receive 'type hints'
    options = dict(
        field_types=field_types,
        fallback_type=fallback_type,
        skip_indexes=skip_indexes,
    )
    instance = type_detector(field_names, **options)
    instance.feed(field_values)
    return instance.fields


def identify_type(value):
    """Return the Field class detected for a single value"""

    column = "name"
    single_cell_table = [[value]]
    detected = detect_types([column], single_cell_table)
    return detected[column]
