diff --git a/examples/crypto.zr b/examples/crypto.zr new file mode 100644 index 0000000..267b678 --- /dev/null +++ b/examples/crypto.zr @@ -0,0 +1,56 @@ +func main[] : i64 + // XChaCha20 + let key: ptr = str.hex_decode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") + let nonce: ptr = str.hex_decode("000102030405060708090a0b0c0d0e0f1011121314151617") + + let input: str = "Hello, World!" + let input_len: i64 = str.len(input) + let out: ptr = mem.alloc(input_len) + + crypto.xchacha20.xor(key, nonce, input, out, input_len) + io.println(str.hex_encode(out, input_len)) + + // X25519 + let scalar: ptr = str.hex_decode("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4") + let point: ptr = str.hex_decode("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c") + let expected: ptr = str.hex_decode("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552") + + let out: ptr = mem.alloc(32) + crypto.x25519.scalarmult(out, scalar, point) + + io.print("Computed: ") + io.println(str.hex_encode(out, 32)) + io.print("Expected: ") + io.println(str.hex_encode(expected, 32)) + + let base_point: ptr = mem.alloc(32) + mem.zero(base_point, 32) + mem.write8(base_point, 9) + + let alice_private: ptr = str.hex_decode("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a") + io.print("A_priv: ") + io.println(str.hex_encode(alice_private, 32)) + + let alice_public: ptr = mem.alloc(32) + crypto.x25519.scalarmult(alice_public, alice_private, base_point) + io.print("A_pub: ") + io.println(str.hex_encode(alice_public, 32)) + + let bob_private: ptr = str.hex_decode("5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6dbddb79b1732920165") + io.print("B_priv: ") + io.println(str.hex_encode(bob_private, 32)) + + let bob_public: ptr = mem.alloc(32) + crypto.x25519.scalarmult(bob_public, bob_private, base_point) + io.print("B_pub: ") + io.println(str.hex_encode(bob_public, 32)) + + let alice_shared: ptr = mem.alloc(32) + crypto.x25519.scalarmult(alice_shared, alice_private, bob_public) + io.print("A_shared: ") + io.println(str.hex_encode(alice_shared, 32)) + + let bob_shared: ptr = mem.alloc(32) + crypto.x25519.scalarmult(bob_shared, bob_private, alice_public) + io.print("B_shared: ") + io.println(str.hex_encode(bob_shared, 32)) \ No newline at end of file diff --git a/examples/tokenizer.zr b/examples/tokenizer.zr deleted file mode 100644 index d862441..0000000 --- a/examples/tokenizer.zr +++ /dev/null @@ -1,264 +0,0 @@ -func eof[current: i64, source_len: i64] : bool - return current >= source_len - -func peek[current: i64, source: str, source_len: i64] : u8 - if eof(current, source_len) - return 0 - return source[current] - -func advance[current: ptr, column: ptr, source: str, source_len: i64] : u8 - if eof(mem.read64(current), source_len) - return 0 - let c: u8 = source[mem.read64(current)] - mem.write64(current, mem.read64(current) + 1) - mem.write64(column, mem.read64(column) + 1) - return c - -func match_char[expected: u8, current: ptr, column: ptr, source: str, source_len: i64] : bool - if eof(mem.read64(current), source_len) - return false - if source[mem.read64(current)] != expected - return false - mem.write64(current, mem.read64(current) + 1) - mem.write64(column, mem.read64(column) + 1) - return true - -func zern_error[filename: str, line: i64, column: i64, message: str] : void - io.print(filename) - io.print(":") - io.print_i64(line) - io.print(":") - io.print_i64(column) - io.print(" ERROR: ") - io.println(message) - os.exit(1) - -func count_indentation[current: ptr, column: ptr, source: str, source_len: i64] : i64 - let count = 0 - while peek(mem.read64(current), source, source_len) == ' ' - count = count + 1 - advance(current, column, source, source_len) - return count - -func handle_indentation[tokens: array, current: ptr, column: ptr, line: i64, source: str, source_len: i64, indent_stack: array, current_indent: ptr, filename: str] : void - if peek(mem.read64(current), source, source_len) == 10 // \n - return 0 - - let new_indent: i64 = count_indentation(current, column, source, source_len) - - if new_indent > mem.read64(current_indent) - array.push(indent_stack, new_indent) - add_token_with_lexeme("Indent", tokens, "", line, mem.read64(column)) - else if new_indent < mem.read64(current_indent) - while array.size(indent_stack) > 1 && array.nth(indent_stack, array.size(indent_stack) - 1) > new_indent - array.pop(indent_stack) - add_token_with_lexeme("Dedent", tokens, "", line, mem.read64(column)) - - if array.size(indent_stack) == 0 || array.nth(indent_stack, array.size(indent_stack) - 1) != new_indent - zern_error(filename, line, mem.read64(column), "invalid indentation") - - mem.write64(current_indent, new_indent) - -func add_token[type: i64, tokens: array, source: str, start: i64, current: i64, line: i64, column: i64] : void - let len: i64 = current - start - let lexeme: str = mem.alloc(len + 1) - for i in 0..len - str.set(lexeme, i, source[start + i]) - str.set(lexeme, len, 0) - array.push(tokens, [type, lexeme, line, column]) - -func add_token_with_lexeme[type: i64, tokens: array, lexeme: str, line: i64, column: i64] : void - array.push(tokens, [type, lexeme, line, column]) - -func scan_number[current: ptr, column: ptr, source: str, source_len: i64] : void - if match_char('x', current, column, source, source_len) - while str.is_hex_digit(peek(mem.read64(current), source, source_len)) - advance(current, column, source, source_len) - else if match_char('o', current, column, source, source_len) - while peek(mem.read64(current), source, source_len) >= '0' && peek(mem.read64(current), source, source_len) <= '7' - advance(current, column, source, source_len) - else - while str.is_digit(peek(mem.read64(current), source, source_len)) - advance(current, column, source, source_len) - -func scan_identifier[tokens: array, current: ptr, column: ptr, start: i64, line: i64, source: str, source_len: i64] : void - while str.is_alphanumeric(peek(mem.read64(current), source, source_len)) || peek(mem.read64(current), source, source_len) == '_' || peek(mem.read64(current), source, source_len) == '.' - advance(current, column, source, source_len) - - let len: i64 = mem.read64(current) - start - let lexeme: str = mem.alloc(len + 1) - for i in 0..len - str.set(lexeme, i, source[start + i]) - str.set(lexeme, len, 0) - - let type: str = "Identifier" - if str.equal(lexeme, "let") - type = "KeywordLet" - if str.equal(lexeme, "const") - type = "KeywordConst" - if str.equal(lexeme, "if") - type = "KeywordIf" - if str.equal(lexeme, "else") - type = "KeywordElse" - if str.equal(lexeme, "while") - type = "KeywordWhile" - if str.equal(lexeme, "for") - type = "KeywordFor" - if str.equal(lexeme, "in") - type = "KeywordIn" - if str.equal(lexeme, "func") - type = "KeywordFunc" - if str.equal(lexeme, "return") - type = "KeywordReturn" - if str.equal(lexeme, "break") - type = "KeywordBreak" - if str.equal(lexeme, "continue") - type = "KeywordContinue" - if str.equal(lexeme, "extern") - type = "KeywordExtern" - if str.equal(lexeme, "export") - type = "KeywordExport" - if str.equal(lexeme, "true") - type = "True" - if str.equal(lexeme, "false") - type = "False" - - add_token_with_lexeme(type, tokens, lexeme, line, mem.read64(column)) - -func scan_token[tokens: array, current: ptr, line: ptr, column: ptr, source: str, source_len: i64, filename: str, indent_stack: array, current_indent: ptr] : void - let start: i64 = mem.read64(current) - let c: u8 = advance(current, column, source, source_len) - - if c == '(' - add_token("LeftParen", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == ')' - add_token("RightParen", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '[' - add_token("LeftBracket", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == ']' - add_token("RightBracket", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == ',' - add_token("Comma", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '+' - add_token("Plus", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '-' - add_token("Minus", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '*' - add_token("Star", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '%' - add_token("Mod", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '^' - add_token("Xor", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == ':' - add_token("Colon", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '.' - if match_char('.', current, column, source, source_len) - add_token("DoubleDot", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - zern_error(filename, mem.read64(line), mem.read64(column), "expected '.' after '.'") - else if c == '/' - if match_char('/', current, column, source, source_len) - while !eof(mem.read64(current), source_len) && peek(mem.read64(current), source, source_len) != 10 - advance(current, column, source, source_len) - else - add_token("Slash", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '&' - if match_char('&', current, column, source, source_len) - add_token("LogicalAnd", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - add_token("BitAnd", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '|' - if match_char('>', current, column, source, source_len) - add_token("Pipe", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if match_char('|', current, column, source, source_len) - add_token("LogicalOr", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - add_token("BitOr", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '!' - if match_char('=', current, column, source, source_len) - add_token("NotEqual", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - add_token("Bang", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '=' - if match_char('=', current, column, source, source_len) - add_token("DoubleEqual", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - add_token("Equal", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '>' - if match_char('>', current, column, source, source_len) - add_token("ShiftRight", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if match_char('=', current, column, source, source_len) - add_token("GreaterEqual", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - add_token("Greater", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == '<' - if match_char('<', current, column, source, source_len) - add_token("ShiftLeft", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if match_char('=', current, column, source, source_len) - add_token("LessEqual", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else - add_token("Less", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == 39 // ' - if eof(mem.read64(current), source_len) - zern_error(filename, mem.read64(line), mem.read64(column), "unterminated char literal") - advance(current, column, source, source_len) - if !match_char(39, current, column, source, source_len) - zern_error(filename, mem.read64(line), mem.read64(column), "expected ' after char literal") - add_token("Char", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == 34 // " - while !eof(mem.read64(current), source_len) && peek(mem.read64(current), source, source_len) != 34 - if peek(mem.read64(current), source, source_len) == 10 // \n - mem.write64(line, mem.read64(line) + 1) - mem.write64(column, 1) - advance(current, column, source, source_len) - if eof(mem.read64(current), source_len) - zern_error(filename, mem.read64(line), mem.read64(column), "unterminated string") - advance(current, column, source, source_len) - add_token("String", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if c == ' ' || c == 13 // \r - return 0 - else if c == 10 // \n - mem.write64(line, mem.read64(line) + 1) - mem.write64(column, 1) - handle_indentation(tokens, current, column, mem.read64(line), source, source_len, indent_stack, current_indent, filename) - else if str.is_digit(c) - scan_number(current, column, source, source_len) - add_token("Number", tokens, source, start, mem.read64(current), mem.read64(line), mem.read64(column)) - else if str.is_letter(c) || c == '_' - scan_identifier(tokens, current, column, start, mem.read64(line), source, source_len) - else - zern_error(filename, mem.read64(line), mem.read64(column), "unexpected character") - -func tokenize[source: str, filename: str] : array - let source_len: i64 = str.len(source) - let current = 0 - let line = 1 - let column = 1 - let indent_stack: array = [0] - let current_indent = 0 - let tokens: array = [] - - while !eof(current, source_len) - scan_token(tokens, ^current, ^line, ^column, source, source_len, filename, indent_stack, ^current_indent) - - add_token_with_lexeme("Eof", tokens, "", line, column) - return tokens - -func main[argc: i64, argv: ptr] : i64 - if argc < 2 - dbg.panic("expected an argument") - - let path: str = mem.read64(argv + 8) - let source: str = io.read_file(path) - let tokens: array = tokenize(source, path) - - for i in 0..array.size(tokens) - let token: array = array.nth(tokens, i) - io.print(array.nth(token, 0)) - io.print(" ") - io.print(array.nth(token, 1)) - io.print(" ") - io.print_i64(array.nth(token, 2)) - io.print(" ") - io.print_i64(array.nth(token, 3)) - io.println("") \ No newline at end of file diff --git a/examples/x25519.zr b/examples/x25519.zr deleted file mode 100644 index 9402ff3..0000000 --- a/examples/x25519.zr +++ /dev/null @@ -1,205 +0,0 @@ -func unpack25519[out: ptr, input: ptr] : void - for i in 0..16 - mem.write64(out + i * 8, input[i * 2] + (input[i * 2 + 1] << 8)) - mem.write64(out + 8 * 15, mem.read64(out + 8 * 15) & 0x7fff) - -func carry25519[elem: ptr] : void - for i in 0..16 - let carry: i64 = mem.read64(elem + i * 8) >> 16 - mem.write64(elem + i * 8, mem.read64(elem + i * 8) - (carry << 16)) - if i < 15 - mem.write64(elem + (i + 1) * 8, mem.read64(elem + (i + 1) * 8) + carry) - else - mem.write64(elem, mem.read64(elem) + 38 * carry) - -func fadd[out: ptr, a: ptr, b: ptr] : void - for i in 0..16 - mem.write64(out + i * 8, mem.read64(a + i * 8) + mem.read64(b + i * 8)) - -func fsub[out: ptr, a: ptr, b: ptr] : void - for i in 0..16 - mem.write64(out + i * 8, mem.read64(a + i * 8) - mem.read64(b + i * 8)) - -func fmul[out: ptr, a: ptr, b: ptr] : void - let product: ptr = mem.alloc(31 * 8) - for i in 0..31 - mem.write64(product + i * 8, 0) - for i in 0..16 - for j in 0..16 - mem.write64(product + (i + j) * 8, mem.read64(product + (i + j) * 8) + (mem.read64(a + i * 8) * mem.read64(b + j * 8))) - for i in 0..15 - mem.write64(product + i * 8, mem.read64(product + i * 8) + 38 * mem.read64(product + (i + 16) * 8)) - for i in 0..16 - mem.write64(out + i * 8, mem.read64(product + i * 8)) - - carry25519(out) - carry25519(out) - mem.free(product) - -func finverse[out: ptr, input: ptr] : void - let c: ptr = mem.alloc(16 * 8) - for i in 0..16 - mem.write64(c + i * 8, mem.read64(input + i * 8)) - - let i = 253 - while i >= 0 - fmul(c, c, c) - if i != 2 && i != 4 - fmul(c, c, input) - i = i - 1 - - for i in 0..16 - mem.write64(out + i * 8, mem.read64(c + i * 8)) - mem.free(c) - -func swap25519[p: ptr, q: ptr, bit: i64] : void - for i in 0..16 - let t: i64 = (-bit) & (mem.read64(p + i * 8) ^ mem.read64(q + i * 8)) - mem.write64(p + i * 8, mem.read64(p + i * 8) ^ t) - mem.write64(q + i * 8, mem.read64(q + i * 8) ^ t) - -func pack25519[out: ptr, input: ptr] : void - let t: ptr = mem.alloc(16 * 8) - for i in 0..16 - mem.write64(t + i * 8, mem.read64(input + i * 8)) - let m: ptr = mem.alloc(16 * 8) - - carry25519(t) - carry25519(t) - carry25519(t) - for j in 0..2 - mem.write64(m, mem.read64(t) - 0xffed) - for i in 1..15 - mem.write64(m + i * 8, mem.read64(t + i * 8) - 0xffff - ((mem.read64(m + (i - 1) * 8) >> 16) & 1)) - mem.write64(m + (i - 1) * 8, mem.read64(m + (i - 1) * 8) & 0xffff) - mem.write64(m + 15 * 8, mem.read64(t + 15 * 8) - 0x7fff - ((mem.read64(m + 14 * 8) >> 16) & 1)) - let carry: i64 = (mem.read64(m + 15 * 8) >> 16) & 1 - mem.write64(m + 14 * 8, mem.read64(m + 14 * 8) & 0xffff) - swap25519(t, m, 1 - carry) - - for i in 0..16 - let v: i64 = mem.read64(t + i * 8) - mem.write8(out + i * 2, v & 0xff) - mem.write8(out + i * 2 + 1, (v >> 8) & 0xff) - - mem.free(t) - mem.free(m) - -func scalarmult[out: ptr, scalar: ptr, point: ptr] : void - let clamped: ptr = mem.alloc(32) - let a: ptr = mem.alloc(16 * 8) - let b: ptr = mem.alloc(16 * 8) - let c: ptr = mem.alloc(16 * 8) - let d: ptr = mem.alloc(16 * 8) - let e: ptr = mem.alloc(16 * 8) - let f: ptr = mem.alloc(16 * 8) - let x: ptr = mem.alloc(16 * 8) - - let magic: ptr = mem.alloc(16 * 8) - mem.zero(magic, 16 * 8) - mem.write64(magic, 0xdb41) // 121665 - mem.write64(magic + 8, 1) - - // copy and clamp scalar - for i in 0..32 - mem.write8(clamped + i, scalar[i]) - mem.write8(clamped, clamped[0] & 0xf8) - mem.write8(clamped + 31, (clamped[31] & 0x7f) | 0x40) - - // load point - unpack25519(x, point) - - // initialize ladder state - for i in 0..16 - mem.write64(a + i * 8, 0) - mem.write64(b + i * 8, mem.read64(x + i * 8)) - mem.write64(c + i * 8, 0) - mem.write64(d + i * 8, 0) - mem.write64(a, 1) - mem.write64(d, 1) - - let i = 254 - while i >= 0 - let bit: i64 = (clamped[i >> 3] >> (i & 7)) & 1 - swap25519(a, b, bit) - swap25519(c, d, bit) - fadd(e, a, c) - fsub(a, a, c) - fadd(c, b, d) - fsub(b, b, d) - fmul(d, e, e) - fmul(f, a, a) - fmul(a, c, a) - fmul(c, b, e) - fadd(e, a, c) - fsub(a, a, c) - fmul(b, a, a) - fsub(c, d, f) - fmul(a, c, magic) - fadd(a, a, d) - fmul(c, c, a) - fmul(a, d, f) - fmul(d, b, x) - fmul(b, e, e) - swap25519(a, b, bit) - swap25519(c, d, bit) - i = i - 1 - - finverse(c, c) - fmul(a, a, c) - pack25519(out, a) - - mem.free(clamped) - mem.free(a) - mem.free(b) - mem.free(c) - mem.free(d) - mem.free(e) - mem.free(f) - mem.free(x) - mem.free(magic) - -func main[] : i64 - let scalar: ptr = str.hex_decode("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4") - let point: ptr = str.hex_decode("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c") - let expected: ptr = str.hex_decode("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552") - - let out: ptr = mem.alloc(32) - scalarmult(out, scalar, point) - - io.print("Computed: ") - io.println(str.hex_encode(out, 32)) - io.print("Expected: ") - io.println(str.hex_encode(expected, 32)) - - let base_point: ptr = mem.alloc(32) - mem.zero(base_point, 32) - mem.write8(base_point, 9) - - let alice_private: ptr = str.hex_decode("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a") - io.print("A_priv: ") - io.println(str.hex_encode(alice_private, 32)) - - let alice_public: ptr = mem.alloc(32) - scalarmult(alice_public, alice_private, base_point) - io.print("A_pub: ") - io.println(str.hex_encode(alice_public, 32)) - - let bob_private: ptr = str.hex_decode("5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6dbddb79b1732920165") - io.print("B_priv: ") - io.println(str.hex_encode(bob_private, 32)) - - let bob_public: ptr = mem.alloc(32) - scalarmult(bob_public, bob_private, base_point) - io.print("B_pub: ") - io.println(str.hex_encode(bob_public, 32)) - - let alice_shared: ptr = mem.alloc(32) - scalarmult(alice_shared, alice_private, bob_public) - io.print("A_shared: ") - io.println(str.hex_encode(alice_shared, 32)) - - let bob_shared: ptr = mem.alloc(32) - scalarmult(bob_shared, bob_private, alice_public) - io.print("B_shared: ") - io.println(str.hex_encode(bob_shared, 32)) \ No newline at end of file diff --git a/examples/xchacha20.zr b/examples/xchacha20.zr deleted file mode 100644 index 07dfd4b..0000000 --- a/examples/xchacha20.zr +++ /dev/null @@ -1,144 +0,0 @@ -func rotl32[x: i64, r: i64] : i64 - return ((x << r) | (x >> (32 - r))) & 0xffffffff - -func load32_le[p: ptr] : i64 - return p[0] | (p[1] << 8) | (p[2] << 16) | (p[3] << 24) - -func store32_le[p: ptr, v: i64] : void - mem.write8(p, v & 0xff) - mem.write8(p + 1, (v >> 8) & 0xff) - mem.write8(p + 2, (v >> 16) & 0xff) - mem.write8(p + 3, (v >> 24) & 0xff) - -func quarter_round[state: ptr, a: i64, b: i64, c: i64, d: i64] : void - let va: i64 = load32_le(state + a * 4) - let vb: i64 = load32_le(state + b * 4) - let vc: i64 = load32_le(state + c * 4) - let vd: i64 = load32_le(state + d * 4) - va = (va + vb) & 0xffffffff - vd = vd ^ va - vd = rotl32(vd, 16) - vc = (vc + vd) & 0xffffffff - vb = vb ^ vc - vb = rotl32(vb, 12) - va = (va + vb) & 0xffffffff - vd = vd ^ va - vd = rotl32(vd, 8) - vc = (vc + vd) & 0xffffffff - vb = vb ^ vc - vb = rotl32(vb, 7) - store32_le(state + a * 4, va) - store32_le(state + b * 4, vb) - store32_le(state + c * 4, vc) - store32_le(state + d * 4, vd) - -func chacha20_permute[state: ptr] : void - for i in 0..10 - quarter_round(state, 0, 4, 8, 12) - quarter_round(state, 1, 5, 9, 13) - quarter_round(state, 2, 6, 10, 14) - quarter_round(state, 3, 7, 11, 15) - quarter_round(state, 0, 5, 10, 15) - quarter_round(state, 1, 6, 11, 12) - quarter_round(state, 2, 7, 8, 13) - quarter_round(state, 3, 4, 9, 14) - -func chacha20_block[key: ptr, nonce: ptr, blocknum: i64, out: ptr] : void - let sigma: str = "expand 32-byte k" - let state: ptr = mem.alloc(16 * 4) - - store32_le(state + 0, load32_le(sigma + 0)) - store32_le(state + 4, load32_le(sigma + 4)) - store32_le(state + 8, load32_le(sigma + 8)) - store32_le(state + 12, load32_le(sigma + 12)) - - for i in 0..8 - store32_le(state + (4 + i) * 4, load32_le(key + i * 4)) - - store32_le(state + 12 * 4, blocknum) - store32_le(state + 13 * 4, load32_le(nonce + 0)) - store32_le(state + 14 * 4, load32_le(nonce + 4)) - store32_le(state + 15 * 4, load32_le(nonce + 8)) - - let working: ptr = mem.alloc(16 * 4) - for i in 0..16 - store32_le(working + i * 4, load32_le(state + i * 4)) - - chacha20_permute(working) - - for i in 0..16 - let v: i64 = (load32_le(working + i * 4) + load32_le(state + i * 4)) & 0xffffffff - store32_le(out + i * 4, v) - mem.free(working) - mem.free(state) - -func hchacha20[key: ptr, input: ptr, out32: ptr] : void - let sigma: str = "expand 32-byte k" - let state: ptr = mem.alloc(16 * 4) - - store32_le(state + 0, load32_le(sigma + 0)) - store32_le(state + 4, load32_le(sigma + 4)) - store32_le(state + 8, load32_le(sigma + 8)) - store32_le(state + 12, load32_le(sigma + 12)) - - for i in 0..8 - store32_le(state + (4 + i) * 4, load32_le(key + i * 4)) - - for i in 0..4 - store32_le(state + (12 + i) * 4, load32_le(input + i * 4)) - - chacha20_permute(state) - - for i in 0..4 - store32_le(out32 + i * 4, load32_le(state + i * 4)) - for i in 0..4 - store32_le(out32 + 16 + i * 4, load32_le(state + (12 + i) * 4)) - mem.free(state) - -func xchacha20_stream[key: ptr, nonce: ptr, out: ptr, len: i64] : void - let subkey: ptr = mem.alloc(32) - hchacha20(key, nonce, subkey) - - let nonce12: ptr = mem.alloc(12) - for i in 0..12 - mem.write8(nonce12 + i, 0) - for i in 0..8 - mem.write8(nonce12 + 4 + i, nonce[16 + i]) - - let blocknum: i64 = 0 - let remaining: i64 = len - let block: ptr = mem.alloc(64) - - while remaining > 0 - chacha20_block(subkey, nonce12, blocknum, block) - let take: i64 = 64 - if remaining < 64 - take = remaining - for i in 0..take - mem.write8(out + (len - remaining) + i, block[i]) - remaining = remaining - take - blocknum = blocknum + 1 - mem.free(block) - mem.free(nonce12) - mem.free(subkey) - -func xchacha20_xor[key: ptr, nonce: ptr, input: ptr, out: ptr, len: i64] : void - if len <= 0 - return 0 - let ks: ptr = mem.alloc(len) - xchacha20_stream(key, nonce, ks, len) - for i in 0..len - mem.write8(out + i, input[i] ^ ks[i]) - mem.free(ks) - -func main[] : i64 - let key: ptr = str.hex_decode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") - let nonce: ptr = str.hex_decode("000102030405060708090a0b0c0d0e0f1011121314151617") - - let input: str = "Hello, World!" - let input_len: i64 = str.len(input) - let out: ptr = mem.alloc(input_len) - - xchacha20_xor(key, nonce, input, out, input_len) - - io.println(str.hex_encode(out, input_len)) \ No newline at end of file diff --git a/src/codegen_x86_64.rs b/src/codegen_x86_64.rs index b13b0dc..40040a7 100644 --- a/src/codegen_x86_64.rs +++ b/src/codegen_x86_64.rs @@ -101,12 +101,6 @@ impl<'a> CodegenX86_64<'a> { "section .note.GNU-stack db 0 -section .text._builtin_read8 -_builtin_read8: - xor rax, rax - mov al, byte [rdi] - ret - section .text._builtin_read64 _builtin_read64: mov rax, qword [rdi] diff --git a/src/std.zr b/src/std.zr index 1c6f508..5649553 100644 --- a/src/std.zr +++ b/src/std.zr @@ -20,12 +20,13 @@ func mem.zero[x: i64, size: i64] : void mem.write8(x + i, 0) func mem.read8[x: ptr] : u8 - return _builtin_read8(x) + return x[0] func mem.read16[x: ptr] : i64 - let low: i64 = mem.read8(x) - let high: i64 = mem.read8(x + 1) - return low | (high << 8) + return x[0] | (x[1] << 8) + +func mem.read32[x: ptr] : i64 + return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24) func mem.read64[x: ptr] : i64 return _builtin_read64(x) @@ -33,6 +34,12 @@ func mem.read64[x: ptr] : i64 func mem.write8[x: ptr, d: u8] : void _builtin_set8(x, d) +func mem.write32[x: ptr, d: i64] : void + mem.write8(x, d & 0xff) + mem.write8(x + 1, (d >> 8) & 0xff) + mem.write8(x + 2, (d >> 16) & 0xff) + mem.write8(x + 3, (d >> 24) & 0xff) + func mem.write64[x: ptr, d: i64] : void _builtin_set64(x, d) @@ -97,7 +104,7 @@ func io.write_file[path: str, content: str] : void func str.len[s: str] : i64 let i = 0 - while mem.read8(s + i) + while s[i] i = i + 1 return i @@ -658,3 +665,288 @@ func net.close[s: i64] : void func net.pack_addr[a: i64, b: i64, c: i64, d: i64] : i64 return (a << 24) | (b << 16) | (c << 8) | d + +func crypto.rotl32[x: i64, r: i64] : i64 + return ((x << r) | (x >> (32 - r))) & 0xffffffff + +func crypto.chacha20._quarter_round[state: ptr, a: i64, b: i64, c: i64, d: i64] : void + let va: i64 = mem.read32(state + a * 4) + let vb: i64 = mem.read32(state + b * 4) + let vc: i64 = mem.read32(state + c * 4) + let vd: i64 = mem.read32(state + d * 4) + va = (va + vb) & 0xffffffff + vd = vd ^ va + vd = crypto.rotl32(vd, 16) + vc = (vc + vd) & 0xffffffff + vb = vb ^ vc + vb = crypto.rotl32(vb, 12) + va = (va + vb) & 0xffffffff + vd = vd ^ va + vd = crypto.rotl32(vd, 8) + vc = (vc + vd) & 0xffffffff + vb = vb ^ vc + vb = crypto.rotl32(vb, 7) + mem.write32(state + a * 4, va) + mem.write32(state + b * 4, vb) + mem.write32(state + c * 4, vc) + mem.write32(state + d * 4, vd) + +func crypto.xchacha20._permute[state: ptr] : void + for i in 0..10 + crypto.chacha20._quarter_round(state, 0, 4, 8, 12) + crypto.chacha20._quarter_round(state, 1, 5, 9, 13) + crypto.chacha20._quarter_round(state, 2, 6, 10, 14) + crypto.chacha20._quarter_round(state, 3, 7, 11, 15) + crypto.chacha20._quarter_round(state, 0, 5, 10, 15) + crypto.chacha20._quarter_round(state, 1, 6, 11, 12) + crypto.chacha20._quarter_round(state, 2, 7, 8, 13) + crypto.chacha20._quarter_round(state, 3, 4, 9, 14) + +func crypto.xchacha20._block[key: ptr, nonce: ptr, blocknum: i64, out: ptr] : void + let sigma: str = "expand 32-byte k" + let state: ptr = mem.alloc(16 * 4) + + mem.write32(state + 0, mem.read32(sigma + 0)) + mem.write32(state + 4, mem.read32(sigma + 4)) + mem.write32(state + 8, mem.read32(sigma + 8)) + mem.write32(state + 12, mem.read32(sigma + 12)) + + for i in 0..8 + mem.write32(state + (4 + i) * 4, mem.read32(key + i * 4)) + + mem.write32(state + 12 * 4, blocknum) + mem.write32(state + 13 * 4, mem.read32(nonce + 0)) + mem.write32(state + 14 * 4, mem.read32(nonce + 4)) + mem.write32(state + 15 * 4, mem.read32(nonce + 8)) + + let working: ptr = mem.alloc(16 * 4) + for i in 0..16 + mem.write32(working + i * 4, mem.read32(state + i * 4)) + + crypto.xchacha20._permute(working) + + for i in 0..16 + let v: i64 = (mem.read32(working + i * 4) + mem.read32(state + i * 4)) & 0xffffffff + mem.write32(out + i * 4, v) + mem.free(working) + mem.free(state) + +func crypto.xchacha20._hchacha20[key: ptr, input: ptr, out32: ptr] : void + let sigma: str = "expand 32-byte k" + let state: ptr = mem.alloc(16 * 4) + + mem.write32(state + 0, mem.read32(sigma + 0)) + mem.write32(state + 4, mem.read32(sigma + 4)) + mem.write32(state + 8, mem.read32(sigma + 8)) + mem.write32(state + 12, mem.read32(sigma + 12)) + + for i in 0..8 + mem.write32(state + (4 + i) * 4, mem.read32(key + i * 4)) + + for i in 0..4 + mem.write32(state + (12 + i) * 4, mem.read32(input + i * 4)) + + crypto.xchacha20._permute(state) + + for i in 0..4 + mem.write32(out32 + i * 4, mem.read32(state + i * 4)) + for i in 0..4 + mem.write32(out32 + 16 + i * 4, mem.read32(state + (12 + i) * 4)) + mem.free(state) + +func crypto.xchacha20._stream[key: ptr, nonce: ptr, out: ptr, len: i64] : void + let subkey: ptr = mem.alloc(32) + crypto.xchacha20._hchacha20(key, nonce, subkey) + + let nonce12: ptr = mem.alloc(12) + for i in 0..12 + mem.write8(nonce12 + i, 0) + for i in 0..8 + mem.write8(nonce12 + 4 + i, nonce[16 + i]) + + let blocknum: i64 = 0 + let remaining: i64 = len + let block: ptr = mem.alloc(64) + + while remaining > 0 + crypto.xchacha20._block(subkey, nonce12, blocknum, block) + let take: i64 = 64 + if remaining < 64 + take = remaining + for i in 0..take + mem.write8(out + (len - remaining) + i, block[i]) + remaining = remaining - take + blocknum = blocknum + 1 + mem.free(block) + mem.free(nonce12) + mem.free(subkey) + +func crypto.xchacha20.xor[key: ptr, nonce: ptr, input: ptr, out: ptr, len: i64] : void + if len <= 0 + return 0 + let ks: ptr = mem.alloc(len) + crypto.xchacha20._stream(key, nonce, ks, len) + for i in 0..len + mem.write8(out + i, input[i] ^ ks[i]) + mem.free(ks) + +func crypto.x25519.carry[elem: ptr] : void + for i in 0..16 + let carry: i64 = mem.read64(elem + i * 8) >> 16 + mem.write64(elem + i * 8, mem.read64(elem + i * 8) - (carry << 16)) + if i < 15 + mem.write64(elem + (i + 1) * 8, mem.read64(elem + (i + 1) * 8) + carry) + else + mem.write64(elem, mem.read64(elem) + 38 * carry) + +func crypto.x25519.fadd[out: ptr, a: ptr, b: ptr] : void + for i in 0..16 + mem.write64(out + i * 8, mem.read64(a + i * 8) + mem.read64(b + i * 8)) + +func crypto.x25519.fsub[out: ptr, a: ptr, b: ptr] : void + for i in 0..16 + mem.write64(out + i * 8, mem.read64(a + i * 8) - mem.read64(b + i * 8)) + +func crypto.x25519.fmul[out: ptr, a: ptr, b: ptr] : void + let product: ptr = mem.alloc(31 * 8) + for i in 0..31 + mem.write64(product + i * 8, 0) + for i in 0..16 + for j in 0..16 + mem.write64(product + (i + j) * 8, mem.read64(product + (i + j) * 8) + (mem.read64(a + i * 8) * mem.read64(b + j * 8))) + for i in 0..15 + mem.write64(product + i * 8, mem.read64(product + i * 8) + 38 * mem.read64(product + (i + 16) * 8)) + for i in 0..16 + mem.write64(out + i * 8, mem.read64(product + i * 8)) + + crypto.x25519.carry(out) + crypto.x25519.carry(out) + mem.free(product) + +func crypto.x25519.finverse[out: ptr, input: ptr] : void + let c: ptr = mem.alloc(16 * 8) + for i in 0..16 + mem.write64(c + i * 8, mem.read64(input + i * 8)) + + let i = 253 + while i >= 0 + crypto.x25519.fmul(c, c, c) + if i != 2 && i != 4 + crypto.x25519.fmul(c, c, input) + i = i - 1 + + for i in 0..16 + mem.write64(out + i * 8, mem.read64(c + i * 8)) + mem.free(c) + +func crypto.x25519.swap[p: ptr, q: ptr, bit: i64] : void + for i in 0..16 + let t: i64 = (-bit) & (mem.read64(p + i * 8) ^ mem.read64(q + i * 8)) + mem.write64(p + i * 8, mem.read64(p + i * 8) ^ t) + mem.write64(q + i * 8, mem.read64(q + i * 8) ^ t) + +func crypto.x25519.unpack[out: ptr, input: ptr] : void + for i in 0..16 + mem.write64(out + i * 8, input[i * 2] + (input[i * 2 + 1] << 8)) + mem.write64(out + 8 * 15, mem.read64(out + 8 * 15) & 0x7fff) + +func crypto.x25519.pack[out: ptr, input: ptr] : void + let t: ptr = mem.alloc(16 * 8) + for i in 0..16 + mem.write64(t + i * 8, mem.read64(input + i * 8)) + let m: ptr = mem.alloc(16 * 8) + + crypto.x25519.carry(t) + crypto.x25519.carry(t) + crypto.x25519.carry(t) + for j in 0..2 + mem.write64(m, mem.read64(t) - 0xffed) + for i in 1..15 + mem.write64(m + i * 8, mem.read64(t + i * 8) - 0xffff - ((mem.read64(m + (i - 1) * 8) >> 16) & 1)) + mem.write64(m + (i - 1) * 8, mem.read64(m + (i - 1) * 8) & 0xffff) + mem.write64(m + 15 * 8, mem.read64(t + 15 * 8) - 0x7fff - ((mem.read64(m + 14 * 8) >> 16) & 1)) + let carry: i64 = (mem.read64(m + 15 * 8) >> 16) & 1 + mem.write64(m + 14 * 8, mem.read64(m + 14 * 8) & 0xffff) + crypto.x25519.swap(t, m, 1 - carry) + + for i in 0..16 + let v: i64 = mem.read64(t + i * 8) + mem.write8(out + i * 2, v & 0xff) + mem.write8(out + i * 2 + 1, (v >> 8) & 0xff) + + mem.free(t) + mem.free(m) + +func crypto.x25519.scalarmult[out: ptr, scalar: ptr, point: ptr] : void + let clamped: ptr = mem.alloc(32) + let a: ptr = mem.alloc(16 * 8) + let b: ptr = mem.alloc(16 * 8) + let c: ptr = mem.alloc(16 * 8) + let d: ptr = mem.alloc(16 * 8) + let e: ptr = mem.alloc(16 * 8) + let f: ptr = mem.alloc(16 * 8) + let x: ptr = mem.alloc(16 * 8) + + let magic: ptr = mem.alloc(16 * 8) + mem.zero(magic, 16 * 8) + mem.write64(magic, 0xdb41) // 121665 + mem.write64(magic + 8, 1) + + // copy and clamp scalar + for i in 0..32 + mem.write8(clamped + i, scalar[i]) + mem.write8(clamped, clamped[0] & 0xf8) + mem.write8(clamped + 31, (clamped[31] & 0x7f) | 0x40) + + // load point + crypto.x25519.unpack(x, point) + + // initialize ladder state + for i in 0..16 + mem.write64(a + i * 8, 0) + mem.write64(b + i * 8, mem.read64(x + i * 8)) + mem.write64(c + i * 8, 0) + mem.write64(d + i * 8, 0) + mem.write64(a, 1) + mem.write64(d, 1) + + let i = 254 + while i >= 0 + let bit: i64 = (clamped[i >> 3] >> (i & 7)) & 1 + crypto.x25519.swap(a, b, bit) + crypto.x25519.swap(c, d, bit) + crypto.x25519.fadd(e, a, c) + crypto.x25519.fsub(a, a, c) + crypto.x25519.fadd(c, b, d) + crypto.x25519.fsub(b, b, d) + crypto.x25519.fmul(d, e, e) + crypto.x25519.fmul(f, a, a) + crypto.x25519.fmul(a, c, a) + crypto.x25519.fmul(c, b, e) + crypto.x25519.fadd(e, a, c) + crypto.x25519.fsub(a, a, c) + crypto.x25519.fmul(b, a, a) + crypto.x25519.fsub(c, d, f) + crypto.x25519.fmul(a, c, magic) + crypto.x25519.fadd(a, a, d) + crypto.x25519.fmul(c, c, a) + crypto.x25519.fmul(a, d, f) + crypto.x25519.fmul(d, b, x) + crypto.x25519.fmul(b, e, e) + crypto.x25519.swap(a, b, bit) + crypto.x25519.swap(c, d, bit) + i = i - 1 + + crypto.x25519.finverse(c, c) + crypto.x25519.fmul(a, a, c) + crypto.x25519.pack(out, a) + + mem.free(clamped) + mem.free(a) + mem.free(b) + mem.free(c) + mem.free(d) + mem.free(e) + mem.free(f) + mem.free(x) + mem.free(magic) \ No newline at end of file