Kaydet (Commit) 2a4af0ea authored tarafından Josh Smeaton's avatar Josh Smeaton

Fixed #25774 -- Refactor datetime expressions into public API

üst 77b73e79
from django.db.models import DateTimeField
from django.db.models.functions import Func
from django.db.models import DateTimeField, Func
class TransactionNow(Func):
......
import copy
import datetime
from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.query_utils import Q
from django.utils import six, timezone
from django.utils import six
from django.utils.functional import cached_property
......@@ -860,111 +859,6 @@ class Case(Expression):
return sql, sql_params
class Date(Expression):
"""
Add a date selection column.
"""
def __init__(self, lookup, lookup_type):
super(Date, self).__init__(output_field=fields.DateField())
self.lookup = lookup
self.col = None
self.lookup_type = lookup_type
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = exprs
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = self.copy()
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
field = copy.col.output_field
assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name
if settings.USE_TZ:
assert not isinstance(field, fields.DateTimeField), (
"%r is a DateTimeField, not a DateField." % field.name
)
return copy
def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(compiler, connection)
assert not(params)
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
def copy(self):
copy = super(Date, self).copy()
copy.lookup = self.lookup
copy.lookup_type = self.lookup_type
return copy
def convert_value(self, value, expression, connection, context):
if isinstance(value, datetime.datetime):
value = value.date()
return value
class DateTime(Expression):
"""
Add a datetime selection column.
"""
def __init__(self, lookup, lookup_type, tzinfo):
super(DateTime, self).__init__(output_field=fields.DateTimeField())
self.lookup = lookup
self.col = None
self.lookup_type = lookup_type
if tzinfo is None:
self.tzname = None
else:
self.tzname = timezone._get_timezone_name(tzinfo)
self.tzinfo = tzinfo
def __repr__(self):
return "{}({}, {}, {})".format(
self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = exprs
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = self.copy()
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
field = copy.col.output_field
assert isinstance(field, fields.DateTimeField), (
"%r isn't a DateTimeField." % field.name
)
return copy
def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(compiler, connection)
assert not(params)
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)
def copy(self):
copy = super(DateTime, self).copy()
copy.lookup = self.lookup
copy.lookup_type = self.lookup_type
copy.tzname = self.tzname
return copy
def convert_value(self, value, expression, connection, context):
if settings.USE_TZ:
if value is None:
raise ValueError(
"Database returned an invalid value in QuerySet.datetimes(). "
"Are time zone definitions for your database and pytz installed?"
)
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo)
return value
class OrderBy(BaseExpression):
template = '%(expression)s %(ordering)s'
......
from .base import (
Cast, Coalesce, Concat, ConcatPair, Greatest, Least, Length, Lower, Now,
Substr, Upper,
)
from .datetime import (
Extract, ExtractDay, ExtractHour, ExtractMinute, ExtractMonth,
ExtractSecond, ExtractWeekDay, ExtractYear, Trunc, TruncDate, TruncDay,
TruncHour, TruncMinute, TruncMonth, TruncSecond, TruncYear,
)
__all__ = [
# base
'Cast', 'Coalesce', 'Concat', 'ConcatPair', 'Greatest', 'Least', 'Length',
'Lower', 'Now', 'Substr', 'Upper',
# datetime
'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
'ExtractSecond', 'ExtractWeekDay', 'ExtractYear',
'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
'TruncSecond', 'TruncYear',
]
from __future__ import absolute_import
from datetime import datetime
from django.conf import settings
from django.db.models import (
DateField, DateTimeField, IntegerField, TimeField, Transform,
)
from django.db.models.lookups import (
YearExact, YearGt, YearGte, YearLt, YearLte,
)
from django.utils import timezone
from django.utils.functional import cached_property
class TimezoneMixin(object):
tzinfo = None
def get_tzname(self):
# Timezone conversions must happen to the input datetime *before*
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
# database as 2016-01-01 01:00:00 +00:00. Any results should be
# based on the input datetime not the stored datetime.
tzname = None
if settings.USE_TZ:
if self.tzinfo is None:
tzname = timezone.get_current_timezone_name()
else:
tzname = timezone._get_timezone_name(self.tzinfo)
return tzname
class Extract(TimezoneMixin, Transform):
lookup_name = None
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
raise ValueError('lookup_name must be provided')
self.tzinfo = tzinfo
super(Extract, self).__init__(expression, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
params.extend(tz_params)
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
# resolve_expression has already validated the output_field so this
# assert should never be hit.
assert False, "Tried to Extract from an invalid type."
return sql, params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super(Extract, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = copy.lhs.output_field
if not isinstance(field, (DateField, DateTimeField, TimeField)):
raise ValueError('Extract input expression must be DateField, DateTimeField, or TimeField.')
# Passing dates to functions expecting datetimes is most likely a mistake.
if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
raise ValueError(
"Cannot extract time component '%s' from DateField '%s'. " % (copy.lookup_name, field.name)
)
return copy
@cached_property
def output_field(self):
return IntegerField()
class ExtractYear(Extract):
lookup_name = 'year'
class ExtractMonth(Extract):
lookup_name = 'month'
class ExtractDay(Extract):
lookup_name = 'day'
class ExtractWeekDay(Extract):
"""
Return Sunday=1 through Saturday=7.
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
lookup_name = 'week_day'
class ExtractHour(Extract):
lookup_name = 'hour'
class ExtractMinute(Extract):
lookup_name = 'minute'
class ExtractSecond(Extract):
lookup_name = 'second'
DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)
DateTimeField.register_lookup(ExtractYear)
DateTimeField.register_lookup(ExtractMonth)
DateTimeField.register_lookup(ExtractDay)
DateTimeField.register_lookup(ExtractWeekDay)
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)
ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)
class TruncBase(TimezoneMixin, Transform):
arity = 1
kind = None
tzinfo = None
def __init__(self, expression, output_field=None, tzinfo=None, **extra):
self.tzinfo = tzinfo
super(TruncBase, self).__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
inner_sql, inner_params = compiler.compile(self.lhs)
# Escape any params because trunc_sql will format the string.
inner_sql = inner_sql.replace('%s', '%%s')
if isinstance(self.output_field, DateTimeField):
tzname = self.get_tzname()
sql, params = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql)
params = []
else:
raise ValueError('Trunc only valid on DateField or DateTimeField.')
return sql, inner_params + params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super(TruncBase, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
assert isinstance(field, DateField), (
"%r isn't a DateField or DateTimeField." % field.name
)
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField)):
raise ValueError('output_field must be either DateField or DateTimeField')
# Passing dates to functions expecting datetimes is most likely a
# mistake.
if type(field) == DateField and (
isinstance(copy.output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second')):
raise ValueError("Cannot truncate DateField '%s' to DateTimeField. " % field.name)
return copy
def convert_value(self, value, expression, connection, context):
if isinstance(self.output_field, DateTimeField):
if settings.USE_TZ:
if value is None:
raise ValueError(
"Database returned an invalid datetime value. "
"Are time zone definitions for your database and pytz installed?"
)
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo)
elif isinstance(value, datetime):
# self.output_field is definitely a DateField here.
value = value.date()
return value
class Trunc(TruncBase):
def __init__(self, expression, kind, output_field=None, tzinfo=None, **extra):
self.kind = kind
super(Trunc, self).__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)
class TruncYear(TruncBase):
kind = 'year'
class TruncMonth(TruncBase):
kind = 'month'
class TruncDay(TruncBase):
kind = 'day'
class TruncDate(TruncBase):
lookup_name = 'date'
@cached_property
def output_field(self):
return DateField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params
class TruncHour(TruncBase):
kind = 'hour'
@cached_property
def output_field(self):
return DateTimeField()
class TruncMinute(TruncBase):
kind = 'minute'
@cached_property
def output_field(self):
return DateTimeField()
class TruncSecond(TruncBase):
kind = 'second'
@cached_property
def output_field(self):
return DateTimeField()
DateTimeField.register_lookup(TruncDate)
......@@ -2,13 +2,9 @@ import math
import warnings
from copy import copy
from django.conf import settings
from django.db.models.expressions import Func, Value
from django.db.models.fields import (
DateField, DateTimeField, Field, IntegerField, TimeField,
)
from django.db.models.fields import DateTimeField, Field, IntegerField
from django.db.models.query_utils import RegisterLookupMixin
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.functional import cached_property
from django.utils.six.moves import range
......@@ -480,46 +476,6 @@ class IRegex(Regex):
Field.register_lookup(IRegex)
class DateTimeDateTransform(Transform):
lookup_name = 'date'
@cached_property
def output_field(self):
return DateField()
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params
class DateTransform(Transform):
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
params.extend(tz_params)
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
return sql, params
@cached_property
def output_field(self):
return IntegerField()
class YearTransform(DateTransform):
lookup_name = 'year'
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
output_field = self.lhs.lhs.output_field
......@@ -530,20 +486,6 @@ class YearLookup(Lookup):
return bounds
@YearTransform.register_lookup
class YearExact(YearLookup):
lookup_name = 'exact'
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
bounds = self.year_lookup_bounds(connection, rhs_params[0])
params.extend(bounds)
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
class YearComparisonLookup(YearLookup):
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
......@@ -564,7 +506,27 @@ class YearComparisonLookup(YearLookup):
)
@YearTransform.register_lookup
class YearExact(YearLookup, Exact):
lookup_name = 'exact'
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
try:
# Check that rhs_params[0] exists (IndexError),
# it isn't None (TypeError), and is a number (ValueError)
int(rhs_params[0])
except (IndexError, TypeError, ValueError):
# Can't determine the bounds before executing the query, so skip
# optimizations by falling back to a standard exact comparison.
return super(Exact, self).as_sql(compiler, connection)
bounds = self.year_lookup_bounds(connection, rhs_params[0])
params.extend(bounds)
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
class YearGt(YearComparisonLookup):
lookup_name = 'gt'
......@@ -572,7 +534,6 @@ class YearGt(YearComparisonLookup):
return finish
@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
lookup_name = 'gte'
......@@ -580,7 +541,6 @@ class YearGte(YearComparisonLookup):
return start
@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
lookup_name = 'lt'
......@@ -588,52 +548,8 @@ class YearLt(YearComparisonLookup):
return start
@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
lookup_name = 'lte'
def get_bound(self, start, finish):
return finish
class MonthTransform(DateTransform):
lookup_name = 'month'
class DayTransform(DateTransform):
lookup_name = 'day'
class WeekDayTransform(DateTransform):
lookup_name = 'week_day'
class HourTransform(DateTransform):
lookup_name = 'hour'
class MinuteTransform(DateTransform):
lookup_name = 'minute'
class SecondTransform(DateTransform):
lookup_name = 'second'
DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)
TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)
DateTimeField.register_lookup(DateTimeDateTransform)
DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
DateTimeField.register_lookup(HourTransform)
DateTimeField.register_lookup(MinuteTransform)
DateTimeField.register_lookup(SecondTransform)
......@@ -13,11 +13,12 @@ from django.db import (
DJANGO_VERSION_PICKLE_KEY, IntegrityError, connections, router,
transaction,
)
from django.db.models import sql
from django.db.models import DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import Collector
from django.db.models.expressions import Date, DateTime, F
from django.db.models.expressions import F
from django.db.models.fields import AutoField
from django.db.models.functions import Trunc
from django.db.models.query_utils import (
InvalidQuery, Q, check_rel_lookup_compatibility,
)
......@@ -739,7 +740,7 @@ class QuerySet(object):
assert order in ('ASC', 'DESC'), \
"'order' must be either 'ASC' or 'DESC'."
return self.annotate(
datefield=Date(field_name, kind),
datefield=Trunc(field_name, kind, output_field=DateField()),
plain_field=F(field_name)
).values_list(
'datefield', flat=True
......@@ -760,7 +761,7 @@ class QuerySet(object):
else:
tzinfo = None
return self.annotate(
datetimefield=DateTime(field_name, kind, tzinfo),
datetimefield=Trunc(field_name, kind, output_field=DateTimeField(), tzinfo=tzinfo),
plain_field=F(field_name)
).values_list(
'datetimefield', flat=True
......
......@@ -443,6 +443,13 @@ Models
* A proxy model may now inherit multiple proxy models that share a common
non-abstract parent class.
* Added :class:`~django.db.models.functions.datetime.Extract` functions
to extract datetime components as integers, such as year and hour.
* Added :class:`~django.db.models.functions.datetime.Trunc` functions to
truncate a date or datetime to a significant component. They enable queries
like sales-per-day or sales-per-hour.
* ``Model.__init__()`` now sets values of virtual fields from its keyword
arguments.
......@@ -900,6 +907,10 @@ Miscellaneous
989 characters. If you were counting on a limited length, truncate the subject
yourself.
* Private expressions ``django.db.models.expressions.Date`` and ``DateTime``
are removed. The new :class:`~django.db.models.functions.datetime.Trunc`
expressions provide the same functionality.
.. _deprecated-features-1.10:
Features deprecated in 1.10
......
......@@ -8,6 +8,7 @@ from django.utils.encoding import python_2_unicode_compatible
class Article(models.Model):
title = models.CharField(max_length=100)
pub_date = models.DateTimeField()
published_on = models.DateField(null=True)
categories = models.ManyToManyField("Category", related_name="articles")
......
......@@ -153,3 +153,9 @@ class DateTimesTests(TestCase):
datetime.datetime(2005, 7, 30, 0, 0),
datetime.datetime(2005, 7, 29, 0, 0),
datetime.datetime(2005, 7, 28, 0, 0)])
def test_datetimes_disallows_date_fields(self):
dt = datetime.datetime(2005, 7, 28, 12, 15)
Article.objects.create(pub_date=dt, published_on=dt.date(), title="Don't put dates into datetime functions!")
with self.assertRaisesMessage(ValueError, "Cannot truncate DateField 'published_on' to DateTimeField"):
list(Article.objects.datetimes('published_on', 'second'))
......@@ -41,3 +41,18 @@ class Fan(models.Model):
def __str__(self):
return self.name
@python_2_unicode_compatible
class DTModel(models.Model):
name = models.CharField(max_length=32)
start_datetime = models.DateTimeField(null=True, blank=True)
end_datetime = models.DateTimeField(null=True, blank=True)
start_date = models.DateField(null=True, blank=True)
end_date = models.DateField(null=True, blank=True)
start_time = models.TimeField(null=True, blank=True)
end_time = models.TimeField(null=True, blank=True)
duration = models.DurationField(null=True, blank=True)
def __str__(self):
return 'DTModel({0})'.format(self.name)
This diff is collapsed.
......@@ -11,8 +11,8 @@ from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance,
)
from django.db.models.expressions import (
Case, Col, Date, DateTime, ExpressionWrapper, F, Func, OrderBy, Random,
RawSQL, Ref, Value, When,
Case, Col, ExpressionWrapper, F, Func, OrderBy, Random, RawSQL, Ref, Value,
When,
)
from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper,
......@@ -20,7 +20,6 @@ from django.db.models.functions import (
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate
from django.utils import six
from django.utils.timezone import utc
from .models import UUID, Company, Employee, Experiment, Number, Time
......@@ -930,8 +929,6 @@ class ReprTests(TestCase):
"<Case: CASE WHEN <Q: (AND: ('a', 1))> THEN Value(None), ELSE Value(None)>"
)
self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)")
self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)")
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, %s)" % utc)
self.assertEqual(repr(F('published')), "F(published)")
self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>")
self.assertEqual(
......
......@@ -1312,7 +1312,7 @@ class Queries3Tests(BaseQuerysetTest):
def test_ticket8683(self):
# An error should be raised when QuerySet.datetimes() is passed the
# wrong type of field.
with self.assertRaisesMessage(AssertionError, "'name' isn't a DateTimeField."):
with self.assertRaisesMessage(AssertionError, "'name' isn't a DateField or DateTimeField."):
Item.objects.datetimes('name', 'month')
def test_ticket22023(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment