# Source code for django_countries.fields

from typing import Any, Iterable, Tuple, Type, Union
from urllib import parse as urlparse

import pkg_resources
from django import forms
from django.contrib.admin.filters import FieldListFilter
from django.core import checks, exceptions
from django.db.models.fields import BLANK_CHOICE_DASH, CharField
from django.utils.encoding import force_str
from django.utils.functional import lazy
from django.utils.html import escape as escape_html

from django_countries import Countries, countries, filters, ioc_data, widgets
from django_countries.conf import settings

# Every entry point registered under the "django_countries.Country" group,
# loaded eagerly at import time and keyed by the entry point's name.
EXTENSIONS = {
    entry_point.name: entry_point.load()
    for entry_point in pkg_resources.iter_entry_points("django_countries.Country")
}


def country_to_text(value):
    """
    Return the text form of a country value.

    When ``value`` has a ``code`` attribute, that attribute is used in its
    place. ``None`` (passed in directly, or found as the code) is returned
    unchanged; anything else is converted with ``force_str``.
    """
    code = getattr(value, "code", value)
    return None if code is None else force_str(code)


class TemporaryEscape:
    """
    Context manager that switches on a country's ``_escape`` flag.

    Inside the ``with`` block the wrapped country's ``_escape`` attribute is
    ``True``; on exit it is reset to whatever it was on entry. The truthiness
    of this object is the country's current ``_escape`` value.
    """

    __slots__ = ["country", "original_escape"]

    def __init__(self, country):
        self.country = country

    def __bool__(self):
        # Reflects the live flag on the country, not the saved one.
        return self.country._escape

    def __enter__(self):
        # Remember the current setting so that __exit__ can put it back.
        target = self.country
        self.original_escape = target._escape
        target._escape = True

    def __exit__(self, type, value, traceback):
        target = self.country
        target._escape = self.original_escape


class Country:
    """
    A single country, identified by its code.

    :param code: The country code. It is passed through the countries
        object's ``alpha2()``; if that yields nothing the given code is kept
        as-is.
    :param flag_url: Optional flag URL template; when ``None`` the
        ``COUNTRIES_FLAG_URL`` setting is used by :attr:`flag`.
    :param str_attr: Name of the attribute returned by ``str(country)``.
    :param custom_countries: Optional countries object to use in place of the
        module-level ``countries`` instance.

    Equality and hashing are both based on the country code, so a
    ``Country`` compares (and hashes) equal to its code string regardless of
    ``str_attr``.
    """

    def __init__(self, code, flag_url=None, str_attr="code", custom_countries=None):
        self.flag_url = flag_url
        self._escape = False
        self._str_attr = str_attr
        # Store ``None`` rather than the default instance so the default is
        # always looked up through the ``countries`` property.
        if custom_countries is countries:
            custom_countries = None
        self.custom_countries = custom_countries
        # Attempt to convert the code to the alpha2 equivalent, but this
        # is not meant to be full validation so use the given code if no
        # match was found.
        self.code = self.countries.alpha2(code) or code

    def __str__(self):
        return force_str(getattr(self, self._str_attr) or "")

    def __eq__(self, other):
        if isinstance(other, Country):
            # Compare code with code: ``str(other)`` follows the other
            # country's ``str_attr`` and may not be a code at all.
            other = other.code
        return force_str(self.code or "") == force_str(other or "")

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash the same value __eq__ compares, so that objects which compare
        # equal (including the plain code string) also hash equal.
        return hash(force_str(self.code or ""))

    def __repr__(self):
        args = [f"code={self.code!r}"]
        if self.flag_url is not None:
            args.append(f"flag_url={self.flag_url!r}")
        if self._str_attr != "code":
            args.append(f"str_attr={self._str_attr!r}")
        args = ", ".join(args)
        return f"{self.__class__.__name__}({args})"

    def __bool__(self):
        return bool(self.code)

    def __len__(self):
        return len(force_str(self))

    @property
    def countries(self):
        """The countries object used for lookups (custom one, else default)."""
        return self.custom_countries or countries

    @property
    def escape(self):
        """
        A :class:`TemporaryEscape` bound to this country.

        Use it as a context manager to turn HTML escaping on temporarily, or
        test its truthiness to see whether escaping is currently on.
        """
        return TemporaryEscape(self)

    def maybe_escape(self, text):
        """Return ``text`` HTML-escaped if escaping is on, else unchanged."""
        if not self.escape:
            return text
        return escape_html(text)

    @property
    def name(self):
        """The country name from the countries object (escaped if enabled)."""
        return self.maybe_escape(self.countries.name(self.code))

    @property
    def alpha3(self):
        """Result of the countries object's ``alpha3()`` for this code."""
        return self.countries.alpha3(self.code)

    @property
    def numeric(self):
        """Result of the countries object's ``numeric()`` for this code."""
        return self.countries.numeric(self.code)

    @property
    def numeric_padded(self):
        """Result of ``numeric()`` for this code, called with ``padded=True``."""
        return self.countries.numeric(self.code, padded=True)

    @property
    def flag(self):
        """
        URL of the flag image, or ``""`` when there is no code or the
        formatted template is empty.

        The template (``flag_url`` or the ``COUNTRIES_FLAG_URL`` setting) is
        formatted with ``code_upper`` and ``code`` (lower-cased) and then
        joined onto ``STATIC_URL``.
        """
        if not self.code:
            return ""
        flag_url = self.flag_url
        if flag_url is None:
            flag_url = settings.COUNTRIES_FLAG_URL
        url = flag_url.format(code_upper=self.code, code=self.code.lower())
        if not url:
            return ""
        url = urlparse.urljoin(settings.STATIC_URL, url)
        return self.maybe_escape(url)

    @property
    def flag_css(self):
        """
        Output the css classes needed to display an HTML element as a flag
        sprite.

        Requires the use of 'flags/sprite.css' or 'flags/sprite-hq.css'.
        Usage example::

            <i class="{{ ctry.flag_css }}" aria-label="{{ ctry.code }}"></i>
        """
        if not self.code:
            return ""
        # Unpacking relies on the code being exactly two characters long.
        x, y = self.code.lower()
        return f"flag-sprite flag-{x} flag-_{y}"

    @property
    def unicode_flag(self):
        """
        Generate a unicode flag for the given country.

        The logic for how these are determined can be found at:

        https://en.wikipedia.org/wiki/Regional_Indicator_Symbol

        Currently, these glyphs appear to only be supported on OS X and iOS.
        """
        if not self.code:
            return ""

        # Don't really like magic numbers, but this is the code point for [A]
        # (Regional Indicator A), minus the code point for ASCII A. By adding
        # this to the uppercase characters making up the ISO 3166-1 alpha-2
        # codes we can get the flag.
        OFFSET = 127397
        points = [ord(x) + OFFSET for x in self.code.upper()]
        return chr(points[0]) + chr(points[1])

    @staticmethod
    def country_from_ioc(ioc_code, flag_url=""):
        """
        Return the ``Country`` mapped to ``ioc_code`` in
        ``ioc_data.IOC_TO_ISO``, or ``None`` when there is no mapping.
        """
        code = ioc_data.IOC_TO_ISO.get(ioc_code, "")
        if code == "":
            return None
        return Country(code, flag_url=flag_url)

    @property
    def ioc_code(self):
        """Result of the countries object's ``ioc_code()`` for this code."""
        return self.countries.ioc_code(self.code)

    def __getattr__(self, attr):
        """
        Resolve attributes supplied by registered extensions.

        Only called when normal attribute lookup fails. If ``attr`` names an
        entry in ``EXTENSIONS`` the extension is called with this country and
        its result returned; otherwise ``AttributeError`` is raised.
        """
        if attr in EXTENSIONS:
            return EXTENSIONS[attr](self)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {attr!r}"
        )


class CountryDescriptor:
    """
    A descriptor for country fields on a model instance. Returns a Country when
    accessed so you can do things like::

        >>> from people import Person
        >>> person = Person.objects.get(name='Chris')

        >>> person.country.name
        'New Zealand'

        >>> person.country.flag
        '/static/flags/nz.gif'

    For a field with ``multiple`` set, a list of Country objects is returned
    instead (an empty list when the stored value is ``None``).
    """

    def __init__(self, field):
        self.field = field

    def __get__(self, instance=None, owner=None):
        """
        Return the descriptor itself on class access, otherwise the stored
        value wrapped as a ``Country`` (or a list of them for ``multiple``).
        """
        if instance is None:
            return self
        # Check in case this field was deferred.
        if self.field.name not in instance.__dict__:
            instance.refresh_from_db(fields=[self.field.name])
        value = instance.__dict__[self.field.name]
        if self.field.multiple:
            # __set__ stores ``None`` untouched; there is nothing to iterate
            # in that case, so report "no countries" instead of raising.
            if value is None:
                return []
            return [self.country(code) for code in value]
        return self.country(value)

    def country(self, code):
        """Build a ``Country`` for ``code`` using the field's configuration."""
        return Country(
            code=code,
            flag_url=self.field.countries_flag_url,
            str_attr=self.field.countries_str_attr,
            custom_countries=self.field.countries,
        )

    def __set__(self, instance, value):
        """Store the field-cleaned form of ``value`` on the instance."""
        value = self.field.get_clean_value(value)
        instance.__dict__[self.field.name] = value


class LazyChoicesMixin(widgets.LazyChoicesMixin):
    # Extends ``widgets.LazyChoicesMixin`` for objects that own a ``widget``:
    # whenever choices are set here they are also set on that widget.

    def _set_choices(self, value):
        """
        Also update the widget's choices.

        The parent mixin's ``_set_choices`` runs first, then the same
        ``value`` is assigned to ``self.widget.choices``.
        """
        super()._set_choices(value)
        self.widget.choices = value


# Type aliases describing the shape of a ``choices`` iterable.
# A single choice: a pair of (any value, string).
_Choice = Tuple[Any, str]
# A named group: a string paired with an iterable of choices.
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
# A choices iterable may mix plain choices and named groups.
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]


class LazyTypedChoiceField(LazyChoicesMixin, forms.TypedChoiceField):
    """
    A form TypedChoiceField that respects choices being a lazy object.

    Combines ``LazyChoicesMixin`` with Django's ``forms.TypedChoiceField``.
    """

    # Annotation only (no value is assigned here); typed loosely as ``Any``.
    choices: Any
    # Widget class used by this field.
    widget = widgets.LazySelect


class LazyTypedMultipleChoiceField(LazyChoicesMixin, forms.TypedMultipleChoiceField):
    """
    A form TypedMultipleChoiceField that respects choices being a lazy object.

    Combines ``LazyChoicesMixin`` with Django's
    ``forms.TypedMultipleChoiceField``.
    """

    # Annotation only (no value is assigned here); typed loosely as ``Any``.
    choices: Any
    # Widget class used by this field.
    widget = widgets.LazySelectMultiple


class CountryField(CharField):
    """
    A country field for Django models that provides all ISO 3166-1 countries as
    choices.

    Extra keyword arguments accepted by the constructor (all optional):

    ``countries``
        A ``Countries`` subclass; it is instantiated and used in place of the
        default ``countries`` instance.
    ``countries_flag_url``
        Flag URL template handed to each ``Country`` built for this field.
    ``countries_str_attr``
        Name of the ``Country`` attribute used as its string form
        (defaults to ``"code"``).
    ``blank_label``
        Label for the blank choice; when ``None``, ``BLANK_CHOICE_DASH`` is
        used.
    ``multiple``
        When truthy the field holds a list of countries, saved as a
        comma-separated string.
    """

    descriptor_class = CountryDescriptor

    def __init__(self, *args, **kwargs):
        countries_class: Type[Countries] = kwargs.pop("countries", None)
        self.countries = countries_class() if countries_class else countries
        self.countries_flag_url = kwargs.pop("countries_flag_url", None)
        self.countries_str_attr = kwargs.pop("countries_str_attr", "code")
        self.blank_label = kwargs.pop("blank_label", None)
        self.multiple = kwargs.pop("multiple", None)
        # The choices are always the countries object, whatever was passed.
        kwargs["choices"] = self.countries
        if "max_length" not in kwargs:
            # Allow explicit max_length so migrations can correctly identify
            # changes in the multiple CountryField fields when new countries are
            # added to the available countries dictionary.
            if self.multiple:
                # Room for every code plus the commas between them.
                kwargs["max_length"] = (
                    len(self.countries)
                    - 1
                    + sum(len(code) for code in self.countries.countries)
                )
            else:
                kwargs["max_length"] = max(
                    len(code) for code in self.countries.countries
                )
        super().__init__(*args, **kwargs)

    def check(self, **kwargs):
        """Run the inherited checks plus the ``multiple``/``null`` check."""
        errors = super().check(**kwargs)
        errors.extend(self._check_multiple())
        return errors

    def _check_multiple(self):
        """
        Return a ``django_countries.E100`` error when the field combines
        ``multiple`` with ``null=True``; otherwise an empty list.
        """
        if not self.multiple or not self.null:
            return []

        hint = "Remove null=True argument on the field"
        if not self.blank:
            hint += " (just add blank=True if you want to allow no selection)"
        hint += "."

        return [
            checks.Error(
                "Field specifies multiple=True, so should not be null.",
                obj=self,
                id="django_countries.E100",
                hint=hint,
            )
        ]

    def get_internal_type(self):
        return "CharField"

    def contribute_to_class(self, cls, name):
        """Attach the field, then install the country descriptor on ``cls``."""
        super().contribute_to_class(cls, name)
        setattr(cls, self.name, self.descriptor_class(self))

    def pre_save(self, *args, **kwargs):
        "Returns field's value just before saving."
        # ``super(CharField, self)`` starts the lookup after CharField in the
        # MRO, i.e. CharField's own override (if any) is bypassed.
        value = super(CharField, self).pre_save(*args, **kwargs)
        return self.get_prep_value(value)

    def get_prep_value(self, value):
        "Returns field's value prepared for saving into a database."
        value = self.get_clean_value(value)
        if self.multiple:
            # A list of codes is stored as one comma-separated string; an
            # empty list or ``None`` is stored as the empty string.
            if value:
                value = ",".join(value)
            else:
                value = ""
        return super(CharField, self).get_prep_value(value)

    def get_clean_value(self, value):
        """
        Normalise ``value`` to plain text code(s).

        ``None`` is returned unchanged. For a single-country field the result
        is ``country_to_text(value)``. For a ``multiple`` field the result is
        a list of codes: a string is split on commas, a lone string/Country
        or a non-iterable is treated as a one-item list, and empty items are
        dropped.
        """
        if value is None:
            return None
        if not self.multiple:
            return country_to_text(value)
        if isinstance(value, (str, Country)):
            if isinstance(value, str) and "," in value:
                value = value.split(",")
            else:
                value = [value]
        else:
            try:
                iter(value)
            except TypeError:
                value = [value]
        return list(filter(None, [country_to_text(c) for c in value]))

    def deconstruct(self):
        """
        Remove choices from deconstructed field, as this is the country list
        and not user editable.

        Not including the ``blank_label`` property, as this isn't database
        related.
        """
        name, path, args, kwargs = super(CharField, self).deconstruct()
        kwargs.pop("choices", None)
        if self.multiple:  # multiple determines the length of the field
            kwargs["multiple"] = self.multiple
        if self.countries is not countries:
            # Include the countries class if it's not the default countries
            # instance.
            kwargs["countries"] = self.countries.__class__
        return name, path, args, kwargs

    def get_choices(self, include_blank=True, blank_choice=None, *args, **kwargs):
        """
        Return the field's choices, using ``blank_label`` for the blank choice
        when one was configured and never adding a blank choice for a
        ``multiple`` field.

        Any further positional or keyword arguments are forwarded to the
        parent implementation.
        """
        if blank_choice is None:
            if self.blank_label is None:
                blank_choice = BLANK_CHOICE_DASH
            else:
                blank_choice = [("", self.blank_label)]
        if self.multiple:
            include_blank = False
        # Forward the first two arguments positionally: passing them as
        # keywords ahead of ``*args`` would make any extra positional argument
        # collide with ``include_blank`` and raise TypeError.
        return super().get_choices(include_blank, blank_choice, *args, **kwargs)

    # Wrap so that the choices are only computed when actually used.
    get_choices = lazy(get_choices, list)

    def formfield(self, **kwargs):
        """
        Build the form field, defaulting to the lazy choice field classes and
        to the parent's ``to_python`` as the ``coerce`` function.
        """
        kwargs.setdefault(
            "choices_form_class",
            LazyTypedMultipleChoiceField if self.multiple else LazyTypedChoiceField,
        )
        if "coerce" not in kwargs:
            kwargs["coerce"] = super().to_python
        return super().formfield(**kwargs)

    def to_python(self, value):
        """
        Convert ``value`` with the parent's ``to_python``.

        For a ``multiple`` field a falsy value is returned unchanged, a string
        is split on commas, and each item is converted individually.
        """
        if not self.multiple:
            return super().to_python(value)
        if not value:
            return value
        if isinstance(value, str):
            value = value.split(",")
        # Bind once outside the comprehension: zero-argument ``super()`` is
        # not usable inside a comprehension on older Python versions.
        parent_to_python = super().to_python
        return [parent_to_python(item) for item in value]

    def validate(self, value, model_instance):
        """
        Use custom validation for when using a multiple countries field.

        Raises ``ValidationError`` with code ``invalid_choice`` for the first
        item not found among the field's choices, or with code ``blank`` when
        the value is empty and the field is not ``blank``.
        """
        if not self.multiple:
            return super().validate(value, model_instance)

        if not self.editable:
            # Skip validation for non-editable fields.
            return

        if value:
            choices = [option_key for option_key, option_value in self.choices]
            for single_value in value:
                if single_value not in choices:
                    raise exceptions.ValidationError(
                        self.error_messages["invalid_choice"],
                        code="invalid_choice",
                        params={"value": single_value},
                    )

        if not self.blank and value in self.empty_values:
            raise exceptions.ValidationError(self.error_messages["blank"], code="blank")

    def value_to_string(self, obj):
        """
        Ensure data is serialized correctly.
        """
        value = self.value_from_object(obj)
        return self.get_prep_value(value)


# Register ``filters.CountryFilter`` with Django's admin so that it is chosen
# as the list filter for any model field that is a ``CountryField``.
FieldListFilter.register(lambda f: isinstance(f, CountryField), filters.CountryFilter)