diff --git a/cachalot/utils.py b/cachalot/utils.py index 5d92113..11017fd 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -18,9 +18,6 @@ from .settings import cachalot_settings from .transaction import AtomicCache -DJANGO_GTE_1_9 = django_version[:2] >= (1, 9) - - class UncachableQuery(Exception): pass @@ -33,19 +30,24 @@ CACHABLE_PARAM_TYPES = { } UNCACHABLE_FUNCS = set() -if DJANGO_GTE_1_9: +if django_version[:2] >= (1, 9): from django.db.models.functions import Now from django.contrib.postgres.functions import TransactionNow UNCACHABLE_FUNCS.update((Now, TransactionNow)) try: + from psycopg2 import Binary from psycopg2.extras import ( NumericRange, DateRange, DateTimeRange, DateTimeTZRange, Inet, Json) except ImportError: pass else: CACHABLE_PARAM_TYPES.update(( + Binary, NumericRange, DateRange, DateTimeRange, DateTimeTZRange, Inet, Json)) + if django_version[:2] >= (1, 11): + from django.contrib.postgres.fields.jsonb import JsonAdapter + CACHABLE_PARAM_TYPES.add(JsonAdapter) def check_parameter_types(params): @@ -75,7 +77,7 @@ def get_query_cache_key(compiler): """ sql, params = compiler.as_sql() check_parameter_types(params) - cache_key = '%s:%s:%s' % (compiler.using, sql, params) + cache_key = '%s:%s:%s' % (compiler.using, sql, [str(p) for p in params]) return sha1(cache_key.encode('utf-8')).hexdigest() @@ -103,8 +105,8 @@ def _get_table_cache_key(db_alias, table): def _get_tables_from_sql(connection, lowercased_sql): - return [t for t in connection.introspection.django_table_names() - if t in lowercased_sql] + return {t for t in connection.introspection.django_table_names() + if t in lowercased_sql} def _find_subqueries(children): @@ -152,21 +154,22 @@ def filter_cachable(tables): return tables -def _get_tables(query, db_alias): - if ('?' in query.order_by and not cachalot_settings.CACHALOT_CACHE_RANDOM) \ - or query.select_for_update: +def _get_tables(db_alias, query): + if query.select_for_update or ( + '?' in query.order_by + and not cachalot_settings.CACHALOT_CACHE_RANDOM): raise UncachableQuery - tables = set(query.table_map) - tables.add(query.get_meta().db_table) - subquery_constraints = _find_subqueries(query.where.children) - for subquery in subquery_constraints: - tables.update(_get_tables(subquery, db_alias)) - if query.extra_select or (hasattr(query, 'subquery') and query.subquery) \ + if query.extra_select or getattr(query, 'subquery', False) \ or any(c.__class__ is ExtraWhere for c in query.where.children): sql = query.get_compiler(db_alias).as_sql()[0].lower() - additional_tables = _get_tables_from_sql(connections[db_alias], sql) - tables.update(additional_tables) + tables = _get_tables_from_sql(connections[db_alias], sql) + else: + tables = set(query.table_map) + tables.add(query.get_meta().db_table) + subquery_constraints = _find_subqueries(query.where.children) + for subquery in subquery_constraints: + tables.update(_get_tables(db_alias, subquery)) if not are_all_cachable(tables): raise UncachableQuery @@ -176,7 +179,7 @@ def _get_tables(query, db_alias): def _get_table_cache_keys(compiler): db_alias = compiler.using return [_get_table_cache_key(db_alias, t) - for t in _get_tables(compiler.query, db_alias)] + for t in _get_tables(db_alias, compiler.query)] def _invalidate_tables(cache, db_alias, tables):