From 8f7d085ae788fbe08d6b1e77d6989f2393d4dd91 Mon Sep 17 00:00:00 2001 From: lionarius Date: Wed, 13 Nov 2024 20:24:40 +0300 Subject: [PATCH] refactor 2 --- src/ast/mod.rs | 1 + src/ast/optimization.rs | 227 +++++++++++++++++++++++++++++++++++ src/ast/typed.rs | 259 ++++------------------------------------ src/cli.rs | 2 +- src/main.rs | 3 +- 5 files changed, 252 insertions(+), 240 deletions(-) create mode 100644 src/ast/optimization.rs diff --git a/src/ast/mod.rs b/src/ast/mod.rs index c873d06..80f199b 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,3 +1,4 @@ +pub mod optimization; pub mod typed; pub mod untyped; diff --git a/src/ast/optimization.rs b/src/ast/optimization.rs new file mode 100644 index 0000000..85e55a3 --- /dev/null +++ b/src/ast/optimization.rs @@ -0,0 +1,227 @@ +use std::fmt; + +use crate::ast::typed::{Type, TypedExpr}; +use crate::ast::{BinOp, OpCommutative}; +use crate::symbols::SymbolsTable; + +#[derive(Copy, Clone, PartialOrd, Ord, PartialEq, Eq, clap::ValueEnum)] +pub enum OLevel { + None = 0, + O1 = 1, + O2 = 2, + O3 = 3, +} + +pub fn optimize_expr(expr: TypedExpr) -> TypedExpr { + match expr { + TypedExpr::BinOp { lhs, op, rhs } => optimize_binop(*lhs, op, *rhs), + TypedExpr::IntToFloat { value } => optimize_int_to_float(*value), + expr => expr, + } +} + +pub fn bubble_binop_vars(expr: TypedExpr) -> TypedExpr { + let expr = reorder_commutative_expr(expr); + + let expr = match expr { + TypedExpr::BinOp { lhs, op, rhs } => { + let lhs = *lhs; + let rhs = *rhs; + + let lhs = bubble_binop_vars(lhs); + let rhs = bubble_binop_vars(rhs); + + let (lhs, rhs, op) = match (lhs, rhs) { + ( + TypedExpr::BinOp { + lhs: lhs1, + op: op1, + rhs: rhs1, + }, + rhs, + ) if rhs.is_const() && op.commutative(&op1) => ( + TypedExpr::BinOp { + lhs: lhs1, + op, + rhs: Box::new(rhs), + }, + *rhs1, + op1, + ), + (lhs, rhs) if !lhs.is_const() && rhs.is_const() && op.swappable(&op) => (rhs, lhs, op), + (lhs, rhs) if op.precedence() < lhs.precedence() && op.swappable(&op) => { + (rhs, lhs, op) + }, + (lhs, rhs) => (lhs, rhs, op), + }; + + let lhs = bubble_binop_vars(lhs); + let rhs = bubble_binop_vars(rhs); + + TypedExpr::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + } + }, + TypedExpr::IntToFloat { value } => TypedExpr::IntToFloat { + value: Box::new(bubble_binop_vars(*value)), + }, + expr => expr, + }; + + expr +} + +pub fn propagate_type_conversions( + expr: TypedExpr, + root_ty: Type, + symbols: &SymbolsTable, +) -> TypedExpr { + match expr { + TypedExpr::Int { value } if root_ty == Type::Float => TypedExpr::IntToFloat { + value: Box::new(TypedExpr::Int { value }), + }, + TypedExpr::Var { name } + if root_ty == Type::Float && symbols.resolve(&name).unwrap().ty == Some(Type::Int) => + { + TypedExpr::IntToFloat { + value: Box::new(TypedExpr::Var { name }), + } + }, + TypedExpr::IntToFloat { value } if root_ty == Type::Float => { + propagate_type_conversions(*value, Type::Float, symbols) + }, + TypedExpr::BinOp { lhs, rhs, op } => TypedExpr::BinOp { + lhs: Box::new(propagate_type_conversions(*lhs, root_ty, symbols)), + rhs: Box::new(propagate_type_conversions(*rhs, root_ty, symbols)), + op, + }, + expr => expr, + } +} + +fn optimize_binop(lhs: TypedExpr, op: BinOp, rhs: TypedExpr) -> TypedExpr { + let lhs = optimize_expr(lhs); + let rhs = optimize_expr(rhs); + + match (lhs, rhs) { + (TypedExpr::Int { value: lhs, .. }, TypedExpr::Int { value: rhs, .. }) => TypedExpr::Int { + value: op.evaluate(lhs, rhs), + }, + (TypedExpr::Float { value: lhs, .. }, TypedExpr::Float { value: rhs, .. }) => { + TypedExpr::Float { + value: op.evaluate(lhs, rhs), + } + }, + (lhs, rhs) => optimize_special_cases(lhs, op, rhs), + } +} + +fn optimize_special_cases(lhs: TypedExpr, op: BinOp, rhs: TypedExpr) -> TypedExpr { + match (lhs, rhs) { + // Addition of zero + (lhs, TypedExpr::Int { value: 0, .. }) | (TypedExpr::Int { value: 0, .. }, lhs) + if matches!(op, BinOp::Add) => + { + lhs + }, + (lhs, TypedExpr::Float { value: 0.0, .. }) | (TypedExpr::Float { value: 0.0, .. }, lhs) + if matches!(op, BinOp::Add) => + { + lhs + }, + + // Multiplication/Division by one + (lhs, TypedExpr::Int { value: 1, .. }) if matches!(op, BinOp::Mul | BinOp::Div) => lhs, + (lhs, TypedExpr::Float { value: 1.0, .. }) if matches!(op, BinOp::Mul | BinOp::Div) => lhs, + (TypedExpr::Int { value: 1, .. }, rhs) if matches!(op, BinOp::Mul) => rhs, + (TypedExpr::Float { value: 1.0, .. }, rhs) if matches!(op, BinOp::Mul) => rhs, + + // Multiplication by zero + (_, TypedExpr::Int { value: 0, .. }) | (TypedExpr::Int { value: 0, .. }, _) + if matches!(op, BinOp::Mul) => + { + TypedExpr::Int { value: 0 } + }, + (_, TypedExpr::Float { value: 0.0, .. }) | (TypedExpr::Float { value: 0.0, .. }, _) + if matches!(op, BinOp::Mul) => + { + TypedExpr::Float { value: 0.0 } + }, + + // Zero division + (TypedExpr::Int { value: 0, .. }, _) if matches!(op, BinOp::Div) => { + TypedExpr::Int { value: 0 } + }, + (TypedExpr::Float { value: 0.0, .. }, _) if matches!(op, BinOp::Div) => { + TypedExpr::Float { value: 0.0 } + }, + + // Default + (lhs, rhs) => TypedExpr::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }, + } +} + +fn optimize_int_to_float(value: TypedExpr) -> TypedExpr { + let value = optimize_expr(value); + if let TypedExpr::Int { value } = value { + TypedExpr::Float { + value: value as f64, + } + } else { + TypedExpr::IntToFloat { + value: Box::new(value), + } + } +} + +fn reorder_commutative_expr(expr: TypedExpr) -> TypedExpr { + match expr { + TypedExpr::BinOp { lhs, op, rhs, .. } => { + let commutative = rhs + .bin_op() + .map_or(OpCommutative::No, |op1| op.commutative_expr(&op1)); + + if let OpCommutative::Yes(op1) = commutative { + let (lhs, rhs) = match *rhs { + TypedExpr::BinOp { + lhs: lhs1, + rhs: rhs1, + .. + } => ( + TypedExpr::BinOp { lhs, op, rhs: lhs1 }, + reorder_commutative_expr(*rhs1), + ), + _ => unreachable!(), + }; + + TypedExpr::BinOp { + lhs: Box::new(lhs), + op: op1, + rhs: Box::new(rhs), + } + } else { + TypedExpr::BinOp { lhs, op, rhs } + } + }, + expr => expr, + } +} + +impl fmt::Display for OLevel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = match self { + OLevel::None => "none", + OLevel::O1 => "o1", + OLevel::O2 => "o2", + OLevel::O3 => "o3", + }; + + write!(f, "{}", str) + } +} diff --git a/src/ast/typed.rs b/src/ast/typed.rs index 22027f9..e50445e 100644 --- a/src/ast/typed.rs +++ b/src/ast/typed.rs @@ -1,33 +1,12 @@ use std::fmt; -use std::fmt::Display; use std::str::FromStr; -use super::{BinOp, OpCommutative, UntypedExpr}; +use super::{optimization, BinOp, UntypedExpr}; +use crate::ast::optimization::OLevel; use crate::error; use crate::error::{SemanticError, SemanticErrorKind}; use crate::symbols::{Symbol, SymbolsTable}; -#[derive(Copy, Clone, PartialOrd, Ord, PartialEq, Eq, clap::ValueEnum)] -pub enum OLevel { - None = 0, - O1 = 1, - O2 = 2, - O3 = 3, -} - -impl Display for OLevel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let str = match self { - OLevel::None => "none", - OLevel::O1 => "o1", - OLevel::O2 => "o2", - OLevel::O3 => "o3", - }; - - write!(f, "{}", str) - } -} - #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Type { Int, @@ -99,18 +78,18 @@ impl fmt::Display for Type { impl TypedExpr { pub fn cast_to_float(self) -> TypedExpr { - TypedExpr::IntToFloat { + Self::IntToFloat { value: Box::new(self), } } pub fn ty(&self, symbols: &SymbolsTable) -> Type { match self { - TypedExpr::Int { .. } => Type::Int, - TypedExpr::Float { .. } => Type::Float, - TypedExpr::Var { name, .. } => symbols.type_of(name).expect("type not found"), - TypedExpr::BinOp { lhs, .. } => lhs.ty(symbols), - TypedExpr::IntToFloat { .. } => Type::Float, + Self::Int { .. } => Type::Int, + Self::Float { .. } => Type::Float, + Self::Var { name, .. } => symbols.type_of(name).expect("type not found"), + Self::BinOp { lhs, .. } => lhs.ty(symbols), + Self::IntToFloat { .. } => Type::Float, } } @@ -123,27 +102,27 @@ impl TypedExpr { pub fn is_var(&self) -> bool { match self { - TypedExpr::Var { .. } => true, - TypedExpr::IntToFloat { value } => value.is_var(), + Self::Var { .. } => true, + Self::IntToFloat { value } => value.is_var(), _ => false, } } pub fn is_const(&self) -> bool { match self { - TypedExpr::Int { .. } | TypedExpr::Float { .. } => true, - TypedExpr::IntToFloat { value } => value.is_const(), - TypedExpr::BinOp { lhs, rhs, .. } => lhs.is_const() && rhs.is_const(), + Self::Int { .. } | Self::Float { .. } => true, + Self::IntToFloat { value } => value.is_const(), + Self::BinOp { lhs, rhs, .. } => lhs.is_const() && rhs.is_const(), _ => false, } } pub fn precedence(&self) -> u8 { match self { - TypedExpr::Int { .. } | TypedExpr::Float { .. } => 0, - TypedExpr::Var { .. } => 0, - TypedExpr::BinOp { op, .. } => op.precedence(), - TypedExpr::IntToFloat { value } => value.precedence() + 1, + Self::Int { .. } | Self::Float { .. } => 0, + Self::Var { .. } => 0, + Self::BinOp { op, .. } => op.precedence(), + Self::IntToFloat { value } => value.precedence(), } } @@ -151,24 +130,24 @@ impl TypedExpr { expr: UntypedExpr, optimization_level: OLevel, symbols: &mut SymbolsTable, - ) -> error::Result { + ) -> error::Result { let expr = Self::convert_to_typed_expr(expr, symbols)?; let expr = Self::coerce_types(expr, symbols)?; let expr = if optimization_level > OLevel::None { let expr = if optimization_level >= OLevel::O3 { let ty = expr.ty(symbols); - Self::propagate_type_conversions(expr, ty, symbols) + optimization::propagate_type_conversions(expr, ty, symbols) } else { expr }; let expr = if optimization_level >= OLevel::O2 { - Self::bubble_binop_vars(expr) + optimization::bubble_binop_vars(expr) } else { expr }; - let expr = Self::optimize_expr(expr); + let expr = optimization::optimize_expr(expr); expr } else { @@ -252,202 +231,6 @@ impl TypedExpr { Ok(()) } - fn bubble_binop_vars(expr: Self) -> Self { - let expr = Self::reorder_commutative_expr(expr); - - let expr = match expr { - Self::BinOp { lhs, op, rhs } => { - let lhs = *lhs; - let rhs = *rhs; - - let lhs = Self::bubble_binop_vars(lhs); - let rhs = Self::bubble_binop_vars(rhs); - - let (lhs, rhs, op) = match (lhs, rhs) { - ( - Self::BinOp { - lhs: lhs1, - op: op1, - rhs: rhs1, - }, - rhs, - ) if rhs.is_const() && op.commutative(&op1) => ( - Self::BinOp { - lhs: lhs1, - op, - rhs: Box::new(rhs), - }, - *rhs1, - op1, - ), - (lhs, rhs) if lhs.is_var() && rhs.is_const() && op.swappable(&op) => { - (rhs, lhs, op) - }, - (lhs, rhs) if op.precedence() < lhs.precedence() && op.swappable(&op) => { - (rhs, lhs, op) - }, - (lhs, rhs) => (lhs, rhs, op), - }; - - let lhs = Self::bubble_binop_vars(lhs); - let rhs = Self::bubble_binop_vars(rhs); - - Self::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - } - }, - Self::IntToFloat { value } => Self::IntToFloat { - value: Box::new(Self::bubble_binop_vars(*value)), - }, - expr => expr, - }; - - expr - } - - fn reorder_commutative_expr(expr: Self) -> Self { - match expr { - Self::BinOp { lhs, op, rhs, .. } => { - let commutative = rhs - .bin_op() - .map_or(OpCommutative::No, |op1| op.commutative_expr(&op1)); - - if let OpCommutative::Yes(op1) = commutative { - let (lhs, rhs) = match *rhs { - Self::BinOp { - lhs: lhs1, - rhs: rhs1, - .. - } => ( - Self::BinOp { lhs, op, rhs: lhs1 }, - Self::reorder_commutative_expr(*rhs1), - ), - _ => unreachable!(), - }; - - Self::BinOp { - lhs: Box::new(lhs), - op: op1, - rhs: Box::new(rhs), - } - } else { - Self::BinOp { lhs, op, rhs } - } - }, - expr => expr, - } - } - - fn optimize_expr(expr: Self) -> Self { - match expr { - Self::BinOp { lhs, op, rhs } => Self::optimize_binop(*lhs, op, *rhs), - Self::IntToFloat { value } => Self::optimize_int_to_float(*value), - expr => expr, - } - } - - fn optimize_binop(lhs: Self, op: BinOp, rhs: Self) -> Self { - let lhs = Self::optimize_expr(lhs); - let rhs = Self::optimize_expr(rhs); - - match (lhs, rhs) { - (Self::Int { value: lhs, .. }, Self::Int { value: rhs, .. }) => Self::Int { - value: op.evaluate(lhs, rhs), - }, - (Self::Float { value: lhs, .. }, Self::Float { value: rhs, .. }) => Self::Float { - value: op.evaluate(lhs, rhs), - }, - (lhs, rhs) => Self::optimize_special_cases(lhs, op, rhs), - } - } - - fn optimize_special_cases(lhs: Self, op: BinOp, rhs: Self) -> Self { - match (lhs, rhs) { - // Addition of zero - (lhs, Self::Int { value: 0, .. }) | (Self::Int { value: 0, .. }, lhs) - if matches!(op, BinOp::Add) => - { - lhs - }, - (lhs, Self::Float { value: 0.0, .. }) | (Self::Float { value: 0.0, .. }, lhs) - if matches!(op, BinOp::Add) => - { - lhs - }, - - // Multiplication/Division by one - (lhs, Self::Int { value: 1, .. }) if matches!(op, BinOp::Mul | BinOp::Div) => lhs, - (lhs, Self::Float { value: 1.0, .. }) if matches!(op, BinOp::Mul | BinOp::Div) => lhs, - (Self::Int { value: 1, .. }, rhs) if matches!(op, BinOp::Mul) => rhs, - (Self::Float { value: 1.0, .. }, rhs) if matches!(op, BinOp::Mul) => rhs, - - // Multiplication by zero - (_, Self::Int { value: 0, .. }) | (Self::Int { value: 0, .. }, _) - if matches!(op, BinOp::Mul) => - { - Self::Int { value: 0 } - }, - (_, Self::Float { value: 0.0, .. }) | (Self::Float { value: 0.0, .. }, _) - if matches!(op, BinOp::Mul) => - { - Self::Float { value: 0.0 } - }, - - // Zero division - (Self::Int { value: 0, .. }, _) if matches!(op, BinOp::Div) => Self::Int { value: 0 }, - (Self::Float { value: 0.0, .. }, _) if matches!(op, BinOp::Div) => { - Self::Float { value: 0.0 } - }, - - // Default - (lhs, rhs) => Self::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }, - } - } - - fn optimize_int_to_float(value: Self) -> Self { - let value = Self::optimize_expr(value); - if let Self::Int { value } = value { - Self::Float { - value: value as f64, - } - } else { - Self::IntToFloat { - value: Box::new(value), - } - } - } - - fn propagate_type_conversions(expr: Self, root_ty: Type, symbols: &SymbolsTable) -> Self { - match expr { - Self::Int { value } if root_ty == Type::Float => Self::IntToFloat { - value: Box::new(Self::Int { value }), - }, - Self::Var { name } - if root_ty == Type::Float - && symbols.resolve(&name).unwrap().ty == Some(Type::Int) => - { - Self::IntToFloat { - value: Box::new(Self::Var { name }), - } - }, - Self::IntToFloat { value } if root_ty == Type::Float => { - Self::propagate_type_conversions(*value, Type::Float, symbols) - }, - Self::BinOp { lhs, rhs, op } => Self::BinOp { - lhs: Box::new(Self::propagate_type_conversions(*lhs, root_ty, symbols)), - rhs: Box::new(Self::propagate_type_conversions(*rhs, root_ty, symbols)), - op, - }, - expr => expr, - } - } - fn coerce_types(expr: Self, symbols: &mut SymbolsTable) -> error::Result { let expr = match expr { TypedExpr::Int { .. } | TypedExpr::Float { .. } | TypedExpr::IntToFloat { .. } => expr, diff --git a/src/cli.rs b/src/cli.rs index 3604ad8..75f4ce5 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use std::path::PathBuf; use clap::{CommandFactory, Parser, Subcommand}; -use developing_compilers::ast::typed::OLevel; +use developing_compilers::ast::optimization::OLevel; pub struct Args { inner: ArgsInner, diff --git a/src/main.rs b/src/main.rs index db85c74..3af34cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,8 @@ use std::io::{self, Write}; use std::path::Path; use cli::GenMode; -use developing_compilers::ast::typed::{OLevel, Type, TypedExpr}; +use developing_compilers::ast::optimization::OLevel; +use developing_compilers::ast::typed::{Type, TypedExpr}; use developing_compilers::interpreter::{Interpreter, Value}; use developing_compilers::representation::intermediate::IntermediateExpr; use developing_compilers::*;