Kaydet (Commit) 482fe047 authored tarafından Ethan Furman's avatar Ethan Furman

issue23673

add private method to enum to support replacing global constants with Enum members:
- search for candidate constants via supplied filter
- create new enum class and members
- insert enum class and replace constants with members via supplied module name
- replace __reduce_ex__ with function that returns member name, so previous Python versions can unpickle
modify IntEnum classes to use new method
üst d833779c
...@@ -511,11 +511,37 @@ class Enum(metaclass=EnumMeta): ...@@ -511,11 +511,37 @@ class Enum(metaclass=EnumMeta):
"""The value of the Enum member.""" """The value of the Enum member."""
return self._value_ return self._value_
@classmethod
def _convert(cls, name, module, filter, source=None):
"""
Create a new Enum subclass that replaces a collection of global constants
"""
# convert all constants from source (or module) that pass filter() to
# a new Enum called name, and export the enum and its members back to
# module;
# also, replace the __reduce_ex__ method so unpickling works in
# previous Python versions
module_globals = vars(sys.modules[module])
if source:
source = vars(source)
else:
source = module_globals
members = {name: value for name, value in source.items()
if filter(name)}
cls = cls(name, members, module=module)
cls.__reduce_ex__ = _reduce_ex_by_name
module_globals.update(cls.__members__)
module_globals[name] = cls
return cls
class IntEnum(int, Enum): class IntEnum(int, Enum):
"""Enum where members are also (and must be) ints""" """Enum where members are also (and must be) ints"""
def _reduce_ex_by_name(self, proto):
return self.name
def unique(enumeration): def unique(enumeration):
"""Class decorator for enumerations ensuring unique member values.""" """Class decorator for enumerations ensuring unique member values."""
duplicates = [] duplicates = []
......
...@@ -69,15 +69,15 @@ __all__.extend(os._get_exports_list(_socket)) ...@@ -69,15 +69,15 @@ __all__.extend(os._get_exports_list(_socket))
# Note that _socket only knows about the integer values. The public interface # Note that _socket only knows about the integer values. The public interface
# in this module understands the enums and translates them back from integers # in this module understands the enums and translates them back from integers
# where needed (e.g. .family property of a socket object). # where needed (e.g. .family property of a socket object).
AddressFamily = IntEnum('AddressFamily', IntEnum._convert(
{name: value for name, value in globals().items() 'AddressFamily',
if name.isupper() and name.startswith('AF_')}) __name__,
globals().update(AddressFamily.__members__) lambda C: C.isupper() and C.startswith('AF_'))
SocketKind = IntEnum('SocketKind', IntEnum._convert(
{name: value for name, value in globals().items() 'SocketKind',
if name.isupper() and name.startswith('SOCK_')}) __name__,
globals().update(SocketKind.__members__) lambda C: C.isupper() and C.startswith('SOCK_'))
def _intenum_converter(value, enum_klass): def _intenum_converter(value, enum_klass):
"""Convert a numeric family value to an IntEnum member. """Convert a numeric family value to an IntEnum member.
......
...@@ -581,6 +581,14 @@ class TestEnum(unittest.TestCase): ...@@ -581,6 +581,14 @@ class TestEnum(unittest.TestCase):
test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs, test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs,
protocol=(4, HIGHEST_PROTOCOL)) protocol=(4, HIGHEST_PROTOCOL))
def test_pickle_by_name(self):
class ReplaceGlobalInt(IntEnum):
ONE = 1
TWO = 2
ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_name
for proto in range(HIGHEST_PROTOCOL):
self.assertEqual(ReplaceGlobalInt.TWO.__reduce_ex__(proto), 'TWO')
def test_exploding_pickle(self): def test_exploding_pickle(self):
BadPickle = Enum( BadPickle = Enum(
'BadPickle', 'dill sweet bread-n-butter', module=__name__) 'BadPickle', 'dill sweet bread-n-butter', module=__name__)
......
...@@ -1375,6 +1375,11 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1375,6 +1375,11 @@ class GeneralModuleTests(unittest.TestCase):
with sock: with sock:
for protocol in range(pickle.HIGHEST_PROTOCOL + 1): for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
self.assertRaises(TypeError, pickle.dumps, sock, protocol) self.assertRaises(TypeError, pickle.dumps, sock, protocol)
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
family = pickle.loads(pickle.dumps(socket.AF_INET, protocol))
self.assertEqual(family, socket.AF_INET)
type = pickle.loads(pickle.dumps(socket.SOCK_STREAM, protocol))
self.assertEqual(type, socket.SOCK_STREAM)
def test_listen_backlog(self): def test_listen_backlog(self):
for backlog in 0, -1: for backlog in 0, -1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment