# This file is part of Indico.
# Copyright (C) 2002 - 2026 CERN
#
# Indico is free software; you can redistribute it and/or
# modify it under the terms of the MIT License; see the
# LICENSE file for more details.

import json
from datetime import time, timedelta

import dateutil.parser
import pytz
from flask import session
from markupsafe import escape
from wtforms import Field, SelectField, ValidationError
from wtforms.fields import TimeField
from wtforms.validators import StopValidation
from wtforms_dateutil import DateField, DateTimeField

from indico.core.config import config
from indico.core.logger import Logger
from indico.util.date_time import convert_py_weekdays_to_js, localize_as_utc, relativedelta
from indico.util.i18n import _, get_current_locale, pgettext
from indico.web.forms.fields import JSONField
from indico.web.forms.validators import DateRange, DateTimeRange, LinkedDate, LinkedDateTime
from indico.web.forms.widgets import JinjaWidget


class TimeDeltaField(Field):
    """A field that lets the user select a simple timedelta.

    It does not support mixing multiple units, but it is smart enough
    to switch to a different unit to represent a timedelta that could
    not be represented otherwise.

    :param units: The available units. Must be a tuple containing any
                  of 'seconds', 'minutes', 'hours', 'days' and 'weeks'.
                  If not specified, ``('hours', 'days')`` is assumed.
    """

    widget = JinjaWidget('forms/timedelta_widget.html', single_line=True, single_kwargs=True)
    unit_names = {
        'seconds': _('Seconds'),
        'minutes': pgettext('Duration', 'Minutes'),
        'hours': _('Hours'),
        'days': _('Days'),
        'weeks': _('Weeks'),
    }
    # Length of each unit in seconds. Ordered from the largest to the smallest
    # unit, which `best_unit` relies on to pick the largest fitting unit.
    magnitudes = {
        'weeks': 7*86400,
        'days': 86400,
        'hours': 3600,
        'minutes': 60,
        'seconds': 1,
    }

    def __init__(self, *args, **kwargs):
        self.units = kwargs.pop('units', ('hours', 'days'))
        super().__init__(*args, **kwargs)

    @property
    def best_unit(self):
        """Return the largest unit that covers the current timedelta.

        That is the largest unit the value is an exact multiple of, or
        ``None`` if the field has no value.
        """
        if self.data is None:
            return None
        seconds = int(self.data.total_seconds())
        for unit, magnitude in self.magnitudes.items():
            if not seconds % magnitude:
                return unit
        return 'seconds'

    @property
    def choices(self):
        """The ``(unit, label)`` pairs offered by the widget."""
        best_unit = self.best_unit
        choices = [(unit, self.unit_names[unit]) for unit in self.units]
        # Add whatever unit is necessary to represent the current value if we have one
        if best_unit and best_unit not in self.units:
            choices.append((best_unit, f'({self.unit_names[best_unit]})'))
        return choices

    def process_formdata(self, valuelist):
        # valuelist is ``[number]`` or ``[number, unit]``; without an explicit
        # unit the first configured unit is used. An empty number clears the value.
        if valuelist:
            if valuelist[0]:
                value = int(valuelist[0])
                unit = valuelist[1] if len(valuelist) == 2 else self.units[0]
                if unit not in self.magnitudes:
                    raise ValueError('Invalid unit')
                self.data = timedelta(seconds=self.magnitudes[unit] * value)
            else:
                self.data = None

    def pre_validate(self, form):
        # An empty value has no unit that could be invalid; whether an empty
        # value is acceptable is up to the field's validators.
        if self.data is None:
            return
        if len(self.units) == 1:
            return
        if self.best_unit in self.units:
            return
        # The value needs a unit that is not offered; this is only accepted
        # when it is the unchanged previous value.
        if self.object_data is None:
            raise ValidationError(_('Please choose a valid unit.'))
        if self.object_data != self.data:
            raise ValidationError(_('Please choose a different unit or keep the previous value.'))

    def _value(self):
        # ``(number, unit)`` pair rendered by the widget
        if self.data is None:
            return '', ''
        else:
            return int(self.data.total_seconds()) // self.magnitudes[self.best_unit], self.best_unit


class RelativeDeltaField(Field):
    """A field that lets the user select a simple relativedelta.

    It does not support mixing multiple units, but it is smart enough
    to switch to a different unit to represent a relativedelta that could
    not be represented otherwise.

    :param units: The available units. Must be a tuple containing any
                  of 'seconds', 'minutes', 'hours', 'days', 'weeks',
                  'months' and 'years'.
                  If not specified, ``('hours', 'days')`` is assumed.
    """

    widget = JinjaWidget('forms/timedelta_widget.html', single_line=True, single_kwargs=True)
    unit_names = {
        'seconds': _('Seconds'),
        'minutes': pgettext('Duration', 'Minutes'),
        'hours': _('Hours'),
        'days': _('Days'),
        'weeks': _('Weeks'),
        'months': _('Months'),
        'years': _('Years')
    }
    # One step of each unit, ordered from the largest to the smallest unit.
    magnitudes = {
        'years': relativedelta(years=1),
        'months': relativedelta(months=1),
        'weeks': relativedelta(weeks=1),
        'days': relativedelta(days=1),
        'hours': relativedelta(hours=1),
        'minutes': relativedelta(minutes=1),
        'seconds': relativedelta(seconds=1)
    }
    # Length in seconds of the fixed-size units, largest first. Months and
    # years have no fixed length and are handled separately in `split_data`.
    _fixed_unit_seconds = (
        ('weeks', 7*86400),
        ('days', 86400),
        ('hours', 3600),
        ('minutes', 60),
        ('seconds', 1),
    )

    def __init__(self, *args, **kwargs):
        self.units = kwargs.pop('units', ('hours', 'days'))
        super().__init__(*args, **kwargs)

    @property
    def split_data(self):
        """Return the current value as a ``(number, unit)`` pair.

        The largest unit that represents the value exactly is used, so e.g.
        10 days stays ``(10, 'days')`` and 90 minutes stays
        ``(90, 'minutes')``. This matters because relativedelta carries
        overflowing units into larger ones (90 minutes is stored as 1 hour
        and 30 minutes), and its ``weeks`` attribute is derived from
        ``days``.

        Returns ``(None, None)`` without a value.

        :raise ValueError: if the value has none of the supported components
                           set.
        """
        if self.data is None:
            return None, None
        data = self.data
        total_months = data.years * 12 + data.months
        total_seconds = ((data.days * 24 + data.hours) * 60 + data.minutes) * 60 + data.seconds
        if total_months and not total_seconds:
            if not total_months % 12:
                return total_months // 12, 'years'
            return total_months, 'months'
        if total_seconds and not total_months:
            for unit, magnitude in self._fixed_unit_seconds:
                if not total_seconds % magnitude:
                    return total_seconds // magnitude, unit
        # A value mixing month- and second-based parts cannot be expressed with a
        # single unit; fall back to its largest non-zero component.
        for unit in self.magnitudes:
            number = getattr(data, unit)
            if number:
                return number, unit
        raise ValueError('Unsupported relativedelta() unit')

    @property
    def choices(self):
        """The ``(unit, label)`` pairs offered by the widget."""
        choices = [(unit, self.unit_names[unit]) for unit in self.units]
        number, unit = self.split_data
        if number is not None and unit not in self.units:
            # Add whatever unit is necessary to represent the current value if we have one
            choices.append((unit, f'({self.unit_names[unit]})'))
        return choices

    def process_formdata(self, valuelist):
        # valuelist is ``[number, unit]``; anything else leaves the data untouched
        if valuelist and len(valuelist) == 2:
            value = int(valuelist[0])
            unit = valuelist[1]
            if unit not in self.magnitudes:
                raise ValueError('Invalid unit')
            self.data = (self.magnitudes[unit] * value).normalized()

    def pre_validate(self, form):
        # NOTE(review): unlike `TimeDeltaField.pre_validate`, this does not
        # check the selected unit against `self.units`. It rejects every
        # submission when there is no previous/default value (`object_data`).
        # Kept as-is; confirm this is intended.
        if self.object_data is None:
            raise ValidationError(_('Please choose a valid unit.'))

    def _value(self):
        # ``(number, unit)`` pair rendered by the widget
        if self.data is None:
            return '', ''
        return self.split_data


class IndicoTimeField(TimeField):
    """Time field rendered using the ``forms/time_widget.html`` template."""

    widget = JinjaWidget(
        'forms/time_widget.html',
        single_line=True,
        single_kwargs=True,
    )


class IndicoDurationField(Field):
    """Field holding a :class:`~datetime.timedelta` submitted as a number of seconds.

    Only whole minutes are accepted; a value with a seconds part is logged
    and rejected with a processing error.
    """

    widget = JinjaWidget('forms/duration_widget.html', single_line=True, single_kwargs=True)

    def _value(self):
        # the widget works with plain seconds; an empty field shows 0
        return 0 if self.data is None else int(self.data.total_seconds())

    def process_formdata(self, valuelist):
        if not valuelist:
            return
        duration = timedelta(seconds=int(valuelist[0]))
        # the value is stored even if it is rejected below
        self.data = duration
        if duration.total_seconds() % 60:
            Logger.get('forms').warning('Duration with seconds submitted')
            raise ValueError('Duration cannot contain seconds')


class IndicoDateField(DateField):
    """Date field rendered using the ``forms/date_widget.html`` template.

    :param allow_clear: ``allow_clear`` flag for the field; when omitted it
                        becomes ``True`` unless the field is required.
    :param disabled_days: Weekdays (Python ``date.weekday()`` numbers) that
                          may not be selected.
    :param disabled_dates: Dates that may not be selected.
    """

    widget = JinjaWidget('forms/date_widget.html', single_line=True, single_kwargs=True)

    def __init__(self, *args, allow_clear=None, disabled_days=None, disabled_dates=None, **kwargs):
        self.disabled_days = disabled_days
        self.disabled_dates = disabled_dates
        self.allow_clear = allow_clear
        super().__init__(*args, **kwargs)
        # the `required` flag is only available once the base field is set up
        if self.allow_clear is None:
            self.allow_clear = not self.flags.required

    def pre_validate(self, form):
        selected = self.data
        if not selected:
            return
        if self.disabled_days and selected.weekday() in self.disabled_days:
            raise StopValidation(_('Disabled day selected'))
        if self.disabled_dates and selected in self.disabled_dates:
            raise StopValidation(_('Disabled date selected'))

    @property
    def disabled_days_js(self):
        """The disabled weekdays converted to the JavaScript numbering."""
        if not self.disabled_days:
            return []
        return convert_py_weekdays_to_js(self.disabled_days)

    def _first_validator(self, validator_cls):
        """Return the first attached validator that is a *validator_cls*, or ``None``."""
        return next((v for v in self.validators if isinstance(v, validator_cls)), None)

    @property
    def earliest_date(self):
        if not self.flags.date_range:
            return None
        validator = self._first_validator(DateRange)
        return None if validator is None else validator.get_earliest(self.get_form(), self)

    @property
    def latest_date(self):
        if not self.flags.date_range:
            return None
        validator = self._first_validator(DateRange)
        return None if validator is None else validator.get_latest(self.get_form(), self)

    @property
    def linked_field(self):
        validator = self.linked_date_validator
        return validator.linked_field if validator else None

    @property
    def linked_date_validator(self):
        if not self.flags.linked_date:
            return None
        return self._first_validator(LinkedDate)


class IndicoDateTimeField(DateTimeField):
    """Friendly datetime field that handles timezones and validations.

    Important: When the form has a `timezone` field it must be
    declared before any `IndicoDateTimeField`. Otherwise, its
    value is not available in this field resulting in an error
    during form submission.

    A submitted naive date/time is interpreted in the field's timezone
    (see `timezone`) and converted with ``localize_as_utc``.

    :param timezone: Name of the timezone to use instead of the form's
                     ``timezone`` field/attribute or the session timezone.
    :param default_time: Stored as ``default_time``; not used by this class
                         itself (presumably consumed by the widget template).
    :param allow_clear: ``allow_clear`` flag for the field; when omitted it
                        becomes ``True`` unless the field is required.
    :param disabled_days: Weekdays (Python ``weekday()`` numbers) that may
                          not be selected, checked in the field's timezone.
    :param disabled_dates: Dates that may not be selected, checked in the
                           field's timezone.
    """

    widget = JinjaWidget('forms/datetime_widget.html', single_line=True, single_kwargs=True)

    def __init__(self, *args, timezone=None, default_time=time(0, 0), allow_clear=None,
                 disabled_days=None, disabled_dates=None, **kwargs):
        self._timezone = timezone
        self.default_time = default_time
        # set by `process_formdata` when only one of date/time was submitted
        self.date_missing = False
        self.time_missing = False
        self.allow_clear = allow_clear
        self.disabled_days = disabled_days
        self.disabled_dates = disabled_dates
        super().__init__(*args, **kwargs)
        if self.allow_clear is None:
            self.allow_clear = not self.flags.required

    def pre_validate(self, form):
        if self.date_missing:
            raise StopValidation(_('Date must be specified'))
        if self.time_missing:
            raise StopValidation(_('Time must be specified'))
        if self.object_data:
            # Normalize datetime resolution of passed data
            self.object_data = self.object_data.replace(second=0, microsecond=0)
        if self.data and (self.disabled_days or self.disabled_dates):
            # `self.data` has been converted to UTC, but the disabled days/dates
            # refer to the calendar date the user picked in the field's timezone
            local_date = self._get_local_date()
            if self.disabled_days and local_date.weekday() in self.disabled_days:
                raise StopValidation(_('Disabled day selected'))
            if self.disabled_dates and local_date in self.disabled_dates:
                raise StopValidation(_('Disabled date selected'))

    def _get_local_date(self):
        """Return the date part of ``self.data`` in this field's timezone."""
        dt = self.data
        if dt.tzinfo is not None:
            dt = dt.astimezone(self.tzinfo)
        return dt.date()

    def process_formdata(self, valuelist):
        # valuelist is expected to be ``[date, time]``; remember which part is
        # missing so `pre_validate` can report a specific error
        if any(valuelist):
            if not valuelist[0]:
                self.date_missing = True
            if len(valuelist) < 2 or not valuelist[1]:
                self.time_missing = True
        if valuelist:
            valuelist = [' '.join(valuelist).strip()]
        super().process_formdata(valuelist)
        if self.data and not self.data.tzinfo:
            self.data = localize_as_utc(self.data, self.timezone)

    @property
    def disabled_days_js(self):
        """The disabled weekdays converted to the JavaScript numbering."""
        return convert_py_weekdays_to_js(self.disabled_days) if self.disabled_days else []

    @property
    def earliest_dt(self):
        if self.flags.datetime_range:
            for validator in self.validators:
                if isinstance(validator, DateTimeRange):
                    return validator.get_earliest(self.get_form(), self)

    @property
    def latest_dt(self):
        if self.flags.datetime_range:
            for validator in self.validators:
                if isinstance(validator, DateTimeRange):
                    return validator.get_latest(self.get_form(), self)

    @property
    def linked_datetime_validator(self):
        if self.flags.linked_datetime:
            for validator in self.validators:
                if isinstance(validator, LinkedDateTime):
                    return validator

    @property
    def linked_field(self):
        validator = self.linked_datetime_validator
        return validator.linked_field if validator else None

    @property
    def timezone_field(self):
        """The form's ``timezone`` field, if it is a select field."""
        field = getattr(self.get_form(), 'timezone', None)
        return field if isinstance(field, SelectField) else None

    @property
    def timezone(self):
        """Name of the timezone used to interpret submitted values.

        Precedence: the ``timezone`` constructor argument, the form's
        ``timezone`` select field, a ``timezone`` attribute on the form,
        and finally the session timezone.
        """
        if self._timezone:
            return self._timezone
        elif self.timezone_field:
            return self.timezone_field.data
        else:
            form = self.get_form()
            if form and hasattr(form, 'timezone'):
                return form.timezone
            return session.tzinfo.zone

    @property
    def tzinfo(self):
        return pytz.timezone(self.timezone)


class OccurrencesField(JSONField):
    """
    A field that lets you select multiple occurrences consisting of a
    start date/time and a duration.

    The submitted JSON must be a list of ``{'date': ..., 'time': ...,
    'duration': ...}`` objects, with the duration in minutes. The processed
    ``data`` is a list of ``(start_dt, duration)`` tuples, where
    ``start_dt`` has been passed through ``localize_as_utc`` using the
    field's timezone and ``duration`` is a :class:`~datetime.timedelta`.

    :param timezone: Name of the timezone the submitted dates/times are in.
                     If omitted, the form's ``timezone`` field/attribute or
                     the session timezone is used.
    """

    widget = JinjaWidget('forms/occurrences_widget.html', single_line=True)
    CAN_POPULATE = True

    def __init__(self, *args, **kwargs):
        self._timezone = kwargs.pop('timezone', None)
        kwargs.setdefault('default', [])
        super().__init__(*args, **kwargs)

    def process_formdata(self, valuelist):
        def _deserialize(occ):
            try:
                date_str, time_str, duration = occ['date'], occ['time'], occ['duration']
            except KeyError:
                raise ValueError('Invalid occurrence') from None
            try:
                dt = dateutil.parser.parse(f'{date_str} {time_str}')
            except (ValueError, OverflowError):
                raise ValueError('Invalid date/time: {} {}'.format(escape(date_str), escape(time_str))) from None
            # a non-numeric duration would otherwise raise TypeError on comparison
            if not isinstance(duration, (int, float)) or duration < 1:
                raise ValueError('Invalid duration')
            try:
                delta = timedelta(minutes=duration)
            except (OverflowError, ValueError):
                # absurdly large or non-finite durations
                raise ValueError('Invalid duration') from None
            return localize_as_utc(dt, self.timezone), delta

        self.data = []
        super().process_formdata(valuelist)
        if not isinstance(self.data, list) or not all(isinstance(occ, dict) for occ in self.data):
            # e.g. `null`, a number or a list of non-objects; reset the data so
            # `_value` can still render the field
            self.data = []
            raise ValueError('Invalid occurrences')
        self.data = [_deserialize(occ) for occ in self.data]

    def _value(self):
        def _serialize(occ):
            if isinstance(occ, dict):
                # raw data from the client (left in place when deserializing failed)
                return occ
            dt = occ[0].astimezone(pytz.timezone(self.timezone))
            return {'date': dt.date().isoformat(),
                    'time': dt.strftime('%H:%M'),
                    'duration': int(occ[1].total_seconds() // 60)}

        return json.dumps([_serialize(occ) for occ in self.data])

    @property
    def timezone_field(self):
        """The form's ``timezone`` field, if it is a select field."""
        field = getattr(self.get_form(), 'timezone', None)
        return field if isinstance(field, SelectField) else None

    @property
    def timezone(self):
        """Name of the timezone used to interpret/display the occurrences.

        Precedence: the ``timezone`` constructor argument, the form's
        ``timezone`` select field, a ``timezone`` attribute on the form,
        and finally the session timezone.
        """
        if self._timezone:
            return self._timezone
        if self.timezone_field:
            return self.timezone_field.data
        form = self.get_form()
        if hasattr(form, 'timezone'):
            return form.timezone
        return session.tzinfo.zone


class IndicoTimezoneSelectField(SelectField):
    """Select field offering the common pytz timezones.

    Unless an explicit ``default`` is passed, it defaults to
    ``config.DEFAULT_TIMEZONE``. A value that is not one of the common
    timezones is added as an extra choice, so it remains a valid choice.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.choices = [(v, v) for v in pytz.common_timezones]
        # Only fall back to the server default; an explicit `default=` wins
        if self.default is None:
            self.default = config.DEFAULT_TIMEZONE

    def process_data(self, value):
        super().process_data(value)
        if self.data is not None and self.data not in pytz.common_timezones_set:
            choice = (self.data, self.data)
            # avoid duplicates if the field is processed more than once
            if choice not in self.choices:
                self.choices.append(choice)


class IndicoWeekDayRepetitionField(Field):
    """Field that lets you select an ordinal day of the week.

    The processed value is a ``(day_number, week_day)`` tuple of ints, where
    ``day_number`` is one of the ``WEEK_DAY_NUMBER_CHOICES`` values (``-1``
    meaning "last") and ``week_day`` is a weekday index in ``range(7)`` as
    passed to ``locale.weekday``. An empty tuple means no value.
    """

    widget = JinjaWidget('forms/week_day_repetition_widget.html', single_line=True)

    WEEK_DAY_NUMBER_CHOICES = (
        (1, _('first')),
        (2, _('second')),
        (3, _('third')),
        (4, _('fourth')),
        (-1, _('last'))
    )

    def __init__(self, *args, **kwargs):
        locale = get_current_locale()
        self.day_number_options = self.WEEK_DAY_NUMBER_CHOICES
        self.week_day_options = [(n, locale.weekday(n, short=False)) for n in range(7)]
        # set by `process_formdata` when only one of the two parts was submitted
        self.day_number_missing = False
        self.week_day_missing = False
        super().__init__(*args, **kwargs)

    def process_formdata(self, valuelist):
        self.data = ()
        if any(valuelist):
            if not valuelist[0]:
                self.day_number_missing = True
            if len(valuelist) < 2 or not valuelist[1]:
                self.week_day_missing = True
        if not valuelist:
            return
        # The widget submits exactly two values; anything else would previously
        # crash with IndexError or produce a tuple of the wrong size.
        if len(valuelist) != 2:
            raise ValueError('Invalid week day repetition')
        # an empty part makes int() raise ValueError, i.e. a processing error
        self.data = tuple(map(int, valuelist))

    @property
    def day_number_data(self):
        return self.data[0] if len(self.data) > 0 else None

    @property
    def week_day_data(self):
        return self.data[1] if len(self.data) > 1 else None
