#!/neo/opt/bin/python
#
# odb.py
#
# Object Database Api
#
# Written by David Jeske <jeske@neotonic.com>, 2001/07. 
# Inspired by eGroups' sqldb.py originally written by Scott Hassan circa 1998.
#
# Copyright (C) 2001, by David Jeske and Neotonic
#
# Goals:
#       - a simple object-like interface to database data
#       - database independent (someday)
#       - relational-style "rigid schema definition"
#       - object style easy-access
#
#
# simple example:
#    import odb_mysql
#    import odb
#    odb_mysql.DB('localhost')

import string
import sys
from log import *

import MySQLdb

# Exceptions raised by this module. They are old-style string exceptions:
# code in this file raises them as "raise eSomething, message" and callers
# catch them by naming the same identifier ("except eNoMatchingRows:").
eNoSuchColumn         = "odb.eNoSuchColumn"
eNonUniqueMatchSpec   = "odb.eNonUniqueMatchSpec"
eNoMatchingRows       = "odb.eNoMatchingRows"
eInternalError        = "odb.eInternalError"
eInvalidMatchSpec     = "odb.eInvalidMatchSpec"
eInvalidData          = "odb.eInvalidData"
eUnsavedObjectLost    = "odb.eUnsavedObjectLost"
eDuplicateKey         = "odb.eDuplicateKey"

#####################################
# COLUMN TYPES                       
################                     ######################
# typename     ####################### size data means:
#              #                     # 
kInteger       = "kInteger"          # -
kFixedString   = "kFixedString"      # size
kVarString     = "kVarString"        # maxsize
kBigString     = "kBigString"        # -
kIncInteger    = "kIncInteger"       # -
kDateTime      = "kDateTime"
kTimeStamp     = "kTimeStamp"


# debug flag; NOTE(review): not read anywhere in the visible part of this file
DEBUG = 0

##############
# Database
#
# this will ultimately turn into a mostly abstract base class for
# the DB adaptors for different database types....
#

class Database:
    """Wrapper around a DB-API connection that owns the Table objects.

    Tables registered with addTable() are reachable as attributes of the
    database object (for example ``db.agents``) through __getattr__.
    """

    def __init__(self,db):
        self._tables = {}
        self.db = db
        self._cursor = None

    # __init__ = None
    # list_tables = None   # list_tables() -> ['a_table','b_table']
    # list_fields = None   # list_fields(tbl_name) -> ['
    # checkTable = None    # checkTable
    # createTable

    def defaultCursor(self):
        """Return the shared cursor, creating it on first use."""
        if self._cursor is None:
            self._cursor = self.db.cursor()
        return self._cursor

    def escape(self,str):
        """Escape a string for use inside a quoted SQL literal."""
        return MySQLdb.escape_string(str)

    def defaultRowClass(self):
        """Row class used by tables that are not given an explicit one."""
        return Row

    def defaultRowListClass(self):
        """Row-list class used by tables; None means a plain list."""
        return None

    def addTable(self, attrname, tblname, tblclass, rowClass = None, check = 0, create = 0, rowListClass = None):
        """Instantiate tblclass for tblname and expose it as self.<attrname>."""
        self._tables[attrname] = tblclass(self, tblname, rowClass=rowClass, check=check, create=create, rowListClass=rowListClass)

    def close(self):
        """Detach all tables and close the underlying connection."""
        for tbl in self._tables.values():
            tbl.db = None
        self._tables = {}
        # the cached cursor belongs to the connection being closed
        self._cursor = None
        if self.db is not None:
            self.db.close()
            self.db = None

    def __getattr__(self, key):
        # Go through __dict__ so that a lookup made before __init__ has run
        # (or on a half-built instance) raises AttributeError instead of
        # recursing into __getattr__ forever.
        tables = self.__dict__.get('_tables') or {}
        try:
            return tables[key]
        except KeyError:
            raise AttributeError("unknown attribute %s" % (key))
        

##########################################
# Table
#


class Table:
    def subclassinit(self):
        """Hook called by __init__ after the columns are locked; does nothing here."""
        return None
    def __init__(self,database,table_name,rowClass = None, check = 0, create = 0, rowListClass = None):
	self.db = database
	self.__table_name = table_name
	if rowClass:
	    self.__defaultRowClass = rowClass
	else:
	    self.__defaultRowClass = database.defaultRowClass()

	if rowListClass:
	    self.__defaultRowListClass = rowListClass
	else:
	    self.__defaultRowListClass = database.defaultRowListClass()

	# get this stuff ready!
	
	self.__column_list = []
	self.__vcolumn_list = []
	self.__columns_locked = 0
	self.__has_value_column = 0

	# this will be used during init...
	self.__col_def_hash = None
	self.__vcol_def_hash = None
	self.__primary_key_list = None
        self.__relations_by_table = {}

	# ask the subclass to def his rows
	self._defineRows()

	# get ready to run!
	self.__lockColumnsAndInit()

        self.subclassinit()
        
	if create:
	    self.db.createTable(self)

	if check:
	    self.db.checkTable(self)

    def getColumnDef(self,column_name):
        try:
            return self.__col_def_hash[column_name]
        except KeyError:
            try:
                return self.__vcol_def_hash[column_name]
            except KeyError:
                raise eNoSuchColumn, "no column (%s) on table %s" % (column_name,self.__table_name)
        

    def getColumnList(self):
        return self.__column_list + self.__vcolumn_list

    def convertDataForColumn(self,data,col_name):
	try:
	    col_def = self.__col_def_hash[col_name]
	except KeyError:
	    try:
		col_def = self.__vcol_def_hash[col_name]
	    except KeyError:
		raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)

	c_name,c_type,c_options = col_def

        if c_type == kIncInteger:
            raise eInvalidData, "invalid operation for column (%s:%s) on table (%s)" % (col_name,c_type,self.__table_name)

	if c_type == kInteger:
	    try:
		return long(data)
	    except ValueError:
		raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data),col_name,c_type,self.__table_name)
	else:
	    if type(data) == type(long(0)):
		return "%d" % data
	    else:
		return str(data)

    def getPrimaryKeyList(self):
	return self.__primary_key_list
    
    def getTableName(self):
	return self.__table_name
    def hasValueColumn(self):
	return self.__has_value_column

    def hasColumn(self,name):
	return self.__col_def_hash.has_key(name)
    def hasVColumn(self,name):
	return self.__vcol_def_hash.has_key(name)
	

    def _defineRows(self):
	raise "can't instantiate base odb.Table type, make a subclass and override _defineRows()"

    def __lockColumnsAndInit(self):
	# add a 'odb_value column' before we lockdown the table def
	if self.__has_value_column:
	    self.d_addColumn("odb_value",kBigText,default='')

	self.__columns_locked = 1
	# walk column list and make lookup hashes, primary_key_list, etc..

	primary_key_list = []
	col_def_hash = {}
	for a_col in self.__column_list:
	    name,type,options = a_col
	    col_def_hash[name] = a_col
	    if options.has_key('primarykey'):
		primary_key_list.append(name)

	self.__col_def_hash = col_def_hash
	self.__primary_key_list = primary_key_list

	# setup the value columns!

	if (not self.__has_value_column) and (len(self.__vcolumn_list) > 0):
	    raise "can't define vcolumns on table without ValueColumn, call d_addValueColumn() in your _defineRows()"

	vcol_def_hash = {}
	for a_col in self.__vcolumn_list:
	    name,type,size_data,options = a_col
	    vcol_def_hash[name] = a_col

	self.__vcol_def_hash = vcol_def_hash
	
	
    def __checkColumnLock(self):
	if self.__columns_locked:
	    raise "can't change column definitions outside of subclass' _defineRows() method!"

    # table definition methods, these are only available while inside the
    # subclass's _defineRows method
    #
    # Ex:
    #
    # import odb
    # class MyTable(odb.Table):
    #   def _defineRows(self):
    #     self.d_addColumn("id",kInteger,primarykey = 1,autoincrement = 1)
    #     self.d_addColumn("name",kVarString,120)
    #     self.d_addColumn("type",kInteger,
    #                      enum_values = { 0 : "alive", 1 : "dead" }

    def d_addColumn(self,col_name,ctype,size=None,primarykey = 0, notnull = 0,indexed=0,
		    default=None,unique=0,autoincrement=0,safeupdate=0,enum_values = None,
                    relations=None):

	self.__checkColumnLock()

	options = {}
	options['default']       = default
	if primarykey:
	    options['primarykey']    = primarykey
	if indexed:
	    options['indexed']       = indexed
	if unique:
	    options['unique']        = unique
	if safeupdate:
	    options['safeupdate']    = safeupdate
	if autoincrement:
	    options['autoincrement'] = autoincrement
	if notnull:
	    options['notnull']       = notnull
	if size:
	    options['size']          = size
	if enum_values:
	    options['enum_values']   = enum_values
	    inv_enum_values = {}
	    for k,v in enum_values.items():
		if inv_enum_values.has_key(v):
		    raise eInvalidData, "enum_values paramater must be a 1 to 1 mapping for Table(%s)" % self.__table_name
		else:
		    inv_enum_values[v] = k
	    options['inv_enum_values'] = inv_enum_values
        if relations:
            options['relations']      = relations
            for a_relation in relations:
                table, foreign_column_name = a_relation
                if self.__relations_by_table.has_key(table):
                    raise eInvalidData, "multiple relations for the same foreign table are not yet supported" 
                self.__relations_by_table[table] = (col_name,foreign_column_name)

	
	self.__column_list.append( (col_name,ctype,options) )
	

    def d_addValueColumn(self):
	self.__checkColumnLock()
	self.__has_value_column = 1

    def d_addVColumn(self,col_name,type,size=None,default=None):
	self.__checkColumnLock()

	col_name = string.lower(col_name)

	if (not self.__has_value_column):
	    raise "can't define VColumns on table without ValueColumn, call d_addValueColumn() first"

	options = {}
	if default:
	    options['default'] = default
	if size:
	    options['size']    = size

	self.__vcolumn_list.append( (col_name,type,options) )

    #####################
    # _fixColMatchSpec(col_match_spec,should_match_unique_row = 0)
    #
    # normalize col_match_spec and raise an error if it contains invalid
    # columns, or (in the case of should_match_unique_row) if it does not
    # fully specify a unique row.
    #
    # NOTE: we don't currently support where clauses with value column fields!
    #
    
    def _fixColMatchSpec(self,col_match_spec, should_match_unique_row = 0):
	if type(col_match_spec) == type([]):
	    if type(col_match_spec[0]) != type((0,)):
		raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"
	elif type(col_match_spec) == type((0,)):
	    col_match_spec = [ col_match_spec ]
        elif type(col_match_spec) == type(None):
            if should_match_unique_row:
                raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec
            else:
                return None
	else:
	    raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"

	if should_match_unique_row:
            unique_column_lists = []

            # first the primary key list
	    my_primary_key_list = []
	    for a_key in self.__primary_key_list:
		my_primary_key_list.append(a_key)

            # then other unique keys
            for a_col in self.__column_list:
                col_name,a_type,options = a_col
                if options.has_key('unique'):
                    unique_column_lists.append( (col_name, [col_name]) )

            unique_column_lists.append( ('primary_key', my_primary_key_list) )
                
	
	new_col_match_spec = []
	for a_col in col_match_spec:
	    name,val = a_col
	    newname = string.lower(name)
	    if not self.__col_def_hash.has_key(newname):
		raise eNoSuchColumn, "no such column in match spec %s" % name

	    new_col_match_spec.append( (newname,val) )

	    if should_match_unique_row:
                for name,a_list in unique_column_lists:
                    try:
                        a_list.remove(newname)
                    except ValueError:
                        # it's okay if they specify too many columns!
                        pass

	if should_match_unique_row:
            for name,a_list in unique_column_lists:
                if len(a_list) == 0:
                    # we matched at least one unique colum spec!
                    # log("using unique column (%s) for query %s" % (name,col_match_spec))
                    return new_col_match_spec
            
            raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec

	return new_col_match_spec

    def __buildWhereClause (self, col_match_spec,other_clauses = None):
	sql_where_list = []

        if not col_match_spec is None:
            for m_col in col_match_spec:
                m_col_name,m_col_val = m_col
                c_name,c_type,c_options = self.__col_def_hash[m_col_name]
                if c_type in (kIncInteger, kInteger):
                    try:
                        m_col_val_long = long(m_col_val)
                    except ValueError:
                        raise ValueError, "invalid literal for long(%s) in table %s" % (repr(m_col_val),self.__table_name)
                        
                    sql_where_list.append("%s = %d" % (c_name, m_col_val_long))
                else:
                    sql_where_list.append("%s = '%s'" % (c_name, self.db.escape(m_col_val)))

        if other_clauses is None:
            pass
        elif type(other_clauses) == type(""):
            sql_where_list = sql_where_list + [other_clauses]
        elif type(other_clauses) == type([]):
            sql_where_list = sql_where_list + other_clauses
        else:
            raise eInvalidData, "unknown type of extra where clause: %s" % repr(other_clauses)
                    
        return sql_where_list

    def __fetchRows(self,col_match_spec,cursor = None, where = None, order_by = None, limit_to = None,
                    skip_to = None, join = None):
	if cursor is None:
	    cursor = self.db.defaultCursor()

        # build column list
        sql_columns = []
        for name,t,options in self.__column_list:
            sql_columns.append(name)

        # build join information

        joined_cols = []
        joined_cols_hash = {}
        join_clauses = []
        if not join is None:
            for a_table,retrieve_foreign_cols in join:
                try:
                    my_col,foreign_col = self.__relations_by_table[a_table]
                    for a_col in retrieve_foreign_cols:
                        full_col_name = "%s.%s" % (my_col,a_col)
                        joined_cols_hash[full_col_name] = 1
                        joined_cols.append(full_col_name)
                        sql_columns.append( full_col_name )

                    join_clauses.append(" left join %s as %s on %s=%s " % (a_table,my_col,my_col,foreign_col))
                        
                except KeyError:
                    eInvalidJoinSpec, "can't find table %s in defined relations for %s" % (a_table,self.__table_name)
                    
        # start buildling SQL
    	sql = "select %s from %s" % (string.join(sql_columns,","),
                                     self.__table_name)

        # add join clause
        if join_clauses:
            sql = sql + string.join(join_clauses," ")
	
	# add where clause elements
        sql_where_list = self.__buildWhereClause (col_match_spec,where)
	if sql_where_list:
	    sql = sql + " where %s" % (string.join(sql_where_list," and "))

        # add order by clause
        if order_by:
            sql = sql + " order by %s " % string.join(order_by,",")

        # add limit
        if not limit_to is None:
            if not skip_to is None:
                sql = sql + " limit %s, %s" % (skip_to,limit_to)
            else:
                sql = sql + " limit %s" % limit_to
        else:
            if not skip_to is None:
                raise eInvalidData, "can't specify skip_to without limit_to in MySQL"

        dlog(DEV_SELECT,sql)
	cursor.execute(sql)

	if self.__defaultRowListClass:
	    return_rows = self.__defaultRowListClass()
	else:
	    return_rows = []
	    
	# should do fetchmany!
	all_rows = cursor.fetchall()
	for a_row in all_rows:
	    data_dict = {}

	    col_num = 0
            
            #	    for a_col in cursor.description:
            #		(name,type_code,display_size,internal_size,precision,scale,null_ok) = a_col
            for name in sql_columns:
		if self.__col_def_hash.has_key(name) or joined_cols_hash.has_key(name):
		    # only include declared columns!
		    data_dict[name] = a_row[col_num]
		    col_num = col_num + 1

	    newrowobj = self.__defaultRowClass(self,data_dict,joined_cols = joined_cols)

	    return_rows.append(newrowobj)
	    
	return return_rows

    def __deleteRow(self,a_row,cursor = None):
	if cursor is None:
	    cursor = self.db.defaultCursor()

        # build the where clause!
        match_spec = a_row.getPKMatchSpec()
        sql_where_list = self.__buildWhereClause (match_spec)

        sql = "delete from %s where %s" % (self.__table_name,
                                           string.join(sql_where_list," and "))
        dlog(DEV_UPDATE,sql)
        cursor.execute(sql)
       

    def __updateRowList(self,a_row_list,cursor = None):
	if cursor is None:
	    cursor = self.db.defaultCursor()

	for a_row in a_row_list:
	    update_list = a_row.changedList()

	    # build the set list!
	    sql_set_list = []
	    for a_change in update_list:
		col_name,col_val,col_inc_val = a_change
		c_name,c_type,c_options = self.__col_def_hash[col_name]

                if col_val is None:
                    sql_set_list.append("%s = NULL" % c_name)
                else:
                    if c_type == kInteger:
                        sql_set_list.append("%s = %d" % (c_name, long(col_val)))
                    elif c_type == kIncInteger:
                        sql_set_list.append("%s = %s + %d" % (c_name,c_name,long(col_inc_val)))
                    else:
                        sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))

	    # build the where clause!
	    match_spec = a_row.getPKMatchSpec()
            sql_where_list = self.__buildWhereClause (match_spec)

	    if sql_set_list:
		sql = "update %s set %s where %s" % (self.__table_name,
						 string.join(sql_set_list,","),
						 string.join(sql_where_list," and "))

                dlog(DEV_UPDATE,sql)
                try:
                    cursor.execute(sql)
                except Exception, reason:
                    if string.find(str(reason), "Duplicate entry") != -1:
                        raise eDuplicateKey, reason
                    raise Exception, reason
		a_row.markClean()

    def __insertRow(self,a_row_obj,cursor = None):
	if cursor is None:
	    cursor = self.db.defaultCursor()

	sql_col_list = []
	sql_data_list = []
	auto_increment_column_name = None

	for a_col in self.__column_list:
	    name,type,options = a_col

	    try:
		data = a_row_obj[name]

		sql_col_list.append(name)
                if data is None:
                    sql_data_list.append("NULL")
                else:
                    if type == kInteger:
                        sql_data_list.append("%d" % data)
                    elif type == kIncInteger:
                        sql_data_list.append("%d" % self.__inc_coldata.get(name,0))
                    else:
                        sql_data_list.append("'%s'" % self.db.escape(data))

	    except KeyError:
		if options.has_key("autoincrement"):
		    if auto_increment_column_name:
			raise eInternalError, "two autoincrement columns (%s,%s) in table (%s)" % (auto_increment_column_name, name,self.__table_name)
		    else:
			auto_increment_column_name = name
		
	
	sql = "insert into %s (%s) values (%s)" % (self.__table_name,
						   string.join(sql_col_list,","),
						   string.join(sql_data_list,","))

        dlog(DEV_UPDATE,sql)
        try:
          cursor.execute(sql)
        except Exception, reason:
          if string.find(str(reason), "Duplicate entry") != -1:
            raise eDuplicateKey, reason
          raise Exception, reason
            
	if auto_increment_column_name:
	    a_row_obj[auto_increment_column_name] = cursor.insert_id()

    # ----------------------------------------------------
    #   Helper methods for Rows...
    # ----------------------------------------------------


	
    #####################
    # r_deleteRow(a_row_obj,cursor = None)
    #
    # normally this is called from within the Row "delete()" method
    # but you can call it yourself if you want
    #

    def r_deleteRow(self,a_row_obj, cursor = None):
	curs = cursor
	self.__deleteRow(a_row_obj, cursor = curs)


    #####################
    # r_updateRow(a_row_obj,cursor = None)
    #
    # normally this is called from within the Row "save()" method
    # but you can call it yourself if you want
    #

    def r_updateRow(self,a_row_obj, cursor = None):
	curs = cursor
	self.__updateRowList([a_row_obj], cursor = curs)

    #####################
    # r_insertRow(a_row_obj,cursor = None)
    #
    # normally this is called from within the Row "save()" method
    # but you can call it yourself if you want
    #

    def r_insertRow(self,a_row_obj, cursor = None):
	curs = cursor
	self.__insertRow(a_row_obj, cursor = curs)


    # ----------------------------------------------------
    #   Public Methods
    # ----------------------------------------------------


	
    #####################
    # deleteRow(col_match_spec)
    #
    # The col_match_spec parameters select the rows to delete; they do not
    # have to identify a unique row.  With neither a col_match_spec nor a
    # where clause nothing is deleted.
    #
    # Ex:
    #    tbl.deleteRow( ("order_id", 1) )
    #    tbl.deleteRow( [ ("order_id", 1), ("enterTime", now) ] )


    def deleteRow(self,col_match_spec, where=None):
        n_match_spec = self._fixColMatchSpec(col_match_spec)
        cursor = self.db.defaultCursor()

        # build sql where clause elements
        sql_where_list = self.__buildWhereClause (n_match_spec,where)
        if not sql_where_list:
            return

        sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and "))

        dlog(DEV_UPDATE,sql)
        cursor.execute(sql)
	
    #####################
    # fetchRow(col_match_spec)
    #
    # The col_match_spec parameters must include all primary key columns.
    #
    # Ex:
    #    a_row = tbl.fetchRow( ("order_id", 1) )
    #    a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] )


    def fetchRow(self, col_match_spec, cursor = None):
	n_match_spec = self._fixColMatchSpec(col_match_spec, should_match_unique_row = 1)

	rows = self.__fetchRows(n_match_spec, cursor = cursor)
	if len(rows) == 0:
	    raise eNoMatchingRows, "no row matches %s" % repr(n_match_spec)

	if len(rows) > 1:
	    raise eInternalError, "unique where clause shouldn't return > 1 row"

	return rows[0]
	    

    #####################
    # fetchRows(col_match_spec)
    #
    # Ex:
    #    a_row_list = tbl.fetchRows( ("order_id", 1) )
    #    a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] )


    def fetchRows(self, col_match_spec = None, cursor = None, where = None, order_by = None, limit_to = None, skip_to = None, join = None):
	n_match_spec = self._fixColMatchSpec(col_match_spec)

	return self.__fetchRows(n_match_spec,
                                cursor = cursor,
                                where = where,
                                order_by = order_by,
                                limit_to = limit_to,
                                skip_to = skip_to,
                                join = join)

    def fetchRowCount (self, col_match_spec = None, cursor = None, where = None):
        """Return count(*) for the rows matching col_match_spec and `where`.

        Returns 0 when fetching/unpacking the result raises TypeError
        (for instance when the cursor returns no row).
        """
        n_match_spec = self._fixColMatchSpec(col_match_spec)
        conditions = self.__buildWhereClause (n_match_spec,where)

        sql = "select count(*) from %s" % self.__table_name
        if conditions:
            sql = "%s where %s" % (sql," and ".join(conditions))

        if cursor is None:
            cursor = self.db.defaultCursor()

        dlog(DEV_SELECT,sql)
        cursor.execute(sql)

        try:
            (count,) = cursor.fetchone()
        except TypeError:
            return 0
        return count


    #####################
    # fetchAllRows()
    #
    # Ex:
    #    a_row_list = tbl.fetchAllRows()

    def fetchAllRows(self):
        try:
            return self.__fetchRows([])
        except eNoMatchingRows:
            return []

    def newRow(self):
        """Create an unsaved row pre-filled with the column defaults.

        Columns whose default is None are left unset, and kIncInteger
        columns are skipped because they cannot be assigned directly.
        """
        row = self.__defaultRowClass(self,None,create=1)
        for (cname, ctype, opts) in self.__column_list:
            # compare by value: identity of equal strings is not guaranteed
            if opts['default'] is not None and ctype != kIncInteger:
                row[cname] = opts['default']
        return row

class Row:
    __instance_data_locked  = 0
    def subclassinit(self):
        """Hook called at the end of __init__, before the instance is locked; does nothing here."""
        return None
    def __init__(self,_table,data_dict,create=0,joined_cols = None):

        self._inside_getattr = 0  # stop recursive __getattr__
	self._table = _table
	self._should_insert = create
        self._rowInactive = None
        self._joinedRows = []
	
	self.__pk_match_spec = None
	self.__vcoldata = {}
        self.__inc_coldata = {}

        self.__joined_cols_dict = {}
        for a_col in joined_cols or []:
            self.__joined_cols_dict[a_col] = 1
	
	if create:
	    self.__coldata = {}
	else:
	    if type(data_dict) != type({}):
		raise eInternalError, "rowdict instantiate with bad data_dict"
	    self.__coldata = data_dict
	    self.__unpackVColumn()

	self.markClean()

        self.subclassinit()
	self.__instance_data_locked = 1

    def joinRowData(self,another_row):
        self._joinedRows.append(another_row)

    def getPKMatchSpec(self):
	return self.__pk_match_spec

    def markClean(self):
	self.__vcolchanged = 0
	self.__colchanged_dict = {}

	if not self._should_insert:
	    # rebuild primary column match spec
	    new_match_spec = []
	    for col_name in self._table.getPrimaryKeyList():
		try:
		    rdata = self[col_name]
		except KeyError:
		    raise eInternalError, "must have primary key data filled in to save Row(%s)" % self._table.getTableName()
		    
		new_match_spec.append( (col_name, rdata) )
	    self.__pk_match_spec = new_match_spec

    def __unpackVColumn(self):
	if self._table.hasValueColumn():
	    pass
	
    def __packVColumn(self):
	if self._table.hasValueColumn():
	    pass

    ## ----- utility stuff ----------------------------------

    def __del__(self):
	# check for unsaved changes
	changed_list = self.changedList()
	if len(changed_list):
            info = "unsaved Row for table (%s) lost, call discard() to avoid this error. Lost changes: %s\n" % (self._table.getTableName(), repr(changed_list)[:256])
            if 0:
                raise eUnsavedObjectLost, info
            else:
                sys.stderr.write(info)
                

    def __repr__(self):
	return "Row from (%s): %s" % (self._table.getTableName(),repr(self.__coldata) + repr(self.__vcoldata))

    ## ---- class emulation --------------------------------

    def __getattr__(self,key):
        if self._inside_getattr:
          raise AttributeError, "recursively called __getattr__ (%s,%s)" % (key,self._table.getTableName())
        try:
            self._inside_getattr = 1
            try:
                return self[key]
            except KeyError:
                if self._table.hasColumn(key) or self._table.hasVColumn(key):
                    return None
                else:
                    raise AttributeError, "unknown field %s in Row(%s)" % (key,self._table.getTableName())
        finally:
            self._inside_getattr = 0

    def __setattr__(self,key,val):
	if not self.__instance_data_locked:
	    self.__dict__[key] = val
	else:
	    my_dict = self.__dict__
	    if my_dict.has_key(key):
		my_dict[key] = val
	    else:
		# try and put it into the rowdata
		try:
		    self[key] = val
		except KeyError, reason:
		    raise AttributeError, reason


    ## ---- dict emulation ---------------------------------
    
    def __getitem__(self,key):
        self.checkRowActive()
        
	try:
	    return self.__coldata[key]
	except KeyError:
            try:
                return self.__vcoldata[key]
            except KeyError:
                for a_joined_row in self._joinedRows:
                    try:
                        return a_joined_row[key]
                    except KeyError:
                        pass

                raise KeyError, "unknown column %s in %s" % (key,self)

    def __setitem__(self,key,data):
        self.checkRowActive()
        
	try:
	    newdata = self._table.convertDataForColumn(data,key)
	except eNoSuchColumn, reason:
	    raise KeyError, reason

	if self._table.hasColumn(key):
	    self.__coldata[key] = newdata
	    self.__colchanged_dict[key] = 1
	elif self._table.hasVColumn(key):
	    self.__vcoldata[key] = newdata
	    self.__vcolchanged = 1
	else:
            for a_joined_row in self._joinedRows:
                try:
                    a_joined_row[key] = data
                    return
                except KeyError:
                    pass
	    raise KeyError, "unknown column name %s" % key

    def __delitem__(self,key,data):
        self.checkRowActive()
        
	if self.table.hasVColumn(key):
	    del self.__vcoldata[key]
	else:
            for a_joined_row in self._joinedRows:
                try:
                    del a_joined_row[key]
                    return
                except KeyError:
                    pass
	    raise KeyError, "unknown column name %s" % key


    def copyFrom(self,source):
        for name,t,options in self._table.getColumnList():
            if not options.has_key("autoincrement"):
                self[name] = source[name]


    # make sure that .keys(), and .items() come out in a nice order!

    def keys(self):
        self.checkRowActive()
        
        key_list = []
        for name,t,options in self._table.getColumnList():
            key_list.append(name)
        for name in self.__joined_cols_dict.keys():
            key_list.append(name)

        for a_joined_row in self._joinedRows:
            key_list = key_list + a_joined_row.keys()
            
        return key_list


    def items(self):
        self.checkRowActive()
        
        item_list = []
        for name,t,options in self._table.getColumnList():
            item_list.append( (name,self[name]) )
        for name in self.__joined_cols_dict.keys():
            item_list.append( (name,self[name]) )

        for a_joined_row in self._joinedRows:
            item_list = item_list + a_joined_row.items()


        return item_list

    def values(elf):
        self.checkRowActive()

        value_list = self.__coldata.values() + self.__vcoldata.values()

        for a_joined_row in self._joinedRows:
            value_list = value_list + a_joined_row.values()

        return value_list
        

    def __len__(self):
        self.checkRowActive()
        
	my_len = len(self.__coldata) + len(self.__vcoldata)

        for a_joined_row in self._joinedRows:
            my_len = my_len + len(a_joined_row)

        return my_len

    def has_key(self,key):
        self.checkRowActive()
        
	if self.__coldata.has_key(key) or self.__vcoldata.has_key(key):
	    return 1
	else:

            for a_joined_row in self._joinedRows:
                if a_joined_row.has_key(key):
                    return 1
	    return 0
	
    def get(self,key,default = None):
        self.checkRowActive()

        
        
	if self.__coldata.has_key(key):
	    return self.__coldata[key]
	elif self.__vcoldata.has_key(key):
	    return self.__vcoldata[key]
	else:
            for a_joined_row in self._joinedRows:
                try:
                    return a_joined_row.get(key,default)
                except eNoSuchColumn:
                    pass
            
	    raise eNoSuchColumn, "no such column %s" % key

    def inc(self,key,count=1):
        self.checkRowActive()

        if self.__coldata.has_key(key):
            try:
                self.__inc_coldata[key] = self.__inc_coldata[key] + count
            except KeyError:
                self.__inc_coldata[key] = count

            self.__colchanged_dict[key] = 1
    

    ## ----------------------------------
    ## real interface


    def fillDefaults(self):
        """Assign every column its declared default value.

        Column definitions are (name, type, options) triples obtained
        from the table's getColumnList().  (The old code called a
        fieldList() method the table does not have and unpacked four
        values from each triple.)  Columns without a default, or with a
        None default, are left alone, as are kIncInteger columns, which
        cannot be assigned directly.
        """
        for name,ctype,options in self._table.getColumnList():
            default = options.get("default")
            if default is not None and ctype != kIncInteger:
                self[name] = default

    ###############
    # changedList()
    #
    # returns a list of tuples for the columns which have changed
    #
    #   changedList() -> [ ('name', 'fred', None), ('hits', 12, 2) ]
    #
    # each tuple is (column, current value, pending increment or None)

    def changedList(self):
	if self.__vcolchanged:
	    self.__packVColumn()

	changed_list = []
	for a_col in self.__colchanged_dict.keys():
	    changed_list.append( (a_col,self[a_col],self.__inc_coldata.get(a_col,None)) )

	return changed_list

    def discard(self):
	self.__coldata = None
	self.__vcoldata = None
	self.__colchanged_dict = {}
	self.__vcolchanged = 0

    def delete(self,cursor = None):
        self.checkRowActive()

        
        fromTable = self._table
        curs = cursor
        fromTable.r_deleteRow(self,cursor=curs)
        self._rowInactive = "deleted"

    def save(self,cursor = None):
	toTable = self._table

        self.checkRowActive()

	if self._should_insert:
	    toTable.r_insertRow(self)
	    self._should_insert = 0
	    self.markClean()  # rebuild the primary key list
	else:
            curs = cursor
	    toTable.r_updateRow(self,cursor = curs)

	# the table will mark us clean!
	# self.markClean()

    def checkRowActive(self):
        if self._rowInactive:
            raise eInvalidData, "row is inactive: %s" % self._rowInactive


## -----------------------------------------------------------------------
##                            T  E  S  T S
## -----------------------------------------------------------------------

	
def TEST(output=log):
    """Run the odb self-test against a live MySQL server.

    Opens two connections to the 'testdb' database on localhost: one
    as 'root', used to (re)create the 'agents' table, and one as
    'trakken', wrapped in the Database under test.  It then exercises
    insert, fetch, save, unknown attribute/item access, column type
    validation and both delete paths.

    output -- callable taking one message string; receives progress
              and "PASSED!" messages (defaults to log)

    A failed check raises a string exception describing the failure.
    """
    class AgentsTable(Table):
        # column definitions matching the "create table agents"
        # statement executed below
        def _defineRows(self):
            self.d_addColumn("agent_id",kInteger,None,primarykey = 1,autoincrement = 1)
            self.d_addColumn("login",kVarString,200,notnull=1)
            self.d_addColumn("ext_email",kVarString,200,notnull=1)
            self.d_addColumn("hashed_pw",kVarString,20,notnull=1)
            self.d_addColumn("name",kVarString,200)
            self.d_addColumn("auth_level",kInteger,None)


    import MySQLdb
    rdb = MySQLdb.connect(host = 'localhost',user='root', passwd = '', db='testdb')
    ndb = MySQLdb.connect(host = 'localhost',user='trakken', passwd = 'trakpas', db='testdb')
    db = Database(ndb)

    tbl = AgentsTable(db,"agents")

    cursor = rdb.cursor()

    # ---------------------------------------------------------------
    # initialize
    output("drop table agents")
    # "if exists" lets the test run on a database where the table has
    # not been created yet
    cursor.execute("drop table if exists agents")   # clean out the table
    output("creating table")
    cursor.execute("create table agents (agent_id integer not null primary key auto_increment, login varchar(200) not null, unique (login), ext_email varchar(200) not null, hashed_pw varchar(20) not null, name varchar(200), auth_level integer default 0)")

    TEST_INSERT_COUNT = 5

    # ---------------------------------------------------------------
    # make sure we can catch a missing row

    try:
        a_row = tbl.fetchRow( ("agent_id", 1000) )
        raise "test error"
    except eNoMatchingRows:
        pass

    output("PASSED! fetch missing row test")

    # --------------------------------------------------------------
    # create new rows and insert them

    for n in range(TEST_INSERT_COUNT):
        new_id = n + 1

        newrow = tbl.newRow()
        newrow.name = "name #%d" % new_id
        newrow.login = "name%d" % new_id
        newrow.ext_email = "%d@name" % new_id
        newrow.save()
        if newrow.agent_id != new_id:
            raise "new insert id (%s) does not match expected value (%d)" % (newrow.agent_id,new_id)

    output("PASSED! autoinsert test")

    # --------------------------------------------------------------
    # fetch one row
    a_row = tbl.fetchRow( ("agent_id", 1) )

    if a_row.name != "name #1":
        raise "row data incorrect"

    output("PASSED! fetch one row test")

    # ---------------------------------------------------------------
    # don't change and save it
    # (i.e. the "dummy cursor" string should never be called!)
    #
    try:
        a_row.save(cursor = "dummy cursor")
    except AttributeError:
        raise "row tried to access cursor on save() when no changes were made!"

    output("PASSED! don't save when there are no changed")

    # ---------------------------------------------------------------
    # change, save, load, test

    a_row.auth_level = 10
    a_row.save()
    b_row = tbl.fetchRow( ("agent_id", 1) )
    if b_row.auth_level != 10:
        raise "save and load failed"

    output("PASSED! change, save, load")

    # --------------------------------------------------------------
    # access unknown attribute
    try:
        a = a_row.UNKNOWN_ATTRIBUTE
        raise "test error"
    except AttributeError:
        pass

    try:
        a_row.UNKNOWN_ATTRIBUTE = 1
        raise "test error"
    except AttributeError:
        pass

    output("PASSED! unknown attribute exception")

    # --------------------------------------------------------------
    # access unknown dict item

    try:
        a = a_row["UNKNOWN_ATTRIBUTE"]
        raise "test error"
    except KeyError:
        pass

    try:
        a_row["UNKNOWN_ATTRIBUTE"] = 1
        raise "test error"
    except KeyError:
        pass

    output("PASSED! unknown dict item exception")

    # --------------------------------------------------------------
    # use wrong data for column type

    try:
        a_row.agent_id = "this is a string"
        raise "test error"
    except eInvalidData:
        pass

    output("PASSED! invalid data for column type")

    # --------------------------------------------------------------
    # fetch 1 rows

    rows = tbl.fetchRows( ('agent_id', 1) )
    if len(rows) != 1:
        # no '%' here: the message has no conversion specifier, so
        # formatting it would raise TypeError and hide the real failure
        raise "fetchRows() did not return 1 row!"

    output("PASSED! fetch one row")


    # --------------------------------------------------------------
    # fetch All rows

    rows = tbl.fetchAllRows()
    if len(rows) != TEST_INSERT_COUNT:
        # dump what came back before failing, to help diagnosis
        for dumped_row in rows:
            output(repr(dumped_row))
        raise "fetchAllRows() did not return TEST_INSERT_COUNT(%d) rows!" % (TEST_INSERT_COUNT)

    output("PASSED! fetchall rows")


    # --------------------------------------------------------------
    # delete row object

    row = tbl.fetchRow( ('agent_id', 1) )
    row.delete()
    try:
        row = tbl.fetchRow( ('agent_id', 1) )
        raise "delete failed to delete row!"
    except eNoMatchingRows:
        pass

    output("PASSED! row delete")

    # --------------------------------------------------------------
    # table deleteRow() call

    # fetch first so a missing row fails here rather than passing vacuously
    row = tbl.fetchRow( ('agent_id',2) )
    tbl.deleteRow( ('agent_id', 2) )
    try:
        row = tbl.fetchRow( ('agent_id',2) )
        raise "table delete failed"
    except eNoMatchingRows:
        pass

    output("PASSED! table deleteRow")


    output("\n==== ALL TESTS PASSED ====")
    

# When run as a script (not imported), execute the self-test with its
# default output callable.
if __name__ == "__main__":
    TEST()

