# publik-django-templatetags
# Copyright (C) 2022 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import math
|
|
from decimal import Decimal
|
|
from decimal import DivisionByZero as DecimalDivisionByZero
|
|
from decimal import InvalidOperation as DecimalInvalidOperation
|
|
|
|
from django import template
|
|
from django.template import defaultfilters
|
|
from django.utils.encoding import force_text
|
|
|
|
# Template library for this module; every filter below is attached to it
# through the ``@register.filter`` decorator.
register = template.Library()
|
|
|
|
|
|
@register.filter(name='get')
def get(obj, key):
    """Look up ``key`` in ``obj``.

    ``obj.get(key)`` is tried first. When that raises AttributeError, the
    lookup falls back to ``obj[key]``. A subscript lookup that fails with
    IndexError, KeyError or TypeError yields None.
    """
    try:
        found = obj.get(key)
    except AttributeError:
        pass
    else:
        return found

    # no usable .get(): try indexing/subscripting instead
    try:
        return obj[key]
    except (IndexError, KeyError, TypeError):
        return None
|
|
|
|
|
|
@register.filter
def getlist(mapping, key):
    """Return a list holding ``item.get(key)`` for every item of ``mapping``.

    Items without a ``.get()`` method contribute None. A None ``mapping``
    gives an empty list.

    The result is a real list rather than a generator. A list can be
    indexed (so ``first`` and ``last`` work), has a length, and can be
    iterated more than once within a template.
    """
    if mapping is None:
        return []
    values = []
    for item in mapping:
        try:
            values.append(item.get(key))
        except AttributeError:
            # item is not dict-like
            values.append(None)
    return values
|
|
|
|
|
|
@register.filter(name='list')
def as_list(obj):
    """Materialize the iterable ``obj`` into a new list."""
    return [*obj]
|
|
|
|
|
|
@register.filter
def split(string, separator=' '):
    """Split ``string`` on ``separator`` (a single space by default).

    Non-string values are converted with ``force_text`` first. None is
    treated as an empty string, so the result is ``['']``; it is not turned
    into the text ``'None'``.
    """
    if string is None:
        # force_text(None) returns 'None', so a later ``or ''`` guard
        # could never catch it; handle None explicitly.
        string = ''
    return force_text(string).split(separator)
|
|
|
|
|
|
@register.filter
def first(value):
    """Django's ``first`` filter, returning '' when it raises TypeError.

    A TypeError happens, for instance, when ``value`` does not support
    indexing.
    """
    try:
        item = defaultfilters.first(value)
    except TypeError:
        item = ''
    return item
|
|
|
|
|
|
@register.filter
def last(value):
    """Django's ``last`` filter, returning '' when it raises TypeError.

    A TypeError happens, for instance, when ``value`` does not support
    indexing.
    """
    try:
        item = defaultfilters.last(value)
    except TypeError:
        item = ''
    return item
|
|
|
|
|
|
def parse_decimal(value, default=Decimal(0)):
    """Convert ``value`` to a Decimal rounded to four decimal places.

    In strings, a comma is accepted as the decimal separator (French
    notation). It is replaced by a dot before parsing. The parsed number is
    quantized to four decimal places and then normalized, which strips
    trailing zeros.

    :param value: anything the ``Decimal`` constructor accepts.
    :param default: returned when ``value`` cannot be converted, or when it
        is not a finite number (NaN or infinity).
    :returns: a finite Decimal, or ``default``.
    """
    if isinstance(value, str):
        # replace , by . for French users comfort
        value = value.replace(',', '.')
    try:
        result = Decimal(value).quantize(Decimal('1.0000')).normalize()
    except (ArithmeticError, TypeError):
        return default
    if not result.is_finite():
        # quantize() signals InvalidOperation for infinities but lets a
        # quiet NaN through. Left unchecked, a NaN would make math.ceil /
        # math.floor raise ValueError and would poison arithmetic, so it
        # is rejected like any other invalid input.
        return default
    return result
|
|
|
|
|
|
@register.filter(is_safe=False)
def decimal(value, arg=None):
    """Convert ``value`` to a Decimal, optionally formatting it.

    A ``value`` that is not already a Decimal goes through
    ``parse_decimal``. Without ``arg`` the Decimal itself is returned.
    With ``arg`` the result of Django's ``floatformat`` filter, called with
    that argument, is returned.
    """
    number = value if isinstance(value, Decimal) else parse_decimal(value)
    if arg is not None:
        return defaultfilters.floatformat(number, arg=arg)
    return number
|
|
|
|
|
|
@register.filter
def add(term1, term2):
    '''replace the "add" native django filter'''
    # The result depends on the operands:
    # - None counts as an empty string;
    # - two numeric operands are summed as Decimals;
    # - an empty operand next to a numeric one yields that number;
    # - anything else is delegated to Django's own ``add`` filter.
    left = '' if term1 is None else term1
    right = '' if term2 is None else term2
    left_number = parse_decimal(left, default=None)
    right_number = parse_decimal(right, default=None)

    if left_number is None or right_number is None:
        if left == '' and right_number is not None:
            return right_number
        if right == '' and left_number is not None:
            return left_number
        return defaultfilters.add(left, right)
    return left_number + right_number
|
|
|
|
|
|
@register.filter
def subtract(term1, term2):
    """Return ``term1 - term2``, each term parsed with ``parse_decimal``.

    Values that cannot be parsed count as 0.
    """
    minuend = parse_decimal(term1)
    subtrahend = parse_decimal(term2)
    return minuend - subtrahend
|
|
|
|
|
|
@register.filter
def multiply(term1, term2):
    """Return ``term1 * term2``, each term parsed with ``parse_decimal``.

    Values that cannot be parsed count as 0.
    """
    left = parse_decimal(term1)
    right = parse_decimal(term2)
    return left * right
|
|
|
|
|
|
@register.filter
def divide(term1, term2):
    """Return ``term1 / term2``, each term parsed with ``parse_decimal``.

    Returns '' when the Decimal division signals DivisionByZero or
    InvalidOperation, for instance when dividing by zero.
    """
    try:
        quotient = parse_decimal(term1) / parse_decimal(term2)
    except (DecimalInvalidOperation, DecimalDivisionByZero):
        quotient = ''
    return quotient
|
|
|
|
|
|
@register.filter
def ceil(value):
    '''the smallest integer value greater than or equal to value'''
    # The result goes through the ``decimal`` filter, so it is a Decimal.
    # Unparseable input counts as 0.
    rounded_up = math.ceil(parse_decimal(value))
    return decimal(rounded_up)
|
|
|
|
|
|
@register.filter
def floor(value):
    """Return the largest integer less than or equal to ``value``.

    The result goes through the ``decimal`` filter, so it is a Decimal.
    Unparseable input counts as 0.
    """
    rounded_down = math.floor(parse_decimal(value))
    return decimal(rounded_down)
|
|
|
|
|
|
@register.filter(name='abs')
def abs_(value):
    """Return the absolute value of ``value`` as a Decimal.

    Unparseable input counts as 0.
    """
    magnitude = abs(parse_decimal(value))
    return decimal(magnitude)
|
|
|
|
|
|
@register.filter(name='sum')
def sum_(list_):
    """Sum the items of ``list_``, each parsed with ``parse_decimal``.

    Returns '' when ``list_`` is a string (deliberately not treated as an
    iterable of characters) or when a TypeError occurs while summing. Such
    an error arises, for instance, when ``list_`` is not iterable.
    """
    # do not consider string as iterable, to avoid misusage
    if isinstance(list_, str):
        return ''
    try:
        total = sum(map(parse_decimal, list_))
    except TypeError:  # list_ is not iterable
        return ''
    return total
|