Set lock_path correctly.
[openstack-build/neutron-build.git] / neutron / db / metering / metering_db.py
1 # Copyright (C) 2013 eNovance SAS <licensing@enovance.com>
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may
4 # not use this file except in compliance with the License. You may obtain
5 # a copy of the License at
6 #
7 #      http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations
13 # under the License.
14
15 import netaddr
16 from oslo_utils import uuidutils
17 import sqlalchemy as sa
18 from sqlalchemy import orm
19 from sqlalchemy import sql
20
21 from neutron.api.rpc.agentnotifiers import metering_rpc_agent_api
22 from neutron.api.v2 import attributes as attr
23 from neutron.common import constants
24 from neutron.db import common_db_mixin as base_db
25 from neutron.db import l3_db
26 from neutron.db import model_base
27 from neutron.extensions import metering
28
29
class MeteringLabelRule(model_base.BASEV2, model_base.HasId):
    """A traffic-matching rule that belongs to one metering label.

    Rows reference their label through ``metering_label_id``; the foreign
    key is declared with ondelete="CASCADE", so the database removes the
    rules when the owning label row is deleted.
    """
    # Direction of traffic the rule matches; stored as a DB enum type named
    # 'meteringlabels_direction' with the values 'ingress' / 'egress'.
    direction = sa.Column(sa.Enum('ingress', 'egress',
                                  name='meteringlabels_direction'))
    # CIDR the rule matches. Overlap with sibling rules is checked by
    # MeteringDbMixin._validate_cidr before insert.
    remote_ip_prefix = sa.Column(sa.String(64))
    # Owning label; required, and cascaded on label deletion.
    metering_label_id = sa.Column(sa.String(36),
                                  sa.ForeignKey("meteringlabels.id",
                                                ondelete="CASCADE"),
                                  nullable=False)
    # Defaults to False both in Python and on the server side.
    # NOTE(review): presumably True means matched traffic is excluded from
    # the label's counters (enforced by the metering agent) -- confirm.
    excluded = sa.Column(sa.Boolean, default=False, server_default=sql.false())
39
40
class MeteringLabel(model_base.BASEV2, model_base.HasId, model_base.HasTenant):
    """A named group of metering rules owned by a tenant, optionally shared.
    """
    name = sa.Column(sa.String(attr.NAME_MAX_LEN))
    description = sa.Column(sa.String(attr.LONG_DESCRIPTION_MAX_LEN))
    # Rules of this label. They are loaded together with the label
    # (lazy="joined"), deleted with it through the ORM (cascade="delete"),
    # and each rule gets a ``label`` back-reference.
    rules = orm.relationship(MeteringLabelRule, backref="label",
                             cascade="delete", lazy="joined")
    # Routers owned by the same tenant as this label. The join is on
    # tenant_id rather than on a declared ForeignKey, so the join condition
    # and the "foreign" column are spelled out explicitly.
    # NOTE(review): this relationship is only read in this module but is
    # not declared viewonly; treat it as read-only -- confirm before
    # mutating it.
    routers = orm.relationship(
        l3_db.Router,
        primaryjoin="MeteringLabel.tenant_id==Router.tenant_id",
        foreign_keys='MeteringLabel.tenant_id',
        uselist=True)
    # When True, the sync-data builders in MeteringDbMixin apply this label
    # to every router instead of only the tenant's own routers.
    shared = sa.Column(sa.Boolean, default=False, server_default=sql.false())
52
53
class MeteringDbMixin(metering.MeteringPluginBase,
                      base_db.CommonDbMixin):
    """Database-backed implementation of the metering extension API.

    Implements CRUD for metering labels and metering label rules on top of
    the MeteringLabel / MeteringLabelRule models, and builds the per-router
    "sync data" payloads (labels and their rules grouped by router).
    """

    def __init__(self):
        # Metering agent notifier. It is not used by the methods of this
        # mixin itself; presumably consumed by the inheriting plugin.
        self.meter_rpc = metering_rpc_agent_api.MeteringAgentNotifyAPI()

    def _make_metering_label_dict(self, metering_label, fields=None):
        """Convert a MeteringLabel row into an API dict.

        :param metering_label: MeteringLabel row (or mapping-like object).
        :param fields: optional list of keys to keep (see ``_fields``).
        :returns: dict with id, name, description, shared and tenant_id.
        """
        res = {'id': metering_label['id'],
               'name': metering_label['name'],
               'description': metering_label['description'],
               'shared': metering_label['shared'],
               'tenant_id': metering_label['tenant_id']}
        return self._fields(res, fields)

    def create_metering_label(self, context, metering_label):
        """Create a metering label.

        :param metering_label: request body of the form
            ``{'metering_label': {...}}`` whose inner dict provides
            'name', 'description', 'tenant_id' and 'shared'.
        :returns: the new label as a dict.
        """
        m = metering_label['metering_label']

        with context.session.begin(subtransactions=True):
            metering_db = MeteringLabel(id=uuidutils.generate_uuid(),
                                        description=m['description'],
                                        tenant_id=m['tenant_id'],
                                        name=m['name'],
                                        shared=m['shared'])
            context.session.add(metering_db)

        return self._make_metering_label_dict(metering_db)

    def delete_metering_label(self, context, label_id):
        """Delete a metering label; its rules are removed by cascade.

        :raises MeteringLabelNotFound: if no label has ``label_id``.
        """
        with context.session.begin(subtransactions=True):
            try:
                label = self._get_by_id(context, MeteringLabel, label_id)
            except orm.exc.NoResultFound:
                raise metering.MeteringLabelNotFound(label_id=label_id)

            context.session.delete(label)

    def get_metering_label(self, context, label_id, fields=None):
        """Return one metering label as a dict.

        :raises MeteringLabelNotFound: if no label has ``label_id``.
        """
        try:
            metering_label = self._get_by_id(context, MeteringLabel, label_id)
        except orm.exc.NoResultFound:
            raise metering.MeteringLabelNotFound(label_id=label_id)

        return self._make_metering_label_dict(metering_label, fields)

    def get_metering_labels(self, context, filters=None, fields=None,
                            sorts=None, limit=None, marker=None,
                            page_reverse=False):
        """List metering labels with filtering, sorting and pagination."""
        marker_obj = self._get_marker_obj(context, 'metering_labels', limit,
                                          marker)
        return self._get_collection(context, MeteringLabel,
                                    self._make_metering_label_dict,
                                    filters=filters, fields=fields,
                                    sorts=sorts,
                                    limit=limit,
                                    marker_obj=marker_obj,
                                    page_reverse=page_reverse)

    def _make_metering_label_rule_dict(self, metering_label_rule, fields=None):
        """Convert a MeteringLabelRule row into an API dict.

        :param fields: optional list of keys to keep (see ``_fields``).
        """
        res = {'id': metering_label_rule['id'],
               'metering_label_id': metering_label_rule['metering_label_id'],
               'direction': metering_label_rule['direction'],
               'remote_ip_prefix': metering_label_rule['remote_ip_prefix'],
               'excluded': metering_label_rule['excluded']}
        return self._fields(res, fields)

    def get_metering_label_rules(self, context, filters=None, fields=None,
                                 sorts=None, limit=None, marker=None,
                                 page_reverse=False):
        """List metering label rules with filtering, sorting, pagination."""
        marker_obj = self._get_marker_obj(context, 'metering_label_rules',
                                          limit, marker)

        return self._get_collection(context, MeteringLabelRule,
                                    self._make_metering_label_rule_dict,
                                    filters=filters, fields=fields,
                                    sorts=sorts,
                                    limit=limit,
                                    marker_obj=marker_obj,
                                    page_reverse=page_reverse)

    def get_metering_label_rule(self, context, rule_id, fields=None):
        """Return one metering label rule as a dict.

        :raises MeteringLabelRuleNotFound: if no rule has ``rule_id``.
        """
        try:
            metering_label_rule = self._get_by_id(context,
                                                  MeteringLabelRule, rule_id)
        except orm.exc.NoResultFound:
            raise metering.MeteringLabelRuleNotFound(rule_id=rule_id)

        return self._make_metering_label_rule_dict(metering_label_rule, fields)

    def _validate_cidr(self, context, label_id, remote_ip_prefix,
                       direction, excluded):
        """Reject a CIDR that overlaps an existing rule of the same kind.

        Only rules of the same label that also share ``direction`` and the
        ``excluded`` flag are compared against ``remote_ip_prefix``.

        :raises MeteringLabelRuleOverlaps: on any address overlap.
        """
        r_ips = self.get_metering_label_rules(
            context,
            filters={'metering_label_id': [label_id],
                     'direction': [direction],
                     'excluded': [excluded]},
            fields=['remote_ip_prefix'])

        cidrs = [r['remote_ip_prefix'] for r in r_ips]
        new_cidr_ipset = netaddr.IPSet([remote_ip_prefix])
        # A non-empty intersection means at least one address is shared.
        if netaddr.IPSet(cidrs) & new_cidr_ipset:
            raise metering.MeteringLabelRuleOverlaps(
                remote_ip_prefix=remote_ip_prefix)

    def create_metering_label_rule(self, context, metering_label_rule):
        """Create a metering label rule after checking for CIDR overlap.

        :param metering_label_rule: request body of the form
            ``{'metering_label_rule': {...}}`` whose inner dict provides
            'metering_label_id', 'remote_ip_prefix', 'direction' and
            'excluded'.
        :raises MeteringLabelRuleOverlaps: see ``_validate_cidr``.
        :returns: the new rule as a dict.
        """
        m = metering_label_rule['metering_label_rule']
        with context.session.begin(subtransactions=True):
            label_id = m['metering_label_id']
            ip_prefix = m['remote_ip_prefix']
            direction = m['direction']
            excluded = m['excluded']

            # The overlap check runs in the same transaction as the insert.
            self._validate_cidr(context, label_id, ip_prefix, direction,
                                excluded)
            metering_db = MeteringLabelRule(id=uuidutils.generate_uuid(),
                                            metering_label_id=label_id,
                                            direction=direction,
                                            excluded=excluded,
                                            remote_ip_prefix=ip_prefix)
            context.session.add(metering_db)

        return self._make_metering_label_rule_dict(metering_db)

    def delete_metering_label_rule(self, context, rule_id):
        """Delete a metering label rule.

        :raises MeteringLabelRuleNotFound: if no rule has ``rule_id``.
        :returns: the deleted rule as a dict (built after the delete).
        """
        with context.session.begin(subtransactions=True):
            try:
                rule = self._get_by_id(context, MeteringLabelRule, rule_id)
            except orm.exc.NoResultFound:
                raise metering.MeteringLabelRuleNotFound(rule_id=rule_id)
            context.session.delete(rule)

        return self._make_metering_label_rule_dict(rule)

    def _get_metering_rules_dict(self, metering_label):
        """Return the label's rules as a list of rule dicts."""
        return [self._make_metering_label_rule_dict(rule)
                for rule in metering_label.rules]

    def _make_router_dict(self, router):
        """Build the router part of a sync-data entry.

        Label entries are appended later to the (initially empty) list under
        ``constants.METERING_LABEL_KEY``.
        """
        res = {'id': router['id'],
               'name': router['name'],
               'tenant_id': router['tenant_id'],
               'admin_state_up': router['admin_state_up'],
               'status': router['status'],
               'gw_port_id': router['gw_port_id'],
               constants.METERING_LABEL_KEY: []}

        return res

    def _process_sync_metering_data(self, context, labels):
        """Group metering labels and their rules by router.

        Shared labels are attached to every router; other labels are
        attached to the routers of their ``routers`` relationship (routers
        with the label's tenant_id).

        :param labels: iterable of MeteringLabel rows.
        :returns: list of router dicts (see ``_make_router_dict``), each
            holding ``{'id': label_id, 'rules': [rule dicts]}`` entries
            under ``constants.METERING_LABEL_KEY``. Entries for the same
            label on different routers share the same rules list.
        """
        # Loaded lazily, only if a shared label shows up, and materialized
        # once so the query is not re-executed for every shared label.
        all_routers = None

        routers_dict = {}
        for label in labels:
            if label.shared:
                if all_routers is None:
                    all_routers = list(self._get_collection_query(
                        context, l3_db.Router))
                routers = all_routers
            else:
                routers = label.routers

            # The rules depend only on the label: build them once per label
            # instead of once per (label, router) pair.
            rules = self._get_metering_rules_dict(label)

            for router in routers:
                router_dict = routers_dict.get(router['id'])
                if router_dict is None:
                    router_dict = self._make_router_dict(router)
                    routers_dict[router['id']] = router_dict

                data = {'id': label['id'], 'rules': rules}
                router_dict[constants.METERING_LABEL_KEY].append(data)

        return list(routers_dict.values())

    def get_sync_data_for_rule(self, context, rule):
        """Build per-router sync data for a single metering label rule.

        :param rule: rule dict (presumably as produced by
            ``_make_metering_label_rule_dict``); ``metering_label_id`` is
            read and the whole dict is embedded in each label entry.
        :raises MeteringLabelNotFound: if the rule's label does not exist
            (e.g. it was deleted in the meantime).
        :returns: list of router dicts, each holding
            ``{'id': label_id, 'rule': rule}`` under
            ``constants.METERING_LABEL_KEY``.
        """
        label_id = rule['metering_label_id']
        label = context.session.query(MeteringLabel).get(label_id)
        if label is None:
            # Query.get() returns None for a missing row; report it like
            # the other label lookups instead of an AttributeError below.
            raise metering.MeteringLabelNotFound(label_id=label_id)

        if label.shared:
            routers = self._get_collection_query(context, l3_db.Router)
        else:
            routers = label.routers

        routers_dict = {}
        for router in routers:
            router_dict = routers_dict.get(router['id'])
            if router_dict is None:
                router_dict = self._make_router_dict(router)
                routers_dict[router['id']] = router_dict
            data = {'id': label['id'], 'rule': rule}
            router_dict[constants.METERING_LABEL_KEY].append(data)

        return list(routers_dict.values())

    def get_sync_data_metering(self, context, label_id=None, router_ids=None):
        """Build per-router sync data for all (or selected) metering labels.

        :param label_id: if given, only this label is processed.
        :param router_ids: if given (and ``label_id`` is not), only labels
            joined through ``MeteringLabel.routers`` to one of these router
            ids are processed.
            NOTE(review): that join is on tenant_id, so shared labels of
            other tenants are not selected by this filter -- confirm intent.
        :returns: see ``_process_sync_metering_data``.
        """
        labels = context.session.query(MeteringLabel)

        if label_id:
            labels = labels.filter(MeteringLabel.id == label_id)
        elif router_ids:
            labels = (labels.join(MeteringLabel.routers).
                      filter(l3_db.Router.id.in_(router_ids)))

        return self._process_sync_metering_data(context, labels)