implement xchacha20 and x25519
This commit is contained in:
205
examples/x25519.zr
Normal file
205
examples/x25519.zr
Normal file
@@ -0,0 +1,205 @@
|
||||
func unpack25519[out: ptr, input: ptr] : void
|
||||
for i in 0..16
|
||||
mem.write64(out + i * 8, input[i * 2] + (input[i * 2 + 1] << 8))
|
||||
mem.write64(out + 8 * 15, mem.read64(out + 8 * 15) & 0x7fff)
|
||||
|
||||
func carry25519[elem: ptr] : void
|
||||
for i in 0..16
|
||||
let carry: i64 = mem.read64(elem + i * 8) >> 16
|
||||
mem.write64(elem + i * 8, mem.read64(elem + i * 8) - (carry << 16))
|
||||
if i < 15
|
||||
mem.write64(elem + (i + 1) * 8, mem.read64(elem + (i + 1) * 8) + carry)
|
||||
else
|
||||
mem.write64(elem, mem.read64(elem) + 38 * carry)
|
||||
|
||||
func fadd[out: ptr, a: ptr, b: ptr] : void
|
||||
for i in 0..16
|
||||
mem.write64(out + i * 8, mem.read64(a + i * 8) + mem.read64(b + i * 8))
|
||||
|
||||
func fsub[out: ptr, a: ptr, b: ptr] : void
|
||||
for i in 0..16
|
||||
mem.write64(out + i * 8, mem.read64(a + i * 8) - mem.read64(b + i * 8))
|
||||
|
||||
func fmul[out: ptr, a: ptr, b: ptr] : void
|
||||
let product: ptr = mem.alloc(31 * 8)
|
||||
for i in 0..31
|
||||
mem.write64(product + i * 8, 0)
|
||||
for i in 0..16
|
||||
for j in 0..16
|
||||
mem.write64(product + (i + j) * 8, mem.read64(product + (i + j) * 8) + (mem.read64(a + i * 8) * mem.read64(b + j * 8)))
|
||||
for i in 0..15
|
||||
mem.write64(product + i * 8, mem.read64(product + i * 8) + 38 * mem.read64(product + (i + 16) * 8))
|
||||
for i in 0..16
|
||||
mem.write64(out + i * 8, mem.read64(product + i * 8))
|
||||
|
||||
carry25519(out)
|
||||
carry25519(out)
|
||||
mem.free(product)
|
||||
|
||||
func finverse[out: ptr, input: ptr] : void
|
||||
let c: ptr = mem.alloc(16 * 8)
|
||||
for i in 0..16
|
||||
mem.write64(c + i * 8, mem.read64(input + i * 8))
|
||||
|
||||
let i = 253
|
||||
while i >= 0
|
||||
fmul(c, c, c)
|
||||
if i != 2 && i != 4
|
||||
fmul(c, c, input)
|
||||
i = i - 1
|
||||
|
||||
for i in 0..16
|
||||
mem.write64(out + i * 8, mem.read64(c + i * 8))
|
||||
mem.free(c)
|
||||
|
||||
func swap25519[p: ptr, q: ptr, bit: i64] : void
|
||||
for i in 0..16
|
||||
let t: i64 = (-bit) & (mem.read64(p + i * 8) ^ mem.read64(q + i * 8))
|
||||
mem.write64(p + i * 8, mem.read64(p + i * 8) ^ t)
|
||||
mem.write64(q + i * 8, mem.read64(q + i * 8) ^ t)
|
||||
|
||||
func pack25519[out: ptr, input: ptr] : void
|
||||
let t: ptr = mem.alloc(16 * 8)
|
||||
for i in 0..16
|
||||
mem.write64(t + i * 8, mem.read64(input + i * 8))
|
||||
let m: ptr = mem.alloc(16 * 8)
|
||||
|
||||
carry25519(t)
|
||||
carry25519(t)
|
||||
carry25519(t)
|
||||
for j in 0..2
|
||||
mem.write64(m, mem.read64(t) - 0xffed)
|
||||
for i in 1..15
|
||||
mem.write64(m + i * 8, mem.read64(t + i * 8) - 0xffff - ((mem.read64(m + (i - 1) * 8) >> 16) & 1))
|
||||
mem.write64(m + (i - 1) * 8, mem.read64(m + (i - 1) * 8) & 0xffff)
|
||||
mem.write64(m + 15 * 8, mem.read64(t + 15 * 8) - 0x7fff - ((mem.read64(m + 14 * 8) >> 16) & 1))
|
||||
let carry: i64 = (mem.read64(m + 15 * 8) >> 16) & 1
|
||||
mem.write64(m + 14 * 8, mem.read64(m + 14 * 8) & 0xffff)
|
||||
swap25519(t, m, 1 - carry)
|
||||
|
||||
for i in 0..16
|
||||
let v: i64 = mem.read64(t + i * 8)
|
||||
mem.write8(out + i * 2, v & 0xff)
|
||||
mem.write8(out + i * 2 + 1, (v >> 8) & 0xff)
|
||||
|
||||
mem.free(t)
|
||||
mem.free(m)
|
||||
|
||||
func scalarmult[out: ptr, scalar: ptr, point: ptr] : void
|
||||
let clamped: ptr = mem.alloc(32)
|
||||
let a: ptr = mem.alloc(16 * 8)
|
||||
let b: ptr = mem.alloc(16 * 8)
|
||||
let c: ptr = mem.alloc(16 * 8)
|
||||
let d: ptr = mem.alloc(16 * 8)
|
||||
let e: ptr = mem.alloc(16 * 8)
|
||||
let f: ptr = mem.alloc(16 * 8)
|
||||
let x: ptr = mem.alloc(16 * 8)
|
||||
|
||||
let magic: ptr = mem.alloc(16 * 8)
|
||||
mem.zero(magic, 16 * 8)
|
||||
mem.write64(magic, 0xdb41) // 121665
|
||||
mem.write64(magic + 8, 1)
|
||||
|
||||
// copy and clamp scalar
|
||||
for i in 0..32
|
||||
mem.write8(clamped + i, scalar[i])
|
||||
mem.write8(clamped, clamped[0] & 0xf8)
|
||||
mem.write8(clamped + 31, (clamped[31] & 0x7f) | 0x40)
|
||||
|
||||
// load point
|
||||
unpack25519(x, point)
|
||||
|
||||
// initialize ladder state
|
||||
for i in 0..16
|
||||
mem.write64(a + i * 8, 0)
|
||||
mem.write64(b + i * 8, mem.read64(x + i * 8))
|
||||
mem.write64(c + i * 8, 0)
|
||||
mem.write64(d + i * 8, 0)
|
||||
mem.write64(a, 1)
|
||||
mem.write64(d, 1)
|
||||
|
||||
let i = 254
|
||||
while i >= 0
|
||||
let bit: i64 = (clamped[i >> 3] >> (i & 7)) & 1
|
||||
swap25519(a, b, bit)
|
||||
swap25519(c, d, bit)
|
||||
fadd(e, a, c)
|
||||
fsub(a, a, c)
|
||||
fadd(c, b, d)
|
||||
fsub(b, b, d)
|
||||
fmul(d, e, e)
|
||||
fmul(f, a, a)
|
||||
fmul(a, c, a)
|
||||
fmul(c, b, e)
|
||||
fadd(e, a, c)
|
||||
fsub(a, a, c)
|
||||
fmul(b, a, a)
|
||||
fsub(c, d, f)
|
||||
fmul(a, c, magic)
|
||||
fadd(a, a, d)
|
||||
fmul(c, c, a)
|
||||
fmul(a, d, f)
|
||||
fmul(d, b, x)
|
||||
fmul(b, e, e)
|
||||
swap25519(a, b, bit)
|
||||
swap25519(c, d, bit)
|
||||
i = i - 1
|
||||
|
||||
finverse(c, c)
|
||||
fmul(a, a, c)
|
||||
pack25519(out, a)
|
||||
|
||||
mem.free(clamped)
|
||||
mem.free(a)
|
||||
mem.free(b)
|
||||
mem.free(c)
|
||||
mem.free(d)
|
||||
mem.free(e)
|
||||
mem.free(f)
|
||||
mem.free(x)
|
||||
mem.free(magic)
|
||||
|
||||
func main[] : i64
|
||||
let scalar: ptr = str.hex_decode("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4")
|
||||
let point: ptr = str.hex_decode("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c")
|
||||
let expected: ptr = str.hex_decode("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552")
|
||||
|
||||
let out: ptr = mem.alloc(32)
|
||||
scalarmult(out, scalar, point)
|
||||
|
||||
io.print("Computed: ")
|
||||
io.println(str.hex_encode(out, 32))
|
||||
io.print("Expected: ")
|
||||
io.println(str.hex_encode(expected, 32))
|
||||
|
||||
let base_point: ptr = mem.alloc(32)
|
||||
mem.zero(base_point, 32)
|
||||
mem.write8(base_point, 9)
|
||||
|
||||
let alice_private: ptr = str.hex_decode("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a")
|
||||
io.print("A_priv: ")
|
||||
io.println(str.hex_encode(alice_private, 32))
|
||||
|
||||
let alice_public: ptr = mem.alloc(32)
|
||||
scalarmult(alice_public, alice_private, base_point)
|
||||
io.print("A_pub: ")
|
||||
io.println(str.hex_encode(alice_public, 32))
|
||||
|
||||
let bob_private: ptr = str.hex_decode("5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6dbddb79b1732920165")
|
||||
io.print("B_priv: ")
|
||||
io.println(str.hex_encode(bob_private, 32))
|
||||
|
||||
let bob_public: ptr = mem.alloc(32)
|
||||
scalarmult(bob_public, bob_private, base_point)
|
||||
io.print("B_pub: ")
|
||||
io.println(str.hex_encode(bob_public, 32))
|
||||
|
||||
let alice_shared: ptr = mem.alloc(32)
|
||||
scalarmult(alice_shared, alice_private, bob_public)
|
||||
io.print("A_shared: ")
|
||||
io.println(str.hex_encode(alice_shared, 32))
|
||||
|
||||
let bob_shared: ptr = mem.alloc(32)
|
||||
scalarmult(bob_shared, bob_private, alice_public)
|
||||
io.print("B_shared: ")
|
||||
io.println(str.hex_encode(bob_shared, 32))
|
||||
Reference in New Issue
Block a user