分布式锁 以及 频率限制器

xiaoxiao2021-02-28  56

最近工作中遇到几个问题,记录一下

1 分布式锁

    最近遇到一个类似减库存问题,原本以为使用数据库事务,加上这个业务调用不频繁,应该没有问题。但是测试环境发现

前端兄弟有个bug每次提交了两次,项目使用f5代理,两个节点,刚好一边一个请求,造成数据库数据混乱。发现后,我先将

serviceImpl中的@Transactional的isolation设置成Isolation.SERIALIZABLE,虽然这样数据库事务并发性最差,但是业务比较

简单,觉得可以解决问题。发布到测试环境后,发现出现了两种情况:要么两个请求都成功,要么一个成功、另一个在提交事务时数据库报异常。

   按照设想,数据库序列化事务应该会让两个请求一前一后都成功,可实际并非如此。数据库使用的是oracle,我不清楚它的最高隔离级别(序列化)具体是如何执行的。(注:Oracle的SERIALIZABLE是基于快照实现的,并不会让事务排队执行;如果事务要修改的行在该事务开始之后已被其他事务修改并提交,就会抛出ORA-08177异常,这与上面观察到的现象一致。)虽然前端可以修复那个bug,但是遇到高并发时仍然有这个风险,还是要解决。想到之前看curator的时候有分布式锁,觉得可以使用。参考curator重新写了一下,使用redis实现了,并且自测使用500个线程竞争,业务还是很快正确执行完毕。回到家记录下来,然后看了下zkClient接口,使用zookeeper也实现了一版,但是我没有测试。

代码如下:

package com.test.util.lock;

import java.util.concurrent.TimeUnit;

/**
 * A lock shared between processes and addressed by a string key.
 */
public interface DistributedLock {

    /**
     * Acquires the lock for {@code lockKey} without a caller-supplied timeout.
     *
     * @param lockKey name of the lock
     * @throws Exception if the lock could not be obtained
     */
    void acquire(String lockKey) throws Exception;

    /**
     * Tries to acquire the lock for {@code lockKey}, waiting at most the given time.
     *
     * <p>Typical usage:
     * <pre>
     * String lockValue = null;
     * try {
     *     lockValue = lock.acquire(lockKey, time, unit);
     *     ...
     *     return x;
     * } catch (Exception e) {
     *     e.printStackTrace();
     *     return y;
     * } finally {
     *     if (lockValue != null) {
     *         lock.release(lockKey, lockValue);
     *     }
     * }
     * </pre>
     *
     * @param lockKey name of the lock
     * @param time    maximum time to wait; a negative value means "use the implementation's default"
     * @param unit    unit of {@code time}; {@code null} means "use the implementation's default"
     * @return lockValue non-null if acquired, null if not
     * @throws Exception if the attempt fails for a reason other than contention
     */
    String acquire(String lockKey, long time, TimeUnit unit) throws Exception;

    /**
     * Releases a lock previously obtained with {@code acquire}.
     *
     * @param lockKey   name of the lock
     * @param lockValue the token returned by {@code acquire}
     * @return true if released, false if not
     */
    boolean release(String lockKey, String lockValue);
}

package com.test.util.lock;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Re-entrant bookkeeping shared by the concrete lock implementations; the actual
 * remote acquire/release is delegated to {@link #doAcquire} and {@link #doRelease}.
 *
 * <p>Modelled on
 * {@link org.apache.curator.framework.recipes.locks.InterProcessMutex}. Unlike that class,
 * one instance of this lock serves many keys, so holds are tracked per thread AND per key:
 * tracking per thread only would hand a thread that already holds lock A the token of A
 * when it asks for lock B, without ever locking B.
 */
public abstract class AbstractDistributedLock implements DistributedLock {

    /** Prefix added to every caller-supplied key before it reaches the backend. */
    private static final String LOCK_NAME = "lock-";

    /**
     * Locks held by each thread, keyed by prefixed lock key. An inner map is only ever
     * touched by the thread it belongs to, so it does not need to be concurrent.
     */
    private final ConcurrentMap<Thread, Map<String, LockData>> threadData = new ConcurrentHashMap<>();

    /** State of one held lock: its owner, its backend token and its re-entrancy depth. */
    private static class LockData {
        final Thread owningThread;
        final String lockValue;
        final AtomicInteger lockCount = new AtomicInteger(1);

        private LockData(Thread owningThread, String lockValue) {
            this.owningThread = owningThread;
            this.lockValue = lockValue;
        }
    }

    /**
     * Acquires the lock using the implementation's default waiting time.
     *
     * @param lockKey name of the lock
     * @throws IOException if the lock was not obtained
     */
    @Override
    public void acquire(String lockKey) throws Exception {
        if (acquire(lockKey, -1, null) == null) {
            throw new IOException("Lost connection while trying to acquire lock: " + lockKey);
        }
    }

    /**
     * Acquires the lock, or re-enters it if the calling thread already holds it.
     *
     * @param lockKey name of the lock
     * @param time    maximum time to wait; negative means "use the implementation's default"
     * @param unit    unit of {@code time}; {@code null} means "use the implementation's default"
     * @return the lock token, or {@code null} if the lock was not obtained
     */
    @Override
    public String acquire(String lockKey, long time, TimeUnit unit) throws Exception {
        lockKey = this.preLockKey(lockKey);
        if (time < 0 || unit == null) {
            return this.internalLock(lockKey, -1);
        }
        return this.internalLock(lockKey, unit.toMillis(time));
    }

    /**
     * Leaves one level of re-entrancy; the backend lock is released when the depth reaches zero.
     *
     * @param lockKey   name of the lock
     * @param lockValue the token returned by {@code acquire}
     * @return true if this was an inner release or the backend release succeeded
     * @throws IllegalMonitorStateException if the calling thread does not hold the lock
     *                                      or {@code lockValue} is not its token
     */
    @Override
    public boolean release(String lockKey, String lockValue) {
        lockKey = this.preLockKey(lockKey);
        Thread currentThread = Thread.currentThread();

        Map<String, LockData> held = threadData.get(currentThread);
        LockData lockData = (held == null) ? null : held.get(lockKey);
        if (lockData == null) {
            throw new IllegalMonitorStateException("You do not own the lock: " + lockKey);
        }
        if (!lockData.lockValue.equals(lockValue)) {
            throw new IllegalMonitorStateException("Wrong lockValue for the lock: " + lockKey);
        }

        int newLockCount = lockData.lockCount.decrementAndGet();
        if (newLockCount > 0) {
            return true;
        }
        if (newLockCount < 0) {
            throw new IllegalMonitorStateException("Lock count has gone negative for lock: " + lockKey);
        }

        try {
            return doRelease(lockKey, lockValue);
        } finally {
            // Forget the hold even when the backend call fails, so the thread is not stuck
            // "owning" a lock it can no longer release.
            held.remove(lockKey);
            if (held.isEmpty()) {
                threadData.remove(currentThread);
            }
        }
    }

    /**
     * Obtains the lock from the backend.
     *
     * @param lockKey prefixed lock key
     * @param millis  maximum time to wait in milliseconds; negative means "implementation default"
     * @return a token identifying this hold, or {@code null} if the lock was not obtained
     */
    protected abstract String doAcquire(String lockKey, long millis);

    /**
     * Releases the lock in the backend.
     *
     * @param lockKey   prefixed lock key
     * @param lockValue token returned by {@link #doAcquire}
     * @return true if released, false if not
     */
    protected abstract boolean doRelease(String lockKey, String lockValue);

    /** Re-enters a lock the thread already holds for this key, otherwise asks the backend. */
    private String internalLock(String lockKey, long millis) throws Exception {
        Thread currentThread = Thread.currentThread();

        Map<String, LockData> held = threadData.get(currentThread);
        LockData lockData = (held == null) ? null : held.get(lockKey);
        if (lockData != null) {
            // re-entering
            lockData.lockCount.incrementAndGet();
            return lockData.lockValue;
        }

        String lockValue = doAcquire(lockKey, millis);
        if (lockValue == null) {
            return null;
        }
        if (held == null) {
            held = new HashMap<>();
            threadData.put(currentThread, held);
        }
        held.put(lockKey, new LockData(currentThread, lockValue));
        return lockValue;
    }

    /** Prepends {@link #LOCK_NAME}; a {@code null} key maps to the bare prefix. */
    private String preLockKey(String lockKey) {
        if (lockKey == null) {
            return LOCK_NAME;
        }
        return LOCK_NAME + lockKey;
    }
}

package com.test.util.lock;

import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;

import redis.clients.jedis.JedisCluster;

/**
 * Lock backed by a Redis key written with {@code SET key value NX EX seconds}.
 */
public class RedisDistributedLock extends AbstractDistributedLock {

    /** Default waiting time, in milliseconds, used when the caller gives no timeout (3600 s). */
    private static final long defaultForever = 3600000;

    /** Default key time-to-live in seconds, so a holder that never releases cannot block forever. */
    private static final int defaultExpire = 15;

    /** Pause between two acquisition attempts, in milliseconds. */
    private static final long RETRY_INTERVAL_MILLIS = 1000;

    /**
     * Compare-and-delete executed atomically inside Redis. A separate GET followed by DEL
     * could delete a lock that expired and was re-acquired by another client in between.
     * Returns 1 when the key is already gone or was deleted, 0 when it holds another token.
     */
    private static final String RELEASE_SCRIPT =
            "local v = redis.call('get', KEYS[1]) "
            + "if v == false then return 1 "
            + "elseif v == ARGV[1] then redis.call('del', KEYS[1]) return 1 "
            + "else return 0 end";

    private JedisCluster jedisCluster;
    private long forever;
    private int expire;

    /** Creates a lock with the default waiting time and key time-to-live. */
    public RedisDistributedLock() {
        super();
        this.forever = defaultForever;
        this.expire = defaultExpire;
    }

    /**
     * @param forever waiting time in milliseconds used when the caller gives no timeout
     * @param expire  time-to-live of the Redis key in seconds
     */
    public RedisDistributedLock(long forever, int expire) {
        super();
        this.forever = forever;
        this.expire = expire;
    }

    /**
     * Repeatedly tries to create the key until it succeeds or the deadline passes.
     * At least one attempt is always made, and the pause between attempts never
     * extends past the deadline.
     *
     * @param lockKey prefixed lock key
     * @param millis  maximum time to wait; negative selects the configured "forever" value,
     *                0 means a single attempt
     * @return the random token stored in the key, or {@code null} on timeout
     */
    @Override
    protected String doAcquire(String lockKey, long millis) {
        long waitMillis = (millis < 0) ? forever : millis;
        long deadline = System.currentTimeMillis() + waitMillis;
        String lockValue = UUID.randomUUID().toString();

        while (true) {
            // SET ... NX EX replies "OK" when the key was created and null when it already exists.
            String reply = jedisCluster.set(lockKey, lockValue, "NX", "EX", expire);
            if ("OK".equals(reply)) {
                return lockValue;
            }
            long remaining = deadline - System.currentTimeMillis();
            if (remaining <= 0) {
                return null;
            }
            LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(Math.min(remaining, RETRY_INTERVAL_MILLIS)));
        }
    }

    /**
     * Deletes the key if it still carries {@code lockValue}.
     *
     * @return true if the key was deleted or had already expired, false if it now
     *         belongs to a different holder
     */
    @Override
    protected boolean doRelease(String lockKey, String lockValue) {
        Object result = jedisCluster.eval(
                RELEASE_SCRIPT,
                Collections.singletonList(lockKey),
                Collections.singletonList(lockValue));
        return Long.valueOf(1L).equals(result);
    }

    public JedisCluster getJedisCluster() {
        return jedisCluster;
    }

    public void setJedisCluster(JedisCluster jedisCluster) {
        this.jedisCluster = jedisCluster;
    }
}

package com.test.util.lock;

import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;

import org.I0Itec.zkclient.ZkClient;
import org.I0Itec.zkclient.exception.ZkNodeExistsException;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;

/**
 * Lock backed by an ephemeral ZooKeeper node under {@code /lock}; whoever manages to
 * create the node holds the lock.
 */
public class ZooKeeperDistributedLock extends AbstractDistributedLock implements InitializingBean, DisposableBean {

    private static final String ROOT_PATH = "/lock";

    /** Pause between two acquisition attempts, in milliseconds. */
    private static final long RETRY_INTERVAL_MILLIS = 1000;

    private ZkClient zkClient;
    private String zkServers;
    private int connectionTimeout;

    /** Waiting time in milliseconds used when the caller gives no timeout. */
    private long forever = 3600000;

    /**
     * Repeatedly tries to create the ephemeral lock node until it succeeds or the
     * deadline passes. At least one attempt is always made, and the pause between
     * attempts never extends past the deadline.
     *
     * <p>Only "node already exists" counts as contention and is retried. Any other
     * failure propagates to the caller instead of being printed and retried until
     * the deadline.
     *
     * @param lockKey prefixed lock key, used as the node name
     * @param millis  maximum time to wait; negative selects {@code forever}, 0 means a single attempt
     * @return the random token stored in the node, or {@code null} on timeout
     */
    @Override
    protected String doAcquire(String lockKey, long millis) {
        this.createRoot();
        String lockKeyPath = this.lockKeyPath(lockKey);

        long waitMillis = (millis < 0) ? forever : millis;
        long deadline = System.currentTimeMillis() + waitMillis;
        String lockValue = UUID.randomUUID().toString();

        while (true) {
            try {
                zkClient.createEphemeral(lockKeyPath, lockValue);
                return lockValue;
            } catch (ZkNodeExistsException ignored) {
                // Another holder owns the node right now; fall through and retry.
            }
            long remaining = deadline - System.currentTimeMillis();
            if (remaining <= 0) {
                return null;
            }
            LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(Math.min(remaining, RETRY_INTERVAL_MILLIS)));
        }
    }

    /**
     * Deletes the lock node if it still carries {@code lockValue}.
     *
     * @return true if the node is already gone, the result of the delete call if the
     *         token matches, false if the node carries a different token
     */
    @Override
    protected boolean doRelease(String lockKey, String lockValue) {
        String lockKeyPath = this.lockKeyPath(lockKey);
        // The second argument makes readData return null instead of throwing when the node is missing.
        Object readData = this.zkClient.readData(lockKeyPath, true);
        if (readData == null) { // no such node
            return true;
        }
        if (lockValue.equals(readData)) {
            // NOTE(review): read-then-delete is not atomic; confirm whether a versioned delete is needed.
            return this.zkClient.delete(lockKeyPath);
        }
        return false;
    }

    /** Closes the ZooKeeper client if one was created. */
    @Override
    public void destroy() throws Exception {
        if (this.zkClient != null) {
            this.zkClient.close();
        }
    }

    /** Creates the ZooKeeper client from the configured servers and connection timeout. */
    @Override
    public void afterPropertiesSet() throws Exception {
        this.zkClient = new ZkClient(zkServers, connectionTimeout);
    }

    /**
     * Makes sure the parent node exists. The createParents variant is used because it
     * tolerates the node already existing, which closes the race between a separate
     * exists() check and the create call.
     */
    private void createRoot() {
        this.zkClient.createPersistent(ROOT_PATH, true);
    }

    private String lockKeyPath(String lockKey) {
        return ROOT_PATH + "/" + lockKey;
    }

    public String getZkServers() {
        return zkServers;
    }

    public void setZkServers(String zkServers) {
        this.zkServers = zkServers;
    }

    public int getConnectionTimeout() {
        return connectionTimeout;
    }

    public void setConnectionTimeout(int connectionTimeout) {
        this.connectionTimeout = connectionTimeout;
    }

    public long getForever() {
        return forever;
    }

    public void setForever(long forever) {
        this.forever = forever;
    }
}

这种循环重试,感觉也可以,但是想到AbstractQueuedSynchronizer里面的实现方式,有这种思路,也使用一个list保存

线程信息,当然list肯定保存在redis中,AbstractQueuedSynchronizer主要是release时候唤醒下一个节点线程,如果我们这个

也想采用这种方式,是不是可以使用redis的发布/订阅功能:每个节点订阅一个频道,比如redisLock;节点里的线程在获取锁不成功时,构造一个状态对象放到一个map里面,同时线程park。每次头节点释放锁的时候,向redisLock频道发布下一个节点的key,各个节点收到这个消息后,检查自己的map里面是否有这个key对应的节点,有则唤醒线程,尝试获取锁。同理好像

Condition也可以实现。按照这种想法好像实现了一个分布式的ReentrantLock,但是比较复杂,中间涉及太多网络传输,是不是还不如循环重试呢

2 频率限制器 

假如需要对系统的接口做调用频率限制,防止恶意调用,可以使用频率限制器,常见的是使用令牌桶算法,参考了guava的

RateLimiter实现,我的实现如下

令牌桶目前只实现了内存存储的版本,后面可以加上基于redis的实现,以适应分布式系统

package com.test.util.ratelimit.ticket; public abstract class AbstractTicketBucket implements TicketBucket{ private double capacity;//容量 private long duration;//时间段 毫秒 public AbstractTicketBucket(double capacity, long duration) { super(); this.capacity = capacity; this.duration = duration; this.setTimes((int)Math.round(capacity)); this.setLastAccess(System.currentTimeMillis()); } /** * 分布式情况需要Override加分布式锁 * @return */ @Override public synchronized boolean access() { this.supply(); if(getTimes() > 0){ int times = getTimes() - 1; setTimes(times); return true; } return false; } protected abstract int getTimes(); protected abstract void setTimes(int times); protected abstract long getLastAccess(); protected abstract void setLastAccess(long lastAccess); private void supply(){ long now = System.currentTimeMillis(); double times = (now - getLastAccess())/duration * capacity + getTimes(); if(times > capacity){ setTimes((int)Math.round(capacity)); }else{ setTimes((int)Math.round(times)); } setLastAccess(now); } } package com.test.util.ratelimit.ticket; public class SimpleTicketBucket extends AbstractTicketBucket{ private volatile int times; private volatile long lastAccess; public SimpleTicketBucket(double capacity, long duration) { super(capacity, duration); } @Override protected int getTimes() { return this.times; } @Override protected void setTimes(int times) { this.times = times; } @Override protected long getLastAccess() { return this.lastAccess; } @Override protected void setLastAccess(long lastAccess) { this.lastAccess = lastAccess; } } RateLimiter

package com.test.util.ratelimit;

/**
 * Decides whether a request identified by url, ip and user id may proceed.
 */
public interface RateLimiter {

    /**
     * @param url request url
     * @param ip  client ip
     * @param uid user id
     * @return true if support
     */
    boolean support(String url, String ip, String uid);

    /**
     * @param url request url
     * @param ip  client ip
     * @param uid user id
     * @return true if can access
     */
    boolean access(String url, String ip, String uid);
}

package com.test.util.ratelimit;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import com.test.util.ratelimit.ticket.TicketBucket;

/**
 * Keeps one {@link TicketBucket} per key and delegates the access decision to it.
 * Subclasses define how a request maps to a key and how a bucket for it is built.
 */
public abstract class AbstractRateLimiter implements RateLimiter {

    private final Map<String, TicketBucket> ticketBucketMap = new HashMap<>();

    // The lock guards this instance's map only, so it is per instance: separate limiters
    // have no shared state and should not contend with each other.
    private final ReadWriteLock rwl = new ReentrantReadWriteLock();
    private final Lock readLock = rwl.readLock();
    private final Lock writeLock = rwl.writeLock();

    /**
     * Looks up the bucket for the request, creating it on first use, and asks it for a token.
     *
     * <p>The key is computed before any lock is taken and every lock hold sits in
     * try/finally, so an exception from {@link #buildKey} or {@link #newTicketBucket}
     * cannot leave a lock held. The bucket is consulted after the map lock is released.
     *
     * @return true if can access
     */
    @Override
    public boolean access(String url, String ip, String uid) {
        String key = buildKey(url, ip, uid);
        TicketBucket ticketBucket = findBucket(key);
        if (ticketBucket == null) {
            ticketBucket = findOrCreateBucket(key, url, ip, uid);
        }
        return ticketBucket.access();
    }

    /** Returns the bucket stored under {@code key}, or {@code null} if there is none. */
    private TicketBucket findBucket(String key) {
        readLock.lock();
        try {
            return ticketBucketMap.get(key);
        } finally {
            readLock.unlock();
        }
    }

    /** Re-checks under the write lock, because another thread may have created the bucket meanwhile. */
    private TicketBucket findOrCreateBucket(String key, String url, String ip, String uid) {
        writeLock.lock();
        try {
            TicketBucket ticketBucket = ticketBucketMap.get(key);
            if (ticketBucket == null) {
                ticketBucket = newTicketBucket(url, ip, uid);
                ticketBucketMap.put(key, ticketBucket);
            }
            return ticketBucket;
        } finally {
            writeLock.unlock();
        }
    }

    /** @return the map key that identifies the bucket for this request */
    protected abstract String buildKey(String url, String ip, String uid);

    /** @return a new bucket for a request that has none yet */
    protected abstract TicketBucket newTicketBucket(String url, String ip, String uid);
}

package com.test.util.ratelimit;

import com.test.util.ratelimit.ticket.SimpleTicketBucket;
import com.test.util.ratelimit.ticket.TicketBucket;

/**
 * Limits how often each user may call certain urls.
 */
public class SimpleUidUrlRateLimiter extends AbstractRateLimiter {

    private final String[] urls;
    private final double[] times;
    private final long[] durations;

    /**
     * The three arrays are parallel: entry {@code i} of each describes one limited url.
     *
     * @param urls      limited urls
     * @param times     bucket capacity for each url
     * @param durations refill period in milliseconds for each url
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public SimpleUidUrlRateLimiter(String[] urls, double[] times, long[] durations) {
        super();
        if (urls.length != times.length || urls.length != durations.length) {
            throw new IllegalArgumentException("error length");
        }
        this.urls = urls;
        this.times = times;
        this.durations = durations;
    }

    /** @return true if {@code url} matches one of the configured urls */
    @Override
    public boolean support(String url, String ip, String uid) {
        return this.index(url) >= 0;
    }

    /** Builds the key from uid and url; a {@code null} uid is replaced by {@code "uid:"}. */
    @Override
    protected String buildKey(String url, String ip, String uid) {
        if (uid == null) {
            uid = "uid:";
        }
        return uid + url;
    }

    /**
     * Builds a bucket from the settings configured for {@code url}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code url} matches no configured url;
     *         callers are expected to check {@link #support} first
     */
    @Override
    protected TicketBucket newTicketBucket(String url, String ip, String uid) {
        int index = this.index(url);
        return new SimpleTicketBucket(times[index], durations[index]);
    }

    /**
     * Finds the first configured url that contains {@code url} as a substring.
     *
     * @return its position, or -1 if there is none or {@code url} is {@code null}
     */
    private int index(String url) {
        if (urls == null || url == null) {
            return -1;
        }
        for (int i = 0; i < urls.length; i++) {
            // NOTE(review): this tests whether the CONFIGURED url contains the request url,
            // so a short request url such as "/" matches every entry. Confirm this is the
            // intended direction of the match.
            if (urls[i].indexOf(url) >= 0) {
                return i;
            }
        }
        return -1;
    }
}

package com.test.util.ratelimit;

import com.test.util.ratelimit.ticket.SimpleTicketBucket;
import com.test.util.ratelimit.ticket.TicketBucket;

/**
 * Limits how often each user may call the system, counted across all urls.
 */
public class SimpleUidRateLimiter extends AbstractRateLimiter {

    private final double times;
    private final long durations;

    /**
     * @param times     bucket capacity per user
     * @param durations refill period in milliseconds
     */
    public SimpleUidRateLimiter(double times, long durations) {
        super();
        this.times = times;
        this.durations = durations;
    }

    /** Applies to every request. */
    @Override
    public boolean support(String url, String ip, String uid) {
        return true;
    }

    /** Builds the key from the uid only, so all urls of one user share a bucket. */
    @Override
    protected String buildKey(String url, String ip, String uid) {
        if (uid == null) {
            return "uid:allUrl";
        }
        return uid + ":allUrl";
    }

    @Override
    protected TicketBucket newTicketBucket(String url, String ip, String uid) {
        return new SimpleTicketBucket(times, durations);
    }
}
RateLimiterChain总起使用,放在过滤器里面

package com.test.util.ratelimit;

import java.util.List;

/**
 * Runs a request through every configured {@link RateLimiter} that supports it.
 */
public class RateLimiterChain {

    private List<RateLimiter> rateLimiters;

    /**
     * Checks the request against each supporting limiter in order.
     *
     * <p>The request is rejected as soon as one limiter denies it. Evaluation stops
     * there, so later limiters do not spend a token on a request that is rejected anyway.
     *
     * @param url request url
     * @param ip  client ip
     * @param uid user id
     * @return true if no limiter is configured or every supporting limiter grants access,
     *         false if any of them denies it
     */
    public boolean access(String url, String ip, String uid) {
        if (rateLimiters == null) {
            return true;
        }
        for (RateLimiter rateLimiter : rateLimiters) {
            if (rateLimiter.support(url, ip, uid) && !rateLimiter.access(url, ip, uid)) {
                return false;
            }
        }
        return true;
    }

    public List<RateLimiter> getRateLimiters() {
        return rateLimiters;
    }

    public void setRateLimiters(List<RateLimiter> rateLimiters) {
        this.rateLimiters = rateLimiters;
    }
}
完整代码  https://github.com/renyiiiii/util.git

转载请注明原文地址: https://www.6miu.com/read-1599970.html

最新回复(0)