Kaydet (Commit) 8ba85ac2 authored tarafından Joffrey F's avatar Joffrey F

Merge pull request #727 from mark-adams/fixes-726

Fixed #726 issue where split_port was checking `len(None)`
def add_port_mapping(port_bindings, internal_port, external): def add_port_mapping(port_bindings, internal_port, external):
if internal_port in port_bindings: if internal_port in port_bindings:
port_bindings[internal_port].append(external) port_bindings[internal_port].append(external)
...@@ -33,9 +32,8 @@ def to_port_range(port): ...@@ -33,9 +32,8 @@ def to_port_range(port):
if "/" in port: if "/" in port:
parts = port.split("/") parts = port.split("/")
if len(parts) != 2: if len(parts) != 2:
raise ValueError('Invalid port "%s", should be ' _raise_invalid_port(port)
'[[remote_ip:]remote_port[-remote_port]:]'
'port[/protocol]' % port)
port, protocol = parts port, protocol = parts
protocol = "/" + protocol protocol = "/" + protocol
...@@ -52,11 +50,17 @@ def to_port_range(port): ...@@ -52,11 +50,17 @@ def to_port_range(port):
'port or startport-endport' % port) 'port or startport-endport' % port)
def _raise_invalid_port(port):
raise ValueError('Invalid port "%s", should be '
'[[remote_ip:]remote_port[-remote_port]:]'
'port[/protocol]' % port)
def split_port(port): def split_port(port):
parts = str(port).split(':') parts = str(port).split(':')
if not 1 <= len(parts) <= 3: if not 1 <= len(parts) <= 3:
raise ValueError('Invalid port "%s", should be ' _raise_invalid_port(port)
'[[remote_ip:]remote_port:]port[/protocol]' % port)
if len(parts) == 1: if len(parts) == 1:
internal_port, = parts internal_port, = parts
...@@ -66,6 +70,10 @@ def split_port(port): ...@@ -66,6 +70,10 @@ def split_port(port):
internal_range = to_port_range(internal_port) internal_range = to_port_range(internal_port)
external_range = to_port_range(external_port) external_range = to_port_range(external_port)
if internal_range is None or external_range is None:
_raise_invalid_port(port)
if len(internal_range) != len(external_range): if len(internal_range) != len(external_range):
raise ValueError('Port ranges don\'t match in length') raise ValueError('Port ranges don\'t match in length')
......
...@@ -419,6 +419,14 @@ class UtilsTest(base.BaseTestCase): ...@@ -419,6 +419,14 @@ class UtilsTest(base.BaseTestCase):
self.assertRaises(ValueError, self.assertRaises(ValueError,
lambda: split_port("0.0.0.0:1000:2000-2002/tcp")) lambda: split_port("0.0.0.0:1000:2000-2002/tcp"))
def test_port_only_with_colon(self):
self.assertRaises(ValueError,
lambda: split_port(":80"))
def test_host_only_with_colon(self):
self.assertRaises(ValueError,
lambda: split_port("localhost:"))
def test_build_port_bindings_with_one_port(self): def test_build_port_bindings_with_one_port(self):
port_bindings = build_port_bindings(["127.0.0.1:1000:1000"]) port_bindings = build_port_bindings(["127.0.0.1:1000:1000"])
self.assertEqual(port_bindings["1000"], [("127.0.0.1", "1000")]) self.assertEqual(port_bindings["1000"], [("127.0.0.1", "1000")])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment