// X25519 scalar multiplication with a small demo driver.
//
// Field elements are stored as 16 limbs, one 64-bit word per limb at an
// 8-byte stride (limb i lives at `elem + i * 8`). Each limb nominally holds
// 16 bits of the value; intermediate results may exceed that range until
// carry25519 is applied. The reduction constants used below (38, and the limb
// pattern 0xffed, 0xffff x 14, 0x7fff) correspond to the prime 2^255 - 19.
//
// Scratch buffers that hold secret-dependent data are cleared with mem.zero
// before they are released.

// Unpacks the 32 bytes at `input` (little-endian) into the 16 limbs at `out`,
// two bytes per limb.
func unpack25519[out: ptr, input: ptr] : void
    for i in 0..16
        mem.write64(out + i * 8, input[i * 2] + (input[i * 2 + 1] << 8))
    // Keep only the low 15 bits of the last limb, i.e. discard bit 255 of the
    // encoded value.
    mem.write64(out + 8 * 15, mem.read64(out + 8 * 15) & 0x7fff)

// Performs one carry pass over `elem` in place: everything above the low 16
// bits of limb i is moved into limb i + 1. The overflow of the last limb is
// multiplied by 38 and added to limb 0 (2^256 is congruent to 38 modulo
// 2^255 - 19).
func carry25519[elem: ptr] : void
    for i in 0..16
        let carry: i64 = mem.read64(elem + i * 8) >> 16
        mem.write64(elem + i * 8, mem.read64(elem + i * 8) - (carry << 16))
        if i < 15
            mem.write64(elem + (i + 1) * 8, mem.read64(elem + (i + 1) * 8) + carry)
        else
            mem.write64(elem, mem.read64(elem) + 38 * carry)

// out = a + b, limb by limb. No carry propagation is performed here.
func fadd[out: ptr, a: ptr, b: ptr] : void
    for i in 0..16
        mem.write64(out + i * 8, mem.read64(a + i * 8) + mem.read64(b + i * 8))

// out = a - b, limb by limb. Limbs may become negative; no carry propagation
// is performed here.
func fsub[out: ptr, a: ptr, b: ptr] : void
    for i in 0..16
        mem.write64(out + i * 8, mem.read64(a + i * 8) - mem.read64(b + i * 8))

// out = a * b. The 31-limb schoolbook product is accumulated in a scratch
// buffer, so `out` may be the same buffer as `a` and/or `b`.
func fmul[out: ptr, a: ptr, b: ptr] : void
    let product: ptr = mem.alloc(31 * 8)
    for i in 0..31
        mem.write64(product + i * 8, 0)
    for i in 0..16
        for j in 0..16
            mem.write64(product + (i + j) * 8, mem.read64(product + (i + j) * 8) + (mem.read64(a + i * 8) * mem.read64(b + j * 8)))
    // Fold limbs 16..30 back onto limbs 0..14 with the factor 38.
    for i in 0..15
        mem.write64(product + i * 8, mem.read64(product + i * 8) + 38 * mem.read64(product + (i + 16) * 8))
    for i in 0..16
        mem.write64(out + i * 8, mem.read64(product + i * 8))
    carry25519(out)
    carry25519(out)
    // The scratch product is derived from the operands; clear it before
    // handing the memory back.
    mem.zero(product, 31 * 8)
    mem.free(product)

// out = input^(2^255 - 21), computed by square-and-multiply. The working
// value lives in a private buffer and `input` is only read, so `out` may be
// the same buffer as `input`.
func finverse[out: ptr, input: ptr] : void
    let c: ptr = mem.alloc(16 * 8)
    for i in 0..16
        mem.write64(c + i * 8, mem.read64(input + i * 8))
    // 254 rounds of "square, then multiply by input". Skipping the multiply
    // in rounds 2 and 4 removes 4 + 16 from the all-ones exponent 2^255 - 1,
    // which yields 2^255 - 21.
    let round = 253
    while round >= 0
        fmul(c, c, c)
        if round != 2 && round != 4
            fmul(c, c, input)
        round = round - 1
    for i in 0..16
        mem.write64(out + i * 8, mem.read64(c + i * 8))
    mem.zero(c, 16 * 8)
    mem.free(c)

// Swaps the 16 limbs of `p` and `q` when `bit` is 1 and leaves them unchanged
// when `bit` is 0. The swap is done with a mask instead of a branch.
// `bit` is expected to be 0 or 1.
func swap25519[p: ptr, q: ptr, bit: i64] : void
    for i in 0..16
        // (-bit) is 0 for bit == 0 and all ones for bit == 1.
        let t: i64 = (-bit) & (mem.read64(p + i * 8) ^ mem.read64(q + i * 8))
        mem.write64(p + i * 8, mem.read64(p + i * 8) ^ t)
        mem.write64(q + i * 8, mem.read64(q + i * 8) ^ t)

// Reduces the field element at `input` and writes it as 32 little-endian
// bytes to `out`. `input` is not modified.
func pack25519[out: ptr, input: ptr] : void
    let t: ptr = mem.alloc(16 * 8)
    for i in 0..16
        mem.write64(t + i * 8, mem.read64(input + i * 8))
    let m: ptr = mem.alloc(16 * 8)
    carry25519(t)
    carry25519(t)
    carry25519(t)
    // Two rounds of conditional subtraction. m = t - (0xffed, 0xffff x 14,
    // 0x7fff), with bit 16 of each limb difference used as the borrow into
    // the next limb.
    for j in 0..2
        mem.write64(m, mem.read64(t) - 0xffed)
        for i in 1..15
            mem.write64(m + i * 8, mem.read64(t + i * 8) - 0xffff - ((mem.read64(m + (i - 1) * 8) >> 16) & 1))
            mem.write64(m + (i - 1) * 8, mem.read64(m + (i - 1) * 8) & 0xffff)
        mem.write64(m + 15 * 8, mem.read64(t + 15 * 8) - 0x7fff - ((mem.read64(m + 14 * 8) >> 16) & 1))
        let carry: i64 = (mem.read64(m + 15 * 8) >> 16) & 1
        mem.write64(m + 14 * 8, mem.read64(m + 14 * 8) & 0xffff)
        // carry == 1 means the subtraction borrowed, so t is kept; otherwise
        // the difference m replaces t.
        swap25519(t, m, 1 - carry)
    for i in 0..16
        let v: i64 = mem.read64(t + i * 8)
        mem.write8(out + i * 2, v & 0xff)
        mem.write8(out + i * 2 + 1, (v >> 8) & 0xff)
    mem.zero(t, 16 * 8)
    mem.zero(m, 16 * 8)
    mem.free(t)
    mem.free(m)

// Multiplies the encoded point at `point` (32 bytes) by the scalar at
// `scalar` (32 bytes) and writes the 32-byte result to `out`. `scalar` and
// `point` are only read; the scalar is clamped on a private copy.
func scalarmult[out: ptr, scalar: ptr, point: ptr] : void
    let clamped: ptr = mem.alloc(32)
    let a: ptr = mem.alloc(16 * 8)
    let b: ptr = mem.alloc(16 * 8)
    let c: ptr = mem.alloc(16 * 8)
    let d: ptr = mem.alloc(16 * 8)
    let e: ptr = mem.alloc(16 * 8)
    let f: ptr = mem.alloc(16 * 8)
    let x: ptr = mem.alloc(16 * 8)
    let magic: ptr = mem.alloc(16 * 8)
    mem.zero(magic, 16 * 8)
    // 121665 = 0xdb41 + (1 << 16), stored as two limbs
    mem.write64(magic, 0xdb41)
    mem.write64(magic + 8, 1)
    // copy and clamp scalar
    for i in 0..32
        mem.write8(clamped + i, scalar[i])
    mem.write8(clamped, clamped[0] & 0xf8)
    mem.write8(clamped + 31, (clamped[31] & 0x7f) | 0x40)
    // load point
    unpack25519(x, point)
    // initialize ladder state: a = 1, b = x, c = 0, d = 1
    for i in 0..16
        mem.write64(a + i * 8, 0)
        mem.write64(b + i * 8, mem.read64(x + i * 8))
        mem.write64(c + i * 8, 0)
        mem.write64(d + i * 8, 0)
    mem.write64(a, 1)
    mem.write64(d, 1)
    // Process scalar bits 254 down to 0. Each step conditionally swaps the
    // two working pairs, runs a fixed sequence of field operations, then
    // swaps back.
    let k = 254
    while k >= 0
        let bit: i64 = (clamped[k >> 3] >> (k & 7)) & 1
        swap25519(a, b, bit)
        swap25519(c, d, bit)
        fadd(e, a, c)
        fsub(a, a, c)
        fadd(c, b, d)
        fsub(b, b, d)
        fmul(d, e, e)
        fmul(f, a, a)
        fmul(a, c, a)
        fmul(c, b, e)
        fadd(e, a, c)
        fsub(a, a, c)
        fmul(b, a, a)
        fsub(c, d, f)
        fmul(a, c, magic)
        fadd(a, a, d)
        fmul(c, c, a)
        fmul(a, d, f)
        fmul(d, b, x)
        fmul(b, e, e)
        swap25519(a, b, bit)
        swap25519(c, d, bit)
        k = k - 1
    // result = a * c^(2^255 - 21), then reduce and serialize
    finverse(c, c)
    fmul(a, a, c)
    pack25519(out, a)
    // Clear the clamped scalar and all ladder state before releasing them.
    // `magic` only holds a public constant and is not wiped.
    mem.zero(clamped, 32)
    mem.zero(a, 16 * 8)
    mem.zero(b, 16 * 8)
    mem.zero(c, 16 * 8)
    mem.zero(d, 16 * 8)
    mem.zero(e, 16 * 8)
    mem.zero(f, 16 * 8)
    mem.zero(x, 16 * 8)
    mem.free(clamped)
    mem.free(a)
    mem.free(b)
    mem.free(c)
    mem.free(d)
    mem.free(e)
    mem.free(f)
    mem.free(x)
    mem.free(magic)

// Demo driver: prints the result of one fixed scalar/point pair next to its
// expected value, then derives two public keys from base point 9 and prints
// the shared value computed from each side.
// NOTE(review): main is declared `: i64` but contains no return statement;
// confirm what the runtime reports as the exit status.
func main[] : i64
    let scalar: ptr = str.hex_decode("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4")
    let point: ptr = str.hex_decode("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c")
    let expected: ptr = str.hex_decode("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552")
    let out: ptr = mem.alloc(32)
    scalarmult(out, scalar, point)
    io.print("Computed: ")
    io.println(str.hex_encode(out, 32))
    io.print("Expected: ")
    io.println(str.hex_encode(expected, 32))

    let base_point: ptr = mem.alloc(32)
    mem.zero(base_point, 32)
    mem.write8(base_point, 9)

    let alice_private: ptr = str.hex_decode("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a")
    io.print("A_priv: ")
    io.println(str.hex_encode(alice_private, 32))
    let alice_public: ptr = mem.alloc(32)
    scalarmult(alice_public, alice_private, base_point)
    io.print("A_pub: ")
    io.println(str.hex_encode(alice_public, 32))

    let bob_private: ptr = str.hex_decode("5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6dbddb79b1732920165")
    io.print("B_priv: ")
    io.println(str.hex_encode(bob_private, 32))
    let bob_public: ptr = mem.alloc(32)
    scalarmult(bob_public, bob_private, base_point)
    io.print("B_pub: ")
    io.println(str.hex_encode(bob_public, 32))

    let alice_shared: ptr = mem.alloc(32)
    scalarmult(alice_shared, alice_private, bob_public)
    io.print("A_shared: ")
    io.println(str.hex_encode(alice_shared, 32))

    let bob_shared: ptr = mem.alloc(32)
    scalarmult(bob_shared, bob_private, alice_public)
    io.print("B_shared: ")
    io.println(str.hex_encode(bob_shared, 32))

    // Release the buffers obtained from mem.alloc above.
    // NOTE(review): ownership of the str.hex_decode / str.hex_encode results
    // is not visible from this file, so they are left untouched; confirm
    // whether they also need to be freed.
    mem.free(out)
    mem.free(base_point)
    mem.free(alice_public)
    mem.free(bob_public)
    mem.free(alice_shared)
    mem.free(bob_shared)