Commit 0ac30f82, authored by Walter Dörwald

Enhance the punycode decoder so that it can decode unicode objects.

Fix the idna codec and the tests.
Parent commit: 1f05a3b7
...@@ -7,7 +7,8 @@ from unicodedata import ucd_3_2_0 as unicodedata ...@@ -7,7 +7,8 @@ from unicodedata import ucd_3_2_0 as unicodedata
dots = re.compile("[\u002E\u3002\uFF0E\uFF61]") dots = re.compile("[\u002E\u3002\uFF0E\uFF61]")
# IDNA section 5 # IDNA section 5
ace_prefix = "xn--" ace_prefix = b"xn--"
sace_prefix = "xn--"
# This assumes query strings, so AllowUnassigned is true # This assumes query strings, so AllowUnassigned is true
def nameprep(label): def nameprep(label):
...@@ -87,7 +88,7 @@ def ToASCII(label): ...@@ -87,7 +88,7 @@ def ToASCII(label):
raise UnicodeError("label empty or too long") raise UnicodeError("label empty or too long")
# Step 5: Check ACE prefix # Step 5: Check ACE prefix
if label.startswith(ace_prefix): if label.startswith(sace_prefix):
raise UnicodeError("Label starts with ACE prefix") raise UnicodeError("Label starts with ACE prefix")
# Step 6: Encode with PUNYCODE # Step 6: Encode with PUNYCODE
...@@ -134,7 +135,7 @@ def ToUnicode(label): ...@@ -134,7 +135,7 @@ def ToUnicode(label):
# Step 7: Compare the result of step 6 with the one of step 3 # Step 7: Compare the result of step 6 with the one of step 3
# label2 will already be in lower case. # label2 will already be in lower case.
if label.lower() != label2: if str(label, "ascii").lower() != str(label2, "ascii"):
raise UnicodeError("IDNA does not round-trip", label, label2) raise UnicodeError("IDNA does not round-trip", label, label2)
# Step 8: return the result of step 5 # Step 8: return the result of step 5
...@@ -143,7 +144,7 @@ def ToUnicode(label): ...@@ -143,7 +144,7 @@ def ToUnicode(label):
### Codec APIs ### Codec APIs
class Codec(codecs.Codec): class Codec(codecs.Codec):
def encode(self,input,errors='strict'): def encode(self, input, errors='strict'):
if errors != 'strict': if errors != 'strict':
# IDNA is quite clear that implementations must be strict # IDNA is quite clear that implementations must be strict
...@@ -152,19 +153,21 @@ class Codec(codecs.Codec): ...@@ -152,19 +153,21 @@ class Codec(codecs.Codec):
if not input: if not input:
return b"", 0 return b"", 0
result = [] result = b""
labels = dots.split(input) labels = dots.split(input)
if labels and len(labels[-1])==0: if labels and not labels[-1]:
trailing_dot = b'.' trailing_dot = b'.'
del labels[-1] del labels[-1]
else: else:
trailing_dot = b'' trailing_dot = b''
for label in labels: for label in labels:
result.append(ToASCII(label)) if result:
# Join with U+002E # Join with U+002E
return b".".join(result)+trailing_dot, len(input) result.extend(b'.')
result.extend(ToASCII(label))
return result+trailing_dot, len(input)
def decode(self,input,errors='strict'): def decode(self, input, errors='strict'):
if errors != 'strict': if errors != 'strict':
raise UnicodeError("Unsupported error handling "+errors) raise UnicodeError("Unsupported error handling "+errors)
...@@ -199,30 +202,31 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder): ...@@ -199,30 +202,31 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
raise UnicodeError("unsupported error handling "+errors) raise UnicodeError("unsupported error handling "+errors)
if not input: if not input:
return ("", 0) return (b'', 0)
labels = dots.split(input) labels = dots.split(input)
trailing_dot = '' trailing_dot = b''
if labels: if labels:
if not labels[-1]: if not labels[-1]:
trailing_dot = '.' trailing_dot = b'.'
del labels[-1] del labels[-1]
elif not final: elif not final:
# Keep potentially unfinished label until the next call # Keep potentially unfinished label until the next call
del labels[-1] del labels[-1]
if labels: if labels:
trailing_dot = '.' trailing_dot = b'.'
result = [] result = b""
size = 0 size = 0
for label in labels: for label in labels:
result.append(ToASCII(label))
if size: if size:
# Join with U+002E
result.extend(b'.')
size += 1 size += 1
result.extend(ToASCII(label))
size += len(label) size += len(label)
# Join with U+002E result += trailing_dot
result = ".".join(result) + trailing_dot
size += len(trailing_dot) size += len(trailing_dot)
return (result, size) return (result, size)
...@@ -239,8 +243,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder): ...@@ -239,8 +243,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
labels = dots.split(input) labels = dots.split(input)
else: else:
# Must be ASCII string # Must be ASCII string
input = str(input) input = str(input, "ascii")
str(input, "ascii")
labels = input.split(".") labels = input.split(".")
trailing_dot = '' trailing_dot = ''
......
...@@ -181,6 +181,8 @@ def insertion_sort(base, extended, errors): ...@@ -181,6 +181,8 @@ def insertion_sort(base, extended, errors):
return base return base
def punycode_decode(text, errors): def punycode_decode(text, errors):
if isinstance(text, str):
text = text.encode("ascii")
pos = text.rfind(b"-") pos = text.rfind(b"-")
if pos == -1: if pos == -1:
base = "" base = ""
...@@ -194,11 +196,11 @@ def punycode_decode(text, errors): ...@@ -194,11 +196,11 @@ def punycode_decode(text, errors):
class Codec(codecs.Codec): class Codec(codecs.Codec):
def encode(self,input,errors='strict'): def encode(self, input, errors='strict'):
res = punycode_encode(input) res = punycode_encode(input)
return res, len(input) return res, len(input)
def decode(self,input,errors='strict'): def decode(self, input, errors='strict'):
if errors not in ('strict', 'replace', 'ignore'): if errors not in ('strict', 'replace', 'ignore'):
raise UnicodeError, "Unsupported error handling "+errors raise UnicodeError, "Unsupported error handling "+errors
res = punycode_decode(input, errors) res = punycode_decode(input, errors)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.