analyzer, check function arity, catch redefinitions

This commit is contained in:
2025-07-27 13:26:23 +02:00
parent 4e43d23263
commit dc3abc36a1
3 changed files with 168 additions and 6 deletions

146
src/analyzer.rs Normal file
View File

@@ -0,0 +1,146 @@
use std::collections::HashMap;
use crate::{
parser::{Expr, Stmt},
tokenizer::{ZernError, error},
};
pub struct Analyzer {
pub functions: HashMap<String, usize>,
}
impl Analyzer {
pub fn new() -> Analyzer {
Analyzer {
functions: HashMap::new(),
}
}
pub fn register_function(&mut self, stmt: &Stmt) -> Result<(), ZernError> {
if let Stmt::Function {
name,
params,
return_type: _,
body: _,
} = stmt
{
if self.functions.contains_key(&name.lexeme) {
return error!(name.loc, format!("tried to redefine '{}'", name.lexeme));
}
self.functions.insert(name.lexeme.clone(), params.len());
}
Ok(())
}
pub fn analyze_stmt(&mut self, stmt: &Stmt) -> Result<(), ZernError> {
match stmt {
Stmt::Expression(expr) => self.analyze_expr(expr)?,
Stmt::Let {
name: _,
var_type: _,
initializer,
} => {
self.analyze_expr(initializer)?;
}
Stmt::Block(statements) => {
for stmt in statements {
self.analyze_stmt(stmt)?;
}
}
Stmt::If {
condition,
then_branch,
else_branch,
} => {
self.analyze_expr(condition)?;
self.analyze_stmt(then_branch)?;
self.analyze_stmt(else_branch)?;
}
Stmt::While { condition, body } => {
self.analyze_expr(condition)?;
self.analyze_stmt(body)?;
}
Stmt::Function {
name,
params: _,
return_type,
body,
} => {
if name.lexeme == "main" && return_type.lexeme != "I64" {
return error!(&name.loc, "main must return I64");
}
self.analyze_stmt(body)?;
}
Stmt::Return(expr) => {
self.analyze_expr(expr)?;
}
Stmt::For {
var: _,
start,
end,
body,
} => {
self.analyze_expr(start)?;
self.analyze_expr(end)?;
self.analyze_stmt(body)?;
}
Stmt::Break => {}
Stmt::Continue => {}
}
Ok(())
}
pub fn analyze_expr(&mut self, expr: &Expr) -> Result<(), ZernError> {
match expr {
Expr::Binary { left, op: _, right } => {
self.analyze_expr(left)?;
self.analyze_expr(right)?;
}
Expr::Grouping(expr) => self.analyze_expr(expr)?,
Expr::Literal(_) => {}
Expr::Unary { op: _, right } => {
self.analyze_expr(right)?;
}
Expr::Variable(_) => {}
Expr::Assign { name: _, value } => {
self.analyze_expr(value)?;
}
Expr::Call {
callee,
paren,
args,
} => {
let callee = match callee.as_ref() {
Expr::Variable(name) => name.lexeme.clone(),
_ => return error!(&paren.loc, "tried to call a non-constant expression"),
};
if let Some(arity) = self.functions.get(&callee) {
if *arity != args.len() {
return error!(
&paren.loc,
format!("expected {} arguments, got {}", arity, args.len())
);
}
} else {
// TODO: cant error here since we dont analyze externs/builtins
}
for arg in args {
self.analyze_expr(arg)?;
}
}
Expr::ArrayLiteral(exprs) => {
for expr in exprs {
self.analyze_expr(expr)?;
}
}
Expr::Index { expr, index } => {
self.analyze_expr(expr)?;
self.analyze_expr(index)?;
}
}
Ok(())
}
}

View File

@@ -160,6 +160,7 @@ _builtin_rshift:
var_type,
initializer,
} => {
// TODO: move to analyzer
if env.get_var(&name.lexeme).is_some() {
return error!(
name.loc,
@@ -215,14 +216,11 @@ _builtin_rshift:
Stmt::Function {
name,
params,
return_type,
return_type: _,
body,
} => {
if name.lexeme == "main" {
emit!(&mut self.output, "global {}", name.lexeme);
if return_type.lexeme != "I64" {
return error!(&name.loc, "main must return I64");
}
}
emit!(&mut self.output, "section .text.{}", name.lexeme);
emit!(&mut self.output, "{}:", name.lexeme);
@@ -430,6 +428,7 @@ _builtin_rshift:
}
}
Expr::Variable(name) => {
// TODO: move to analyzer
let var = match env.get_var(&name.lexeme) {
Some(x) => x,
None => {
@@ -445,6 +444,7 @@ _builtin_rshift:
Expr::Assign { name, value } => {
self.compile_expr(env, *value)?;
// TODO: move to analyzer
let var = match env.get_var(&name.lexeme) {
Some(x) => x,
None => {

View File

@@ -1,3 +1,4 @@
mod analyzer;
mod codegen_x86_64;
mod parser;
mod tokenizer;
@@ -13,6 +14,7 @@ use tokenizer::ZernError;
use clap::Parser;
fn compile_file_to(
analyzer: &mut analyzer::Analyzer,
codegen: &mut codegen_x86_64::CodegenX86_64,
filename: &str,
source: String,
@@ -23,6 +25,14 @@ fn compile_file_to(
let parser = parser::Parser::new(tokens);
let statements = parser.parse()?;
for stmt in &statements {
analyzer.register_function(stmt)?;
}
for stmt in &statements {
analyzer.analyze_stmt(stmt)?;
}
for stmt in statements {
codegen.compile_stmt(&mut codegen_x86_64::Env::new(), stmt)?;
}
@@ -51,10 +61,16 @@ fn compile_file(args: Args) -> Result<(), ZernError> {
let filename = Path::new(&args.path).file_name().unwrap().to_str().unwrap();
let mut analyzer = analyzer::Analyzer::new();
let mut codegen = codegen_x86_64::CodegenX86_64::new();
codegen.emit_prologue()?;
compile_file_to(&mut codegen, "std.zr", include_str!("std.zr").into())?;
compile_file_to(&mut codegen, filename, source)?;
compile_file_to(
&mut analyzer,
&mut codegen,
"std.zr",
include_str!("std.zr").into(),
)?;
compile_file_to(&mut analyzer, &mut codegen, filename, source)?;
if !args.output_asm {
fs::write(format!("{}.s", args.out), codegen.get_output()).unwrap();