]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Fix callback registry notification for security group rule
authorOleg Bondarev <obondarev@mirantis.com>
Tue, 2 Jun 2015 13:14:40 +0000 (16:14 +0300)
committerOleg Bondarev <obondarev@mirantis.com>
Mon, 22 Jun 2015 14:19:44 +0000 (17:19 +0300)
Some housekeeping was done in
 - SecurityGroupDbMixin:
   - create_rule_bulk() calls to create_rule();
   - registry notification is in create_rule();
   - separate validation for a single rule and for a group of rules
 - SecurityGroupServerRpcMixin:
   - overriden methods call to corresponding super class methods;

Hopefully code is now self-documented enough

Closes-Bug: #1461024
Change-Id: Ia75d7e206716bbe74aae89e4cebd0c2c40af68a8

neutron/db/securitygroups_db.py
neutron/db/securitygroups_rpc_base.py
neutron/tests/unit/db/test_securitygroups_db.py

index 3caca9bbc918f60cfb699e17c4946238cdbe1f1a..3e24dc13ef5bc27fa330d2a4829af701699a7b35 100644 (file)
@@ -347,40 +347,36 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
                                  security_group_rules)
 
     def create_security_group_rule_bulk_native(self, context,
-                                               security_group_rule):
-        r = security_group_rule['security_group_rules']
-
+                                               security_group_rules):
+        rules = security_group_rules['security_group_rules']
         scoped_session(context.session)
         security_group_id = self._validate_security_group_rules(
-            context, security_group_rule)
+            context, security_group_rules)
         with context.session.begin(subtransactions=True):
             if not self.get_security_group(context, security_group_id):
                 raise ext_sg.SecurityGroupNotFound(id=security_group_id)
 
-            self._check_for_duplicate_rules(context, r)
+            self._check_for_duplicate_rules(context, rules)
             ret = []
-            for rule_dict in r:
-                rule = rule_dict['security_group_rule']
-                tenant_id = self._get_tenant_id_for_create(context, rule)
-                db = SecurityGroupRule(
-                    id=(rule.get('id') or uuidutils.generate_uuid()),
-                    tenant_id=tenant_id,
-                    security_group_id=rule['security_group_id'],
-                    direction=rule['direction'],
-                    remote_group_id=rule.get('remote_group_id'),
-                    ethertype=rule['ethertype'],
-                    protocol=rule['protocol'],
-                    port_range_min=rule['port_range_min'],
-                    port_range_max=rule['port_range_max'],
-                    remote_ip_prefix=rule.get('remote_ip_prefix'))
-                context.session.add(db)
-                ret.append(self._make_security_group_rule_dict(db))
-        return ret
+            for rule_dict in rules:
+                res_rule_dict = self._create_security_group_rule(
+                    context, rule_dict, validate=False)
+                ret.append(res_rule_dict)
+            return ret
 
     def create_security_group_rule(self, context, security_group_rule):
+        return self._create_security_group_rule(context, security_group_rule)
+
+    def _create_security_group_rule(self, context, security_group_rule,
+                                    validate=True):
+        if validate:
+            self._validate_security_group_rule(context, security_group_rule)
+            self._check_for_duplicate_rules_in_db(context, security_group_rule)
+
+        rule_dict = security_group_rule['security_group_rule']
         kwargs = {
             'context': context,
-            'security_group_rule': security_group_rule,
+            'security_group_rule': rule_dict
         }
         # NOTE(armax): a callback exception here will prevent the request
         # from being processed. This is a hook point for backend's validation;
@@ -392,15 +388,26 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
         except exceptions.CallbackFailure as e:
             raise ext_sg.SecurityGroupConflict(reason=e)
 
-        bulk_rule = {'security_group_rules': [security_group_rule]}
-        sg_rule_dict = self.create_security_group_rule_bulk_native(
-            context, bulk_rule)[0]
-
-        kwargs['security_group_rule'] = sg_rule_dict
+        tenant_id = self._get_tenant_id_for_create(context, rule_dict)
+        with context.session.begin(subtransactions=True):
+            db = SecurityGroupRule(
+                id=(rule_dict.get('id') or uuidutils.generate_uuid()),
+                tenant_id=tenant_id,
+                security_group_id=rule_dict['security_group_id'],
+                direction=rule_dict['direction'],
+                remote_group_id=rule_dict.get('remote_group_id'),
+                ethertype=rule_dict['ethertype'],
+                protocol=rule_dict['protocol'],
+                port_range_min=rule_dict['port_range_min'],
+                port_range_max=rule_dict['port_range_max'],
+                remote_ip_prefix=rule_dict.get('remote_ip_prefix'))
+            context.session.add(db)
+        res_rule_dict = self._make_security_group_rule_dict(db)
+        kwargs['security_group_rule'] = res_rule_dict
         registry.notify(
             resources.SECURITY_GROUP_RULE, events.AFTER_CREATE, self,
             **kwargs)
-        return sg_rule_dict
+        return res_rule_dict
 
     def _get_ip_proto_number(self, protocol):
         if protocol is None:
@@ -436,45 +443,50 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
                 raise ext_sg.SecurityGroupMissingIcmpType(
                     value=rule['port_range_max'])
 
-    def _validate_security_group_rules(self, context, security_group_rules):
-        """Check that rules being installed.
-
-        Check that all rules belong to the same security
-        group, remote_group_id/security_group_id belong to the same tenant,
-        and rules are valid.
+    def _validate_single_tenant_and_group(self, security_group_rules):
+        """Check that all rules belong to the same security group and tenant
         """
-        new_rules = set()
-        tenant_ids = set()
+        sg_groups = set()
+        tenants = set()
         for rule_dict in security_group_rules['security_group_rules']:
-            rule = rule_dict.get('security_group_rule')
-            new_rules.add(rule['security_group_id'])
-
-            self._validate_port_range(rule)
-            self._validate_ip_prefix(rule)
-
-            if rule['remote_ip_prefix'] and rule['remote_group_id']:
-                raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix()
-
-            if rule['tenant_id'] not in tenant_ids:
-                tenant_ids.add(rule['tenant_id'])
-            remote_group_id = rule.get('remote_group_id')
-            # Check that remote_group_id exists for tenant
-            if remote_group_id:
-                self.get_security_group(context, remote_group_id,
-                                        tenant_id=rule['tenant_id'])
-        if len(new_rules) > 1:
-            raise ext_sg.SecurityGroupNotSingleGroupRules()
-        security_group_id = new_rules.pop()
-
-        # Confirm single tenant and that the tenant has permission
+            rule = rule_dict['security_group_rule']
+            sg_groups.add(rule['security_group_id'])
+            if len(sg_groups) > 1:
+                raise ext_sg.SecurityGroupNotSingleGroupRules()
+
+            tenants.add(rule['tenant_id'])
+            if len(tenants) > 1:
+                raise ext_sg.SecurityGroupRulesNotSingleTenant()
+        return sg_groups.pop()
+
+    def _validate_security_group_rule(self, context, security_group_rule):
+        rule = security_group_rule['security_group_rule']
+        self._validate_port_range(rule)
+        self._validate_ip_prefix(rule)
+
+        if rule['remote_ip_prefix'] and rule['remote_group_id']:
+            raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix()
+
+        remote_group_id = rule['remote_group_id']
+        # Check that remote_group_id exists for tenant
+        if remote_group_id:
+            self.get_security_group(context, remote_group_id,
+                                    tenant_id=rule['tenant_id'])
+
+        security_group_id = rule['security_group_id']
+
+        # Confirm that the tenant has permission
         # to add rules to this security group.
-        if len(tenant_ids) > 1:
-            raise ext_sg.SecurityGroupRulesNotSingleTenant()
-        for tenant_id in tenant_ids:
-            self.get_security_group(context, security_group_id,
-                                    tenant_id=tenant_id)
+        self.get_security_group(context, security_group_id,
+                                tenant_id=rule['tenant_id'])
         return security_group_id
 
+    def _validate_security_group_rules(self, context, security_group_rules):
+        sg_id = self._validate_single_tenant_and_group(security_group_rules)
+        for rule in security_group_rules['security_group_rules']:
+            self._validate_security_group_rule(context, rule)
+        return sg_id
+
     def _make_security_group_rule_dict(self, security_group_rule, fields=None):
         res = {'id': security_group_rule['id'],
                'tenant_id': security_group_rule['tenant_id'],
@@ -513,23 +525,27 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
                         raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i)
                     found_self = True
 
-            # Check in database if rule exists
-            filters = self._make_security_group_rule_filter_dict(i)
-            db_rules = self.get_security_group_rules(context, filters)
-            # Note(arosen): the call to get_security_group_rules wildcards
-            # values in the filter that have a value of [None]. For
-            # example, filters = {'remote_group_id': [None]} will return
-            # all security group rules regardless of their value of
-            # remote_group_id. Therefore it is not possible to do this
-            # query unless the behavior of _get_collection()
-            # is changed which cannot be because other methods are already
-            # relying on this behavior. Therefore, we do the filtering
-            # below to check for these corner cases.
-            for db_rule in db_rules:
-                # need to remove id from db_rule for matching
-                id = db_rule.pop('id')
-                if (i['security_group_rule'] == db_rule):
-                    raise ext_sg.SecurityGroupRuleExists(id=id)
+            self._check_for_duplicate_rules_in_db(context, i)
+
+    def _check_for_duplicate_rules_in_db(self, context, security_group_rule):
+        # Check in database if rule exists
+        filters = self._make_security_group_rule_filter_dict(
+            security_group_rule)
+        db_rules = self.get_security_group_rules(context, filters)
+        # Note(arosen): the call to get_security_group_rules wildcards
+        # values in the filter that have a value of [None]. For
+        # example, filters = {'remote_group_id': [None]} will return
+        # all security group rules regardless of their value of
+        # remote_group_id. Therefore it is not possible to do this
+        # query unless the behavior of _get_collection()
+        # is changed which cannot be because other methods are already
+        # relying on this behavior. Therefore, we do the filtering
+        # below to check for these corner cases.
+        for db_rule in db_rules:
+            # need to remove id from db_rule for matching
+            id = db_rule.pop('id')
+            if (security_group_rule['security_group_rule'] == db_rule):
+                raise ext_sg.SecurityGroupRuleExists(id=id)
 
     def _validate_ip_prefix(self, rule):
         """Check that a valid cidr was specified as remote_ip_prefix
index 56c14fe5c5bfe393f9645a1bc8716372317ad76a..3e90c124b420b9f492b079795e90cf0d64e673d1 100644 (file)
@@ -69,18 +69,17 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
                 for device in devices]
 
     def create_security_group_rule(self, context, security_group_rule):
-        bulk_rule = {'security_group_rules': [security_group_rule]}
-        rule = self.create_security_group_rule_bulk_native(context,
-                                                           bulk_rule)[0]
+        rule = super(SecurityGroupServerRpcMixin,
+                     self).create_security_group_rule(context,
+                                                      security_group_rule)
         sgids = [rule['security_group_id']]
         self.notifier.security_groups_rule_updated(context, sgids)
         return rule
 
-    def create_security_group_rule_bulk(self, context,
-                                        security_group_rule):
+    def create_security_group_rule_bulk(self, context, security_group_rules):
         rules = super(SecurityGroupServerRpcMixin,
                       self).create_security_group_rule_bulk_native(
-                          context, security_group_rule)
+                          context, security_group_rules)
         sgids = set([r['security_group_id'] for r in rules])
         self.notifier.security_groups_rule_updated(context, list(sgids))
         return rules
index 0626f9ca65fdbf6ae363727e65540830e72a322b..db98f4622c0dd4758ad2a7765624adf7bd44fef0 100644 (file)
@@ -62,11 +62,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
                 self.mixin.update_security_group(self.ctx, 'foo_id', secgroup)
 
     def test_create_security_group_rule_conflict(self):
-        with mock.patch.object(registry, "notify") as mock_notify:
+        with mock.patch.object(self.mixin, '_validate_security_group_rule'),\
+                mock.patch.object(self.mixin,
+                                  '_check_for_duplicate_rules_in_db'),\
+                mock.patch.object(registry, "notify") as mock_notify:
             mock_notify.side_effect = exceptions.CallbackFailure(Exception())
             with testtools.ExpectedException(
                 securitygroup.SecurityGroupConflict):
-                self.mixin.create_security_group_rule(self.ctx, mock.ANY)
+                self.mixin.create_security_group_rule(
+                    self.ctx, mock.MagicMock())
 
     def test_delete_security_group_rule_in_use(self):
         with mock.patch.object(registry, "notify") as mock_notify: