分布式主键ID生成方案

unimof 2021年02月19日 166次浏览

分库分表带来的好处是将数据读写压力分流到多个数据库,从而实现更大负载能力,但任何事物都有两面性,分库分表也有很多副作用,比如,增加程序复杂性,跨库事务处理,分布式主键 ID 也是其中最直接的问题。

起源

问题的起源很明显,分库分表前,所有数据在一个表,使用数据库表的唯一主键,可以简单可靠的实现数据主键ID的唯一性,不需要做任何处理;分库分表后,问题就来了,一个表的数据需要按照一定的规则分的不同的库或表中,那么就需要有一个机制,来保证所有表数据的并集ID不重复。

当然,对于静态数据来说,数据入库后,就不再改变,那就不存在问题,但这是极少数的情况,不在我们讨论范围之列。

解决方案

1. UUID

UUID是国际标准化组织(ISO)提出的一个概念。UUID是一个128比特的数值,这个数值可以通过一定的算法计算出来。为了提高效率,常用的UUID可缩短至16位。UUID用来识别属性类型,在所有空间和时间上被视为唯一的标识。一般来说,可以保证这个值是真正唯一的任何地方产生的任意一个UUID都不会有相同的值。使用UUID的一个好处是可以为新的服务创建新的标识符。这样一来,客户端在查找一个服务时,只需要在它的服务查找请求中指出与某类服务(或某个特定服务)有关的UUID,如果服务的提供者能将可用的服务与这个UUID相匹配,就返回一个响应。
UUID是基于当前时间、计数器(counter)和硬件标识(通常为无线网卡的MAC地址)等数据计算生成的。UUID可以被任何人独立创建,并按需发布。UUID没有集中管理机构,因为它们是不会被复制的独特标识符。属性协议允许设备使用UUID识别属性类型,从而不需要用读/写请求来识别它们的本地句柄。

目前最广泛应用的 UUID,即是微软的 Microsoft's Globally Unique Identifiers (GUIDs)。

基于UUID良好的特性,将自增主键改为uuid,可以有效的避免主键重复问题。但是一般情况下,不推荐使用 UUID 作为主键:

  • 生成和查询消耗性能比较大
  • 存储空间大

之所以会有上述差异,主要是innodb引擎采用B+树作为索引算法,Innodb的存储结构,是聚簇索引。对于聚簇索引顺序主键和随机主键的对效率的影响很大。

自增是顺序主键存储,查找和插入都很方便(插入会按顺序插到前一个的后面),但UUid是无序的,通过计算获得的hashcode也会是无序的(是按照hashcode选择存储位置)所以对于他的查找效率很低;而且因为他是无序的,他的插入有可能会插到前面的数据中,会造成很多其他的操作,很影响性能或者很多存储空间因为没有顺序的存储而被空缺浪费。

2. SNOWFLAKE

也就是 雪花算法。
雪花算法的思想和uuid类似,采用一定的规则(可以根据自己的业务需要进行设计)生成一个数据Id,只是生成的Id是一个大整数,通常是64位,这样就避免了UUID导致效率降低的问题,又有了UUID生成数据灵活、不重复的优点。

经典的雪花算法如下:

image.png

  • 1bit,不用,因为二进制中最高位是符号位,1表示负数,0表示正数。生成的id一般都是用整数,所以最高位固定为0。

  • 41bit-时间戳,用来记录时间戳,毫秒级。

    • 41位可以表示个数字,
    • 如果只用来表示正整数(计算机中正数包含0),可以表示的数值范围是:0 至 ,减1是因为可表示的数值范围是从0开始算的,而不是1。
    • 也就是说41位可以表示个毫秒的值,转化成单位年则是年
  • 10bit-工作机器id,用来记录工作机器id。

    • 可以部署在个节点,包括5位datacenterId和5位workerId
    • 5位(bit)可以表示的最大正整数是,即可以用0、1、2、3、....31这32个数字,来表示不同的datecenterId或workerId
  • 12bit-序列号,序列号,用来记录同毫秒内产生的不同id。

    • 12位(bit)可以表示的最大正整数是,即可以用0、1、2、3、....4094这4095个数字,来表示同一机器同一时间截(毫秒)内产生的4095个ID序号。

由于在Java中64bit的整数是long类型,所以在Java中SnowFlake算法生成的id就是long来存储的。根据该算法,可以实现如下优点:

  • 所有生成的id按时间趋势递增
  • 整个分布式系统内不会产生重复id(因为有datacenterId和workerId来做区分)

原始的雪花算法是Twitter的Scala版本:

/** Copyright 2010-2012 Twitter, Inc.*/
package com.twitter.service.snowflake

import com.twitter.ostrich.stats.Stats
import com.twitter.service.snowflake.gen._
import java.util.Random
import com.twitter.logging.Logger

/**
 * An object that generates IDs.
 * This is broken into a separate class in case
 * we ever want to support multiple worker threads
 * per process
 */
class IdWorker(
    val workerId: Long, 
    val datacenterId: Long, 
    private val reporter: Reporter, 
    var sequence: Long = 0L) extends Snowflake.Iface {
    
  private[this] def genCounter(agent: String) = {
    Stats.incr("ids_generated")
    Stats.incr("ids_generated_%s".format(agent))
  }
  private[this] val exceptionCounter = Stats.getCounter("exceptions")
  private[this] val log = Logger.get
  private[this] val rand = new Random

  val twepoch = 1288834974657L

  private[this] val workerIdBits = 5L
  private[this] val datacenterIdBits = 5L
  private[this] val maxWorkerId = -1L ^ (-1L << workerIdBits)
  private[this] val maxDatacenterId = -1L ^ (-1L << datacenterIdBits)
  private[this] val sequenceBits = 12L

  private[this] val workerIdShift = sequenceBits
  private[this] val datacenterIdShift = sequenceBits + workerIdBits
  private[this] val timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits
  private[this] val sequenceMask = -1L ^ (-1L << sequenceBits)

  private[this] var lastTimestamp = -1L

  // sanity check for workerId
  if (workerId > maxWorkerId || workerId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("worker Id can't be greater than %d or less than 0".format(maxWorkerId))
  }

  if (datacenterId > maxDatacenterId || datacenterId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("datacenter Id can't be greater than %d or less than 0".format(maxDatacenterId))
  }

  log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
    timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)

  def get_id(useragent: String): Long = {
    if (!validUseragent(useragent)) {
      exceptionCounter.incr(1)
      throw new InvalidUserAgentError
    }

    val id = nextId()
    genCounter(useragent)

    reporter.report(new AuditLogEntry(id, useragent, rand.nextLong))
    id
  }

  def get_worker_id(): Long = workerId
  def get_datacenter_id(): Long = datacenterId
  def get_timestamp() = System.currentTimeMillis

  protected[snowflake] def nextId(): Long = synchronized {
    var timestamp = timeGen()

    if (timestamp < lastTimestamp) {
      exceptionCounter.incr(1)
      log.error("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp);
      throw new InvalidSystemClock("Clock moved backwards.  Refusing to generate id for %d milliseconds".format(
        lastTimestamp - timestamp))
    }

    if (lastTimestamp == timestamp) {
      sequence = (sequence + 1) & sequenceMask
      if (sequence == 0) {
        timestamp = tilNextMillis(lastTimestamp)
      }
    } else {
      sequence = 0
    }

    lastTimestamp = timestamp
    ((timestamp - twepoch) << timestampLeftShift) |
      (datacenterId << datacenterIdShift) |
      (workerId << workerIdShift) | 
      sequence
  }

  protected def tilNextMillis(lastTimestamp: Long): Long = {
    var timestamp = timeGen()
    while (timestamp <= lastTimestamp) {
      timestamp = timeGen()
    }
    timestamp
  }

  protected def timeGen(): Long = System.currentTimeMillis()

  val AgentParser = """([a-zA-Z][a-zA-Z\-0-9]*)""".r

  def validUseragent(useragent: String): Boolean = useragent match {
    case AgentParser(_) => true
    case _ => false
  }
}

Java 版本:

public class IdWorker{

    private long workerId;
    private long datacenterId;
    private long sequence;

    public IdWorker(long workerId, long datacenterId, long sequence){
        // sanity check for workerId
        if (workerId > maxWorkerId || workerId < 0) {
            throw new IllegalArgumentException(String.format("worker Id can't be greater than %d or less than 0",maxWorkerId));
        }
        if (datacenterId > maxDatacenterId || datacenterId < 0) {
            throw new IllegalArgumentException(String.format("datacenter Id can't be greater than %d or less than 0",maxDatacenterId));
        }
        System.out.printf("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
                timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId);

        this.workerId = workerId;
        this.datacenterId = datacenterId;
        this.sequence = sequence;
    }

    private long twepoch = 1288834974657L;

    private long workerIdBits = 5L;
    private long datacenterIdBits = 5L;
    private long maxWorkerId = -1L ^ (-1L << workerIdBits);
    private long maxDatacenterId = -1L ^ (-1L << datacenterIdBits);
    private long sequenceBits = 12L;

    private long workerIdShift = sequenceBits;
    private long datacenterIdShift = sequenceBits + workerIdBits;
    private long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
    private long sequenceMask = -1L ^ (-1L << sequenceBits);

    private long lastTimestamp = -1L;

    public long getWorkerId(){
        return workerId;
    }

    public long getDatacenterId(){
        return datacenterId;
    }

    public long getTimestamp(){
        return System.currentTimeMillis();
    }

    public synchronized long nextId() {
        long timestamp = timeGen();

        if (timestamp < lastTimestamp) {
            System.err.printf("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp);
            throw new RuntimeException(String.format("Clock moved backwards.  Refusing to generate id for %d milliseconds",
                    lastTimestamp - timestamp));
        }

        if (lastTimestamp == timestamp) {
            sequence = (sequence + 1) & sequenceMask;
            if (sequence == 0) {
                timestamp = tilNextMillis(lastTimestamp);
            }
        } else {
            sequence = 0;
        }

        lastTimestamp = timestamp;
        return ((timestamp - twepoch) << timestampLeftShift) |
                (datacenterId << datacenterIdShift) |
                (workerId << workerIdShift) |
                sequence;
    }

    private long tilNextMillis(long lastTimestamp) {
        long timestamp = timeGen();
        while (timestamp <= lastTimestamp) {
            timestamp = timeGen();
        }
        return timestamp;
    }

    private long timeGen(){
        return System.currentTimeMillis();
    }

    //---------------测试---------------
    public static void main(String[] args) {
        IdWorker worker = new IdWorker(1,1,1);
        for (int i = 0; i < 30; i++) {
            System.out.println(worker.nextId());
        }
    }

}

引用

https://segmentfault.com/a/1190000011282426