Kaydet (Commit) 7171bf75 authored tarafından Josh Smeaton's avatar Josh Smeaton

Refs #14030 -- Added repr methods to all expressions

üst f218a2ff
...@@ -94,6 +94,13 @@ class Count(Aggregate): ...@@ -94,6 +94,13 @@ class Count(Aggregate):
super(Count, self).__init__( super(Count, self).__init__(
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
def __repr__(self):
return "{}({}, distinct={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.extra['distinct'] == '' else 'True',
)
def convert_value(self, value, connection, context): def convert_value(self, value, connection, context):
if value is None: if value is None:
return 0 return 0
...@@ -117,6 +124,13 @@ class StdDev(Aggregate): ...@@ -117,6 +124,13 @@ class StdDev(Aggregate):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra) super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
def __repr__(self):
return "{}({}, sample={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.function == 'STDDEV_POP' else 'True',
)
def convert_value(self, value, connection, context): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value
...@@ -135,6 +149,13 @@ class Variance(Aggregate): ...@@ -135,6 +149,13 @@ class Variance(Aggregate):
self.function = 'VAR_SAMP' if sample else 'VAR_POP' self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super(Variance, self).__init__(expression, output_field=FloatField(), **extra) super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
def __repr__(self):
return "{}({}, sample={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.function == 'VAR_POP' else 'True',
)
def convert_value(self, value, connection, context): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value
......
...@@ -340,6 +340,12 @@ class Expression(ExpressionNode): ...@@ -340,6 +340,12 @@ class Expression(ExpressionNode):
self.lhs = lhs self.lhs = lhs
self.rhs = rhs self.rhs = rhs
def __repr__(self):
return "<{}: {}>".format(self.__class__.__name__, self)
def __str__(self):
return "{} {} {}".format(self.lhs, self.connector, self.rhs)
def get_source_expressions(self): def get_source_expressions(self):
return [self.lhs, self.rhs] return [self.lhs, self.rhs]
...@@ -408,7 +414,7 @@ class DurationExpression(Expression): ...@@ -408,7 +414,7 @@ class DurationExpression(Expression):
return expression_wrapper % sql, expression_params return expression_wrapper % sql, expression_params
class F(CombinableMixin): class F(Combinable):
""" """
An object capable of resolving references to existing query objects. An object capable of resolving references to existing query objects.
""" """
...@@ -419,6 +425,9 @@ class F(CombinableMixin): ...@@ -419,6 +425,9 @@ class F(CombinableMixin):
""" """
self.name = name self.name = name
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.name)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize) return query.resolve_ref(self.name, allow_joins, reuse, summarize)
...@@ -446,6 +455,13 @@ class Func(ExpressionNode): ...@@ -446,6 +455,13 @@ class Func(ExpressionNode):
self.source_expressions = self._parse_expressions(*expressions) self.source_expressions = self._parse_expressions(*expressions)
self.extra = extra self.extra = extra
def __repr__(self):
args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items())
if extra:
return "{}({}, {})".format(self.__class__.__name__, args, extra)
return "{}({})".format(self.__class__.__name__, args)
def get_source_expressions(self): def get_source_expressions(self):
return self.source_expressions return self.source_expressions
...@@ -504,6 +520,9 @@ class Value(ExpressionNode): ...@@ -504,6 +520,9 @@ class Value(ExpressionNode):
super(Value, self).__init__(output_field=output_field) super(Value, self).__init__(output_field=output_field)
self.value = value self.value = value
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.value)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
val = self.value val = self.value
...@@ -545,6 +564,9 @@ class RawSQL(ExpressionNode): ...@@ -545,6 +564,9 @@ class RawSQL(ExpressionNode):
self.sql, self.params = sql, params self.sql, self.params = sql, params
super(RawSQL, self).__init__(output_field=output_field) super(RawSQL, self).__init__(output_field=output_field)
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return '(%s)' % self.sql, self.params return '(%s)' % self.sql, self.params
...@@ -556,6 +578,9 @@ class Random(ExpressionNode): ...@@ -556,6 +578,9 @@ class Random(ExpressionNode):
def __init__(self): def __init__(self):
super(Random, self).__init__(output_field=fields.FloatField()) super(Random, self).__init__(output_field=fields.FloatField())
def __repr__(self):
return "Random()"
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return connection.ops.random_function_sql(), [] return connection.ops.random_function_sql(), []
...@@ -567,6 +592,10 @@ class Col(ExpressionNode): ...@@ -567,6 +592,10 @@ class Col(ExpressionNode):
super(Col, self).__init__(output_field=source) super(Col, self).__init__(output_field=source)
self.alias, self.target = alias, target self.alias, self.target = alias, target
def __repr__(self):
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.target)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias qn = compiler.quote_name_unless_alias
return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
...@@ -588,8 +617,10 @@ class Ref(ExpressionNode): ...@@ -588,8 +617,10 @@ class Ref(ExpressionNode):
""" """
def __init__(self, refs, source): def __init__(self, refs, source):
super(Ref, self).__init__() super(Ref, self).__init__()
self.source = source self.refs, self.source = refs, source
self.refs = refs
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
def get_source_expressions(self): def get_source_expressions(self):
return [self.source] return [self.source]
...@@ -743,6 +774,9 @@ class Date(ExpressionNode): ...@@ -743,6 +774,9 @@ class Date(ExpressionNode):
self.col = None self.col = None
self.lookup_type = lookup_type self.lookup_type = lookup_type
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)
def get_source_expressions(self): def get_source_expressions(self):
return [self.col] return [self.col]
...@@ -792,6 +826,10 @@ class DateTime(ExpressionNode): ...@@ -792,6 +826,10 @@ class DateTime(ExpressionNode):
self.tzname = timezone._get_timezone_name(tzinfo) self.tzname = timezone._get_timezone_name(tzinfo)
self.tzinfo = tzinfo self.tzinfo = tzinfo
def __repr__(self):
return "{}({}, {}, {})".format(
self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)
def get_source_expressions(self): def get_source_expressions(self):
return [self.col] return [self.col]
...@@ -833,8 +871,6 @@ class DateTime(ExpressionNode): ...@@ -833,8 +871,6 @@ class DateTime(ExpressionNode):
class OrderBy(BaseExpression): class OrderBy(BaseExpression):
template = '%(expression)s %(ordering)s' template = '%(expression)s %(ordering)s'
descending_template = 'DESC'
ascending_template = 'ASC'
def __init__(self, expression, descending=False): def __init__(self, expression, descending=False):
self.descending = descending self.descending = descending
...@@ -842,6 +878,10 @@ class OrderBy(BaseExpression): ...@@ -842,6 +878,10 @@ class OrderBy(BaseExpression):
raise ValueError('expression must be an expression type') raise ValueError('expression must be an expression type')
self.expression = expression self.expression = expression
def __repr__(self):
return "{}({}, descending={})".format(
self.__class__.__name__, self.expression, self.descending)
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.expression = exprs[0] self.expression = exprs[0]
......
...@@ -6,10 +6,17 @@ import uuid ...@@ -6,10 +6,17 @@ import uuid
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, transaction, DatabaseError from django.db import connection, transaction, DatabaseError
from django.db.models import F, Value, TimeField, UUIDField from django.db.models import TimeField, UUIDField
from django.db.models.aggregates import Avg, Count, Max, Min, StdDev, Sum, Variance
from django.db.models.expressions import (
Case, Col, Date, DateTime, F, Func, OrderBy,
Random, RawSQL, Ref, Value, When
)
from django.db.models.functions import Coalesce, Concat, Length, Lower, Substr, Upper
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils import six from django.utils import six
from django.utils.timezone import utc
from .models import Company, Employee, Number, Experiment, Time, UUID from .models import Company, Employee, Number, Experiment, Time, UUID
...@@ -812,3 +819,40 @@ class ValueTests(TestCase): ...@@ -812,3 +819,40 @@ class ValueTests(TestCase):
UUID.objects.create() UUID.objects.create()
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
class ReprTests(TestCase):
def test_expressions(self):
self.assertEqual(
repr(Case(When(a=1))),
"<Case: CASE WHEN <Q: (AND: ('a', 1))> THEN Value(None), ELSE Value(None)>"
)
self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)")
self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)")
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, UTC)")
self.assertEqual(repr(F('published')), "F(published)")
self.assertEqual(repr(F('cost') + F('tax')), "<Expression: F(cost) + F(tax)>")
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
self.assertEqual(repr(Random()), "Random()")
self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
self.assertEqual(repr(Value(1)), "Value(1)")
def test_functions(self):
self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")
self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))")
self.assertEqual(repr(Length('a')), "Length(F(a))")
self.assertEqual(repr(Lower('a')), "Lower(F(a))")
self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))")
self.assertEqual(repr(Upper('a')), "Upper(F(a))")
def test_aggregates(self):
self.assertEqual(repr(Avg('a')), "Avg(F(a))")
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
self.assertEqual(repr(Max('a')), "Max(F(a))")
self.assertEqual(repr(Min('a')), "Min(F(a))")
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment