diff --git a/examples/crypto.zr b/examples/crypto.zr index 267b678..1700556 100644 --- a/examples/crypto.zr +++ b/examples/crypto.zr @@ -5,18 +5,16 @@ func main[] : i64 let input: str = "Hello, World!" let input_len: i64 = str.len(input) - let out: ptr = mem.alloc(input_len) - crypto.xchacha20.xor(key, nonce, input, out, input_len) - io.println(str.hex_encode(out, input_len)) + let ciphertext: ptr = crypto.xchacha20.xor(key, nonce, input, input_len) + io.println(str.hex_encode(ciphertext, input_len)) // X25519 let scalar: ptr = str.hex_decode("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4") let point: ptr = str.hex_decode("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c") let expected: ptr = str.hex_decode("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552") - let out: ptr = mem.alloc(32) - crypto.x25519.scalarmult(out, scalar, point) + let out: ptr = crypto.x25519.scalarmult(scalar, point) io.print("Computed: ") io.println(str.hex_encode(out, 32)) @@ -31,8 +29,7 @@ func main[] : i64 io.print("A_priv: ") io.println(str.hex_encode(alice_private, 32)) - let alice_public: ptr = mem.alloc(32) - crypto.x25519.scalarmult(alice_public, alice_private, base_point) + let alice_public: ptr = crypto.x25519.scalarmult(alice_private, base_point) io.print("A_pub: ") io.println(str.hex_encode(alice_public, 32)) @@ -40,17 +37,14 @@ func main[] : i64 io.print("B_priv: ") io.println(str.hex_encode(bob_private, 32)) - let bob_public: ptr = mem.alloc(32) - crypto.x25519.scalarmult(bob_public, bob_private, base_point) + let bob_public: ptr = crypto.x25519.scalarmult(bob_private, base_point) io.print("B_pub: ") io.println(str.hex_encode(bob_public, 32)) - let alice_shared: ptr = mem.alloc(32) - crypto.x25519.scalarmult(alice_shared, alice_private, bob_public) + let alice_shared: ptr = crypto.x25519.scalarmult(alice_private, bob_public) io.print("A_shared: ") io.println(str.hex_encode(alice_shared, 32)) - let bob_shared: ptr = mem.alloc(32) - crypto.x25519.scalarmult(bob_shared, bob_private, alice_public) + let bob_shared: ptr = crypto.x25519.scalarmult(bob_private, alice_public) io.print("B_shared: ") io.println(str.hex_encode(bob_shared, 32)) \ No newline at end of file diff --git a/src/std.zr b/src/std.zr index 5649553..015a75a 100644 --- a/src/std.zr +++ b/src/std.zr @@ -781,14 +781,16 @@ func crypto.xchacha20._stream[key: ptr, nonce: ptr, out: ptr, len: i64] : void mem.free(nonce12) mem.free(subkey) -func crypto.xchacha20.xor[key: ptr, nonce: ptr, input: ptr, out: ptr, len: i64] : void +func crypto.xchacha20.xor[key: ptr, nonce: ptr, input: ptr, len: i64] : ptr if len <= 0 - return 0 + return dbg.panic("empty buffer passed to crypto.xchacha20.xor") + let out: ptr = mem.alloc(len) let ks: ptr = mem.alloc(len) crypto.xchacha20._stream(key, nonce, ks, len) for i in 0..len mem.write8(out + i, input[i] ^ ks[i]) mem.free(ks) + return out func crypto.x25519.carry[elem: ptr] : void for i in 0..16 @@ -877,7 +879,7 @@ func crypto.x25519.pack[out: ptr, input: ptr] : void mem.free(t) mem.free(m) -func crypto.x25519.scalarmult[out: ptr, scalar: ptr, point: ptr] : void +func crypto.x25519.scalarmult[scalar: ptr, point: ptr] : ptr let clamped: ptr = mem.alloc(32) let a: ptr = mem.alloc(16 * 8) let b: ptr = mem.alloc(16 * 8) @@ -939,6 +941,7 @@ func crypto.x25519.scalarmult[out: ptr, scalar: ptr, point: ptr] : void crypto.x25519.finverse(c, c) crypto.x25519.fmul(a, a, c) + let out: ptr = mem.alloc(32) crypto.x25519.pack(out, a) mem.free(clamped) @@ -949,4 +952,5 @@ func crypto.x25519.scalarmult[out: ptr, scalar: ptr, point: ptr] : void mem.free(e) mem.free(f) mem.free(x) - mem.free(magic) \ No newline at end of file + mem.free(magic) + return out \ No newline at end of file