/* connection.c - the connection type
 *
 * Copyright (C) 2004-2010 Gerhard Häring <gh@ghaering.de>
 *
 * This file is part of pysqlite.
 *
 * This software is provided 'as-is', without any express or implied
 * warranty.  In no event will the authors be held liable for any damages
 * arising from the use of this software.
 *
 * Permission is granted to anyone to use this software for any purpose,
 * including commercial applications, and to alter it and redistribute it
 * freely, subject to the following restrictions:
 *
 * 1. The origin of this software must not be misrepresented; you must not
 *    claim that you wrote the original software. If you use this software
 *    in a product, an acknowledgment in the product documentation would be
 *    appreciated but is not required.
 * 2. Altered source versions must be plainly marked as such, and must not be
 *    misrepresented as being the original software.
 * 3. This notice may not be removed or altered from any source distribution.
 */

#include "cache.h"
#include "module.h"
26
#include "structmember.h"
27 28 29 30 31 32 33 34 35
#include "connection.h"
#include "statement.h"
#include "cursor.h"
#include "prepare_protocol.h"
#include "util.h"
#include "sqlitecompat.h"

#include "pythread.h"

36 37 38
#define ACTION_FINALIZE 1
#define ACTION_RESET 2

39 40 41 42 43 44
#if SQLITE_VERSION_NUMBER >= 3003008
#ifndef SQLITE_OMIT_LOAD_EXTENSION
#define HAVE_LOAD_EXTENSION
#endif
#endif

45
static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level);
46
static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
47

48

49
/*
 * Report a failed user-defined callback back to SQLite.
 *
 * With SQLite releases older than 3.3.3, calling sqlite3_result_error() from
 * inside a callback triggers a SQLite bug (odd results or segfaults, depending
 * on the version). For those versions a Python OperationalError is raised
 * instead, and ctx/len are unused.
 */
static void _sqlite3_result_error(sqlite3_context* ctx, const char* errmsg, int len)
{
#if SQLITE_VERSION_NUMBER < 3003003
    PyErr_SetString(pysqlite_OperationalError, errmsg);
#else
    sqlite3_result_error(ctx, errmsg, len);
#endif
}

61
/*
 * tp_init for Connection objects:
 *   Connection(database, timeout=5.0, detect_types=0, isolation_level="",
 *              check_same_thread=1, factory=None, cached_statements=100)
 *
 * Opens the database and configures the connection. It then creates the
 * statement cache and the weak-reference lists of statements/cursors, and
 * copies the module's exception classes onto the object.
 *
 * Returns 0 on success, or -1 with a Python exception set.
 */
int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL};

    char* database;
    int detect_types = 0;
    PyObject* isolation_level = NULL;
    /* Parsed so the keyword is accepted; it is not used in this function. */
    PyObject* factory = NULL;
    int check_same_thread = 1;
    int cached_statements = 100;
    double timeout = 5.0;
    int rc;

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist,
                                     &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements))
    {
        return -1;
    }

    self->initialized = 1;

    self->begin_statement = NULL;

    self->statement_cache = NULL;
    self->statements = NULL;
    self->cursors = NULL;

    Py_INCREF(Py_None);
    self->row_factory = Py_None;

    Py_INCREF(&PyUnicode_Type);
    self->text_factory = (PyObject*)&PyUnicode_Type;

    Py_BEGIN_ALLOW_THREADS
    rc = sqlite3_open(database, &self->db);
    Py_END_ALLOW_THREADS

    if (rc != SQLITE_OK) {
        _pysqlite_seterror(self->db, NULL);
        return -1;
    }

    if (!isolation_level) {
        isolation_level = PyUnicode_FromString("");
        if (!isolation_level) {
            return -1;
        }
    } else {
        Py_INCREF(isolation_level);
    }
    self->isolation_level = NULL;
    /* A negative result signals failure (usual C-API convention; TODO confirm
     * against the helper's definition). Stop here instead of calling further
     * Python code with an exception pending. */
    rc = pysqlite_connection_set_isolation_level(self, isolation_level);
    Py_DECREF(isolation_level);
    if (rc < 0) {
        return -1;
    }

    self->statement_cache = (pysqlite_Cache*)PyObject_CallFunction((PyObject*)&pysqlite_CacheType, "Oi", self, cached_statements);
    if (!self->statement_cache) {
        return -1;
    }

    /* By default, the Cache class INCREFs the factory in its initializer, and
     * decrefs it in its deallocator method. Since this would create a circular
     * reference here, we're breaking it by decrementing self, and telling the
     * cache class to not decref the factory (self) in its deallocator.
     *
     * This is done right after the cache is created so that every error return
     * below leaves self's reference count balanced. Otherwise the cache's extra
     * reference would keep a failed Connection alive forever. */
    self->statement_cache->decref_factory = 0;
    Py_DECREF(self);

    self->created_statements = 0;
    self->created_cursors = 0;

    /* Create lists of weak references to statements/cursors */
    self->statements = PyList_New(0);
    self->cursors = PyList_New(0);
    if (!self->statements || !self->cursors) {
        return -1;
    }

    self->inTransaction = 0;
    self->detect_types = detect_types;
    self->timeout = timeout;
    /* sqlite3_busy_timeout() takes milliseconds; timeout is in seconds. */
    (void)sqlite3_busy_timeout(self->db, (int)(timeout*1000));
#ifdef WITH_THREAD
    self->thread_ident = PyThread_get_thread_ident();
#endif
    self->check_same_thread = check_same_thread;

    self->function_pinboard = PyDict_New();
    if (!self->function_pinboard) {
        return -1;
    }

    self->collations = PyDict_New();
    if (!self->collations) {
        return -1;
    }

    self->Warning               = pysqlite_Warning;
    self->Error                 = pysqlite_Error;
    self->InterfaceError        = pysqlite_InterfaceError;
    self->DatabaseError         = pysqlite_DatabaseError;
    self->DataError             = pysqlite_DataError;
    self->OperationalError      = pysqlite_OperationalError;
    self->IntegrityError        = pysqlite_IntegrityError;
    self->InternalError         = pysqlite_InternalError;
    self->ProgrammingError      = pysqlite_ProgrammingError;
    self->NotSupportedError     = pysqlite_NotSupportedError;

    return 0;
}

171
/* Empty the entire statement cache of this connection */
172
void pysqlite_flush_statement_cache(pysqlite_Connection* self)
173
{
174 175
    pysqlite_Node* node;
    pysqlite_Statement* statement;
176 177 178 179

    node = self->statement_cache->first;

    while (node) {
180 181
        statement = (pysqlite_Statement*)(node->data);
        (void)pysqlite_statement_finalize(statement);
182 183 184 185
        node = node->next;
    }

    Py_DECREF(self->statement_cache);
186
    self->statement_cache = (pysqlite_Cache*)PyObject_CallFunction((PyObject*)&pysqlite_CacheType, "O", self);
187 188 189 190
    Py_DECREF(self);
    self->statement_cache->decref_factory = 0;
}

191
/* action in (ACTION_RESET, ACTION_FINALIZE) */
192
void pysqlite_do_all_statements(pysqlite_Connection* self, int action, int reset_cursors)
193
{
194 195 196
    int i;
    PyObject* weakref;
    PyObject* statement;
197
    pysqlite_Cursor* cursor;
198 199 200 201 202

    for (i = 0; i < PyList_Size(self->statements); i++) {
        weakref = PyList_GetItem(self->statements, i);
        statement = PyWeakref_GetObject(weakref);
        if (statement != Py_None) {
203
            Py_INCREF(statement);
204 205 206 207 208
            if (action == ACTION_RESET) {
                (void)pysqlite_statement_reset((pysqlite_Statement*)statement);
            } else {
                (void)pysqlite_statement_finalize((pysqlite_Statement*)statement);
            }
209
            Py_DECREF(statement);
210
        }
211
    }
212 213 214 215 216 217 218 219 220 221

    if (reset_cursors) {
        for (i = 0; i < PyList_Size(self->cursors); i++) {
            weakref = PyList_GetItem(self->cursors, i);
            cursor = (pysqlite_Cursor*)PyWeakref_GetObject(weakref);
            if ((PyObject*)cursor != Py_None) {
                cursor->reset = 1;
            }
        }
    }
222 223
}

224
/*
 * tp_dealloc for Connection objects.
 *
 * Releases the statement cache first, then closes the SQLite handle if the
 * user never called close(). Finally it drops every other owned reference and
 * frees the object itself.
 *
 * NOTE(review): sqlite3_close() reports SQLITE_BUSY while prepared statements
 * are still unfinalized; its result is ignored here.
 */
void pysqlite_connection_dealloc(pysqlite_Connection* self)
{
    Py_XDECREF(self->statement_cache);

    /* close() was never called: release the handle on the user's behalf. */
    if (self->db != NULL) {
        Py_BEGIN_ALLOW_THREADS
        sqlite3_close(self->db);
        Py_END_ALLOW_THREADS
    }

    /* PyMem_Free() accepts NULL, so no guard is needed. */
    PyMem_Free(self->begin_statement);

    Py_XDECREF(self->isolation_level);
    Py_XDECREF(self->function_pinboard);
    Py_XDECREF(self->row_factory);
    Py_XDECREF(self->text_factory);
    Py_XDECREF(self->collations);
    Py_XDECREF(self->statements);
    Py_XDECREF(self->cursors);

    Py_TYPE(self)->tp_free((PyObject*)self);
}

249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
/*
 * Track a cursor by appending a weak reference to it to connection->cursors.
 *
 * Returns 1 on success. Returns 0 if the weak reference could not be created
 * or appended; in that case the exception raised by the failing C-API call is
 * left set.
 */
int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor)
{
    PyObject* ref = PyWeakref_NewRef(cursor, NULL);
    int appended;

    if (ref == NULL) {
        return 0;
    }

    /* The list keeps its own reference to ref, so ours is dropped either way. */
    appended = (PyList_Append(connection->cursors, ref) == 0);
    Py_DECREF(ref);

    return appended;
}

275
/*
 * Connection.cursor([factory]) -> cursor
 *
 * Calls factory(self); factory defaults to the Cursor type. If the connection's
 * row_factory is not None, it is copied onto the new cursor. In that case the
 * factory must return a Cursor (or subclass) instance, because the
 * row_factory slot is written directly; otherwise TypeError is raised.
 *
 * Returns a new reference, or NULL with an exception set.
 */
PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    static char *kwlist[] = {"factory", NULL};
    PyObject* factory = NULL;
    PyObject* cursor;
    PyObject* old_row_factory;

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
                                     &factory)) {
        return NULL;
    }

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (factory == NULL) {
        factory = (PyObject*)&pysqlite_CursorType;
    }

    cursor = PyObject_CallFunction(factory, "O", self);

    _pysqlite_drop_unused_cursor_references(self);

    if (cursor && self->row_factory != Py_None) {
        /* Casting an arbitrary object to pysqlite_Cursor* and writing its
         * row_factory field would corrupt memory outside that object. */
        if (!PyObject_TypeCheck(cursor, &pysqlite_CursorType)) {
            PyErr_Format(PyExc_TypeError,
                         "factory must return a cursor, not %.100s",
                         Py_TYPE(cursor)->tp_name);
            Py_DECREF(cursor);
            return NULL;
        }
        /* Install the new value before releasing the old one, so no
         * destructor runs while the slot holds a stale pointer. */
        old_row_factory = ((pysqlite_Cursor*)cursor)->row_factory;
        Py_INCREF(self->row_factory);
        ((pysqlite_Cursor*)cursor)->row_factory = self->row_factory;
        Py_XDECREF(old_row_factory);
    }

    return cursor;
}

307
/*
 * Connection.close()
 *
 * Finalizes all statements of this connection, marks all of its cursors as
 * reset, and closes the SQLite handle. Closing an already closed connection
 * is a no-op.
 *
 * Returns None, or NULL with an exception set. ProgrammingError is raised if
 * __init__ was never run; an SQLite error is raised if sqlite3_close() fails,
 * and the handle is then kept.
 */
PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
{
    int rc;

    if (!pysqlite_check_thread(self)) {
        return NULL;
    }

    if (!self->initialized) {
        PyErr_SetString(pysqlite_ProgrammingError, "Base Connection.__init__ not called.");
        return NULL;
    }

    /* __init__ sets these lists to NULL before work that can fail (e.g.
     * sqlite3_open()), so they may be missing even on an initialized object;
     * walking a NULL list would crash. */
    if (self->statements && self->cursors) {
        pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
    }

    if (self->db) {
        Py_BEGIN_ALLOW_THREADS
        rc = sqlite3_close(self->db);
        Py_END_ALLOW_THREADS

        if (rc != SQLITE_OK) {
            _pysqlite_seterror(self->db, NULL);
            return NULL;
        }
        self->db = NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

/*
 * Checks if a connection object is usable (i. e. not closed).
 *
 * 0 => error; 1 => ok
 */
339
int pysqlite_check_connection(pysqlite_Connection* con)
340
{
341 342 343 344 345
    if (!con->initialized) {
        PyErr_SetString(pysqlite_ProgrammingError, "Base Connection.__init__ not called.");
        return 0;
    }

346
    if (!con->db) {
347
        PyErr_SetString(pysqlite_ProgrammingError, "Cannot operate on a closed database.");
348 349 350 351 352 353
        return 0;
    } else {
        return 1;
    }
}

354
/*
 * Open a transaction by preparing and stepping self->begin_statement.
 *
 * Sets self->inTransaction when the statement runs to completion. Returns a
 * new reference to None on success, or NULL with an exception set.
 */
PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)
{
    sqlite3_stmt* statement;
    const char* tail;
    int rc;

    Py_BEGIN_ALLOW_THREADS
    rc = sqlite3_prepare(self->db, self->begin_statement, -1, &statement, &tail);
    Py_END_ALLOW_THREADS

    if (rc == SQLITE_OK) {
        rc = pysqlite_step(statement, self);
        if (rc != SQLITE_DONE) {
            _pysqlite_seterror(self->db, statement);
        } else {
            self->inTransaction = 1;
        }

        Py_BEGIN_ALLOW_THREADS
        rc = sqlite3_finalize(statement);
        Py_END_ALLOW_THREADS

        /* Never let a finalize error mask the error from stepping. */
        if (rc != SQLITE_OK && !PyErr_Occurred()) {
            _pysqlite_seterror(self->db, NULL);
        }
    } else {
        _pysqlite_seterror(self->db, statement);
    }

    if (PyErr_Occurred()) {
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

393
/*
 * Connection.commit()
 *
 * If a transaction is open, resets every statement of this connection
 * (cursor reset flags are left alone) and executes COMMIT. Does nothing
 * otherwise.
 *
 * Returns None, or NULL with an exception set.
 */
PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args)
{
    sqlite3_stmt* statement;
    const char* tail;
    int rc;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (self->inTransaction) {
        pysqlite_do_all_statements(self, ACTION_RESET, 0);

        Py_BEGIN_ALLOW_THREADS
        rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail);
        Py_END_ALLOW_THREADS

        if (rc != SQLITE_OK) {
            _pysqlite_seterror(self->db, NULL);
        } else {
            rc = pysqlite_step(statement, self);
            if (rc == SQLITE_DONE) {
                self->inTransaction = 0;
            } else {
                _pysqlite_seterror(self->db, statement);
            }

            Py_BEGIN_ALLOW_THREADS
            rc = sqlite3_finalize(statement);
            Py_END_ALLOW_THREADS

            /* Keep the step error, if any, in preference to a finalize error. */
            if (rc != SQLITE_OK && !PyErr_Occurred()) {
                _pysqlite_seterror(self->db, NULL);
            }
        }
    }

    if (PyErr_Occurred()) {
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

439
/*
 * Connection.rollback()
 *
 * If a transaction is open, resets every statement, flags every cursor of
 * this connection as reset, and executes ROLLBACK. Does nothing otherwise.
 *
 * Returns None, or NULL with an exception set.
 */
PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args)
{
    sqlite3_stmt* statement;
    const char* tail;
    int rc;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (self->inTransaction) {
        pysqlite_do_all_statements(self, ACTION_RESET, 1);

        Py_BEGIN_ALLOW_THREADS
        rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail);
        Py_END_ALLOW_THREADS

        if (rc == SQLITE_OK) {
            rc = pysqlite_step(statement, self);
            if (rc != SQLITE_DONE) {
                _pysqlite_seterror(self->db, statement);
            } else {
                self->inTransaction = 0;
            }

            Py_BEGIN_ALLOW_THREADS
            rc = sqlite3_finalize(statement);
            Py_END_ALLOW_THREADS

            /* Keep the step error, if any, in preference to a finalize error. */
            if (rc != SQLITE_OK && !PyErr_Occurred()) {
                _pysqlite_seterror(self->db, NULL);
            }
        } else {
            _pysqlite_seterror(self->db, NULL);
        }
    }

    if (PyErr_Occurred()) {
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

485 486
/*
 * Store a Python value as the result of a user-defined SQL function.
 *
 * Conversions: None -> NULL, int -> INTEGER (64-bit), float -> REAL,
 * str -> TEXT (UTF-8), any buffer-protocol object -> BLOB.
 *
 * Returns 0 on success and -1 on failure. -1 is also returned, without an
 * exception set, for values of any other type.
 */
static int
_pysqlite_set_result(sqlite3_context* context, PyObject* py_val)
{
    if (py_val == Py_None) {
        sqlite3_result_null(context);
    } else if (PyLong_Check(py_val)) {
        sqlite_int64 value = _pysqlite_long_as_int64(py_val);
        if (value == -1 && PyErr_Occurred())
            return -1;
        sqlite3_result_int64(context, value);
    } else if (PyFloat_Check(py_val)) {
        sqlite3_result_double(context, PyFloat_AsDouble(py_val));
    } else if (PyUnicode_Check(py_val)) {
        Py_ssize_t size;
        const char *str = _PyUnicode_AsStringAndSize(py_val, &size);
        if (str == NULL)
            return -1;
        /* sqlite3_result_text() takes an int length. Passing the real length
         * (instead of -1) keeps text after an embedded NUL character. */
        if (size > INT_MAX) {
            PyErr_SetString(PyExc_OverflowError, "string longer than INT_MAX bytes");
            return -1;
        }
        sqlite3_result_text(context, str, (int)size, SQLITE_TRANSIENT);
    } else if (PyObject_CheckBuffer(py_val)) {
        Py_buffer view;

        /* Hold the buffer until SQLite has copied it (SQLITE_TRANSIENT).
         * PyObject_AsCharBuffer() releases the buffer before returning, so the
         * pointer it yields is not guaranteed to stay valid. */
        if (PyObject_GetBuffer(py_val, &view, PyBUF_SIMPLE) != 0) {
            PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer");
            return -1;
        }
        if (view.len > INT_MAX) {
            PyBuffer_Release(&view);
            PyErr_SetString(PyExc_OverflowError, "BLOB longer than INT_MAX bytes");
            return -1;
        }
        sqlite3_result_blob(context, view.buf, (int)view.len, SQLITE_TRANSIENT);
        PyBuffer_Release(&view);
    } else {
        return -1;
    }
    return 0;
}

516
/*
 * Convert the SQLite arguments of a user-defined function call into a new
 * Python tuple.
 *
 * Conversions: INTEGER -> int, REAL -> float, TEXT -> str (None if decoding
 * fails), BLOB -> bytes, NULL (and anything else) -> None.
 *
 * Returns a new reference, or NULL with an exception set.
 */
PyObject* _pysqlite_build_py_params(sqlite3_context *context, int argc, sqlite3_value** argv)
{
    PyObject* args;
    int i;
    sqlite3_value* cur_value;
    PyObject* cur_py_value;
    const char* val_str;
    Py_ssize_t buflen;

    args = PyTuple_New(argc);
    if (!args) {
        return NULL;
    }

    for (i = 0; i < argc; i++) {
        cur_value = argv[i];
        switch (sqlite3_value_type(cur_value)) {
            case SQLITE_INTEGER:
                cur_py_value = _pysqlite_long_from_int64(sqlite3_value_int64(cur_value));
                break;
            case SQLITE_FLOAT:
                cur_py_value = PyFloat_FromDouble(sqlite3_value_double(cur_value));
                break;
            case SQLITE_TEXT:
                val_str = (const char*)sqlite3_value_text(cur_value);
                if (!val_str) {
                    /* For a TEXT value, a NULL result means SQLite failed to
                     * allocate memory; never pass NULL to the decoder. */
                    PyErr_NoMemory();
                    cur_py_value = NULL;
                    break;
                }
                /* Read the byte count after the text conversion, as the SQLite
                 * docs recommend. Passing it keeps embedded NUL characters. */
                buflen = sqlite3_value_bytes(cur_value);
                cur_py_value = PyUnicode_FromStringAndSize(val_str, buflen);
                /* TODO: have a way to show errors here */
                if (!cur_py_value) {
                    PyErr_Clear();
                    Py_INCREF(Py_None);
                    cur_py_value = Py_None;
                }
                break;
            case SQLITE_BLOB:
                buflen = sqlite3_value_bytes(cur_value);
                cur_py_value = PyBytes_FromStringAndSize(
                    sqlite3_value_blob(cur_value), buflen);
                break;
            case SQLITE_NULL:
            default:
                Py_INCREF(Py_None);
                cur_py_value = Py_None;
        }

        if (!cur_py_value) {
            Py_DECREF(args);
            return NULL;
        }

        /* Steals the reference to cur_py_value. */
        PyTuple_SetItem(args, i, cur_py_value);
    }

    return args;
}

572
/*
 * C trampoline installed by create_function().
 *
 * Calls the Python callable stored as the function's user data with the
 * converted SQL arguments and hands the return value back to SQLite. A Python
 * error in argument conversion, the call, or result conversion is printed or
 * cleared (per _enable_callback_tracebacks). It is then reported to SQLite as
 * "user-defined function raised exception".
 */
void _pysqlite_func_callback(sqlite3_context* context, int argc, sqlite3_value** argv)
{
    PyObject* py_func;
    PyObject* args;
    PyObject* py_retval;
    int failed = 1;
#ifdef WITH_THREAD
    PyGILState_STATE threadstate = PyGILState_Ensure();
#endif

    py_func = (PyObject*)sqlite3_user_data(context);

    args = _pysqlite_build_py_params(context, argc, argv);
    py_retval = (args != NULL) ? PyObject_CallObject(py_func, args) : NULL;
    Py_XDECREF(args);

    if (py_retval != NULL) {
        failed = (_pysqlite_set_result(context, py_retval) != 0);
        Py_DECREF(py_retval);
    }

    if (failed) {
        if (_enable_callback_tracebacks) {
            PyErr_Print();
        } else {
            PyErr_Clear();
        }
        _sqlite3_result_error(context, "user-defined function raised exception", -1);
    }

#ifdef WITH_THREAD
    PyGILState_Release(threadstate);
#endif
}

612
/*
 * xStep trampoline for aggregates registered by create_aggregate().
 *
 * On the first row, instantiates the aggregate class (the function's user
 * data) and stores the instance in SQLite's aggregate context. It then calls
 * instance.step(*args) with the converted SQL arguments.
 *
 * Python errors are printed or cleared (per _enable_callback_tracebacks) and
 * reported to SQLite, so no exception is left pending when the GIL is released.
 */
static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_value** params)
{
    PyObject* args;
    PyObject* function_result = NULL;
    PyObject* aggregate_class;
    PyObject** aggregate_instance;
    PyObject* stepmethod = NULL;

#ifdef WITH_THREAD
    PyGILState_STATE threadstate;

    threadstate = PyGILState_Ensure();
#endif

    aggregate_class = (PyObject*)sqlite3_user_data(context);

    aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
    if (!aggregate_instance) {
        /* With a positive size, sqlite3_aggregate_context() returns NULL only
         * when it cannot allocate memory; there is nowhere to keep an instance.
         * NOTE(review): consider sqlite3_result_error_nomem() where the
         * minimum supported SQLite version provides it. */
        goto error;
    }

    if (*aggregate_instance == NULL) {
        *aggregate_instance = PyObject_CallFunction(aggregate_class, "");

        /* Test the call's result rather than PyErr_Occurred(), so an unrelated
         * stale exception cannot be mistaken for an __init__ failure. */
        if (*aggregate_instance == NULL) {
            if (_enable_callback_tracebacks) {
                PyErr_Print();
            } else {
                PyErr_Clear();
            }
            _sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1);
            goto error;
        }
    }

    stepmethod = PyObject_GetAttrString(*aggregate_instance, "step");
    if (stepmethod) {
        args = _pysqlite_build_py_params(context, argc, params);
        if (args) {
            function_result = PyObject_CallObject(stepmethod, args);
            Py_DECREF(args);
        }
    }

    if (!function_result) {
        /* A missing 'step' attribute, failed argument conversion, or an
         * exception raised by step() all land here. Report the error and
         * don't leave it pending. */
        if (_enable_callback_tracebacks) {
            PyErr_Print();
        } else {
            PyErr_Clear();
        }
        _sqlite3_result_error(context, "user-defined aggregate's 'step' method raised error", -1);
    }

error:
    Py_XDECREF(stepmethod);
    Py_XDECREF(function_result);

#ifdef WITH_THREAD
    PyGILState_Release(threadstate);
#endif
}

676
/*
 * xFinal trampoline for aggregates registered by create_aggregate().
 *
 * Calls instance.finalize() on the aggregate instance kept in SQLite's
 * aggregate context and hands its return value back to SQLite. It then drops
 * the instance. Python errors are printed or cleared (per
 * _enable_callback_tracebacks) and reported to SQLite.
 */
void _pysqlite_final_callback(sqlite3_context* context)
{
    PyObject* function_result;
    PyObject** aggregate_instance;
    _Py_IDENTIFIER(finalize);
    int ok;

#ifdef WITH_THREAD
    PyGILState_STATE threadstate;

    threadstate = PyGILState_Ensure();
#endif

    aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
    if (!aggregate_instance || !*aggregate_instance) {
        /* The context pointer is NULL when SQLite could not allocate memory.
         * The instance is NULL when the aggregate's __init__ raised (already
         * reported by the step callback), or when step never ran: per the
         * SQLite docs, a context first requested in xFinal is zero-filled. */
        goto error;
    }

    function_result = _PyObject_CallMethodId(*aggregate_instance, &PyId_finalize, "");
    Py_DECREF(*aggregate_instance);
    /* Don't leave a dangling pointer in SQLite's context memory. */
    *aggregate_instance = NULL;

    ok = 0;
    if (function_result) {
        ok = _pysqlite_set_result(context, function_result) == 0;
        Py_DECREF(function_result);
    }
    if (!ok) {
        if (_enable_callback_tracebacks) {
            PyErr_Print();
        } else {
            PyErr_Clear();
        }
        _sqlite3_result_error(context, "user-defined aggregate's 'finalize' method raised error", -1);
    }

error:
#ifdef WITH_THREAD
    PyGILState_Release(threadstate);
#endif
    /* explicit return to avoid a compilation error if WITH_THREAD
       is not defined */
    return;
}

723
/*
 * Prune self->statements of weak references whose statements are gone.
 *
 * Only does work on every 201st call (created_statements counts calls).
 * Builds a new list holding the still-live weakrefs and swaps it in. If
 * allocation or appending fails, the old list is kept unchanged.
 */
static void _pysqlite_drop_unused_statement_references(pysqlite_Connection* self)
{
    PyObject* new_list;
    PyObject* weakref;
    Py_ssize_t i;
    Py_ssize_t count;

    /* we only need to do this once in a while */
    if (self->created_statements++ < 200) {
        return;
    }

    self->created_statements = 0;

    new_list = PyList_New(0);
    if (!new_list) {
        return;
    }

    /* Py_ssize_t matches PyList_Size(); the list is not modified while we scan it. */
    count = PyList_Size(self->statements);
    for (i = 0; i < count; i++) {
        weakref = PyList_GetItem(self->statements, i);
        if (PyWeakref_GetObject(weakref) != Py_None) {
            if (PyList_Append(new_list, weakref) != 0) {
                Py_DECREF(new_list);
                return;
            }
        }
    }

    Py_DECREF(self->statements);
    self->statements = new_list;
}
754

755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786
/*
 * Prune self->cursors of weak references whose cursors are gone.
 *
 * Only does work on every 201st call (created_cursors counts calls). Builds a
 * new list holding the still-live weakrefs and swaps it in. If allocation or
 * appending fails, the old list is kept unchanged.
 */
static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
{
    PyObject* new_list;
    PyObject* weakref;
    Py_ssize_t i;
    Py_ssize_t count;

    /* we only need to do this once in a while */
    if (self->created_cursors++ < 200) {
        return;
    }

    self->created_cursors = 0;

    new_list = PyList_New(0);
    if (!new_list) {
        return;
    }

    /* Py_ssize_t matches PyList_Size(); the list is not modified while we scan it. */
    count = PyList_Size(self->cursors);
    for (i = 0; i < count; i++) {
        weakref = PyList_GetItem(self->cursors, i);
        if (PyWeakref_GetObject(weakref) != Py_None) {
            if (PyList_Append(new_list, weakref) != 0) {
                Py_DECREF(new_list);
                return;
            }
        }
    }

    Py_DECREF(self->cursors);
    self->cursors = new_list;
}

787
/*
 * Connection.create_function(name, narg, func)
 *
 * Registers the Python callable func as the SQL function 'name' with narg
 * arguments; narg is passed straight to sqlite3_create_function().
 *
 * SQLite receives func only as an unowned user-data pointer, so func is
 * first pinned in self->function_pinboard to keep it alive for the lifetime
 * of the connection.
 *
 * Returns None, or NULL with an exception set.
 */
PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    static char *kwlist[] = {"name", "narg", "func", NULL};

    PyObject* func;
    char* name;
    int narg;
    int rc;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "siO", kwlist,
                                     &name, &narg, &func))
    {
        return NULL;
    }

    /* Pin before registering. If pinning failed after registration (e.g. an
     * unhashable callable), SQLite would keep a pointer to an object that
     * nothing keeps alive. */
    if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) {
        return NULL;
    }

    rc = sqlite3_create_function(self->db, name, narg, SQLITE_UTF8, (void*)func, _pysqlite_func_callback, NULL, NULL);

    if (rc != SQLITE_OK) {
        /* Workaround for SQLite bug: no error code or string is available here */
        PyErr_SetString(pysqlite_OperationalError, "Error creating function");
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

821
/*
 * Connection.create_aggregate(name, n_arg, aggregate_class)
 *
 * Registers aggregate_class as the SQL aggregate 'name' with n_arg arguments.
 * Instances are created per aggregation by _pysqlite_step_callback and
 * finished by _pysqlite_final_callback.
 *
 * SQLite receives the class only as an unowned user-data pointer, so it is
 * first pinned in self->function_pinboard.
 *
 * Returns None, or NULL with an exception set.
 */
PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    PyObject* aggregate_class;

    int n_arg;
    char* name;
    static char *kwlist[] = { "name", "n_arg", "aggregate_class", NULL };
    int rc;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "siO:create_aggregate",
                                      kwlist, &name, &n_arg, &aggregate_class)) {
        return NULL;
    }

    /* Pin before registering, so SQLite never holds a pointer to an object
     * we failed to keep alive. */
    if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) {
        return NULL;
    }

    rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback);
    if (rc != SQLITE_OK) {
        /* Workaround for SQLite bug: no error code or string is available here */
        PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate");
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

853
static int _authorizer_callback(void* user_arg, int action, const char* arg1, const char* arg2 , const char* dbname, const char* access_attempt_source)
854 855 856
{
    PyObject *ret;
    int rc;
857
#ifdef WITH_THREAD
858 859 860
    PyGILState_STATE gilstate;

    gilstate = PyGILState_Ensure();
861
#endif
862 863 864 865 866 867 868 869 870 871 872
    ret = PyObject_CallFunction((PyObject*)user_arg, "issss", action, arg1, arg2, dbname, access_attempt_source);

    if (!ret) {
        if (_enable_callback_tracebacks) {
            PyErr_Print();
        } else {
            PyErr_Clear();
        }

        rc = SQLITE_DENY;
    } else {
873
        if (PyLong_Check(ret)) {
874 875 876
            rc = _PyLong_AsInt(ret);
            if (rc == -1 && PyErr_Occurred())
                rc = SQLITE_DENY;
877 878 879 880 881 882
        } else {
            rc = SQLITE_DENY;
        }
        Py_DECREF(ret);
    }

883
#ifdef WITH_THREAD
884
    PyGILState_Release(gilstate);
885
#endif
886 887 888
    return rc;
}

889 890 891 892
static int _progress_handler(void* user_arg)
{
    int rc;
    PyObject *ret;
893
#ifdef WITH_THREAD
894 895 896
    PyGILState_STATE gilstate;

    gilstate = PyGILState_Ensure();
897
#endif
898 899 900 901 902 903 904 905 906
    ret = PyObject_CallFunction((PyObject*)user_arg, "");

    if (!ret) {
        if (_enable_callback_tracebacks) {
            PyErr_Print();
        } else {
            PyErr_Clear();
        }

907
        /* abort query if error occurred */
908
        rc = 1;
909 910 911 912 913
    } else {
        rc = (int)PyObject_IsTrue(ret);
        Py_DECREF(ret);
    }

914
#ifdef WITH_THREAD
915
    PyGILState_Release(gilstate);
916
#endif
917 918 919
    return rc;
}

920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951
static void _trace_callback(void* user_arg, const char* statement_string)
{
    PyObject *py_statement = NULL;
    PyObject *ret = NULL;

#ifdef WITH_THREAD
    PyGILState_STATE gilstate;

    gilstate = PyGILState_Ensure();
#endif
    py_statement = PyUnicode_DecodeUTF8(statement_string,
            strlen(statement_string), "replace");
    if (py_statement) {
        ret = PyObject_CallFunctionObjArgs((PyObject*)user_arg, py_statement, NULL);
        Py_DECREF(py_statement);
    }

    if (ret) {
        Py_DECREF(ret);
    } else {
        if (_enable_callback_tracebacks) {
            PyErr_Print();
        } else {
            PyErr_Clear();
        }
    }

#ifdef WITH_THREAD
    PyGILState_Release(gilstate);
#endif
}

952
/*
 * Connection.set_authorizer(authorizer_callback)
 *
 * Installs authorizer_callback as the SQLite authorizer, called through
 * _authorizer_callback. SQLite receives the callable only as an unowned
 * user-data pointer, so it is first pinned in self->function_pinboard.
 *
 * Returns None, or NULL with an exception set.
 */
static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    PyObject* authorizer_cb;

    static char *kwlist[] = { "authorizer_callback", NULL };
    int rc;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_authorizer",
                                      kwlist, &authorizer_cb)) {
        return NULL;
    }

    /* Pin before installing. If pinning failed afterwards (e.g. an unhashable
     * callable), SQLite would keep calling an object nothing keeps alive. */
    if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) {
        return NULL;
    }

    rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb);

    if (rc != SQLITE_OK) {
        PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback");
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}

982
/*
 * Connection.set_progress_handler(progress_handler, n)
 *
 * Installs progress_handler to be called through _progress_handler, with n
 * passed to sqlite3_progress_handler(). Passing None removes any previously
 * installed handler. SQLite receives the callable only as an unowned pointer,
 * so it is first pinned in self->function_pinboard.
 *
 * Returns None, or NULL with an exception set.
 */
static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    PyObject* progress_handler;
    int n;

    static char *kwlist[] = { "progress_handler", "n", NULL };

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler",
                                      kwlist, &progress_handler, &n)) {
        return NULL;
    }

    if (progress_handler == Py_None) {
        /* None clears the progress handler previously set */
        sqlite3_progress_handler(self->db, 0, 0, (void*)0);
    } else {
        /* Pin before installing, so SQLite never holds a pointer to an object
         * we failed to keep alive (e.g. an unhashable callable). */
        if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1) {
            return NULL;
        }
        sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
    }

    Py_INCREF(Py_None);
    return Py_None;
}

1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038
/* set_trace_callback(trace_callback): install a callback that is invoked for
 * each SQL statement, or remove the current one when trace_callback is None.
 *
 * A non-None callable is first recorded in self->function_pinboard.  Only
 * after that succeeds is it passed to sqlite3_trace() as the user-data
 * pointer for _trace_callback.
 *
 * Returns None on success, NULL with an exception set on failure. */
static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    static char *kwlist[] = { "trace_callback", NULL };
    PyObject* callback;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_trace_callback",
                                      kwlist, &callback)) {
        return NULL;
    }

    if (callback != Py_None) {
        if (PyDict_SetItem(self->function_pinboard, callback, Py_None) == -1) {
            return NULL;
        }
        sqlite3_trace(self->db, _trace_callback, callback);
    } else {
        /* Registering no callback switches tracing off again. */
        sqlite3_trace(self->db, 0, (void*)0);
    }

    Py_RETURN_NONE;
}

1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088
#ifdef HAVE_LOAD_EXTENSION
/* enable_load_extension(onoff): pass the integer flag straight through to
 * sqlite3_enable_load_extension() for this connection.
 *
 * Returns None on success.  Raises OperationalError if SQLite reports
 * anything other than SQLITE_OK. */
static PyObject* pysqlite_enable_load_extension(pysqlite_Connection* self, PyObject* args)
{
    int enable;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTuple(args, "i", &enable)) {
        return NULL;
    }

    if (sqlite3_enable_load_extension(self->db, enable) == SQLITE_OK) {
        Py_RETURN_NONE;
    }

    PyErr_SetString(pysqlite_OperationalError, "Error enabling load extension");
    return NULL;
}

/* load_extension(name): ask SQLite to load the extension library `name`
 * using its default entry point (zProc == 0).
 *
 * On failure, sqlite3_load_extension() may store an error message allocated
 * with sqlite3_malloc() in errmsg.  The caller owns that buffer: it is
 * copied into the OperationalError and then released with sqlite3_free().
 * errmsg starts out NULL because SQLite only "attempts" to fill it (e.g. it
 * may be unable to allocate the message).  In that case a generic message is
 * used instead of passing NULL to PyErr_SetString().
 *
 * Returns None on success, NULL with OperationalError set on failure. */
static PyObject* pysqlite_load_extension(pysqlite_Connection* self, PyObject* args)
{
    int rc;
    char* extension_name;
    char* errmsg = NULL;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTuple(args, "s", &extension_name)) {
        return NULL;
    }

    rc = sqlite3_load_extension(self->db, extension_name, 0, &errmsg);
    if (rc != SQLITE_OK) {
        PyErr_SetString(pysqlite_OperationalError,
                        errmsg ? errmsg : "Error loading extension");
        sqlite3_free(errmsg);   /* no-op when errmsg is NULL */
        return NULL;
    }

    Py_INCREF(Py_None);
    return Py_None;
}
#endif

1089
/* Return 1 if the calling thread may use this connection.  Otherwise set
 * ProgrammingError and return 0.
 *
 * The comparison only happens when the module is built WITH_THREAD and
 * self->check_same_thread is set.  In that case the current thread
 * identifier must equal self->thread_ident. */
int pysqlite_check_thread(pysqlite_Connection* self)
{
#ifdef WITH_THREAD
    if (self->check_same_thread) {
        if (PyThread_get_thread_ident() != self->thread_ident) {
            /* The two literals are concatenated, so the first one must end
             * with a space to keep the sentences apart. */
            PyErr_Format(pysqlite_ProgrammingError,
                        "SQLite objects created in a thread can only be used in that same thread. "
                        "The object was created in thread id %ld and this is thread id %ld",
                        self->thread_ident, PyThread_get_thread_ident());
            return 0;
        }
    }
#endif
    return 1;
}

1106
/* Getter for Connection.isolation_level: returns a new reference to the
 * object currently stored in self->isolation_level. */
static PyObject* pysqlite_connection_get_isolation_level(pysqlite_Connection* self, void* unused)
{
    PyObject* level = self->isolation_level;

    Py_INCREF(level);
    return level;
}

1112
/* Getter for Connection.total_changes: the result of
 * sqlite3_total_changes() for this connection, as a Python int.  Returns
 * NULL if pysqlite_check_connection() rejects the connection. */
static PyObject* pysqlite_connection_get_total_changes(pysqlite_Connection* self, void* unused)
{
    if (!pysqlite_check_connection(self))
        return NULL;

    return PyLong_FromLong((long)sqlite3_total_changes(self->db));
}

1121
/* Setter for Connection.isolation_level.  It is also called directly
 * elsewhere in this file, hence the tolerance for a NULL previous
 * self->isolation_level.
 *
 * isolation_level == NULL (attribute deletion): raise AttributeError.
 * isolation_level is None:
 *     - free self->begin_statement and store None;
 *     - call pysqlite_connection_commit();
 *     - on success, clear self->inTransaction.
 * anything else:
 *     - build "BEGIN " + isolation_level with PyUnicode_Concat (so non-str
 *       values raise TypeError);
 *     - store a PyMem-allocated UTF-8 copy in self->begin_statement and
 *       remember the new value.
 *     The new statement is fully built before any connection state is
 *     modified, so a failure leaves the previous isolation_level /
 *     begin_statement pair untouched.
 *
 * Returns 0 on success, -1 with an exception set on failure. */
static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level)
{
    static PyObject* begin_word;
    PyObject* old_level;

    if (isolation_level == NULL) {
        PyErr_SetString(PyExc_AttributeError, "cannot delete attribute");
        return -1;
    }

    if (isolation_level == Py_None) {
        PyObject* res;

        if (self->begin_statement) {
            PyMem_Free(self->begin_statement);
            self->begin_statement = NULL;
        }

        /* Install the new value before releasing the old one. */
        old_level = self->isolation_level;
        Py_INCREF(Py_None);
        self->isolation_level = Py_None;
        Py_XDECREF(old_level);

        res = pysqlite_connection_commit(self, NULL);
        if (!res) {
            return -1;
        }
        Py_DECREF(res);

        self->inTransaction = 0;
    } else {
        PyObject* begin_statement;
        const char* statement;
        Py_ssize_t size;
        char* new_begin;

        if (!begin_word) {
            begin_word = PyUnicode_FromString("BEGIN ");
            if (!begin_word) {
                return -1;
            }
        }

        begin_statement = PyUnicode_Concat(begin_word, isolation_level);
        if (!begin_statement) {
            return -1;
        }

        statement = _PyUnicode_AsStringAndSize(begin_statement, &size);
        if (!statement) {
            Py_DECREF(begin_statement);
            return -1;
        }

        /* size excludes the terminating NUL, which is copied as well. */
        new_begin = PyMem_Malloc(size + 1);
        if (!new_begin) {
            Py_DECREF(begin_statement);
            PyErr_NoMemory();
            return -1;
        }
        memcpy(new_begin, statement, size + 1);
        Py_DECREF(begin_statement);

        /* Nothing below can fail: swap in the new state. */
        if (self->begin_statement) {
            PyMem_Free(self->begin_statement);
        }
        self->begin_statement = new_begin;

        old_level = self->isolation_level;
        Py_INCREF(isolation_level);
        self->isolation_level = isolation_level;
        Py_XDECREF(old_level);
    }

    return 0;
}

1179
/* tp_call for Connection: conn(sql) creates a Statement object for `sql`
 * via pysqlite_statement_create().  A weak reference to the statement is
 * appended to self->statements.  kwargs is not consulted.
 *
 * Before creating the statement, _pysqlite_drop_unused_statement_references()
 * is called (presumably pruning dead entries from self->statements).
 *
 * Creation errors are reported as follows:
 *   PYSQLITE_TOO_MUCH_SQL   -> Warning ("one statement at a time")
 *   PYSQLITE_SQL_WRONG_TYPE -> Warning ("wrong type")
 *   anything else           -> the statement is reset and the SQLite error
 *                              is raised via _pysqlite_seterror().
 *
 * Returns a new reference to the statement, or NULL with an exception set.
 * Every failure path releases the statement, so no object is ever returned
 * while an exception is pending. */
PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    PyObject* sql;
    pysqlite_Statement* statement;
    PyObject* weakref;
    int rc;

    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
        return NULL;
    }

    if (!PyArg_ParseTuple(args, "O", &sql)) {
        return NULL;
    }

    _pysqlite_drop_unused_statement_references(self);

    statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType);
    if (!statement) {
        return NULL;
    }

    /* Initialise every field before anything can trigger deallocation. */
    statement->db = NULL;
    statement->st = NULL;
    statement->sql = NULL;
    statement->in_use = 0;
    statement->in_weakreflist = NULL;

    rc = pysqlite_statement_create(statement, self, sql);
    if (rc != SQLITE_OK) {
        if (rc == PYSQLITE_TOO_MUCH_SQL) {
            PyErr_SetString(pysqlite_Warning, "You can only execute one statement at a time.");
        } else if (rc == PYSQLITE_SQL_WRONG_TYPE) {
            PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode.");
        } else {
            (void)pysqlite_statement_reset(statement);
            _pysqlite_seterror(self->db, NULL);
        }
        goto error;
    }

    weakref = PyWeakref_NewRef((PyObject*)statement, NULL);
    if (!weakref) {
        goto error;
    }

    if (PyList_Append(self->statements, weakref) != 0) {
        Py_DECREF(weakref);
        goto error;
    }
    Py_DECREF(weakref);

    return (PyObject*)statement;

error:
    Py_DECREF(statement);
    return NULL;
}

1239
/* Connection.execute(*args): non-standard shortcut.
 *   1. Create a cursor by calling self.cursor().
 *   2. Forward the positional arguments to that cursor's execute() method.
 *   3. Return the cursor; execute()'s own return value is discarded.
 * kwargs is not consulted.  If any step fails, the cursor is released and
 * NULL is returned with the exception set. */
PyObject* pysqlite_connection_execute(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    _Py_IDENTIFIER(cursor);
    PyObject* cur;
    PyObject* bound;
    PyObject* outcome;

    cur = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, "");
    if (cur == NULL) {
        return NULL;
    }

    bound = PyObject_GetAttrString(cur, "execute");
    if (bound == NULL) {
        Py_DECREF(cur);
        return NULL;
    }

    outcome = PyObject_CallObject(bound, args);
    if (outcome != NULL) {
        Py_DECREF(outcome);
    } else {
        Py_CLEAR(cur);   /* report the failure by returning NULL */
    }
    Py_DECREF(bound);

    return cur;
}

1269
/* Connection.executemany(*args): non-standard shortcut.
 *   1. Create a cursor by calling self.cursor().
 *   2. Forward the positional arguments to that cursor's executemany().
 *   3. Return the cursor; executemany()'s return value is discarded.
 * kwargs is not consulted.  If any step fails, the cursor is released and
 * NULL is returned with the exception set. */
PyObject* pysqlite_connection_executemany(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    _Py_IDENTIFIER(cursor);
    PyObject* cur;
    PyObject* bound;
    PyObject* outcome;

    cur = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, "");
    if (cur == NULL) {
        return NULL;
    }

    bound = PyObject_GetAttrString(cur, "executemany");
    if (bound == NULL) {
        Py_DECREF(cur);
        return NULL;
    }

    outcome = PyObject_CallObject(bound, args);
    if (outcome != NULL) {
        Py_DECREF(outcome);
    } else {
        Py_CLEAR(cur);   /* report the failure by returning NULL */
    }
    Py_DECREF(bound);

    return cur;
}

1299
/* Connection.executescript(*args): non-standard shortcut.
 *   1. Create a cursor by calling self.cursor().
 *   2. Forward the positional arguments to that cursor's executescript().
 *   3. Return the cursor; executescript()'s return value is discarded.
 * kwargs is not consulted.  If any step fails, the cursor is released and
 * NULL is returned with the exception set. */
PyObject* pysqlite_connection_executescript(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
    _Py_IDENTIFIER(cursor);
    PyObject* cur;
    PyObject* bound;
    PyObject* outcome;

    cur = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, "");
    if (cur == NULL) {
        return NULL;
    }

    bound = PyObject_GetAttrString(cur, "executescript");
    if (bound == NULL) {
        Py_DECREF(cur);
        return NULL;
    }

    outcome = PyObject_CallObject(bound, args);
    if (outcome != NULL) {
        Py_DECREF(outcome);
    } else {
        Py_CLEAR(cur);   /* report the failure by returning NULL */
    }
    Py_DECREF(bound);

    return cur;
}

/* ------------------------- COLLATION CODE ------------------------ */

static int
1332
pysqlite_collation_callback(
1333 1334 1335 1336 1337 1338 1339
        void* context,
        int text1_length, const void* text1_data,
        int text2_length, const void* text2_data)
{
    PyObject* callback = (PyObject*)context;
    PyObject* string1 = 0;
    PyObject* string2 = 0;
1340
#ifdef WITH_THREAD
1341
    PyGILState_STATE gilstate;
1342
#endif
1343
    PyObject* retval = NULL;
1344
    long longval;
1345
    int result = 0;
1346
#ifdef WITH_THREAD
1347
    gilstate = PyGILState_Ensure();
1348
#endif
1349 1350 1351 1352 1353

    if (PyErr_Occurred()) {
        goto finally;
    }

1354 1355
    string1 = PyUnicode_FromStringAndSize((const char*)text1_data, text1_length);
    string2 = PyUnicode_FromStringAndSize((const char*)text2_data, text2_length);
1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367

    if (!string1 || !string2) {
        goto finally; /* failed to allocate strings */
    }

    retval = PyObject_CallFunctionObjArgs(callback, string1, string2, NULL);

    if (!retval) {
        /* execution failed */
        goto finally;
    }

1368 1369 1370
    longval = PyLong_AsLongAndOverflow(retval, &result);
    if (longval == -1 && PyErr_Occurred()) {
        PyErr_Clear();
1371 1372
        result = 0;
    }
1373 1374 1375 1376 1377 1378
    else if (!result) {
        if (longval > 0)
            result = 1;
        else if (longval < 0)
            result = -1;
    }
1379 1380 1381 1382 1383

finally:
    Py_XDECREF(string1);
    Py_XDECREF(string2);
    Py_XDECREF(retval);
1384
#ifdef WITH_THREAD
1385
    PyGILState_Release(gilstate);
1386
#endif
1387 1388 1389
    return result;
}

1390
static PyObject *
1391
pysqlite_connection_interrupt(pysqlite_Connection* self, PyObject* args)
1392 1393 1394
{
    PyObject* retval = NULL;

1395
    if (!pysqlite_check_connection(self)) {
1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407
        goto finally;
    }

    sqlite3_interrupt(self->db);

    Py_INCREF(Py_None);
    retval = Py_None;

finally:
    return retval;
}

1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453
/* Function author: Paul Kippes <kippesp@gmail.com>
 * Class method of Connection to call the Python function _iterdump
 * of the sqlite3 module.
 *
 * Imports MODULE_NAME ".dump", looks up its module-level _iterdump and
 * returns the result of _iterdump(self).  Raises OperationalError if the
 * function cannot be found.  `args` is unused (the method is METH_NOARGS)
 * and is deliberately never reassigned or released here: it is not owned
 * by this function.
 */
static PyObject *
pysqlite_connection_iterdump(pysqlite_Connection* self, PyObject* args)
{
    PyObject* retval = NULL;
    PyObject* module = NULL;
    PyObject* module_dict;
    PyObject* pyfn_iterdump;

    if (!pysqlite_check_connection(self)) {
        goto finally;
    }

    module = PyImport_ImportModule(MODULE_NAME ".dump");
    if (!module) {
        goto finally;
    }

    module_dict = PyModule_GetDict(module);
    if (!module_dict) {
        goto finally;
    }

    /* Borrowed reference, owned by the module dict held via `module`. */
    pyfn_iterdump = PyDict_GetItemString(module_dict, "_iterdump");
    if (!pyfn_iterdump) {
        PyErr_SetString(pysqlite_OperationalError, "Failed to obtain _iterdump() reference");
        goto finally;
    }

    retval = PyObject_CallFunctionObjArgs(pyfn_iterdump, (PyObject*)self, NULL);

finally:
    Py_XDECREF(module);
    return retval;
}

1454
static PyObject *
1455
pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
1456 1457 1458 1459 1460
{
    PyObject* callable;
    PyObject* uppercase_name = 0;
    PyObject* name;
    PyObject* retval;
1461
    Py_ssize_t i, len;
1462
    _Py_IDENTIFIER(upper);
1463
    char *uppercase_name_str;
1464
    int rc;
Martin v. Löwis's avatar
Martin v. Löwis committed
1465 1466
    unsigned int kind;
    void *data;
1467

1468
    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
1469 1470 1471
        goto finally;
    }

1472
    if (!PyArg_ParseTuple(args, "O!O:create_collation(name, callback)", &PyUnicode_Type, &name, &callable)) {
1473 1474 1475
        goto finally;
    }

1476
    uppercase_name = _PyObject_CallMethodId(name, &PyId_upper, "");
1477 1478 1479 1480
    if (!uppercase_name) {
        goto finally;
    }

Martin v. Löwis's avatar
Martin v. Löwis committed
1481 1482 1483 1484 1485 1486 1487 1488 1489 1490
    if (PyUnicode_READY(uppercase_name))
        goto finally;
    len = PyUnicode_GET_LENGTH(uppercase_name);
    kind = PyUnicode_KIND(uppercase_name);
    data = PyUnicode_DATA(uppercase_name);
    for (i=0; i<len; i++) {
        Py_UCS4 ch = PyUnicode_READ(kind, data, i);
        if ((ch >= '0' && ch <= '9')
         || (ch >= 'A' && ch <= 'Z')
         || (ch == '_'))
1491
        {
1492
            continue;
1493
        } else {
1494
            PyErr_SetString(pysqlite_ProgrammingError, "invalid character in collation name");
1495 1496 1497 1498
            goto finally;
        }
    }

1499 1500 1501 1502
    uppercase_name_str = _PyUnicode_AsString(uppercase_name);
    if (!uppercase_name_str)
        goto finally;

1503 1504 1505 1506 1507 1508
    if (callable != Py_None && !PyCallable_Check(callable)) {
        PyErr_SetString(PyExc_TypeError, "parameter must be callable");
        goto finally;
    }

    if (callable != Py_None) {
1509 1510
        if (PyDict_SetItem(self->collations, uppercase_name, callable) == -1)
            goto finally;
1511
    } else {
1512 1513
        if (PyDict_DelItem(self->collations, uppercase_name) == -1)
            goto finally;
1514 1515 1516
    }

    rc = sqlite3_create_collation(self->db,
1517
                                  uppercase_name_str,
1518 1519
                                  SQLITE_UTF8,
                                  (callable != Py_None) ? callable : NULL,
1520
                                  (callable != Py_None) ? pysqlite_collation_callback : NULL);
1521 1522
    if (rc != SQLITE_OK) {
        PyDict_DelItem(self->collations, uppercase_name);
1523
        _pysqlite_seterror(self->db, NULL);
1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539
        goto finally;
    }

finally:
    Py_XDECREF(uppercase_name);

    if (PyErr_Occurred()) {
        retval = NULL;
    } else {
        Py_INCREF(Py_None);
        retval = Py_None;
    }

    return retval;
}

1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576
/* __enter__ for the context-manager protocol: the connection itself is the
 * value bound by `with ... as`, so a new reference to self is returned. */
static PyObject *
pysqlite_connection_enter(pysqlite_Connection* self, PyObject* args)
{
    PyObject* me = (PyObject*)self;

    Py_INCREF(me);
    return me;
}

/** __exit__(exc_type, exc_value, exc_tb) for the context-manager protocol.
 * Calls self.commit() when all three arguments are None and
 * self.rollback() otherwise.
 *
 * Returns False, so an exception raised inside the with-block is never
 * suppressed.  Returns NULL if argument parsing or the commit/rollback call
 * fails. */
static PyObject *
pysqlite_connection_exit(pysqlite_Connection* self, PyObject* args)
{
    PyObject *exc_type, *exc_value, *exc_tb;
    PyObject* outcome;
    int clean_exit;

    if (!PyArg_ParseTuple(args, "OOO", &exc_type, &exc_value, &exc_tb))
        return NULL;

    clean_exit = (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None);
    outcome = PyObject_CallMethod((PyObject*)self,
                                  clean_exit ? "commit" : "rollback", "");
    if (outcome == NULL)
        return NULL;

    Py_DECREF(outcome);
    Py_RETURN_FALSE;
}

1577
static char connection_doc[] =
1578
PyDoc_STR("SQLite database connection object.");
1579 1580

static PyGetSetDef connection_getset[] = {
1581 1582
    {"isolation_level",  (getter)pysqlite_connection_get_isolation_level, (setter)pysqlite_connection_set_isolation_level},
    {"total_changes",  (getter)pysqlite_connection_get_total_changes, (setter)0},
1583 1584 1585 1586
    {NULL}
};

static PyMethodDef connection_methods[] = {
1587
    {"cursor", (PyCFunction)pysqlite_connection_cursor, METH_VARARGS|METH_KEYWORDS,
1588
        PyDoc_STR("Return a cursor for the connection.")},
1589
    {"close", (PyCFunction)pysqlite_connection_close, METH_NOARGS,
1590
        PyDoc_STR("Closes the connection.")},
1591
    {"commit", (PyCFunction)pysqlite_connection_commit, METH_NOARGS,
1592
        PyDoc_STR("Commit the current transaction.")},
1593
    {"rollback", (PyCFunction)pysqlite_connection_rollback, METH_NOARGS,
1594
        PyDoc_STR("Roll back the current transaction.")},
1595
    {"create_function", (PyCFunction)pysqlite_connection_create_function, METH_VARARGS|METH_KEYWORDS,
1596
        PyDoc_STR("Creates a new function. Non-standard.")},
1597
    {"create_aggregate", (PyCFunction)pysqlite_connection_create_aggregate, METH_VARARGS|METH_KEYWORDS,
1598
        PyDoc_STR("Creates a new aggregate. Non-standard.")},
1599
    {"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS,
1600
        PyDoc_STR("Sets authorizer callback. Non-standard.")},
1601 1602 1603 1604 1605 1606
    #ifdef HAVE_LOAD_EXTENSION
    {"enable_load_extension", (PyCFunction)pysqlite_enable_load_extension, METH_VARARGS,
        PyDoc_STR("Enable dynamic loading of SQLite extension modules. Non-standard.")},
    {"load_extension", (PyCFunction)pysqlite_load_extension, METH_VARARGS,
        PyDoc_STR("Load SQLite extension module. Non-standard.")},
    #endif
1607 1608
    {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS,
        PyDoc_STR("Sets progress handler callback. Non-standard.")},
1609 1610
    {"set_trace_callback", (PyCFunction)pysqlite_connection_set_trace_callback, METH_VARARGS|METH_KEYWORDS,
        PyDoc_STR("Sets a trace callback called for each SQL statement (passed as unicode). Non-standard.")},
1611
    {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS,
1612
        PyDoc_STR("Executes a SQL statement. Non-standard.")},
1613
    {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS,
1614
        PyDoc_STR("Repeatedly executes a SQL statement. Non-standard.")},
1615
    {"executescript", (PyCFunction)pysqlite_connection_executescript, METH_VARARGS,
1616
        PyDoc_STR("Executes a multiple SQL statements at once. Non-standard.")},
1617
    {"create_collation", (PyCFunction)pysqlite_connection_create_collation, METH_VARARGS,
1618
        PyDoc_STR("Creates a collation function. Non-standard.")},
1619
    {"interrupt", (PyCFunction)pysqlite_connection_interrupt, METH_NOARGS,
1620
        PyDoc_STR("Abort any pending database operation. Non-standard.")},
1621
    {"iterdump", (PyCFunction)pysqlite_connection_iterdump, METH_NOARGS,
1622
        PyDoc_STR("Returns iterator to the dump of the database in an SQL text format. Non-standard.")},
1623 1624 1625 1626
    {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS,
        PyDoc_STR("For context manager. Non-standard.")},
    {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS,
        PyDoc_STR("For context manager. Non-standard.")},
1627 1628 1629 1630 1631
    {NULL, NULL}
};

static struct PyMemberDef connection_members[] =
{
1632 1633 1634 1635 1636 1637 1638 1639 1640 1641
    {"Warning", T_OBJECT, offsetof(pysqlite_Connection, Warning), READONLY},
    {"Error", T_OBJECT, offsetof(pysqlite_Connection, Error), READONLY},
    {"InterfaceError", T_OBJECT, offsetof(pysqlite_Connection, InterfaceError), READONLY},
    {"DatabaseError", T_OBJECT, offsetof(pysqlite_Connection, DatabaseError), READONLY},
    {"DataError", T_OBJECT, offsetof(pysqlite_Connection, DataError), READONLY},
    {"OperationalError", T_OBJECT, offsetof(pysqlite_Connection, OperationalError), READONLY},
    {"IntegrityError", T_OBJECT, offsetof(pysqlite_Connection, IntegrityError), READONLY},
    {"InternalError", T_OBJECT, offsetof(pysqlite_Connection, InternalError), READONLY},
    {"ProgrammingError", T_OBJECT, offsetof(pysqlite_Connection, ProgrammingError), READONLY},
    {"NotSupportedError", T_OBJECT, offsetof(pysqlite_Connection, NotSupportedError), READONLY},
1642 1643
    {"row_factory", T_OBJECT, offsetof(pysqlite_Connection, row_factory)},
    {"text_factory", T_OBJECT, offsetof(pysqlite_Connection, text_factory)},
1644
    {"in_transaction", T_BOOL, offsetof(pysqlite_Connection, inTransaction), READONLY},
1645 1646 1647
    {NULL}
};

1648
PyTypeObject pysqlite_ConnectionType = {
1649
        PyVarObject_HEAD_INIT(NULL, 0)
1650
        MODULE_NAME ".Connection",                      /* tp_name */
1651
        sizeof(pysqlite_Connection),                    /* tp_basicsize */
1652
        0,                                              /* tp_itemsize */
1653
        (destructor)pysqlite_connection_dealloc,        /* tp_dealloc */
1654 1655 1656
        0,                                              /* tp_print */
        0,                                              /* tp_getattr */
        0,                                              /* tp_setattr */
1657
        0,                                              /* tp_reserved */
1658 1659 1660 1661 1662
        0,                                              /* tp_repr */
        0,                                              /* tp_as_number */
        0,                                              /* tp_as_sequence */
        0,                                              /* tp_as_mapping */
        0,                                              /* tp_hash */
1663
        (ternaryfunc)pysqlite_connection_call,          /* tp_call */
1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683
        0,                                              /* tp_str */
        0,                                              /* tp_getattro */
        0,                                              /* tp_setattro */
        0,                                              /* tp_as_buffer */
        Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,         /* tp_flags */
        connection_doc,                                 /* tp_doc */
        0,                                              /* tp_traverse */
        0,                                              /* tp_clear */
        0,                                              /* tp_richcompare */
        0,                                              /* tp_weaklistoffset */
        0,                                              /* tp_iter */
        0,                                              /* tp_iternext */
        connection_methods,                             /* tp_methods */
        connection_members,                             /* tp_members */
        connection_getset,                              /* tp_getset */
        0,                                              /* tp_base */
        0,                                              /* tp_dict */
        0,                                              /* tp_descr_get */
        0,                                              /* tp_descr_set */
        0,                                              /* tp_dictoffset */
1684
        (initproc)pysqlite_connection_init,             /* tp_init */
1685 1686 1687 1688 1689
        0,                                              /* tp_alloc */
        0,                                              /* tp_new */
        0                                               /* tp_free */
};

1690
extern int pysqlite_connection_setup_types(void)
1691
{
1692 1693
    pysqlite_ConnectionType.tp_new = PyType_GenericNew;
    return PyType_Ready(&pysqlite_ConnectionType);
1694
}