Kaydet (Commit) ee85ef83 authored tarafından Simon Charette's avatar Simon Charette

Fixed #28792 -- Fixed index name truncation of namespaced tables.

Refs #27458, #27843.

Thanks Tim and Mariusz for the review.
üst 532a4f22
...@@ -5,7 +5,7 @@ from datetime import datetime ...@@ -5,7 +5,7 @@ from datetime import datetime
from django.db.backends.ddl_references import ( from django.db.backends.ddl_references import (
Columns, ForeignKeyName, IndexName, Statement, Table, Columns, ForeignKeyName, IndexName, Statement, Table,
) )
from django.db.backends.utils import strip_quotes from django.db.backends.utils import split_identifier
from django.db.models import Index from django.db.models import Index
from django.db.transaction import TransactionManagementError, atomic from django.db.transaction import TransactionManagementError, atomic
from django.utils import timezone from django.utils import timezone
...@@ -858,7 +858,7 @@ class BaseDatabaseSchemaEditor: ...@@ -858,7 +858,7 @@ class BaseDatabaseSchemaEditor:
The name is divided into 3 parts: the table name, the column names, The name is divided into 3 parts: the table name, the column names,
and a unique digest and suffix. and a unique digest and suffix.
""" """
table_name = strip_quotes(table_name) _, table_name = split_identifier(table_name)
hash_data = [table_name] + list(column_names) hash_data = [table_name] + list(column_names)
hash_suffix_part = '%s%s' % (self._digest(*hash_data), suffix) hash_suffix_part = '%s%s' % (self._digest(*hash_data), suffix)
max_length = self.connection.ops.max_name_length() or 200 max_length = self.connection.ops.max_name_length() or 200
......
...@@ -3,7 +3,6 @@ import decimal ...@@ -3,7 +3,6 @@ import decimal
import functools import functools
import hashlib import hashlib
import logging import logging
import re
from time import time from time import time
from django.conf import settings from django.conf import settings
...@@ -194,20 +193,35 @@ def rev_typecast_decimal(d): ...@@ -194,20 +193,35 @@ def rev_typecast_decimal(d):
return str(d) return str(d)
def truncate_name(name, length=None, hash_len=4): def split_identifier(identifier):
""" """
Shorten a string to a repeatable mangled version with the given length. Split a SQL identifier into a two element tuple of (namespace, name).
If a quote stripped name contains a username, e.g. USERNAME"."TABLE,
The identifier could be a table, column, or sequence name might be prefixed
by a namespace.
"""
try:
namespace, name = identifier.split('"."')
except ValueError:
namespace, name = '', identifier
return namespace.strip('"'), name.strip('"')
def truncate_name(identifier, length=None, hash_len=4):
"""
Shorten a SQL identifier to a repeatable mangled version with the given
length.
If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
truncate the table portion only. truncate the table portion only.
""" """
match = re.match(r'([^"]+)"\."([^"]+)', name) namespace, name = split_identifier(identifier)
table_name = match.group(2) if match else name
if length is None or len(table_name) <= length: if length is None or len(name) <= length:
return name return identifier
hsh = hashlib.md5(force_bytes(table_name)).hexdigest()[:hash_len] digest = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
return '%s%s%s' % (match.group(1) + '"."' if match else '', table_name[:length - hash_len], hsh) return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest)
def format_number(value, max_digits, decimal_places): def format_number(value, max_digits, decimal_places):
......
...@@ -15,3 +15,6 @@ Bugfixes ...@@ -15,3 +15,6 @@ Bugfixes
* Added support for ``QuerySet.values()`` and ``values_list()`` for * Added support for ``QuerySet.values()`` and ``values_list()`` for
``union()``, ``difference()``, and ``intersection()`` queries ``union()``, ``difference()``, and ``intersection()`` queries
(:ticket:`28781`). (:ticket:`28781`).
* Fixed incorrect index name truncation when using a namespaced ``db_table``
(:ticket:`28792`).
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
from decimal import Decimal, Rounded from decimal import Decimal, Rounded
from django.db import connection from django.db import connection
from django.db.backends.utils import format_number, truncate_name from django.db.backends.utils import (
format_number, split_identifier, truncate_name,
)
from django.db.utils import NotSupportedError from django.db.utils import NotSupportedError
from django.test import ( from django.test import (
SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature,
...@@ -21,6 +23,12 @@ class TestUtils(SimpleTestCase): ...@@ -21,6 +23,12 @@ class TestUtils(SimpleTestCase):
self.assertEqual(truncate_name('username"."some_long_table', 10), 'username"."some_la38a') self.assertEqual(truncate_name('username"."some_long_table', 10), 'username"."some_la38a')
self.assertEqual(truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38') self.assertEqual(truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38')
def test_split_identifier(self):
self.assertEqual(split_identifier('some_table'), ('', 'some_table'))
self.assertEqual(split_identifier('"some_table"'), ('', 'some_table'))
self.assertEqual(split_identifier('namespace"."some_table'), ('namespace', 'some_table'))
self.assertEqual(split_identifier('"namespace"."some_table"'), ('namespace', 'some_table'))
def test_format_number(self): def test_format_number(self):
def equal(value, max_d, places, result): def equal(value, max_d, places, result):
self.assertEqual(format_number(Decimal(value), max_d, places), result) self.assertEqual(format_number(Decimal(value), max_d, places), result)
......
...@@ -2370,6 +2370,21 @@ class SchemaTests(TransactionTestCase): ...@@ -2370,6 +2370,21 @@ class SchemaTests(TransactionTestCase):
cast_function=lambda x: x.time(), cast_function=lambda x: x.time(),
) )
def test_namespaced_db_table_create_index_name(self):
"""
Table names are stripped of their namespace/schema before being used to
generate index names.
"""
with connection.schema_editor() as editor:
max_name_length = connection.ops.max_name_length() or 200
namespace = 'n' * max_name_length
table_name = 't' * max_name_length
namespaced_table_name = '"%s"."%s"' % (namespace, table_name)
self.assertEqual(
editor._create_index_name(table_name, []),
editor._create_index_name(namespaced_table_name, []),
)
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle specific db_table syntax') @unittest.skipUnless(connection.vendor == 'oracle', 'Oracle specific db_table syntax')
def test_creation_with_db_table_double_quotes(self): def test_creation_with_db_table_double_quotes(self):
oracle_user = connection.creation._test_database_user() oracle_user = connection.creation._test_database_user()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment