Kaydet (Commit) 714c09b8 authored tarafından Malcolm Tredinnick's avatar Malcolm Tredinnick

Fixed #4831 -- Added an "add" cache key method, for parity with memcached's

API. This works for all cache backends. Patch from Matt McClanahan.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@6572 bcc190cf-cafb-0310-a4f2-bffc1f526a37
üst b678601d
...@@ -14,6 +14,14 @@ class BaseCache(object): ...@@ -14,6 +14,14 @@ class BaseCache(object):
timeout = 300 timeout = 300
self.default_timeout = timeout self.default_timeout = timeout
def add(self, key, value, timeout=None):
"""
Set a value in the cache if the key does not already exist. If
timeout is given, that timeout will be used for the key; otherwise
the default cache timeout will be used.
"""
raise NotImplementedError
def get(self, key, default=None): def get(self, key, default=None):
""" """
Fetch a given key from the cache. If the key does not exist, return Fetch a given key from the cache. If the key does not exist, return
......
...@@ -24,6 +24,9 @@ class CacheClass(BaseCache): ...@@ -24,6 +24,9 @@ class CacheClass(BaseCache):
except (ValueError, TypeError): except (ValueError, TypeError):
self._cull_frequency = 3 self._cull_frequency = 3
def add(self, key, value, timeout=None):
return self._base_set('add', key, value, timeout)
def get(self, key, default=None): def get(self, key, default=None):
cursor = connection.cursor() cursor = connection.cursor()
cursor.execute("SELECT cache_key, value, expires FROM %s WHERE cache_key = %%s" % self._table, [key]) cursor.execute("SELECT cache_key, value, expires FROM %s WHERE cache_key = %%s" % self._table, [key])
...@@ -38,6 +41,9 @@ class CacheClass(BaseCache): ...@@ -38,6 +41,9 @@ class CacheClass(BaseCache):
return pickle.loads(base64.decodestring(row[1])) return pickle.loads(base64.decodestring(row[1]))
def set(self, key, value, timeout=None): def set(self, key, value, timeout=None):
return self._base_set('set', key, value, timeout)
def _base_set(self, mode, key, value, timeout=None):
if timeout is None: if timeout is None:
timeout = self.default_timeout timeout = self.default_timeout
cursor = connection.cursor() cursor = connection.cursor()
...@@ -50,10 +56,11 @@ class CacheClass(BaseCache): ...@@ -50,10 +56,11 @@ class CacheClass(BaseCache):
encoded = base64.encodestring(pickle.dumps(value, 2)).strip() encoded = base64.encodestring(pickle.dumps(value, 2)).strip()
cursor.execute("SELECT cache_key FROM %s WHERE cache_key = %%s" % self._table, [key]) cursor.execute("SELECT cache_key FROM %s WHERE cache_key = %%s" % self._table, [key])
try: try:
if cursor.fetchone(): if mode == 'set' and cursor.fetchone():
cursor.execute("UPDATE %s SET value = %%s, expires = %%s WHERE cache_key = %%s" % self._table, [encoded, str(exp), key]) cursor.execute("UPDATE %s SET value = %%s, expires = %%s WHERE cache_key = %%s" % self._table, [encoded, str(exp), key])
else: else:
cursor.execute("INSERT INTO %s (cache_key, value, expires) VALUES (%%s, %%s, %%s)" % self._table, [key, encoded, str(exp)]) if mode == 'add':
cursor.execute("INSERT INTO %s (cache_key, value, expires) VALUES (%%s, %%s, %%s)" % self._table, [key, encoded, str(exp)])
except DatabaseError: except DatabaseError:
# To be threadsafe, updates/inserts are allowed to fail silently # To be threadsafe, updates/inserts are allowed to fail silently
pass pass
......
...@@ -6,6 +6,9 @@ class CacheClass(BaseCache): ...@@ -6,6 +6,9 @@ class CacheClass(BaseCache):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
def add(self, *args, **kwargs):
pass
def get(self, key, default=None): def get(self, key, default=None):
return default return default
......
...@@ -17,6 +17,26 @@ class CacheClass(SimpleCacheClass): ...@@ -17,6 +17,26 @@ class CacheClass(SimpleCacheClass):
del self._cache del self._cache
del self._expire_info del self._expire_info
def add(self, key, value, timeout=None):
fname = self._key_to_file(key)
if timeout is None:
timeout = self.default_timeout
try:
filelist = os.listdir(self._dir)
except (IOError, OSError):
self._createdir()
filelist = []
if len(filelist) > self._max_entries:
self._cull(filelist)
if os.path.basename(fname) not in filelist:
try:
f = open(fname, 'wb')
now = time.time()
pickle.dump(now + timeout, f, 2)
pickle.dump(value, f, 2)
except (IOError, OSError):
pass
def get(self, key, default=None): def get(self, key, default=None):
fname = self._key_to_file(key) fname = self._key_to_file(key)
try: try:
......
...@@ -14,6 +14,13 @@ class CacheClass(SimpleCacheClass): ...@@ -14,6 +14,13 @@ class CacheClass(SimpleCacheClass):
SimpleCacheClass.__init__(self, host, params) SimpleCacheClass.__init__(self, host, params)
self._lock = RWLock() self._lock = RWLock()
def add(self, key, value, timeout=None):
self._lock.writer_enters()
try:
SimpleCacheClass.add(self, key, value, timeout)
finally:
self._lock.writer_leaves()
def get(self, key, default=None): def get(self, key, default=None):
should_delete = False should_delete = False
self._lock.reader_enters() self._lock.reader_enters()
......
...@@ -16,6 +16,9 @@ class CacheClass(BaseCache): ...@@ -16,6 +16,9 @@ class CacheClass(BaseCache):
BaseCache.__init__(self, params) BaseCache.__init__(self, params)
self._cache = memcache.Client(server.split(';')) self._cache = memcache.Client(server.split(';'))
def add(self, key, value, timeout=0):
self._cache.add(key.encode('ascii', 'ignore'), value, timeout or self.default_timeout)
def get(self, key, default=None): def get(self, key, default=None):
val = self._cache.get(smart_str(key)) val = self._cache.get(smart_str(key))
if val is None: if val is None:
......
...@@ -21,6 +21,15 @@ class CacheClass(BaseCache): ...@@ -21,6 +21,15 @@ class CacheClass(BaseCache):
except (ValueError, TypeError): except (ValueError, TypeError):
self._cull_frequency = 3 self._cull_frequency = 3
def add(self, key, value, timeout=None):
if len(self._cache) >= self._max_entries:
self._cull()
if timeout is None:
timeout = self.default_timeout
if key not in self._cache.keys():
self._cache[key] = value
self._expire_info[key] = time.time() + timeout
def get(self, key, default=None): def get(self, key, default=None):
now = time.time() now = time.time()
exp = self._expire_info.get(key) exp = self._expire_info.get(key)
......
...@@ -326,6 +326,15 @@ get() can take a ``default`` argument:: ...@@ -326,6 +326,15 @@ get() can take a ``default`` argument::
>>> cache.get('my_key', 'has expired') >>> cache.get('my_key', 'has expired')
'has expired' 'has expired'
To add a key only if it doesn't already exist, there is an add() method. It
takes the same parameters as set(), but will not attempt to update the cache
if the key specified is already present::
>>> cache.set('add_key', 'Initial value')
>>> cache.add('add_key', 'New value')
>>> cache.get('add_key')
'Initial value'
There's also a get_many() interface that only hits the cache once. get_many() There's also a get_many() interface that only hits the cache once. get_many()
returns a dictionary with all the keys you asked for that actually exist in the returns a dictionary with all the keys you asked for that actually exist in the
cache (and haven't expired):: cache (and haven't expired)::
......
...@@ -19,6 +19,12 @@ class Cache(unittest.TestCase): ...@@ -19,6 +19,12 @@ class Cache(unittest.TestCase):
cache.set("key", "value") cache.set("key", "value")
self.assertEqual(cache.get("key"), "value") self.assertEqual(cache.get("key"), "value")
def test_add(self):
# test add (only add if key isn't already in cache)
cache.add("addkey1", "value")
cache.add("addkey1", "newvalue")
self.assertEqual(cache.get("addkey1"), "value")
def test_non_existent(self): def test_non_existent(self):
# get with non-existent keys # get with non-existent keys
self.assertEqual(cache.get("does_not_exist"), None) self.assertEqual(cache.get("does_not_exist"), None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment