The One Billion Row Challenge in Java
Background
The One Billion Row Challenge -- A fun exploration of how quickly 1B rows from a text file can be aggregated with Java
The 1B-row input file is about 13 GB.
Environment requirements
- Hardware: 32 cores, 128 GB RAM
- JDK 21
- Linux
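Implementation
Iteration 1
- Use MappedByteBuffer to avoid copying buffers between kernel space and user space
- Split the file into chunks according to Runtime.getRuntime().availableProcessors()
- Aggregate the temperatures in a HashMap, with k, v => byte[], MeasurementAggregator (count, min, sum, max)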
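A minimal sketch of the chunking step, assuming each chunk boundary is pushed forward to just after the next '\n' so that no line is split between two chunks, and that every chunk stays below 2 GB (the size limit of a single MappedByteBuffer). The names ChunkSplitter, split and nextLineStart are hypothetical, not taken from the author's code.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;

public class ChunkSplitter {

    // Hypothetical helper: maps the file as roughly equal read-only chunks,
    // each ending right after a '\n' (assumes every chunk is smaller than 2 GB)
    static List<MappedByteBuffer> split(Path file, int chunkCount) throws IOException {
        List<MappedByteBuffer> chunks = new ArrayList<>();
        try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ)) {
            long size = ch.size();
            long step = Math.max(1, size / chunkCount);
            long start = 0;
            while (start < size) {
                long end = start + step >= size ? size : nextLineStart(ch, start + step, size);
                // A mapping stays valid after the channel that created it is closed
                chunks.add(ch.map(FileChannel.MapMode.READ_ONLY, start, end - start));
                start = end;
            }
        }
        return chunks;
    }

    // Returns the position just after the first '\n' at or after pos, or size if there is none
    static long nextLineStart(FileChannel ch, long pos, long size) throws IOException {
        ByteBuffer buf = ByteBuffer.allocate(128);
        while (pos < size) {
            buf.clear();
            int n = ch.read(buf, pos);
            if (n <= 0) {
                break;
            }
            for (int i = 0; i < n; i++) {
                if (buf.get(i) == '\n') {
                    return pos + i + 1;
                }
            }
            pos += n;
        }
        return size;
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("measurements", ".txt");
        tmp.toFile().deleteOnExit();
        Files.writeString(tmp, "Hamburg;12.0\nBulawayo;8.9\nPalembang;38.8\nSt. John's;15.2\n",
                StandardCharsets.UTF_8);
        int cores = Runtime.getRuntime().availableProcessors();
        for (MappedByteBuffer chunk : split(tmp, cores)) {
            System.out.println("chunk of " + chunk.limit() + " bytes");
        }
    }
}
```

In a real run each chunk would be processed by its own thread; the main method here only prints the chunk sizes.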
The results were disappointing.
A flame graph showed:
- HashMap::containsKey(key) was very slow
- key::hashCode and key::equals performed poorly
Iteration 2
Optimizations, starting with a key class that wraps the key bytes:
static class CalculateKey {
    private final byte[] bytes;
    private final int length;
    private final int hash;

    public CalculateKey(byte[] bytes, int length, int hash) {
        this.bytes = bytes;
        this.length = length;
        this.hash = hash;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CalculateKey that = (CalculateKey) o;
        // Cheap length and hash checks first, then the full byte comparison
        return length == that.length && hash == that.hash && Arrays.equals(bytes, that.bytes);
    }

    @Override
    public int hashCode() {
        // Initial version: recomputes a hash over all key bytes on every call
        int result = Objects.hash(length, hash);
        result = 31 * result + Arrays.hashCode(bytes);
        return result;
    }
}
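- Removed the HashMap::containsKey call
- Optimized hashCode: the key's hash is computed on the fly while its bytes are read, so put no longer computes it
  - hashCode = 31 * hashCode + positionByte;
hashCode() then simply returns the stored value: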
@Override
public int hashCode() {
    // The hash was already computed while the key bytes were read
    return hash;
}
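The on-the-fly hash can be sketched as a single pass that searches for ';' and applies hashCode = 31 * hashCode + b to every key byte, so the hash is ready the moment the key ends. This is a minimal sketch: IncrementalHash, KeyScan and scanKey are hypothetical names, the hash is assumed to start at 0 (the text does not state the initial value), and the line is assumed to contain a ';'.

```java
import java.nio.charset.StandardCharsets;

public class IncrementalHash {

    // Hypothetical result holder: key length plus the hash accumulated while scanning
    record KeyScan(int length, int hash) {}

    // Scans line[from..] up to ';', updating the hash once per key byte (assumes a ';' exists)
    static KeyScan scanKey(byte[] line, int from) {
        int hash = 0;
        int i = from;
        while (line[i] != ';') {
            hash = 31 * hash + line[i];
            i++;
        }
        return new KeyScan(i - from, hash);
    }

    public static void main(String[] args) {
        byte[] line = "Hamburg;12.0\n".getBytes(StandardCharsets.UTF_8);
        KeyScan scan = scanKey(line, 0);
        // Starting from 0, the recurrence matches String.hashCode for ASCII station names
        System.out.println(scan.length() + " " + (scan.hash() == "Hamburg".hashCode()));
    }
}
```

For "Hamburg;12.0\n" this reports a key length of 7 and a hash equal to "Hamburg".hashCode(). For non-ASCII names the signed UTF-8 bytes give a different, but still usable, hash.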
This helped: 100 million rows took a little over 2 s on my local machine (8 cores, 16 GB).
Profiling again showed:
- MappedByteBuffer.get()
- HashMap.get
- new byte[] for each key's byte array
These three were the most expensive, so they became the next targets.
Iteration 3
Replace HashMap with a custom implementation
Memory layout:
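- ByteBuffer data; // original data
- int[] index; // offsets into mem
- long[] mem; // memory for every station
  - mem[0] = byteOffset, mem[1] = keyLength, mem[2] = measurement, mem[3] = key's hash code
    - byteOffset is where the key starts in data; together with keyLength it recovers the key's byte[]
    - the key's hash code together with keyLength is used to decide whether two keys are equal (on its own this cannot rule out hash collisions)
  - mem[4] to mem[7] = count, sum, min, max statistics
  - the next station's data starts at mem[8]
  - each station takes 8 longs
- This trades space for time
- data is only a reference: nothing is reallocated, and keys are located through their offset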
static class SimpleHashMap {
    ByteBuffer data; // original data
    private static final int STATION_SIZE = 8; // longs per station in mem
    private static final int CAPACITY = 1024 * 64; // number of buckets (power of two)
    private static final int INDEX_MASK = CAPACITY - 1;
    // index[] -> mem[]'s offset; 0 marks an empty bucket
    int[] index;
    // mem[0] = byteOffset + mem[1]=keyLength + mem[2]=measurement + mem[3]=key's hashcode
    // mem[4] ~ mem[7] = count + sum + min + max
    long[] mem;

    public SimpleHashMap(ByteBuffer chunk) {
        index = new int[CAPACITY];
        // +1 so that stations can start at offset 1 (offset 0 means "empty")
        mem = new long[CAPACITY * STATION_SIZE + 1];
        data = chunk;
    }

    // memOffset must be > 0, since 0 marks an empty bucket
    public void put(int byteOffset, int hash, long value, int keyLength, int memOffset) {
        int bucket = hash & INDEX_MASK;
        for (;; bucket = (bucket + 1) & INDEX_MASK) { // linear probing
            int offset = this.index[bucket];
            if (offset == 0) {
                this.index[bucket] = memOffset;
                mem[memOffset] = byteOffset;
                mem[memOffset + 1] = keyLength;
                mem[memOffset + 2] = value;
                mem[memOffset + 3] = hash;
                mem[memOffset + 4] = 1; // count
                mem[memOffset + 5] = value; // sum
                mem[memOffset + 6] = value; // min
                mem[memOffset + 7] = value; // max
                break;
            }
            int prevKeyLength = (int) mem[offset + 1];
            int prevHash = (int) mem[offset + 3];
            // hash and length are cheap pre-checks; the byte comparison rules out hash collisions
            if (prevHash == hash && prevKeyLength == keyLength
                    && keyEquals((int) mem[offset], byteOffset, keyLength)) {
                mem[offset + 4] += 1; // count
                mem[offset + 5] += value; // sum
                mem[offset + 6] = Math.min(value, mem[offset + 6]); // min
                mem[offset + 7] = Math.max(value, mem[offset + 7]); // max
                break;
            }
        }
    }

    // Compares two keys in place inside data, without copying them
    private boolean keyEquals(int start1, int start2, int length) {
        for (int i = 0; i < length; i++) {
            if (data.get(start1 + i) != data.get(start2 + i)) {
                return false;
            }
        }
        return true;
    }

    // Probes from the home bucket; returns the mem offset of the first entry with this hash, or 0
    public int get(int hash) {
        for (int bucket = hash & INDEX_MASK;; bucket = (bucket + 1) & INDEX_MASK) {
            int offset = index[bucket];
            if (offset == 0 || (int) mem[offset + 3] == hash) {
                return offset;
            }
        }
    }

    void merge(Map<String, MeasurementAggregator> target) {
        this.data.flip();
        for (int i = 0; i < CAPACITY; i++) {
            int offset = this.index[i];
            if (offset == 0) {
                continue;
            }
            int start = (int) mem[offset];
            int keyLen = (int) mem[offset + 1];
            byte[] keyByte = new byte[keyLen];
            data.get(start, keyByte); // absolute bulk get (JDK 13+)
            String key = new String(keyByte, StandardCharsets.UTF_8);
            target.compute(key, (k, v) -> {
                if (v == null) {
                    v = new MeasurementAggregator();
                }
                v.min = Math.min(v.min, mem[offset + 6]);
                v.max = Math.max(v.max, mem[offset + 7]);
                v.sum += mem[offset + 5];
                v.count += mem[offset + 4];
                return v;
            });
        }
    }
}
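To see the flat layout and linear probing in isolation, here is a small self-contained sketch. It is not the author's code: FlatStationTable, add, slot and hashAt are hypothetical names, the table is much smaller than CAPACITY above, keys are compared byte by byte after the hash and length pre-checks, and measurements are assumed to be stored as longs in tenths of a degree.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class FlatStationTable {
    static final int STATION_SIZE = 8;     // longs per station, as in the layout above
    static final int CAPACITY = 1 << 10;   // buckets; power of two so "& MASK" acts as modulo
    static final int MASK = CAPACITY - 1;

    final ByteBuffer data;                 // the chunk; keys are never copied out of it
    final int[] index = new int[CAPACITY]; // bucket -> offset into mem, 0 = empty
    final long[] mem = new long[CAPACITY * STATION_SIZE + 1];
    int nextOffset = 1;                    // offset 0 is reserved to mean "empty"

    FlatStationTable(ByteBuffer data) {
        this.data = data;
    }

    // Linear probing: returns the bucket holding this key, or the empty bucket where it belongs
    // (assumes fewer stations than buckets, otherwise this loops forever)
    int slot(int keyStart, int keyLength, int hash) {
        int bucket = hash & MASK;
        while (true) {
            int offset = index[bucket];
            if (offset == 0) return bucket;
            if ((int) mem[offset + 3] == hash && mem[offset + 1] == keyLength
                    && sameBytes((int) mem[offset], keyStart, keyLength)) return bucket;
            bucket = (bucket + 1) & MASK;
        }
    }

    // Adds one measurement for the key stored at data[keyStart, keyStart + keyLength)
    void add(int keyStart, int keyLength, int hash, long value) {
        int bucket = slot(keyStart, keyLength, hash);
        int offset = index[bucket];
        if (offset == 0) {                 // new station: claim the next 8-long record
            offset = nextOffset;
            nextOffset += STATION_SIZE;
            index[bucket] = offset;
            mem[offset] = keyStart;
            mem[offset + 1] = keyLength;
            mem[offset + 3] = hash;
            mem[offset + 6] = Long.MAX_VALUE; // min
            mem[offset + 7] = Long.MIN_VALUE; // max
        }
        mem[offset + 2] = value;           // last measurement
        mem[offset + 4]++;                 // count
        mem[offset + 5] += value;          // sum
        mem[offset + 6] = Math.min(mem[offset + 6], value);
        mem[offset + 7] = Math.max(mem[offset + 7], value);
    }

    private boolean sameBytes(int a, int b, int len) {
        for (int i = 0; i < len; i++) {
            if (data.get(a + i) != data.get(b + i)) return false;
        }
        return true;
    }

    // Same 31 * h + b recurrence as in iteration 2, computed here directly on data
    int hashAt(int start, int len) {
        int h = 0;
        for (int i = 0; i < len; i++) h = 31 * h + data.get(start + i);
        return h;
    }

    public static void main(String[] args) {
        ByteBuffer data = ByteBuffer.wrap(
                "Hamburg;12.0\nOslo;-3.5\nHamburg;8.0\n".getBytes(StandardCharsets.UTF_8));
        FlatStationTable t = new FlatStationTable(data);
        t.add(0, 7, t.hashAt(0, 7), 120);   // Hamburg 12.0
        t.add(13, 4, t.hashAt(13, 4), -35); // Oslo -3.5
        t.add(23, 7, t.hashAt(23, 7), 80);  // Hamburg 8.0
        int h = t.index[t.slot(0, 7, t.hashAt(0, 7))];
        System.out.println("Hamburg count=" + t.mem[h + 4] + " sum=" + t.mem[h + 5]
                + " min=" + t.mem[h + 6] + " max=" + t.mem[h + 7]);
    }
}
```

The two Hamburg lines land in the same 8-long record even though their bytes live at different offsets in data, which is the point of storing only byteOffset and keyLength.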
Results on an EC2 c5a.8xlarge instance (32 vCPUs, 64 GB):
[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_gumingcn.sh
real 0m2.068s
user 0m53.515s
sys 0m0.910s
// time results of other implementations
[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_gonix.sh
real 0m1.023s
user 0m0.000s
sys 0m0.002s
[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_merykitty.sh
real 0m1.502s
user 0m32.113s
sys 0m1.183s
[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_thomaswue.sh
real 0m1.440s
user 0m34.203s
sys 0m0.912s
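Iteration 4
How can MappedByteBuffer.get() be optimized?
gonix, the author of one of the implementations timed above, offers an approach.
In short, bit tricks locate ';' and '\n', which split each station;measurement\n line.
ByteBuffer.getLong() reads 8 bytes at a time, which reduces the time spent in the byte-by-byte read loop.
But how do we find ';'?
The ASCII code of ';' is 59 decimal, 0x3B hex.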
private static long valueSepMark(long keyLong) {
    // Seen this trick used in multiple other solutions.
    // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
    long match = keyLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';'
    match = (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L);
    // The high bit is set in each byte that held ';' (the lowest flagged byte is always exact)
    return match;
}
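A small sketch of how the mask can be turned into a byte position, assuming the 8 bytes are read in little-endian order so that the first byte of the text is the lowest byte of the long (ByteBuffer defaults to big-endian, so the order is set explicitly here). SemicolonFinder, firstSemicolon and wordAt are hypothetical names; valueSepMark is the function above.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class SemicolonFinder {

    // Same bit trick as valueSepMark above: flags the high bit of every byte equal to ';'
    static long valueSepMark(long keyLong) {
        long match = keyLong ^ 0x3B3B3B3B_3B3B3B3BL;
        return (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L);
    }

    // Index (0..7) of the first ';' in the word, or 8 if the word contains none
    static int firstSemicolon(long word) {
        return Long.numberOfTrailingZeros(valueSepMark(word)) >>> 3;
    }

    // Hypothetical helper: first 8 bytes of text (zero-padded), little-endian
    static long wordAt(String text) {
        byte[] bytes = Arrays.copyOf(text.getBytes(StandardCharsets.US_ASCII), 8);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong(0);
    }

    public static void main(String[] args) {
        System.out.println(firstSemicolon(wordAt("Oslo;-3.5\n")));    // 4
        System.out.println(firstSemicolon(wordAt("Hamburg;12.0\n"))); // 7
        System.out.println(firstSemicolon(wordAt("Bulawayo;8.9\n"))); // 8: not in this word
    }
}
```

A result of 8 tells the caller to fold the whole word into the key (and its hash) and read the next 8 bytes.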
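How do we find '\n' quickly? Every measurement has exactly one fractional digit, so it is enough to find the decimal point: '\n' is always two bytes after '.'.
- The ASCII codes of the digits 0-9 all have bit 4 (0x10, counting from 0) set, e.g. '6' is 0011 0110
- '.' (0010 1110) and the newline (0000 1010) have bit 4 cleared
So the decimal point can be located with bit operations: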
private static int decimalSepMark(long value) {
    // Seen this trick used in multiple other solutions.
    // Looks like the original author is @merykitty.
    // The 4th binary digit of the ascii of a digit is 1 while
    // that of the '.' is 0. This finds the decimal separator
    // The value can be 12, 20, 28
    return Long.numberOfTrailingZeros(~value & 0x10101000);
}
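The mask 0x10101000 tests bit 4 of bytes 1, 2 and 3 of the word, so the result 12, 20 or 28 says whether '.' is in byte 1, 2 or 3. A minimal sketch, assuming the word starts at the first byte after ';' and is read little-endian; DecimalPointFinder, newlineIndex and wordAt are hypothetical names.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class DecimalPointFinder {

    // Same expression as decimalSepMark above: the lowest byte among 1..3 whose bit 4
    // is cleared is the '.', since digits have bit 4 set
    static int decimalSepMark(long value) {
        return Long.numberOfTrailingZeros(~value & 0x10101000);
    }

    // One fractional digit follows '.', so '\n' sits two bytes after it
    static int newlineIndex(long value) {
        return (decimalSepMark(value) >>> 3) + 2;
    }

    // Hypothetical helper: first 8 bytes of text (zero-padded), little-endian
    static long wordAt(String text) {
        byte[] bytes = Arrays.copyOf(text.getBytes(StandardCharsets.US_ASCII), 8);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong(0);
    }

    public static void main(String[] args) {
        for (String m : new String[] {"1.2\n", "12.3\n", "-4.5\n", "-99.9\n"}) {
            long w = wordAt(m);
            System.out.println(m.strip() + " -> mark " + decimalSepMark(w)
                    + ", newline at byte " + newlineIndex(w));
        }
    }
}
```

Byte 0 is never tested because it is always either a digit or '-', and '-' also has bit 4 cleared.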
For details, see the article on bit-twiddling tricks (read more).
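The memory layout was optimized at the same time:
- the key length fits in an int (station names are at most 100 bytes)
- an offset within a single chunk fits in an int
- so a single long can hold both values
- bit operations turn the measurement into an int holding tenths of a degree, instead of parsing a double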
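One way to keep an offset and a length in a single long is sketched below, assuming both are non-negative ints. PackedKeyRef, pack, offset and length are hypothetical names, and this is not how tailAndLen packs its fields.

```java
public class PackedKeyRef {

    // Hypothetical packing: chunk offset in the high 32 bits, key length in the low 32 bits
    static long pack(int offset, int length) {
        return ((long) offset << 32) | (length & 0xFFFFFFFFL);
    }

    static int offset(long packed) {
        return (int) (packed >>> 32);
    }

    static int length(long packed) {
        return (int) packed;
    }

    public static void main(String[] args) {
        long p = pack(123_456_789, 42);
        System.out.println(offset(p) + " " + length(p)); // 123456789 42
    }
}
```

Storing one long instead of two halves the memory spent on key references.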
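private static long tailAndLen(int tailBits, long keyLong, long keyLen) {
    // Keep only the low tailBits of keyLong
    long tailMask = ~(-1L << tailBits);
    long tail = keyLong & tailMask;
    // Low byte: keyLen >> 3; the masked tail goes in the bits above it
    return (tail << 8) | ((keyLen >> 3) & 0xFF);
}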
private static int decimalValue(int decimalSepMark, long value) {
    // Seen this trick used in multiple other solutions.
    // Looks like the original author is @merykitty.
    int shift = 28 - decimalSepMark;
    // signed is -1 if negative, 0 otherwise
    long signed = (~value << 59) >> 63;
    long designMask = ~(signed & 0xFF);
    // Align the number to a specific position and transform the ascii code
    // to actual digit value in each byte
    long digits = ((value & designMask) << shift) & 0x0F000F0F00L;
    // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
    // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
    // 0x000000UU00TTHH00 +
    // 0x00UU00TTHH000000 * 10 +
    // 0xUU00TTHH00000000 * 100
    // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
    // This results in our value lies in the bit 32 to 41 of this product
    // That was close :)
    long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
    return (int) ((absValue ^ signed) - signed);
}
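A quick way to check decimalSepMark and decimalValue together is to feed them a few measurements. This sketch assumes the 8-byte word starts at the first byte after ';' and is read little-endian; MeasurementParser and wordAt are hypothetical names, and decimalValue is copied from above with its long comment shortened.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class MeasurementParser {

    static int decimalSepMark(long value) {
        return Long.numberOfTrailingZeros(~value & 0x10101000);
    }

    // Copied from decimalValue above; returns the temperature in tenths of a degree
    static int decimalValue(int decimalSepMark, long value) {
        int shift = 28 - decimalSepMark;
        long signed = (~value << 59) >> 63;      // -1 if byte 0 is '-', else 0
        long designMask = ~(signed & 0xFF);      // clears the '-' byte when present
        long digits = ((value & designMask) << shift) & 0x0F000F0F00L;
        long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
        return (int) ((absValue ^ signed) - signed);
    }

    // Hypothetical helper: first 8 bytes of text (zero-padded), little-endian
    static long wordAt(String text) {
        byte[] bytes = Arrays.copyOf(text.getBytes(StandardCharsets.US_ASCII), 8);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong(0);
    }

    public static void main(String[] args) {
        for (String m : new String[] {"0.0\n", "9.9\n", "12.3\n", "-4.5\n", "-99.9\n"}) {
            long w = wordAt(m);
            System.out.println(m.strip() + " -> " + decimalValue(decimalSepMark(w), w));
        }
    }
}
```

This prints 0, 99, 123, -45 and -999: every value comes out as an int in tenths, so sums and min/max stay in integer arithmetic until the final output.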
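- With gonix's mem layout, a station whose key is at most 8 bytes long needs only 4 index entries, whereas my implementation always needs 8
- The measurement is likewise computed with bit operations, which is very efficient
The drawback is that so much bit-manipulation code is hard to understand.
Summary
The implementation described here: https://github.com/guming/1brc
- I did not go on to test with Unsafe or MemorySegment/ByteVector (used under the hood by Flink)
- I did not tune the JVM options
The reason: in my view, the decisive optimizations are the memory layout and using bit operations to simplify byte reading. Which utility classes are used is not the core issue (although they would still improve performance).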