# License for the specific language governing permissions and limitations
# under the License.
+import os
+
import netaddr
from oslo.config import cfg
return self._execute(options, command, args,
log_fail_as_error=self.log_fail_as_error)
- def _as_root(self, options, command, args, use_root_namespace=False):
- if not self.root_helper:
+ def enforce_root_helper(self):
+ if not self.root_helper and os.geteuid() != 0:
raise exceptions.SudoRequired()
+ def _as_root(self, options, command, args, use_root_namespace=False):
+ self.enforce_root_helper()
+
namespace = self.namespace if not use_root_namespace else None
return self._execute(options,
extra_ok_codes=None):
ns_params = []
if self._parent.namespace:
- if not self._parent.root_helper:
- raise exceptions.SudoRequired()
+ self._parent.enforce_root_helper()
ns_params = ['ip', 'netns', 'exec', self._parent.namespace]
env_params = []
# License for the specific language governing permissions and limitations
# under the License.
+import os
+
import mock
from neutron.agent.linux import ip_lib
root_helper='sudo',
log_fail_as_error=True)
- def test_as_root_no_root_helper(self):
+ def test_enforce_root_helper_no_root_helper(self):
+ base = ip_lib.SubProcessBase()
+ not_root = 42
+ with mock.patch.object(os, 'geteuid', return_value=not_root):
+ self.assertRaises(exceptions.SudoRequired,
+ base.enforce_root_helper)
+
+ def test_enforce_root_helper_with_root_helper_supplied(self):
+ base = ip_lib.SubProcessBase('sudo')
+ try:
+ base.enforce_root_helper()
+ except exceptions.SudoRequired:
+ self.fail('enforce_root_helper should not raise SudoRequired '
+ 'when a root_helper is supplied.')
+
+ def test_enforce_root_helper_with_no_root_helper_but_root(self):
base = ip_lib.SubProcessBase()
- self.assertRaises(exceptions.SudoRequired,
- base._as_root,
- [], 'link', ('list',))
+ with mock.patch.object(os, 'geteuid', return_value=0):
+ try:
+ base.enforce_root_helper()
+ except exceptions.SudoRequired:
+ self.fail('enforce_root_helper should not require a root '
+ 'helper when run as root.')
class TestIpWrapper(base.BaseTestCase):