242 lines
7.3 KiB
Rust
242 lines
7.3 KiB
Rust
use std::fmt;
|
|
|
|
use crate::ast::typed::{Type, TypedExpr};
|
|
use crate::ast::{BinOp, OpCommutative};
|
|
use crate::symbols::SymbolsTable;
|
|
|
|
#[derive(Copy, Clone, PartialOrd, Ord, PartialEq, Eq, clap::ValueEnum)]
|
|
pub enum OLevel {
|
|
None = 0,
|
|
O1 = 1,
|
|
O2 = 2,
|
|
O3 = 3,
|
|
}
|
|
|
|
pub fn optimize_expr(expr: TypedExpr) -> TypedExpr {
|
|
match expr {
|
|
TypedExpr::BinOp { lhs, op, rhs } => optimize_binop(*lhs, op, *rhs),
|
|
TypedExpr::IntToFloat { value } => optimize_int_to_float(*value),
|
|
expr => expr,
|
|
}
|
|
}
|
|
|
|
pub fn bubble_binop_vars(expr: TypedExpr) -> TypedExpr {
|
|
if expr.is_const() {
|
|
return expr;
|
|
}
|
|
|
|
let expr = reorder_commutative_expr(expr);
|
|
|
|
let expr = match expr {
|
|
TypedExpr::BinOp { lhs, op, rhs } => {
|
|
let lhs = *lhs;
|
|
let rhs = *rhs;
|
|
|
|
let lhs = bubble_binop_vars(lhs);
|
|
let rhs = bubble_binop_vars(rhs);
|
|
|
|
let (lhs, rhs, op, rebubble) = match (lhs, rhs) {
|
|
(
|
|
TypedExpr::BinOp {
|
|
lhs: lhs1,
|
|
op: op1,
|
|
rhs: rhs1,
|
|
},
|
|
rhs,
|
|
) if rhs.is_const() && op.commutative(&op1) => (
|
|
TypedExpr::BinOp {
|
|
lhs: lhs1,
|
|
op,
|
|
rhs: Box::new(rhs),
|
|
},
|
|
*rhs1,
|
|
op1,
|
|
true,
|
|
),
|
|
(lhs, rhs) if !lhs.is_const() && rhs.is_const() && op.swappable(&op) => {
|
|
(rhs, lhs, op, false)
|
|
},
|
|
(lhs, rhs) if op.precedence() < lhs.precedence() && op.swappable(&op) => {
|
|
(rhs, lhs, op, true)
|
|
},
|
|
(lhs, rhs) => (lhs, rhs, op, false),
|
|
};
|
|
|
|
let lhs = if rebubble {
|
|
bubble_binop_vars(lhs)
|
|
} else {
|
|
lhs
|
|
};
|
|
|
|
TypedExpr::BinOp {
|
|
lhs: Box::new(lhs),
|
|
op,
|
|
rhs: Box::new(rhs),
|
|
}
|
|
},
|
|
TypedExpr::IntToFloat { value } => TypedExpr::IntToFloat {
|
|
value: Box::new(bubble_binop_vars(*value)),
|
|
},
|
|
expr => expr,
|
|
};
|
|
|
|
expr
|
|
}
|
|
|
|
pub fn propagate_type_conversions(
|
|
expr: TypedExpr,
|
|
root_ty: Type,
|
|
symbols: &SymbolsTable,
|
|
) -> TypedExpr {
|
|
match expr {
|
|
TypedExpr::Int { value } if root_ty == Type::Float => TypedExpr::IntToFloat {
|
|
value: Box::new(TypedExpr::Int { value }),
|
|
},
|
|
TypedExpr::Var { name }
|
|
if root_ty == Type::Float && symbols.resolve(&name).unwrap().ty == Some(Type::Int) =>
|
|
{
|
|
TypedExpr::IntToFloat {
|
|
value: Box::new(TypedExpr::Var { name }),
|
|
}
|
|
},
|
|
TypedExpr::IntToFloat { value } if root_ty == Type::Float => {
|
|
propagate_type_conversions(*value, Type::Float, symbols)
|
|
},
|
|
TypedExpr::BinOp { lhs, rhs, op } => TypedExpr::BinOp {
|
|
lhs: Box::new(propagate_type_conversions(*lhs, root_ty, symbols)),
|
|
rhs: Box::new(propagate_type_conversions(*rhs, root_ty, symbols)),
|
|
op,
|
|
},
|
|
expr => expr,
|
|
}
|
|
}
|
|
|
|
fn optimize_binop(lhs: TypedExpr, op: BinOp, rhs: TypedExpr) -> TypedExpr {
|
|
let lhs = optimize_expr(lhs);
|
|
let rhs = optimize_expr(rhs);
|
|
|
|
match (lhs, rhs) {
|
|
(TypedExpr::Int { value: lhs, .. }, TypedExpr::Int { value: rhs, .. }) => TypedExpr::Int {
|
|
value: op.evaluate(lhs, rhs),
|
|
},
|
|
(TypedExpr::Float { value: lhs, .. }, TypedExpr::Float { value: rhs, .. }) => {
|
|
TypedExpr::Float {
|
|
value: op.evaluate(lhs, rhs),
|
|
}
|
|
},
|
|
(lhs, rhs) => optimize_special_cases(lhs, op, rhs),
|
|
}
|
|
}
|
|
|
|
fn optimize_special_cases(lhs: TypedExpr, op: BinOp, rhs: TypedExpr) -> TypedExpr {
|
|
match (lhs, rhs) {
|
|
// Addition of zero
|
|
(lhs, TypedExpr::Int { value: 0, .. }) | (TypedExpr::Int { value: 0, .. }, lhs)
|
|
if matches!(op, BinOp::Add) =>
|
|
{
|
|
lhs
|
|
},
|
|
(lhs, TypedExpr::Float { value: 0.0, .. }) | (TypedExpr::Float { value: 0.0, .. }, lhs)
|
|
if matches!(op, BinOp::Add) =>
|
|
{
|
|
lhs
|
|
},
|
|
|
|
// Subtraction by zero
|
|
(lhs, TypedExpr::Int { value: 0, .. }) if matches!(op, BinOp::Sub) => lhs,
|
|
(lhs, TypedExpr::Float { value: 0.0, .. }) if matches!(op, BinOp::Sub) => lhs,
|
|
|
|
// Multiplication/Division by one
|
|
(lhs, TypedExpr::Int { value: 1, .. }) if matches!(op, BinOp::Mul | BinOp::Div) => lhs,
|
|
(lhs, TypedExpr::Float { value: 1.0, .. }) if matches!(op, BinOp::Mul | BinOp::Div) => lhs,
|
|
(TypedExpr::Int { value: 1, .. }, rhs) if matches!(op, BinOp::Mul) => rhs,
|
|
(TypedExpr::Float { value: 1.0, .. }, rhs) if matches!(op, BinOp::Mul) => rhs,
|
|
|
|
// Multiplication by zero
|
|
(_, TypedExpr::Int { value: 0, .. }) | (TypedExpr::Int { value: 0, .. }, _)
|
|
if matches!(op, BinOp::Mul) =>
|
|
{
|
|
TypedExpr::Int { value: 0 }
|
|
},
|
|
(_, TypedExpr::Float { value: 0.0, .. }) | (TypedExpr::Float { value: 0.0, .. }, _)
|
|
if matches!(op, BinOp::Mul) =>
|
|
{
|
|
TypedExpr::Float { value: 0.0 }
|
|
},
|
|
|
|
// Zero division
|
|
(TypedExpr::Int { value: 0, .. }, _) if matches!(op, BinOp::Div) => {
|
|
TypedExpr::Int { value: 0 }
|
|
},
|
|
(TypedExpr::Float { value: 0.0, .. }, _) if matches!(op, BinOp::Div) => {
|
|
TypedExpr::Float { value: 0.0 }
|
|
},
|
|
|
|
// Default
|
|
(lhs, rhs) => TypedExpr::BinOp {
|
|
lhs: Box::new(lhs),
|
|
op,
|
|
rhs: Box::new(rhs),
|
|
},
|
|
}
|
|
}
|
|
|
|
fn optimize_int_to_float(value: TypedExpr) -> TypedExpr {
|
|
let value = optimize_expr(value);
|
|
if let TypedExpr::Int { value } = value {
|
|
TypedExpr::Float {
|
|
value: value as f64,
|
|
}
|
|
} else {
|
|
TypedExpr::IntToFloat {
|
|
value: Box::new(value),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn reorder_commutative_expr(expr: TypedExpr) -> TypedExpr {
|
|
match expr {
|
|
TypedExpr::BinOp { lhs, op, rhs, .. } => {
|
|
let commutative = rhs
|
|
.bin_op()
|
|
.map_or(OpCommutative::No, |op1| op.commutative_expr(&op1));
|
|
|
|
if let OpCommutative::Yes(op1) = commutative {
|
|
let (lhs, rhs) = match *rhs {
|
|
TypedExpr::BinOp {
|
|
lhs: lhs1,
|
|
rhs: rhs1,
|
|
..
|
|
} => (
|
|
TypedExpr::BinOp { lhs, op, rhs: lhs1 },
|
|
reorder_commutative_expr(*rhs1),
|
|
),
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
TypedExpr::BinOp {
|
|
lhs: Box::new(lhs),
|
|
op: op1,
|
|
rhs: Box::new(rhs),
|
|
}
|
|
} else {
|
|
TypedExpr::BinOp { lhs, op, rhs }
|
|
}
|
|
},
|
|
expr => expr,
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for OLevel {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
let str = match self {
|
|
OLevel::None => "none",
|
|
OLevel::O1 => "o1",
|
|
OLevel::O2 => "o2",
|
|
OLevel::O3 => "o3",
|
|
};
|
|
|
|
write!(f, "{}", str)
|
|
}
|
|
}
|