Kaydet (Commit) 7b00d902 authored tarafından Claude Paroz's avatar Claude Paroz

[py3] Made GeoIP tests pass with Python 3

üst 465a29ab
...@@ -137,9 +137,6 @@ class GeoIP(object): ...@@ -137,9 +137,6 @@ class GeoIP(object):
if not isinstance(query, six.string_types): if not isinstance(query, six.string_types):
raise TypeError('GeoIP query must be a string, not type %s' % type(query).__name__) raise TypeError('GeoIP query must be a string, not type %s' % type(query).__name__)
# GeoIP only takes ASCII-encoded strings.
query = query.encode('ascii')
# Extra checks for the existence of country and city databases. # Extra checks for the existence of country and city databases.
if city_or_country and not (self._country or self._city): if city_or_country and not (self._country or self._city):
raise GeoIPException('Invalid GeoIP country and city data files.') raise GeoIPException('Invalid GeoIP country and city data files.')
...@@ -148,8 +145,8 @@ class GeoIP(object): ...@@ -148,8 +145,8 @@ class GeoIP(object):
elif city and not self._city: elif city and not self._city:
raise GeoIPException('Invalid GeoIP city data file: %s' % self._city_file) raise GeoIPException('Invalid GeoIP city data file: %s' % self._city_file)
# Return the query string back to the caller. # Return the query string back to the caller. GeoIP only takes bytestrings.
return query return force_bytes(query)
def city(self, query): def city(self, query):
""" """
...@@ -157,33 +154,33 @@ class GeoIP(object): ...@@ -157,33 +154,33 @@ class GeoIP(object):
Fully Qualified Domain Name (FQDN). Some information in the dictionary Fully Qualified Domain Name (FQDN). Some information in the dictionary
may be undefined (None). may be undefined (None).
""" """
query = self._check_query(query, city=True) enc_query = self._check_query(query, city=True)
if ipv4_re.match(query): if ipv4_re.match(query):
# If an IP address was passed in # If an IP address was passed in
return GeoIP_record_by_addr(self._city, c_char_p(query)) return GeoIP_record_by_addr(self._city, c_char_p(enc_query))
else: else:
# If a FQDN was passed in. # If a FQDN was passed in.
return GeoIP_record_by_name(self._city, c_char_p(query)) return GeoIP_record_by_name(self._city, c_char_p(enc_query))
def country_code(self, query): def country_code(self, query):
"Returns the country code for the given IP Address or FQDN." "Returns the country code for the given IP Address or FQDN."
query = self._check_query(query, city_or_country=True) enc_query = self._check_query(query, city_or_country=True)
if self._country: if self._country:
if ipv4_re.match(query): if ipv4_re.match(query):
return GeoIP_country_code_by_addr(self._country, query) return GeoIP_country_code_by_addr(self._country, enc_query)
else: else:
return GeoIP_country_code_by_name(self._country, query) return GeoIP_country_code_by_name(self._country, enc_query)
else: else:
return self.city(query)['country_code'] return self.city(query)['country_code']
def country_name(self, query): def country_name(self, query):
"Returns the country name for the given IP Address or FQDN." "Returns the country name for the given IP Address or FQDN."
query = self._check_query(query, city_or_country=True) enc_query = self._check_query(query, city_or_country=True)
if self._country: if self._country:
if ipv4_re.match(query): if ipv4_re.match(query):
return GeoIP_country_name_by_addr(self._country, query) return GeoIP_country_name_by_addr(self._country, enc_query)
else: else:
return GeoIP_country_name_by_name(self._country, query) return GeoIP_country_name_by_name(self._country, enc_query)
else: else:
return self.city(query)['country_name'] return self.city(query)['country_name']
......
...@@ -92,7 +92,7 @@ def check_string(result, func, cargs): ...@@ -92,7 +92,7 @@ def check_string(result, func, cargs):
free(result) free(result)
else: else:
s = '' s = ''
return s return s.decode()
GeoIP_database_info = lgeoip.GeoIP_database_info GeoIP_database_info = lgeoip.GeoIP_database_info
GeoIP_database_info.restype = geoip_char_p GeoIP_database_info.restype = geoip_char_p
...@@ -100,7 +100,12 @@ GeoIP_database_info.errcheck = check_string ...@@ -100,7 +100,12 @@ GeoIP_database_info.errcheck = check_string
# String output routines. # String output routines.
def string_output(func): def string_output(func):
def _err_check(result, func, cargs):
if result:
return result.decode()
return result
func.restype = c_char_p func.restype = c_char_p
func.errcheck = _err_check
return func return func
GeoIP_country_code_by_addr = string_output(lgeoip.GeoIP_country_code_by_addr) GeoIP_country_code_by_addr = string_output(lgeoip.GeoIP_country_code_by_addr)
......
...@@ -106,12 +106,6 @@ class GeoIPTest(unittest.TestCase): ...@@ -106,12 +106,6 @@ class GeoIPTest(unittest.TestCase):
d = g.city("www.osnabrueck.de") d = g.city("www.osnabrueck.de")
self.assertEqual('Osnabrück', d['city']) self.assertEqual('Osnabrück', d['city'])
def test06_unicode_query(self):
"Testing that GeoIP accepts unicode string queries, see #17059."
g = GeoIP()
d = g.country('whitehouse.gov')
self.assertEqual('US', d['country_code'])
def suite(): def suite():
s = unittest.TestSuite() s = unittest.TestSuite()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment