From 1f6b90e1e1bb8c8ade9b1416e36465b449eb0225 Mon Sep 17 00:00:00 2001 From: lionarius Date: Fri, 1 Nov 2024 08:37:01 +0300 Subject: [PATCH] lab4.1 --- src/ast/mod.rs | 70 ++++++++++++++++++++++++++++++++++++++++----- src/ast/typed.rs | 16 +++++++---- src/error.rs | 2 ++ src/main.rs | 28 +++++++++++------- src/symbols/mod.rs | 71 +++++++++++++++++++++++++++++++++++----------- 5 files changed, 148 insertions(+), 39 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 71b9aef..c38bf05 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -54,7 +54,17 @@ impl fmt::Display for BinOp { } } -pub fn to_typed_expr(expr: UntypedExpr, symbols: &SymbolsTable) -> error::Result { +pub fn to_typed_expr(expr: UntypedExpr, symbols: &mut SymbolsTable) -> error::Result { + let expr = convert_to_typed_expr(expr, symbols)?; + let expr = coerce_types(expr, symbols)?; + + Ok(expr) +} + +fn convert_to_typed_expr( + expr: UntypedExpr, + symbols: &mut SymbolsTable, +) -> error::Result { let expr = match expr { UntypedExpr::Int { span, value } => TypedExpr::Int { span, value }, UntypedExpr::Float { span, value } => TypedExpr::Float { span, value }, @@ -65,12 +75,28 @@ pub fn to_typed_expr(expr: UntypedExpr, symbols: &SymbolsTable) -> error::Result } => { let ty = typename .and_then(|t| symbols.resolve(t)) + .map(|data| data.name.as_str()) .map(Type::from_str) .transpose() - .map_err(|e| SemanticError::new(span, e))? - .unwrap_or(Type::Int); + .map_err(|e| SemanticError::new(span, e))?; + { + let symbol = symbols.resolve_mut(name).unwrap(); + match (symbol.ty, ty) { + (Some(ty), Some(ty2)) if ty != ty2 => { + return Err(SemanticError::new( + span, + SemanticErrorKind::DuplicateSymbol(symbol.name.clone()), + ) + .into()) + } + (None, Some(ty)) => { + symbol.ty = Some(ty); + } + _ => {} + } + } - TypedExpr::Var { span, name, ty } + TypedExpr::Var { span, name } } UntypedExpr::BinOp { span, lhs, op, rhs } => { let rhs = *rhs; @@ -101,10 +127,39 @@ pub fn to_typed_expr(expr: UntypedExpr, symbols: &SymbolsTable) -> error::Result _ => {} } - let lhs = to_typed_expr(lhs, symbols)?; - let rhs = to_typed_expr(rhs, symbols)?; + let lhs = convert_to_typed_expr(lhs, symbols)?; + let rhs = convert_to_typed_expr(rhs, symbols)?; - let (lhs, rhs) = match (lhs.ty(), rhs.ty()) { + TypedExpr::BinOp { + span, + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + } + } + }; + + Ok(expr) +} + +fn coerce_types(expr: TypedExpr, symbols: &mut SymbolsTable) -> error::Result { + let expr = match expr { + TypedExpr::Int { .. } => expr, + TypedExpr::Float { .. } => expr, + TypedExpr::Var { name, .. } => { + let symbol = symbols.resolve_mut(name).unwrap(); + + if symbol.ty.is_none() { + symbol.ty = Some(Type::Int); + } + + expr + } + TypedExpr::BinOp { lhs, rhs, op, span } => { + let lhs = coerce_types(*lhs, symbols)?; + let rhs = coerce_types(*rhs, symbols)?; + + let (lhs, rhs) = match (lhs.ty(symbols), rhs.ty(symbols)) { (Type::Int, Type::Int) => (lhs, rhs), (Type::Float, Type::Float) => (lhs, rhs), (Type::Int, Type::Float) => { @@ -124,6 +179,7 @@ pub fn to_typed_expr(expr: UntypedExpr, symbols: &SymbolsTable) -> error::Result rhs: Box::new(rhs), } } + TypedExpr::IntToFloat { .. } => expr, }; Ok(expr) diff --git a/src/ast/typed.rs b/src/ast/typed.rs index 340d2ff..fac8eec 100644 --- a/src/ast/typed.rs +++ b/src/ast/typed.rs @@ -1,6 +1,9 @@ use std::{fmt, str::FromStr}; -use crate::{error::SemanticErrorKind, symbols::Symbol}; +use crate::{ + error::SemanticErrorKind, + symbols::{Symbol, SymbolsTable}, +}; use super::{BinOp, Span}; @@ -44,7 +47,6 @@ pub enum TypedExpr { Var { span: Span, name: Symbol, - ty: Type, }, BinOp { span: Span, @@ -74,12 +76,16 @@ impl TypedExpr { } } - pub fn ty(&self) -> Type { + pub fn ty(&self, symbols: &SymbolsTable) -> Type { match self { TypedExpr::Int { .. } => Type::Int, TypedExpr::Float { .. } => Type::Float, - TypedExpr::Var { ty, .. } => *ty, - TypedExpr::BinOp { rhs, .. } => rhs.ty(), + TypedExpr::Var { name, .. } => symbols + .resolve(*name) + .expect("symbol not found") + .ty + .expect("type not found"), + TypedExpr::BinOp { rhs, .. } => rhs.ty(symbols), TypedExpr::IntToFloat { .. } => Type::Float, } } diff --git a/src/error.rs b/src/error.rs index f8180c1..06c2e58 100644 --- a/src/error.rs +++ b/src/error.rs @@ -54,6 +54,7 @@ pub struct SemanticError { pub enum SemanticErrorKind { UnknownType(String), DivisionByZero, + DuplicateSymbol(String), } impl Error { @@ -197,6 +198,7 @@ impl fmt::Display for SemanticErrorKind { match self { SemanticErrorKind::UnknownType(s) => write!(f, "unknown type '{}'", s), SemanticErrorKind::DivisionByZero => write!(f, "division by zero"), + SemanticErrorKind::DuplicateSymbol(s) => write!(f, "duplicate symbol '{}'", s), } } } diff --git a/src/main.rs b/src/main.rs index c47b978..7c4bd2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ use symbols::SymbolsTable; fn write_typed_expr( expr: &TypedExpr, + symbols: &SymbolsTable, writer: &mut impl Write, prefix: &str, is_last: bool, @@ -23,8 +24,9 @@ fn write_typed_expr( match expr { TypedExpr::Int { value, .. } => writeln!(writer, "<{}>", value), TypedExpr::Float { value, .. } => writeln!(writer, "<{}>", value), - TypedExpr::Var { name: id, ty, .. } => { - writeln!(writer, "", id, ty) + TypedExpr::Var { name, .. } => { + let ty = symbols.resolve(*name).unwrap().ty.unwrap(); + writeln!(writer, "", name, ty) } TypedExpr::BinOp { lhs, op, rhs, .. } => { writeln!(writer, "<{}>", op)?; @@ -35,8 +37,8 @@ fn write_typed_expr( format!("{}│ ", prefix) }; - write_typed_expr(lhs, writer, &new_prefix, false)?; - write_typed_expr(rhs, writer, &new_prefix, true) + write_typed_expr(lhs, symbols, writer, &new_prefix, false)?; + write_typed_expr(rhs, symbols, writer, &new_prefix, true) } TypedExpr::IntToFloat { value, .. } => { writeln!(writer, "i2f")?; @@ -47,13 +49,17 @@ fn write_typed_expr( format!("{}│ ", prefix) }; - write_typed_expr(value, writer, &new_prefix, true) + write_typed_expr(value, symbols, writer, &new_prefix, true) } } } -fn print_typed_expr(expr: &TypedExpr, writer: &mut impl Write) -> io::Result<()> { - write_typed_expr(expr, writer, "", true) +fn print_typed_expr( + expr: &TypedExpr, + symbols: &SymbolsTable, + writer: &mut impl Write, +) -> io::Result<()> { + write_typed_expr(expr, symbols, writer, "", true) } fn write_untyped_expr( @@ -122,8 +128,8 @@ fn main() -> anyhow::Result<()> { let mut writer_symbols = io::BufWriter::new(std::fs::File::create(output_symbols)?); - for (name, id) in &symbols { - writeln!(writer_symbols, "{name} -> {id}")?; + for (name, data) in &symbols { + writeln!(writer_symbols, "{name} -> {}", data)?; } } Err(e) => { @@ -155,12 +161,12 @@ fn main() -> anyhow::Result<()> { let mut parser = Parser::new(tokens); parser.parse() } - .and_then(|expr| ast::to_typed_expr(expr, &symbols)); + .and_then(|expr| ast::to_typed_expr(expr, &mut symbols)); match res { Ok(expr) => { let mut writer_tree = io::BufWriter::new(std::fs::File::create(output_tree)?); - print_typed_expr(&expr, &mut writer_tree)?; + print_typed_expr(&expr, &symbols, &mut writer_tree)?; } Err(e) => eprintln!("error: {}", e), } diff --git a/src/symbols/mod.rs b/src/symbols/mod.rs index 102919f..0af391d 100644 --- a/src/symbols/mod.rs +++ b/src/symbols/mod.rs @@ -5,9 +5,29 @@ use std::{ fmt::Display, }; +use crate::ast::typed::Type; + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct Symbol(usize); +#[derive(Debug)] +pub struct SymbolData { + pub id: Symbol, + pub name: String, + pub ty: Option, +} + +impl Display for SymbolData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.id)?; + if let Some(ty) = self.ty { + write!(f, ":{}", ty)?; + } + + Ok(()) + } +} + impl Display for Symbol { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) @@ -16,7 +36,7 @@ impl Display for Symbol { #[derive(Debug)] pub struct SymbolsTable { - symbols: HashMap, + symbols: HashMap, next_id: usize, } @@ -29,24 +49,43 @@ impl SymbolsTable { } pub fn add(&mut self, symbol: impl Into) { - if let hash_map::Entry::Vacant(e) = self.symbols.entry(symbol.into()) { - e.insert(Symbol(self.next_id)); + let symbol = symbol.into(); + if let hash_map::Entry::Vacant(e) = self.symbols.entry(symbol.clone()) { + e.insert(SymbolData { + id: Symbol(self.next_id), + name: symbol, + ty: None, + }); self.next_id += 1; } } pub fn get(&self, symbol: &str) -> Option { - self.symbols.get(symbol).copied() + self.symbols.get(symbol).map(|data| data.id) } - pub fn resolve(&self, symbol: Symbol) -> Option<&str> { - self.symbols.iter().find_map(|(name, id)| { - if *id == symbol { - Some(name.as_str()) - } else { - None - } - }) + pub fn resolve(&self, symbol: Symbol) -> Option<&SymbolData> { + self.symbols.iter().find_map( + |(_name, data)| { + if data.id == symbol { + Some(data) + } else { + None + } + }, + ) + } + + pub fn resolve_mut(&mut self, symbol: Symbol) -> Option<&mut SymbolData> { + self.symbols.iter_mut().find_map( + |(_name, data)| { + if data.id == symbol { + Some(data) + } else { + None + } + }, + ) } } @@ -57,8 +96,8 @@ impl Default for SymbolsTable { } impl IntoIterator for SymbolsTable { - type Item = (String, Symbol); - type IntoIter = hash_map::IntoIter; + type Item = (String, SymbolData); + type IntoIter = hash_map::IntoIter; fn into_iter(self) -> Self::IntoIter { self.symbols.into_iter() @@ -66,8 +105,8 @@ impl IntoIterator for SymbolsTable { } impl<'a> IntoIterator for &'a SymbolsTable { - type Item = (&'a String, &'a Symbol); - type IntoIter = hash_map::Iter<'a, String, Symbol>; + type Item = (&'a String, &'a SymbolData); + type IntoIter = hash_map::Iter<'a, String, SymbolData>; fn into_iter(self) -> Self::IntoIter { self.symbols.iter()