test_associate.py 15 KB
Newer Older
1
"""
Gregory P. Smith's avatar
Gregory P. Smith committed
2
TestCases for DB.associate.
3 4 5 6 7 8 9 10 11 12 13 14 15 16
"""

import sys, os, string
import tempfile
import time
from pprint import pprint

try:
    from threading import Thread, currentThread
    have_threads = 1
except ImportError:
    have_threads = 0

import unittest
17
from test_all import verbose
18

19
try:
20 21 22
    # For Pythons w/distutils pybsddb
    from bsddb3 import db, dbshelve
except ImportError:
23 24
    # For Python 2.3
    from bsddb import db, dbshelve
25

26 27 28 29 30
try:
    from bsddb3 import test_support
except ImportError:
    from test import test_support

31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82

#----------------------------------------------------------------------


# Sample records used by every test case in this module.
# Maps an integer record id to an (artist, title, genre) tuple.
# The tests rely on two properties of this table:
#   * exactly one record (id 6) has the genre "Blues"; the getGenre()
#     callbacks refuse to index it, so the secondary holds len(musicdata)-1
#     entries;
#   * record 99 is the only one whose genre is "Unknown".
musicdata = {
    1: ("Bad English", "The Price Of Love", "Rock"),
    2: ("DNA featuring Suzanne Vega", "Tom's Diner", "Rock"),
    3: ("George Michael", "Praying For Time", "Rock"),
    4: ("Gloria Estefan", "Here We Are", "Rock"),
    5: ("Linda Ronstadt", "Don't Know Much", "Rock"),
    6: ("Michael Bolton", "How Am I Supposed To Live Without You", "Blues"),
    7: ("Paul Young", "Oh Girl", "Rock"),
    8: ("Paula Abdul", "Opposites Attract", "Rock"),
    9: ("Richard Marx", "Should've Known Better", "Rock"),
    10: ("Rod Stewart", "Forever Young", "Rock"),
    11: ("Roxette", "Dangerous", "Rock"),
    12: ("Sheena Easton", "The Lover In Me", "Rock"),
    13: ("Sinead O'Connor", "Nothing Compares 2 U", "Rock"),
    14: ("Stevie B.", "Because I Love You", "Rock"),
    15: ("Taylor Dayne", "Love Will Lead You Back", "Rock"),
    16: ("The Bangles", "Eternal Flame", "Rock"),
    17: ("Wilson Phillips", "Release Me", "Rock"),
    18: ("Billy Joel", "Blonde Over Blue", "Rock"),
    19: ("Billy Joel", "Famous Last Words", "Rock"),
    20: ("Billy Joel", "Lullabye (Goodnight, My Angel)", "Rock"),
    21: ("Billy Joel", "The River Of Dreams", "Rock"),
    22: ("Billy Joel", "Two Thousand Years", "Rock"),
    23: ("Janet Jackson", "Alright", "Rock"),
    24: ("Janet Jackson", "Black Cat", "Rock"),
    25: ("Janet Jackson", "Come Back To Me", "Rock"),
    26: ("Janet Jackson", "Escapade", "Rock"),
    27: ("Janet Jackson", "Love Will Never Do (Without You)", "Rock"),
    28: ("Janet Jackson", "Miss You Much", "Rock"),
    29: ("Janet Jackson", "Rhythm Nation", "Rock"),
    30: ("Janet Jackson", "State Of The World", "Rock"),
    31: ("Janet Jackson", "The Knowledge", "Rock"),
    32: ("Spyro Gyra", "End of Romanticism", "Jazz"),
    33: ("Spyro Gyra", "Heliopolis", "Jazz"),
    34: ("Spyro Gyra", "Jubilee", "Jazz"),
    35: ("Spyro Gyra", "Little Linda", "Jazz"),
    36: ("Spyro Gyra", "Morning Dance", "Jazz"),
    37: ("Spyro Gyra", "Song for Lorraine", "Jazz"),
    38: ("Yes", "Owner Of A Lonely Heart", "Rock"),
    39: ("Yes", "Rhythm Of Love", "Rock"),
    40: ("Cusco", "Dream Catcher", "New Age"),
    41: ("Cusco", "Geronimos Laughter", "New Age"),
    42: ("Cusco", "Ghost Dance", "New Age"),
    43: ("Blue Man Group", "Drumbone", "New Age"),
    44: ("Blue Man Group", "Endless Column", "New Age"),
    45: ("Blue Man Group", "Klein Mandelbrot", "New Age"),
    46: ("Kenny G", "Silhouette", "Jazz"),
    47: ("Sade", "Smooth Operator", "Jazz"),
    48: ("David Arkenstone", "Papillon (On The Wings Of The Butterfly)",
         "New Age"),
    49: ("David Arkenstone", "Stepping Stars", "New Age"),
    50: ("David Arkenstone", "Carnation Lily Lily Rose", "New Age"),
    51: ("David Lanz", "Behind The Waterfall", "New Age"),
    52: ("David Lanz", "Cristofori's Dream", "New Age"),
    53: ("David Lanz", "Heartsounds", "New Age"),
    54: ("David Lanz", "Leaves on the Seine", "New Age"),
    99: ("unknown artist", "Unnamed song", "Unknown"),
}

94 95 96 97 98
#----------------------------------------------------------------------

class AssociateErrorTestCase(unittest.TestCase):
    def setUp(self):
        self.filename = self.__class__.__name__ + '.db'
99
        homeDir = os.path.join(tempfile.gettempdir(), 'db_home%d'%os.getpid())
100
        self.homeDir = homeDir
101
        try:
102
            os.mkdir(homeDir)
103
        except os.error:
104 105 106 107
            import glob
            files = glob.glob(os.path.join(self.homeDir, '*'))
            for file in files:
                os.remove(file)
108 109 110 111 112
        self.env = db.DBEnv()
        self.env.open(homeDir, db.DB_CREATE | db.DB_INIT_MPOOL)

    def tearDown(self):
        self.env.close()
113
        self.env = None
114
        test_support.rmtree(self.homeDir)
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130

    def test00_associateDBError(self):
        if verbose:
            print '\n', '-=' * 30
            print "Running %s.test00_associateDBError..." % \
                  self.__class__.__name__

        dupDB = db.DB(self.env)
        dupDB.set_flags(db.DB_DUP)
        dupDB.open(self.filename, "primary", db.DB_BTREE, db.DB_CREATE)

        secDB = db.DB(self.env)
        secDB.open(self.filename, "secondary", db.DB_BTREE, db.DB_CREATE)

        # dupDB has been configured to allow duplicates, it can't
        # associate with a secondary.  BerkeleyDB will return an error.
131 132 133 134 135 136 137 138 139 140 141
        try:
            def f(a,b): return a+b
            dupDB.associate(secDB, f)
        except db.DBError:
            # good
            secDB.close()
            dupDB.close()
        else:
            secDB.close()
            dupDB.close()
            self.fail("DBError exception was expected")
142 143 144



145 146 147 148 149
#----------------------------------------------------------------------


class AssociateTestCase(unittest.TestCase):
    keytype = ''
150 151
    envFlags = 0
    dbFlags = 0
152 153 154

    def setUp(self):
        self.filename = self.__class__.__name__ + '.db'
155
        homeDir = os.path.join(tempfile.gettempdir(), 'db_home%d'%os.getpid())
156
        self.homeDir = homeDir
157
        try:
158
            os.mkdir(homeDir)
159
        except os.error:
160 161 162 163
            import glob
            files = glob.glob(os.path.join(self.homeDir, '*'))
            for file in files:
                os.remove(file)
164 165
        self.env = db.DBEnv()
        self.env.open(homeDir, db.DB_CREATE | db.DB_INIT_MPOOL |
166
                               db.DB_INIT_LOCK | db.DB_THREAD | self.envFlags)
167 168 169 170

    def tearDown(self):
        self.closeDB()
        self.env.close()
171
        self.env = None
172 173 174 175 176
        import glob
        files = glob.glob(os.path.join(self.homeDir, '*'))
        for file in files:
            os.remove(file)

177
    def addDataToDB(self, d, txn=None):
178 179 180
        for key, value in musicdata.items():
            if type(self.keytype) == type(''):
                key = "%02d" % key
181
            d.put(key, string.join(value, '|'), txn=txn)
182

183
    def createDB(self, txn=None):
184 185
        self.cur = None
        self.secDB = None
186
        self.primary = db.DB(self.env)
187
        self.primary.set_get_returns_none(2)
188 189
        if db.version() >= (4, 1):
            self.primary.open(self.filename, "primary", self.dbtype,
190
                          db.DB_CREATE | db.DB_THREAD | self.dbFlags, txn=txn)
191 192
        else:
            self.primary.open(self.filename, "primary", self.dbtype,
193
                          db.DB_CREATE | db.DB_THREAD | self.dbFlags)
194 195

    def closeDB(self):
196 197
        if self.cur:
            self.cur.close()
198
            self.cur = None
199 200
        if self.secDB:
            self.secDB.close()
201
            self.secDB = None
202
        self.primary.close()
203
        self.primary = None
204 205 206 207

    def getDB(self):
        return self.primary

208

209 210 211
    def test01_associateWithDB(self):
        if verbose:
            print '\n', '-=' * 30
212 213
            print "Running %s.test01_associateWithDB..." % \
                  self.__class__.__name__
214 215 216

        self.createDB()

217 218 219 220
        self.secDB = db.DB(self.env)
        self.secDB.set_flags(db.DB_DUP)
        self.secDB.set_get_returns_none(2)
        self.secDB.open(self.filename, "secondary", db.DB_BTREE,
221
                   db.DB_CREATE | db.DB_THREAD | self.dbFlags)
222
        self.getDB().associate(self.secDB, self.getGenre)
223 224 225

        self.addDataToDB(self.getDB())

226
        self.finish_test(self.secDB)
227 228 229 230 231


    def test02_associateAfterDB(self):
        if verbose:
            print '\n', '-=' * 30
232 233
            print "Running %s.test02_associateAfterDB..." % \
                  self.__class__.__name__
234 235 236 237

        self.createDB()
        self.addDataToDB(self.getDB())

238 239 240
        self.secDB = db.DB(self.env)
        self.secDB.set_flags(db.DB_DUP)
        self.secDB.open(self.filename, "secondary", db.DB_BTREE,
241
                   db.DB_CREATE | db.DB_THREAD | self.dbFlags)
242 243

        # adding the DB_CREATE flag will cause it to index existing records
244
        self.getDB().associate(self.secDB, self.getGenre, db.DB_CREATE)
245

246
        self.finish_test(self.secDB)
247 248


249
    def finish_test(self, secDB, txn=None):
250
        # 'Blues' should not be in the secondary database
251
        vals = secDB.pget('Blues', txn=txn)
252 253
        assert vals == None, vals

254
        vals = secDB.pget('Unknown', txn=txn)
255 256 257 258 259
        assert vals[0] == 99 or vals[0] == '99', vals
        vals[1].index('Unknown')
        vals[1].index('Unnamed')
        vals[1].index('unknown')

260 261
        if verbose:
            print "Primary key traversal:"
262
        self.cur = self.getDB().cursor(txn)
263
        count = 0
264
        rec = self.cur.first()
265 266 267 268 269 270 271 272
        while rec is not None:
            if type(self.keytype) == type(''):
                assert string.atoi(rec[0])  # for primary db, key is a number
            else:
                assert rec[0] and type(rec[0]) == type(0)
            count = count + 1
            if verbose:
                print rec
273
            rec = self.cur.next()
274 275 276 277 278
        assert count == len(musicdata) # all items accounted for


        if verbose:
            print "Secondary key traversal:"
279
        self.cur = secDB.cursor(txn)
280
        count = 0
281 282

        # test cursor pget
283
        vals = self.cur.pget('Unknown', flags=db.DB_LAST)
284 285 286 287 288 289
        assert vals[1] == 99 or vals[1] == '99', vals
        assert vals[0] == 'Unknown'
        vals[2].index('Unknown')
        vals[2].index('Unnamed')
        vals[2].index('unknown')

290
        vals = self.cur.pget('Unknown', data='wrong value', flags=db.DB_GET_BOTH)
291 292
        assert vals == None, vals

293
        rec = self.cur.first()
294 295 296 297 298
        assert rec[0] == "Jazz"
        while rec is not None:
            count = count + 1
            if verbose:
                print rec
299
            rec = self.cur.next()
300 301
        # all items accounted for EXCEPT for 1 with "Blues" genre
        assert count == len(musicdata)-1
302

303
        self.cur = None
304

305 306 307
    def getGenre(self, priKey, priData):
        assert type(priData) == type("")
        if verbose:
308
            print 'getGenre key: %r data: %r' % (priKey, priData)
309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
        genre = string.split(priData, '|')[2]
        if genre == 'Blues':
            return db.DB_DONOTINDEX
        else:
            return genre


#----------------------------------------------------------------------


class AssociateHashTestCase(AssociateTestCase):
    """Run the AssociateTestCase tests with a hash primary database."""
    dbtype = db.DB_HASH

class AssociateBTreeTestCase(AssociateTestCase):
    """Run the AssociateTestCase tests with a B-tree primary database."""
    dbtype = db.DB_BTREE

class AssociateRecnoTestCase(AssociateTestCase):
    """Run the AssociateTestCase tests with a recno primary database."""
    dbtype = db.DB_RECNO
    # An int here makes the base class keep the integer musicdata keys
    # instead of formatting them as strings.
    keytype = 0

329 330 331 332 333 334
#----------------------------------------------------------------------

class AssociateBTreeTxnTestCase(AssociateBTreeTestCase):
    envFlags = db.DB_INIT_TXN
    dbFlags = 0

335
    def txn_finish_test(self, sDB, txn):
336 337 338 339 340 341 342 343
        try:
            self.finish_test(sDB, txn=txn)
        finally:
            if self.cur:
                self.cur.close()
                self.cur = None
            if txn:
                txn.commit()
344 345

    def test13_associate_in_transaction(self):
346 347 348 349 350
        if verbose:
            print '\n', '-=' * 30
            print "Running %s.test13_associateAutoCommit..." % \
                  self.__class__.__name__

351 352 353 354 355 356 357 358 359 360 361 362 363
        txn = self.env.txn_begin()
        try:
            self.createDB(txn=txn)

            self.secDB = db.DB(self.env)
            self.secDB.set_flags(db.DB_DUP)
            self.secDB.set_get_returns_none(2)
            self.secDB.open(self.filename, "secondary", db.DB_BTREE,
                       db.DB_CREATE | db.DB_THREAD, txn=txn)
            if db.version() >= (4,1):
                self.getDB().associate(self.secDB, self.getGenre, txn=txn)
            else:
                self.getDB().associate(self.secDB, self.getGenre)
364

365 366 367 368
            self.addDataToDB(self.getDB(), txn=txn)
        except:
            txn.abort()
            raise
369

370
        self.txn_finish_test(self.secDB, txn=txn)
371

372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392

#----------------------------------------------------------------------

class ShelveAssociateTestCase(AssociateTestCase):

    def createDB(self):
        self.primary = dbshelve.open(self.filename,
                                     dbname="primary",
                                     dbenv=self.env,
                                     filetype=self.dbtype)

    def addDataToDB(self, d):
        for key, value in musicdata.items():
            if type(self.keytype) == type(''):
                key = "%02d" % key
            d.put(key, value)    # save the value as is this time


    def getGenre(self, priKey, priData):
        assert type(priData) == type(())
        if verbose:
393
            print 'getGenre key: %r data: %r' % (priKey, priData)
394 395 396 397 398 399 400 401
        genre = priData[2]
        if genre == 'Blues':
            return db.DB_DONOTINDEX
        else:
            return genre


class ShelveAssociateHashTestCase(ShelveAssociateTestCase):
    """Run the dbshelve associate tests with a hash primary database."""
    dbtype = db.DB_HASH
403 404

class ShelveAssociateBTreeTestCase(ShelveAssociateTestCase):
    """Run the dbshelve associate tests with a B-tree primary database."""
    dbtype = db.DB_BTREE
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440

class ShelveAssociateRecnoTestCase(ShelveAssociateTestCase):
    """Run the dbshelve associate tests with a recno primary database."""
    dbtype = db.DB_RECNO
    # An int here makes addDataToDB() keep the integer musicdata keys.
    keytype = 0


#----------------------------------------------------------------------

class ThreadedAssociateTestCase(AssociateTestCase):
    """Variant whose addDataToDB() fills the database from two threads."""

    def addDataToDB(self, d):
        """Run writer1 and writer2 concurrently and wait for both."""
        workers = [Thread(target = fn, args = (d, ))
                   for fn in (self.writer1, self.writer2)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    def writer1(self, d):
        """Store every musicdata record as 'artist|title|genre'."""
        stringKeys = type(self.keytype) == type('')
        for recno, fields in musicdata.items():
            if stringKeys:
                recno = "%02d" % recno
            d.put(recno, '|'.join(fields))

    def writer2(self, d):
        """Store 500 filler records keyed 'z100' through 'z599'."""
        for n in range(100, 600):
            filler = 'z%2d' % n
            # the value is the key repeated four times, '|'-separated
            d.put(filler, '|'.join([filler] * 4))


# NOTE(review): despite its name this case derives from
# ShelveAssociateTestCase, not ThreadedAssociateTestCase, so the threaded
# writers are not exercised by it -- confirm whether that is intended.
class ThreadedAssociateHashTestCase(ShelveAssociateTestCase):
    dbtype = db.DB_HASH
442 443

# NOTE(review): despite its name this case derives from
# ShelveAssociateTestCase, not ThreadedAssociateTestCase, so the threaded
# writers are not exercised by it -- confirm whether that is intended.
class ThreadedAssociateBTreeTestCase(ShelveAssociateTestCase):
    dbtype = db.DB_BTREE
445 446 447 448 449 450 451 452

# NOTE(review): despite its name this case derives from
# ShelveAssociateTestCase, not ThreadedAssociateTestCase, so the threaded
# writers are not exercised by it -- confirm whether that is intended.
class ThreadedAssociateRecnoTestCase(ShelveAssociateTestCase):
    dbtype = db.DB_RECNO
    # An int here makes addDataToDB() keep the integer musicdata keys.
    keytype = 0


#----------------------------------------------------------------------

453 454
def test_suite():
    """Return a TestSuite holding the associate tests that can run here.

    Nothing is added when db.version() is below (3, 3, 11).  The
    transactional case additionally requires db.version() >= (4, 1), and
    the Threaded* cases are only added when the threading module could be
    imported.
    """
    suite = unittest.TestSuite()

    if db.version() < (3, 3, 11):
        return suite

    # Kept in the order the cases are meant to be added to the suite.
    cases = [
        AssociateErrorTestCase,
        AssociateHashTestCase,
        AssociateBTreeTestCase,
        AssociateRecnoTestCase,
    ]

    if db.version() >= (4, 1):
        cases.append(AssociateBTreeTxnTestCase)

    cases.extend([
        ShelveAssociateHashTestCase,
        ShelveAssociateBTreeTestCase,
        ShelveAssociateRecnoTestCase,
    ])

    if have_threads:
        cases.extend([
            ThreadedAssociateHashTestCase,
            ThreadedAssociateBTreeTestCase,
            ThreadedAssociateRecnoTestCase,
        ])

    for case in cases:
        suite.addTest(unittest.makeSuite(case))

    return suite
476 477 478


if __name__ == '__main__':
    # Go through test_suite() so its version and threading checks decide
    # which cases run.
    unittest.main(defaultTest='test_suite')