aboutsummaryrefslogtreecommitdiff
path: root/Vland/db/db.py
diff options
context:
space:
mode:
Diffstat (limited to 'Vland/db/db.py')
-rw-r--r--Vland/db/db.py825
1 files changed, 825 insertions, 0 deletions
diff --git a/Vland/db/db.py b/Vland/db/db.py
new file mode 100644
index 0000000..d5b541b
--- /dev/null
+++ b/Vland/db/db.py
@@ -0,0 +1,825 @@
+#! /usr/bin/python
+
+# Copyright 2014-2018 Linaro Limited
+# Authors: Dave Pigott <dave.pigott@linaro.org>,
+# Steve McIntyre <steve.mcintyre@linaro.org>
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA 02110-1301, USA.
+
+import psycopg2
+import psycopg2.extras
+import datetime, os, sys
+import logging
+
+TRUNK_ID_NONE = -1
+
+# The schema version that this code expects. If it finds an older version (or
+# no version!) at startup, it will auto-migrate to the latest version
+#
+# Version 0: Base, no version found
+#
+# Version 1: No changes, except adding the version and coping with upgrade
+#
+# Version 2: Add "lock_reason" field in the port table, and code to deal with
+# it
+DATABASE_SCHEMA_VERSION = 2
+
+from Vland.errors import CriticalError, InputError, NotFoundError
+
+class VlanDB:
+ def __init__(self, db_name="vland", username="vland", readonly=True):
+ try:
+ self.connection = psycopg2.connect(database=db_name, user=username)
+ # Create first cursor for normal usage - returns tuples
+ self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
+ # Create second cursor for full-row lookups - returns a dict
+ # instead, much more useful in the admin interface
+ self.dictcursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
+ if not readonly:
+ self._init_state()
+ except Exception as e:
+ logging.error("Failed to access database: %s", e)
+ raise
+
+ def __del__(self):
+ self.cursor.close()
+ self.dictcursor.close()
+ self.connection.close()
+
+ # Create the state table (if needed) and add its only record
+ #
+ # Use the stored record of the expected database schema to track what
+ # version the on-disk database is, and upgrade it to match the current code
+ # if necessary.
+ def _init_state(self):
+ found_db = False
+ current_db_version = 0
+ try:
+ sql = "SELECT * FROM state"
+ self.cursor.execute(sql)
+ found_db = True
+ except psycopg2.ProgrammingError:
+ self.connection.commit() # state doesn't exist; clear error
+ sql = "CREATE TABLE state (last_modified TIMESTAMP, schema_version INTEGER)"
+ self.cursor.execute(sql)
+ # We've just created a version 1 database
+ current_db_version = 1
+
+ if found_db:
+ # Grab the version of the database we have
+ try:
+ sql = "SELECT schema_version FROM state"
+ self.cursor.execute(sql)
+ current_db_version = self.cursor.fetchone()[0]
+ # No version found ==> we have "version 0"
+ except psycopg2.ProgrammingError:
+ self.connection.commit() # state doesn't exist; clear error
+ current_db_version = 0
+
+ # Now delete the existing state record, we'll write a new one in a
+ # moment
+ self.cursor.execute('DELETE FROM state')
+ logging.info("Found a database, version %d", current_db_version)
+
+ # Apply upgrades here!
+ if current_db_version < 1:
+ logging.info("Upgrading database to match schema version 1")
+ sql = "ALTER TABLE state ADD schema_version INTEGER"
+ self.cursor.execute(sql)
+ logging.info("Schema version 1 upgrade successful")
+
+ if current_db_version < 2:
+ logging.info("Upgrading database to match schema version 2")
+ sql = "ALTER TABLE port ADD lock_reason VARCHAR(64)"
+ self.cursor.execute(sql)
+ logging.info("Schema version 2 upgrade successful")
+
+ sql = "INSERT INTO state (last_modified, schema_version) VALUES (%s, %s)"
+ data = (datetime.datetime.now(), DATABASE_SCHEMA_VERSION)
+ self.cursor.execute(sql, data)
+ self.connection.commit()
+
+ # Create a new switch in the database. Switches are really simple
+ # devices - they're just containers for ports.
+ #
+ # Constraints:
+ # Switches must be uniquely named
+ def create_switch(self, name):
+
+ switch_id = self.get_switch_id_by_name(name)
+ if switch_id is not None:
+ raise InputError("Switch name %s already exists" % name)
+
+ try:
+ sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
+ data = (name, )
+ self.cursor.execute(sql, data)
+ switch_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+
+ return switch_id
+
+ # Create a new port in the database. Three of the fields are
+ # created with default values (is_locked, is_trunk, trunk_id)
+ # here, and should be updated separately if desired. For the
+ # current_vlan_id and base_vlan_id fields, *BE CAREFUL* that you
+ # have already looked up the correct VLAN_ID for each. This is
+ # *NOT* the same as the VLAN tag (likely to be 1). You Have Been
+ # Warned!
+ #
+ # Constraints:
+ # 1. The switch referred to must already exist
+ # 2. The VLANs mentioned here must already exist
+ # 3. (Switch/name) must be unique
+ # 4. (Switch/number) must be unique
+ def create_port(self, switch_id, name, number, current_vlan_id, base_vlan_id):
+
+ switch = self.get_switch_by_id(switch_id)
+ if switch is None:
+ raise NotFoundError("Switch ID %d does not exist" % int(switch_id))
+
+ for vlan_id in (current_vlan_id, base_vlan_id):
+ vlan = self.get_vlan_by_id(vlan_id)
+ if vlan is None:
+ raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))
+
+ port_id = self.get_port_by_switch_and_name(switch_id, name)
+ if port_id is not None:
+ raise InputError("Already have a port %s on switch ID %d" % (name, int(switch_id)))
+
+ port_id = self.get_port_by_switch_and_number(switch_id, int(number))
+ if port_id is not None:
+ raise InputError("Already have a port %d on switch ID %d" % (int(number), int(switch_id)))
+
+ try:
+ sql = "INSERT INTO port (name, number, switch_id, is_locked, lock_reason, is_trunk, current_vlan_id, base_vlan_id, trunk_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING port_id"
+ data = (name, number, switch_id,
+ False, "",
+ False,
+ current_vlan_id, base_vlan_id, TRUNK_ID_NONE)
+ self.cursor.execute(sql, data)
+ port_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+
+ return port_id
+
+ # Create a new vlan in the database. We locally add a creation
+ # timestamp, for debug purposes. If vlans seems to be sticking
+ # around, we'll be able to see when they were created.
+ #
+ # Constraints:
+ # Names and tags must be unique
+ # Tags must be in the range 1-4095 (802.1q spec)
+ # Names can be any free-form text, length 1-32 characters
+ def create_vlan(self, name, tag, is_base_vlan):
+
+ if int(tag) < 1 or int(tag) > 4095:
+ raise InputError("VLAN tag %d is outside of the valid range (1-4095)" % int(tag))
+
+ if (len(name) < 1) or (len(name) > 32):
+ raise InputError("VLAN name %s is invalid (must be 1-32 chars)" % name)
+
+ vlan_id = self.get_vlan_id_by_name(name)
+ if vlan_id is not None:
+ raise InputError("VLAN name %s is already in use" % name)
+
+ vlan_id = self.get_vlan_id_by_tag(tag)
+ if vlan_id is not None:
+ raise InputError("VLAN tag %d is already in use" % int(tag))
+
+ try:
+ dt = datetime.datetime.now()
+ sql = "INSERT INTO vlan (name, tag, is_base_vlan, creation_time) VALUES (%s, %s, %s, %s) RETURNING vlan_id"
+ data = (name, tag, is_base_vlan, dt)
+ self.cursor.execute(sql, data)
+ vlan_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+
+ return vlan_id
+
+ # Create a new trunk in the database, linking two ports. Trunks
+ # are really simple objects for our use - they're just containers
+ # for 2 ports.
+ #
+ # Constraints:
+ # 1. Both ports listed must already exist.
+ # 2. Both ports must be in trunk mode.
+ # 3. Both must not be locked.
+ # 4. Both must not already be in a trunk.
+ def create_trunk(self, port_id1, port_id2):
+
+ for port_id in (port_id1, port_id2):
+ port = self.get_port_by_id(int(port_id))
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+ if not port['is_trunk']:
+ raise InputError("Port ID %d is not in trunk mode" % int(port_id))
+ if port['is_locked']:
+ raise InputError("Port ID %d is locked" % int(port_id))
+ if port['trunk_id'] != TRUNK_ID_NONE:
+ raise InputError("Port ID %d is already on trunk ID %d" % (int(port_id), int(port['trunk_id'])))
+
+ try:
+ # Add the trunk itself
+ dt = datetime.datetime.now()
+ sql = "INSERT INTO trunk (creation_time) VALUES (%s) RETURNING trunk_id"
+ data = (dt, )
+ self.cursor.execute(sql, data)
+ trunk_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ # And update the ports
+ for port_id in (port_id1, port_id2):
+ self._set_port_trunk(port_id, trunk_id)
+ except:
+ self.delete_trunk(trunk_id)
+ raise
+
+ return trunk_id
+
+ # Internal helper function
+ def _delete_row(self, table, field, value):
+ try:
+ sql = "DELETE FROM %s WHERE %s = %s" % (table, field, '%s')
+ data = (value,)
+ self.cursor.execute(sql, data)
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+
+ # Delete the specified switch
+ #
+ # Constraints:
+ # 1. The switch must exist
+ # 2. The switch may not be referenced by any ports -
+ # delete them first!
+ def delete_switch(self, switch_id):
+ switch = self.get_switch_by_id(switch_id)
+ if switch is None:
+ raise NotFoundError("Switch ID %d does not exist" % int(switch_id))
+ ports = self.get_ports_by_switch(switch_id)
+ if ports is not None:
+ raise InputError("Cannot delete switch ID %d when it still has %d ports" %
+ (int(switch_id), len(ports)))
+ self._delete_row("switch", "switch_id", switch_id)
+ return switch_id
+
+ # Delete the specified port
+ #
+ # Constraints:
+ # 1. The port must exist
+ # 2. The port must not be locked
+ def delete_port(self, port_id):
+ port = self.get_port_by_id(port_id)
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+ if port['is_locked']:
+ raise InputError("Cannot delete port ID %d as it is locked" % int(port_id))
+ self._delete_row("port", "port_id", port_id)
+ return port_id
+
+ # Delete the specified VLAN
+ #
+ # Constraints:
+ # 1. The VLAN must exist
+ # 2. The VLAN may not contain any ports - move or delete them first!
+ def delete_vlan(self, vlan_id):
+ vlan = self.get_vlan_by_id(vlan_id)
+ if vlan is None:
+ raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))
+ ports = self.get_ports_by_current_vlan(vlan_id)
+ if ports is not None:
+ raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
+ (int(vlan_id), len(ports)))
+ ports = self.get_ports_by_base_vlan(vlan_id)
+ if ports is not None:
+ raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
+ (int(vlan_id), len(ports)))
+ self._delete_row("vlan", "vlan_id", vlan_id)
+ return vlan_id
+
+ # Delete the specified trunk
+ #
+ # Constraints:
+ # 1. The trunk must exist
+ #
+ # Any ports attached will be detached (i.e. moved to trunk TRUNK_ID_NONE)
+ def delete_trunk(self, trunk_id):
+ trunk = self.get_trunk_by_id(trunk_id)
+ if trunk is None:
+ raise NotFoundError("Trunk ID %d does not exist" % int(trunk_id))
+ ports = self.get_ports_by_trunk(trunk_id)
+ for port_id in ports:
+ self._set_port_trunk(port_id, TRUNK_ID_NONE)
+ self._delete_row("trunk", "trunk_id", trunk_id)
+ return trunk_id
+
+ # Find the lowest unused VLAN tag and return it
+ #
+ # Constraints:
+ # None
+ def find_lowest_unused_vlan_tag(self):
+ sql = "SELECT tag FROM vlan ORDER BY tag ASC"
+ self.cursor.execute(sql,)
+
+ # Walk through the list, looking for gaps
+ last = 1
+ result = None
+
+ for record in self.cursor:
+ if (record[0] - last) > 1:
+ result = last + 1
+ break
+ last = record[0]
+
+ if result is None:
+ result = last + 1
+
+ if result > 4093:
+ raise CriticalError("Can't find any VLAN tags remaining for allocation!")
+
+ return result
+
+ # Grab one column from one row of a query on one column; useful as
+ # a quick wrapper
+ def _get_element(self, select_field, table, compare_field, value):
+
+ if value is None:
+ raise ValueError("Asked to look up using None as a key in %s" % compare_field)
+
+ # We really want to use psycopg's type handling deal with the
+ # (potentially) user-supplied data in the value field, so we
+ # have to pass (sql,data) through to cursor.execute. However,
+ # we can't have psycopg do all the argument substitution here
+ # as it will quote all the params like the table name. That
+ # doesn't work. So, we substitute a "%s" for "%s" here so we
+ # keep it after python's own string substitution.
+ sql = "SELECT %s FROM %s WHERE %s = %s" % (select_field, table, compare_field, "%s")
+
+ # Now, the next icky thing: we need to make sure that we're
+ # passing a dict so that psycopg2 can pick it apart properly
+ # for its own substitution code. We force this with the
+ # trailing comma here
+ data = (value, )
+ self.cursor.execute(sql, data)
+
+ if self.cursor.rowcount > 0:
+ return self.cursor.fetchone()[0]
+ else:
+ return None
+
+ # Grab one column from one row of a query on 2 columns; useful as
+ # a quick wrapper
+ def _get_element2(self, select_field, table, compare_field1, value1, compare_field2, value2):
+
+ if value1 is None:
+ raise ValueError("Asked to look up using None as a key in %s" % compare_field1)
+ if value2 is None:
+ raise ValueError("Asked to look up using None as a key in %s" % compare_field2)
+
+ # We really want to use psycopg's type handling deal with the
+ # (potentially) user-supplied data in the value field, so we
+ # have to pass (sql,data) through to cursor.execute. However,
+ # we can't have psycopg do all the argument substitution here
+ # as it will quote all the params like the table name. That
+ # doesn't work. So, we substitute a "%s" for "%s" here so we
+ # keep it after python's own string substitution.
+ sql = "SELECT %s FROM %s WHERE %s = %s AND %s = %s" % (select_field, table, compare_field1, "%s", compare_field2, "%s")
+
+ data = (value1, value2)
+ self.cursor.execute(sql, data)
+
+ if self.cursor.rowcount > 0:
+ return self.cursor.fetchone()[0]
+ else:
+ return None
+
+ # Grab one column from multiple rows of a query; useful as a quick
+ # wrapper
+ def _get_multi_elements(self, select_field, table, compare_field, value, sort_field):
+
+ if value is None:
+ raise ValueError("Asked to look up using None as a key in %s" % compare_field)
+
+ # We really want to use psycopg's type handling deal with the
+ # (potentially) user-supplied data in the value field, so we
+ # have to pass (sql,data) through to cursor.execute. However,
+ # we can't have psycopg do all the argument substitution here
+ # as it will quote all the params like the table name. That
+ # doesn't work. So, we substitute a "%s" for "%s" here so we
+ # keep it after python's own string substitution.
+ sql = "SELECT %s FROM %s WHERE %s = %s ORDER BY %s ASC" % (select_field, table, compare_field, "%s", sort_field)
+
+ # Now, the next icky thing: we need to make sure that we're
+ # passing a dict so that psycopg2 can pick it apart properly
+ # for its own substitution code. We force this with the
+ # trailing comma here
+ data = (value, )
+ self.cursor.execute(sql, data)
+
+ if self.cursor.rowcount > 0:
+ results = []
+ for record in self.cursor:
+ results.append(record[0])
+ return results
+ else:
+ return None
+
+ # Grab one column from multiple rows of a 2-part query; useful as
+ # a wrapper
+ def _get_multi_elements2(self, select_field, table, compare_field1, value1, compare_field2, value2, sort_field):
+
+ if value1 is None:
+ raise ValueError("Asked to look up using None as a key in %s" % compare_field1)
+ if value2 is None:
+ raise ValueError("Asked to look up using None as a key in %s" % compare_field2)
+
+ # We really want to use psycopg's type handling deal with the
+ # (potentially) user-supplied data in the value field, so we
+ # have to pass (sql,data) through to cursor.execute. However,
+ # we can't have psycopg do all the argument substitution here
+ # as it will quote all the params like the table name. That
+ # doesn't work. So, we substitute a "%s" for "%s" here so we
+ # keep it after python's own string substitution.
+ sql = "SELECT %s FROM %s WHERE %s = %s AND %s = %s ORDER by %s ASC" % (select_field, table, compare_field1, "%s", compare_field2, "%s", sort_field)
+
+ data = (value1, value2)
+ self.cursor.execute(sql, data)
+
+ if self.cursor.rowcount > 0:
+ results = []
+ for record in self.cursor:
+ results.append(record[0])
+ return results
+ else:
+ return None
+
+ # Simple lookup: look up a switch by ID, and return all the
+ # details of that switch.
+ #
+ # Returns None on failure.
+ def get_switch_by_id(self, switch_id):
+ return self._get_row("switch", "switch_id", int(switch_id))
+
+ # Simple lookup: look up a switch by name, and return the ID of
+ # that switch.
+ #
+ # Returns None on failure.
+ def get_switch_id_by_name(self, name):
+ return self._get_element("switch_id", "switch", "name", name)
+
+ # Simple lookup: look up a switch by ID, and return the name of
+ # that switch.
+ #
+ # Returns None on failure.
+ def get_switch_name_by_id(self, switch_id):
+ return self._get_element("name", "switch", "switch_id", int(switch_id))
+
+ # Simple lookup: look up a port by ID, and return all the details
+ # of that port.
+ #
+ # Returns None on failure.
+ def get_port_by_id(self, port_id):
+ return self._get_row("port", "port_id", int(port_id))
+
+ # Simple lookup: look up a switch by ID, and return the IDs of all
+ # the ports on that switch.
+ #
+ # Returns None on failure.
+ def get_ports_by_switch(self, switch_id):
+ return self._get_multi_elements("port_id", "port", "switch_id", int(switch_id), "port_id")
+
+ # More complex lookup: look up all the trunk ports on a switch by
+ # ID
+ #
+ # Returns None on failure.
+ def get_trunk_port_names_by_switch(self, switch_id):
+ return self._get_multi_elements2("name", "port", "switch_id", int(switch_id), "is_trunk", True, "port_id")
+
+ # Simple lookup: look up a port by its name and its parent switch
+ # by ID, and return the ID of the port.
+ #
+ # Returns None on failure.
+ def get_port_by_switch_and_name(self, switch_id, name):
+ return self._get_element2("port_id", "port", "switch_id", int(switch_id), "name", name)
+
+ # Simple lookup: look up a port by its external name and its
+ # parent switch by ID, and return the ID of the port.
+ #
+ # Returns None on failure.
+ def get_port_by_switch_and_number(self, switch_id, number):
+ return self._get_element2("port_id", "port", "switch_id", int(switch_id), "number", int(number))
+
+ # Simple lookup: look up a port by ID, and return the current VLAN
+ # id of that port.
+ #
+ # Returns None on failure.
+ def get_current_vlan_id_by_port(self, port_id):
+ return self._get_element("current_vlan_id", "port", "port_id", int(port_id))
+
+ # Simple lookup: look up a port by ID, and return the mode of that port.
+ #
+ # Returns None on failure.
+ def get_port_mode(self, port_id):
+ is_trunk = self._get_element("is_trunk", "port", "port_id", int(port_id))
+ if is_trunk is not None:
+ if is_trunk:
+ return "trunk"
+ else:
+ return "access"
+ return None
+
+ # Simple lookup: look up a port by ID, and return the base VLAN
+ # id of that port.
+ #
+ # Returns None on failure.
+ def get_base_vlan_id_by_port(self, port_id):
+ return self._get_element("base_vlan_id", "port", "port_id", int(port_id))
+
+ # Simple lookup: look up a current VLAN by ID, and return the IDs
+ # of all the ports on that VLAN.
+ #
+ # Returns None on failure.
+ def get_ports_by_current_vlan(self, vlan_id):
+ return self._get_multi_elements("port_id", "port", "current_vlan_id", int(vlan_id), "port_id")
+
+ # Simple lookup: look up a base VLAN by ID, and return the IDs
+ # of all the ports on that VLAN.
+ #
+ # Returns None on failure.
+ def get_ports_by_base_vlan(self, vlan_id):
+ return self._get_multi_elements("port_id", "port", "base_vlan_id", int(vlan_id), "port_id")
+
+ # Simple lookup: look up a trunk by ID, and return the IDs of the
+ # ports on both ends of that trunk.
+ #
+ # Returns None on failure.
+ def get_ports_by_trunk(self, trunk_id):
+ return self._get_multi_elements("port_id", "port", "trunk_id", int(trunk_id), "port_id")
+
+ # Simple lookup: look up a VLAN by ID, and return all the details
+ # of that VLAN.
+ #
+ # Returns None on failure.
+ def get_vlan_by_id(self, vlan_id):
+ return self._get_row("vlan", "vlan_id", int(vlan_id))
+
+ # Simple lookup: look up a VLAN by name, and return the ID of that
+ # VLAN.
+ #
+ # Returns None on failure.
+ def get_vlan_id_by_name(self, name):
+ return self._get_element("vlan_id", "vlan", "name", name)
+
+ # Simple lookup: look up a VLAN by tag, and return the ID of that
+ # VLAN.
+ #
+ # Returns None on failure.
+ def get_vlan_id_by_tag(self, tag):
+ return self._get_element("vlan_id", "vlan", "tag", int(tag))
+
+ # Simple lookup: look up a VLAN by ID, and return the name of that
+ # VLAN.
+ #
+ # Returns None on failure.
+ def get_vlan_name_by_id(self, vlan_id):
+ return self._get_element("name", "vlan", "vlan_id", int(vlan_id))
+
+ # Simple lookup: look up a VLAN by ID, and return the tag of that
+ # VLAN.
+ #
+ # Returns None on failure.
+ def get_vlan_tag_by_id(self, vlan_id):
+ return self._get_element("tag", "vlan", "vlan_id", int(vlan_id))
+
+ # Simple lookup: look up a trunk by ID, and return all the details
+ # of that trunk.
+ #
+ # Returns None on failure.
+ def get_trunk_by_id(self, trunk_id):
+ return self._get_row("trunk", "trunk_id", int(trunk_id))
+
+ # Get the last-modified time for the database
+ def get_last_modified_time(self):
+ sql = "SELECT last_modified FROM state"
+ self.cursor.execute(sql)
+ return self.cursor.fetchone()[0]
+
+ # Grab one row of a query on one column; useful as a quick wrapper
+ def _get_row(self, table, field, value):
+
+ # We really want to use psycopg's type handling deal with the
+ # (potentially) user-supplied data in the value field, so we
+ # have to pass (sql,data) through to cursor.execute. However,
+ # we can't have psycopg do all the argument substitution here
+ # as it will quote all the params like the table name. That
+ # doesn't work. So, we substitute a "%s" for "%s" here so we
+ # keep it after python's own string substitution.
+ sql = "SELECT * FROM %s WHERE %s = %s" % (table, field, "%s")
+
+ # Now, the next icky thing: we need to make sure that we're
+ # passing a dict so that psycopg2 can pick it apart properly
+ # for its own substitution code. We force this with the
+ # trailing comma here
+ data = (value, )
+ self.dictcursor.execute(sql, data)
+ return self.dictcursor.fetchone()
+
+ # (Un)Lock a port in the database. This can only be done through
+ # the admin interface, and will stop API users from modifying
+ # settings on the port. Use this to lock down ports that are used
+ # for PDUs and other core infrastructure
+ def set_port_is_locked(self, port_id, is_locked, lock_reason=""):
+ port = self.get_port_by_id(port_id)
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+ try:
+ sql = "UPDATE port SET is_locked=%s, lock_reason=%s WHERE port_id=%s RETURNING port_id"
+ data = (is_locked, lock_reason, port_id)
+ self.cursor.execute(sql, data)
+ port_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise InputError("lock failed on Port ID %d" % int(port_id))
+ return port_id
+
+ # Set the mode of a port in the database. Valid values for mode
+ # are "trunk" and "access"
+ def set_port_mode(self, port_id, mode):
+ port = self.get_port_by_id(port_id)
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+ if mode == "access":
+ is_trunk = False
+ elif mode == "trunk":
+ is_trunk = True
+ else:
+ raise InputError("Port mode %s is not valid" % mode)
+ try:
+ sql = "UPDATE port SET is_trunk=%s WHERE port_id=%s RETURNING port_id"
+ data = (is_trunk, port_id)
+ self.cursor.execute(sql, data)
+ port_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+ return port_id
+
+ # Set the current vlan of a port in the database. The VLAN is
+ # passed by ID.
+ #
+ # Constraints:
+ # 1. The port must already exist
+ # 2. The port must not be a trunk port
+ # 3. The port must not be locked
+ # 1. The VLAN must already exist in the database
+ def set_current_vlan(self, port_id, vlan_id):
+ port = self.get_port_by_id(port_id)
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+
+ if port['is_trunk'] or port['is_locked']:
+ raise CriticalError("The port is locked")
+
+ vlan = self.get_vlan_by_id(vlan_id)
+ if vlan is None:
+ raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))
+
+ try:
+ sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s RETURNING port_id"
+ data = (vlan_id, port_id)
+ self.cursor.execute(sql, data)
+ port_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+ return port_id
+
+ # Set the base vlan of a port in the database. The VLAN is
+ # passed by ID.
+ #
+ # Constraints:
+ # 1. The port must already exist
+ # 2. The port must not be a trunk port
+ # 3. The port must not be locked
+ # 4. The VLAN must already exist in the database
+ def set_base_vlan(self, port_id, vlan_id):
+ port = self.get_port_by_id(port_id)
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+
+ if port['is_trunk'] or port['is_locked']:
+ raise CriticalError("The port is locked")
+
+ vlan = self.get_vlan_by_id(vlan_id)
+ if vlan is None:
+ raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))
+ if not vlan['is_base_vlan']:
+ raise InputError("VLAN ID %d is not a base VLAN" % int(vlan_id))
+
+ try:
+ sql = "UPDATE port SET base_vlan_id=%s WHERE port_id=%s RETURNING port_id"
+ data = (vlan_id, port_id)
+ self.cursor.execute(sql, data)
+ port_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+ return port_id
+
+ # Internal function: Attach a port to a trunk in the database.
+ #
+ # Constraints:
+ # 1. The port must already exist
+ # 2. The port must not be locked
+ def _set_port_trunk(self, port_id, trunk_id):
+ port = self.get_port_by_id(port_id)
+ if port is None:
+ raise NotFoundError("Port ID %d does not exist" % int(port_id))
+ if port['is_locked']:
+ raise CriticalError("The port is locked")
+ try:
+ sql = "UPDATE port SET trunk_id=%s WHERE port_id=%s RETURNING port_id"
+ data = (int(trunk_id), int(port_id))
+ self.cursor.execute(sql, data)
+ port_id = self.cursor.fetchone()[0]
+ self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
+ self.connection.commit()
+ except:
+ self.connection.rollback()
+ raise
+ return port_id
+
+ # Trivial helper function to return all the rows in a given table
+ def _dump_table(self, table, order):
+ result = []
+ self.dictcursor.execute("SELECT * FROM %s ORDER by %s ASC" % (table, order))
+ record = self.dictcursor.fetchone()
+ while record != None:
+ result.append(record)
+ record = self.dictcursor.fetchone()
+ return result
+
+ def all_switches(self):
+ return self._dump_table("switch", "switch_id")
+
+ def all_ports(self):
+ return self._dump_table("port", "port_id")
+
+ def all_vlans(self):
+ return self._dump_table("vlan", "vlan_id")
+
+ def all_trunks(self):
+ return self._dump_table("trunk", "trunk_id")
+
+if __name__ == '__main__':
+ db = VlanDB()
+ s = db.all_switches()
+ print 'The DB knows about %d switch(es)' % len(s)
+ print s
+ p = db.all_ports()
+ print 'The DB knows about %d port(s)' % len(p)
+ print p
+ v = db.all_vlans()
+ print 'The DB knows about %d vlan(s)' % len(v)
+ print v
+ t = db.all_trunks()
+ print 'The DB knows about %d trunks(s)' % len(t)
+ print t
+
+ print 'First free VLAN tag is %d' % db.find_lowest_unused_vlan_tag()