diff --git a/.gitignore b/.gitignore index 56a172d9..ace80227 100644 --- a/.gitignore +++ b/.gitignore @@ -81,4 +81,5 @@ docs/paper/examples/ # Claude Code logs claude-output.log .worktrees/ -.worktree/ \ No newline at end of file +.worktree/ +mydocs/ diff --git a/docs/plans/2026-02-15-symbolic-expr-design.md b/docs/plans/2026-02-15-symbolic-expr-design.md new file mode 100644 index 00000000..f24a65cf --- /dev/null +++ b/docs/plans/2026-02-15-symbolic-expr-design.md @@ -0,0 +1,194 @@ +# Symbolic Expression System for Reduction Overhead + +**Date:** 2026-02-15 +**Status:** Approved + +## Goal + +Replace the current `Polynomial`/`Monomial`/`poly!` system with a general-purpose symbolic expression DSL that supports exponentials, logarithms, min/max, floor/ceil, and arbitrary arithmetic — not just polynomials. + +## Motivation + +The current `Polynomial` type only supports sums of monomials (coefficient × product of variables with integer exponents). It cannot represent: +- Exponential overhead: `1.44 ^ num_vertices` +- Logarithmic factors: `num_vertices * log2(num_edges)` +- Min/max: `max(num_vertices, num_edges)` +- Floor/ceil: `ceil(num_vertices / 2)` + +These are needed to accurately model reduction overhead between problems. + +## Design + +### 1. 
Expression AST + +```rust +#[derive(Clone, Debug)] +pub enum Expr { + Num(f64), + Var(Box<str>), + BinOp { op: BinOp, lhs: Box<Expr>, rhs: Box<Expr> }, + Neg(Box<Expr>), + Call { func: Func, args: Vec<Expr> }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BinOp { Add, Sub, Mul, Div, Pow } + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Func { Log2, Log10, Ln, Exp, Sqrt, Min, Max, Floor, Ceil, Abs } +``` + +**Key decisions:** +- `Box<str>` for variable names (avoids `String` allocation overhead on clone) +- No `PartialEq` on `Expr` (f64 makes it dangerous; NaN != NaN) +- `Func` enum for built-ins (compile-time exhaustiveness, no runtime string matching) +- Separate `Neg` variant (cleaner than encoding as `0 - x`) +- No user-defined functions (not needed for this use case) + +### 2. Evaluation + +```rust +pub enum EvalError { + UnknownVar(Box<str>), + DivideByZero, + Arity { func: Func, expected: usize, got: usize }, + Domain { func: Func, detail: Box<str> }, +} +``` + +- Recursive tree walk with `ProblemSize` providing variable bindings +- Returns `Result<f64, EvalError>` +- **No NaN guarantee:** domain violations return `EvalError::Domain` instead of producing NaN + - `log(-1)` → Domain error + - Negative base + non-integer exponent → Domain error + - `0 / 0` → DivideByZero +- `^` is right-associative: `a ^ b ^ c` = `a ^ (b ^ c)` + +### 3. Parser + +Pratt parser (precedence climbing), ~200-300 lines. + +**Tokens:** `Num(f64)`, `Ident(Box<str>)`, `+`, `-`, `*`, `/`, `^`, `(`, `)`, `,`, `Eof` + +**Precedence (lowest to highest):** + +| Level | Operators | Associativity | +|-------|-----------|---------------| +| 1 | `+`, `-` | Left | +| 2 | `*`, `/` | Left | +| 3 | unary `-` | Prefix | +| 4 | `^` | Right | + +Function calls parsed as part of primary expressions. + +**Grammar (for documentation; implementation uses Pratt):** +``` +expr = term (('+' | '-') term)* +term = unary (('*' | '/') unary)* +unary = '-' unary | power +power = primary ('^' power)? 
+primary = NUM | IDENT '(' args ')' | IDENT | '(' expr ')' +args = expr (',' expr)* +``` + +**Ident resolution:** If followed by `(`, matched against `Func` variants (ASCII case-insensitive). Unknown function → `ParseError::UnknownFunction`. + +**No implicit multiplication:** `2 num_vertices` is a parse error; must write `2 * num_vertices`. + +**Tokenizer rules:** +- Numbers: `42`, `1.5`, `.5` — no scientific notation +- Idents: `[a-zA-Z_][a-zA-Z0-9_]*` +- Whitespace: skipped silently + +**Parse errors with spans:** +```rust +pub struct Span { pub start: usize, pub end: usize } + +pub enum ParseError { + UnexpectedToken { expected: &'static str, got: Box<str>, span: Span }, + UnexpectedEof { expected: &'static str }, + UnknownFunction { name: Box<str>, span: Span }, + InvalidNumber { lexeme: Box<str>, span: Span }, +} +``` + +### 4. Display and Serialization + +**Display** uses minimal parenthesization: +- Parenthesize child when its precedence < parent precedence +- For right-associative `^`: parenthesize left child if same precedence +- Integer-valued floats: `3.0` → `"3"`, `1.5` → `"1.5"` +- Round-trip invariant: `parse(expr.to_string())` ≡ `expr` semantically + +**Serde:** `Expr` serializes as its display string, deserializes by parsing. This means JSON contains human-readable expressions like `"1.44 ^ num_vertices"`. + +### 5. Integration with ReductionOverhead + +**`ReductionOverhead` stores `Expr`:** +```rust +pub struct ReductionOverhead { + pub output_size: Vec<(&'static str, Expr)>, +} +``` + +**Constructor takes string pairs:** +```rust +impl ReductionOverhead { + pub fn new(specs: Vec<(&'static str, &'static str)>) -> Self { + // Parses each expression string immediately. + // Panics on parse error (developer bug — these are static literals). 
+ } +} +``` + +**`evaluate_output_size` returns `Result`:** +```rust +pub fn evaluate_output_size(&self, input: &ProblemSize) -> Result +``` + +Float → usize conversion: `round()` (matching current behavior), error on non-finite/negative/overflow. + +**Reduction macro usage changes from:** +```rust +#[reduction(overhead = { + ReductionOverhead::new(vec![ + ("num_vertices", poly!(num_vertices)), + ("num_edges", poly!(num_edges)), + ]) +})] +``` + +**To:** +```rust +#[reduction(overhead = { + ReductionOverhead::new(vec![ + ("num_vertices", "num_vertices"), + ("num_edges", "num_edges"), + ]) +})] +``` + +The `#[reduction]` proc macro itself needs no changes (it passes through the token stream). + +**JSON export format:** +```json +[ + {"field": "num_vertices", "expression": "num_vertices ^ 2"}, + {"field": "num_edges", "expression": "1.44 ^ num_vertices"} +] +``` + +### 6. File Organization + +| Action | File | Description | +|---------|------|-------------| +| **New** | `src/expr.rs` | `Expr`, `BinOp`, `Func`, parser, evaluator, Display, Serde (~400-500 lines) | +| **New** | `src/unit_tests/expr.rs` | Parser, eval, display round-trip, error case tests | +| **Delete** | `src/polynomial.rs` | Replaced by `src/expr.rs` | +| **Delete** | `src/unit_tests/polynomial.rs` | Replaced by `src/unit_tests/expr.rs` | +| **Modify** | `src/lib.rs` | `mod polynomial` → `mod expr`, update re-exports | +| **Modify** | `src/rules/registry.rs` | `Polynomial` → `Expr`, `new()` takes `(&str, &str)` pairs | +| **Modify** | `src/export.rs` | Remove `MonomialJson`, use expression strings | +| **Modify** | `src/rules/cost.rs` | Propagate `Result` from `evaluate_output_size` | +| **Modify** | `src/rules/graph.rs` | Propagate `Result` from `evaluate_output_size` | +| **Modify** | `src/rules/*.rs` (all reductions) | `poly!(x)` → `"x"` string literals | diff --git a/docs/plans/2026-02-15-symbolic-expr-plan.md b/docs/plans/2026-02-15-symbolic-expr-plan.md new file mode 100644 index 
00000000..384c5fc4 --- /dev/null +++ b/docs/plans/2026-02-15-symbolic-expr-plan.md @@ -0,0 +1,1164 @@ +# Symbolic Expression System — Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Replace `Polynomial`/`Monomial`/`poly!` with a general-purpose symbolic expression DSL supporting exponentials, logs, min/max, floor/ceil. + +**Architecture:** New `src/expr.rs` module with AST, Pratt parser, evaluator, and Display. `ReductionOverhead` switches from `Polynomial` to `Expr` parsed from string literals. All 30 reduction files migrate from `poly!()` to string specs. + +**Tech Stack:** Pure Rust, no new dependencies. Pratt parser hand-written. + +**Design doc:** `docs/plans/2026-02-15-symbolic-expr-design.md` + +--- + +### Task 1: Core AST and Evaluator + +**Files:** +- Create: `src/expr.rs` +- Create: `src/unit_tests/expr.rs` + +**Step 1: Write failing tests for AST construction and evaluation** + +Create `src/unit_tests/expr.rs`: + +```rust +use super::*; +use crate::types::ProblemSize; + +#[test] +fn test_eval_num() { + let expr = Expr::Num(42.0); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 42.0); +} + +#[test] +fn test_eval_var() { + let expr = Expr::Var("n".into()); + let size = ProblemSize::new(vec![("n", 10)]); + assert_eq!(expr.evaluate(&size).unwrap(), 10.0); +} + +#[test] +fn test_eval_unknown_var() { + let expr = Expr::Var("missing".into()); + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::UnknownVar(_)))); +} + +#[test] +fn test_eval_add() { + let expr = Expr::binop(BinOp::Add, Expr::Num(3.0), Expr::Num(4.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 7.0); +} + +#[test] +fn test_eval_sub() { + let expr = Expr::binop(BinOp::Sub, Expr::Num(10.0), Expr::Num(3.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 
7.0); +} + +#[test] +fn test_eval_mul() { + let expr = Expr::binop(BinOp::Mul, Expr::Num(3.0), Expr::Var("n".into())); + let size = ProblemSize::new(vec![("n", 5)]); + assert_eq!(expr.evaluate(&size).unwrap(), 15.0); +} + +#[test] +fn test_eval_div() { + let expr = Expr::binop(BinOp::Div, Expr::Num(10.0), Expr::Num(4.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 2.5); +} + +#[test] +fn test_eval_div_by_zero() { + let expr = Expr::binop(BinOp::Div, Expr::Num(1.0), Expr::Num(0.0)); + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::DivideByZero))); +} + +#[test] +fn test_eval_pow() { + let expr = Expr::binop(BinOp::Pow, Expr::Num(2.0), Expr::Num(10.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 1024.0); +} + +#[test] +fn test_eval_pow_fractional_base_negative() { + // negative base with non-integer exponent -> domain error + let expr = Expr::binop(BinOp::Pow, Expr::Num(-2.0), Expr::Num(0.5)); + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::Domain { .. }))); +} + +#[test] +fn test_eval_neg() { + let expr = Expr::Neg(Box::new(Expr::Num(5.0))); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), -5.0); +} + +#[test] +fn test_eval_log2() { + let expr = Expr::Call { func: Func::Log2, args: vec![Expr::Num(8.0)] }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 3.0); +} + +#[test] +fn test_eval_log2_negative() { + let expr = Expr::Call { func: Func::Log2, args: vec![Expr::Num(-1.0)] }; + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::Domain { .. 
}))); +} + +#[test] +fn test_eval_sqrt() { + let expr = Expr::Call { func: Func::Sqrt, args: vec![Expr::Num(25.0)] }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 5.0); +} + +#[test] +fn test_eval_min() { + let expr = Expr::Call { func: Func::Min, args: vec![Expr::Num(3.0), Expr::Num(7.0)] }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 3.0); +} + +#[test] +fn test_eval_max() { + let expr = Expr::Call { func: Func::Max, args: vec![Expr::Num(3.0), Expr::Num(7.0)] }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 7.0); +} + +#[test] +fn test_eval_floor() { + let expr = Expr::Call { func: Func::Floor, args: vec![Expr::Num(3.7)] }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 3.0); +} + +#[test] +fn test_eval_ceil() { + let expr = Expr::Call { func: Func::Ceil, args: vec![Expr::Num(3.2)] }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 4.0); +} + +#[test] +fn test_eval_arity_error() { + let expr = Expr::Call { func: Func::Log2, args: vec![Expr::Num(1.0), Expr::Num(2.0)] }; + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::Arity { .. }))); +} + +#[test] +fn test_eval_complex() { + // 3 * n ^ 2 + 1.44 ^ m + let expr = Expr::binop( + BinOp::Add, + Expr::binop(BinOp::Mul, Expr::Num(3.0), Expr::binop(BinOp::Pow, Expr::Var("n".into()), Expr::Num(2.0))), + Expr::binop(BinOp::Pow, Expr::Num(1.44), Expr::Var("m".into())), + ); + let size = ProblemSize::new(vec![("n", 4), ("m", 3)]); + let result = expr.evaluate(&size).unwrap(); + let expected = 3.0 * 16.0 + 1.44_f64.powi(3); + assert!((result - expected).abs() < 1e-10); +} +``` + +**Step 2: Run tests to verify they fail** + +Run: `cargo test expr --lib` +Expected: compilation error (module doesn't exist) + +**Step 3: Implement AST types and evaluator** + +Create `src/expr.rs` with: + +```rust +//! 
Symbolic expression system for reduction overhead. +//! +//! Provides a DSL for expressing how problem sizes transform during reductions. +//! Supports arithmetic, exponentiation, and built-in math functions. + +use crate::types::ProblemSize; +use std::fmt; + +/// A symbolic expression over named variables. +#[derive(Clone, Debug)] +pub enum Expr { + /// Numeric literal. + Num(f64), + /// Named variable (e.g., `num_vertices`). + Var(Box), + /// Binary operation. + BinOp { op: BinOp, lhs: Box, rhs: Box }, + /// Unary negation. + Neg(Box), + /// Built-in function call. + Call { func: Func, args: Vec }, +} + +/// Binary operators. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BinOp { Add, Sub, Mul, Div, Pow } + +/// Built-in functions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Func { Log2, Log10, Ln, Exp, Sqrt, Min, Max, Floor, Ceil, Abs } + +/// Evaluation error. +#[derive(Debug)] +pub enum EvalError { + UnknownVar(Box), + DivideByZero, + Arity { func: Func, expected: usize, got: usize }, + Domain { func: Func, detail: Box }, +} + +impl Expr { + /// Convenience constructor for binary operations. + pub fn binop(op: BinOp, lhs: Expr, rhs: Expr) -> Self { + Expr::BinOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs) } + } + + /// Evaluate the expression given variable bindings from `ProblemSize`. 
+ pub fn evaluate(&self, size: &ProblemSize) -> Result<f64, EvalError> { + match self { + Expr::Num(v) => Ok(*v), + Expr::Var(name) => size + .get(name) + .map(|v| v as f64) + .ok_or_else(|| EvalError::UnknownVar(name.clone())), + Expr::Neg(inner) => Ok(-inner.evaluate(size)?), + Expr::BinOp { op, lhs, rhs } => { + let l = lhs.evaluate(size)?; + let r = rhs.evaluate(size)?; + match op { + BinOp::Add => Ok(l + r), + BinOp::Sub => Ok(l - r), + BinOp::Mul => Ok(l * r), + BinOp::Div => { + if r == 0.0 { return Err(EvalError::DivideByZero); } + Ok(l / r) + } + BinOp::Pow => { + if l < 0.0 && r.fract() != 0.0 { + return Err(EvalError::Domain { + func: Func::Sqrt, // closest built-in + detail: "negative base with non-integer exponent".into(), + }); + } + let result = l.powf(r); + if result.is_nan() || result.is_infinite() { + return Err(EvalError::Domain { + func: Func::Exp, + detail: format!("{l} ^ {r} produced non-finite result").into(), + }); + } + Ok(result) + } + } + } + Expr::Call { func, args } => eval_func(*func, args, size), + } + } +} + +fn eval_func(func: Func, args: &[Expr], size: &ProblemSize) -> Result<f64, EvalError> { + // Check arity + let (min_args, max_args) = match func { + Func::Min | Func::Max => (2, 2), + _ => (1, 1), + }; + if args.len() < min_args || args.len() > max_args { + return Err(EvalError::Arity { func, expected: min_args, got: args.len() }); + } + + let a = args[0].evaluate(size)?; + match func { + Func::Log2 => { + if a <= 0.0 { return Err(EvalError::Domain { func, detail: "log2 of non-positive".into() }); } + Ok(a.log2()) + } + Func::Log10 => { + if a <= 0.0 { return Err(EvalError::Domain { func, detail: "log10 of non-positive".into() }); } + Ok(a.log10()) + } + Func::Ln => { + if a <= 0.0 { return Err(EvalError::Domain { func, detail: "ln of non-positive".into() }); } + Ok(a.ln()) + } + Func::Exp => { + let result = a.exp(); + if result.is_infinite() { + return Err(EvalError::Domain { func, detail: "exp overflow".into() }); + } + Ok(result) + } + Func::Sqrt => { + if 
a < 0.0 { return Err(EvalError::Domain { func, detail: "sqrt of negative".into() }); } + Ok(a.sqrt()) + } + Func::Abs => Ok(a.abs()), + Func::Floor => Ok(a.floor()), + Func::Ceil => Ok(a.ceil()), + Func::Min => { let b = args[1].evaluate(size)?; Ok(a.min(b)) } + Func::Max => { let b = args[1].evaluate(size)?; Ok(a.max(b)) } + } +} +``` + +**Step 4: Wire up the module and test file** + +In `src/expr.rs`, add at the bottom: +```rust +#[cfg(test)] +#[path = "unit_tests/expr.rs"] +mod tests; +``` + +In `src/lib.rs`, add: `pub mod expr;` (keep `polynomial` for now — we'll remove it later). + +**Step 5: Run tests to verify they pass** + +Run: `cargo test expr --lib` +Expected: all tests pass + +**Step 6: Commit** + +```bash +git add src/expr.rs src/unit_tests/expr.rs src/lib.rs +git commit -m "feat: add symbolic expression AST and evaluator" +``` + +--- + +### Task 2: Parser (Tokenizer + Pratt) + +**Files:** +- Modify: `src/expr.rs` +- Modify: `src/unit_tests/expr.rs` + +**Step 1: Write failing parser tests** + +Add to `src/unit_tests/expr.rs`: + +```rust +#[test] +fn test_parse_num() { + let expr = Expr::parse("42").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 42.0); +} + +#[test] +fn test_parse_float() { + let expr = Expr::parse("1.44").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 1.44); +} + +#[test] +fn test_parse_var() { + let expr = Expr::parse("num_vertices").unwrap(); + let size = ProblemSize::new(vec![("num_vertices", 10)]); + assert_eq!(expr.evaluate(&size).unwrap(), 10.0); +} + +#[test] +fn test_parse_add() { + let expr = Expr::parse("3 + 4").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 7.0); +} + +#[test] +fn test_parse_precedence_mul_add() { + // 2 + 3 * 4 = 14 (not 20) + let expr = Expr::parse("2 + 3 * 4").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 14.0); +} + +#[test] +fn test_parse_precedence_pow() { + // 2 ^ 3 ^ 2 = 2 ^ 9 = 512 
(right-associative) + let expr = Expr::parse("2 ^ 3 ^ 2").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 512.0); +} + +#[test] +fn test_parse_unary_neg() { + let expr = Expr::parse("-5").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), -5.0); +} + +#[test] +fn test_parse_neg_pow() { + // -2 ^ 2 = -(2^2) = -4 + let expr = Expr::parse("-2 ^ 2").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), -4.0); +} + +#[test] +fn test_parse_parens() { + let expr = Expr::parse("(2 + 3) * 4").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 20.0); +} + +#[test] +fn test_parse_function_log2() { + let expr = Expr::parse("log2(8)").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 3.0); +} + +#[test] +fn test_parse_function_max() { + let expr = Expr::parse("max(3, 7)").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 7.0); +} + +#[test] +fn test_parse_function_case_insensitive() { + let expr = Expr::parse("Log2(8)").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 3.0); +} + +#[test] +fn test_parse_complex_expression() { + // 3 * num_vertices ^ 2 + 1.44 ^ num_edges + let expr = Expr::parse("3 * num_vertices ^ 2 + 1.44 ^ num_edges").unwrap(); + let size = ProblemSize::new(vec![("num_vertices", 4), ("num_edges", 3)]); + let expected = 3.0 * 16.0 + 1.44_f64.powi(3); + assert!((expr.evaluate(&size).unwrap() - expected).abs() < 1e-10); +} + +#[test] +fn test_parse_nested_functions() { + let expr = Expr::parse("floor(log2(16))").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 4.0); +} + +#[test] +fn test_parse_unknown_function() { + let result = Expr::parse("foo(3)"); + assert!(matches!(result, Err(ParseError::UnknownFunction { .. 
}))); +} + +#[test] +fn test_parse_unexpected_eof() { + let result = Expr::parse("3 +"); + assert!(result.is_err()); +} + +#[test] +fn test_parse_empty() { + let result = Expr::parse(""); + assert!(result.is_err()); +} + +#[test] +fn test_parse_leading_dot() { + let expr = Expr::parse(".5").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 0.5); +} + +#[test] +fn test_parse_subtraction() { + let expr = Expr::parse("10 - 3 - 2").unwrap(); + // left-associative: (10 - 3) - 2 = 5 + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 5.0); +} + +#[test] +fn test_parse_division() { + let expr = Expr::parse("num_vertices / 2").unwrap(); + let size = ProblemSize::new(vec![("num_vertices", 10)]); + assert_eq!(expr.evaluate(&size).unwrap(), 5.0); +} +``` + +**Step 2: Run tests to verify they fail** + +Run: `cargo test expr --lib` +Expected: `Expr::parse` doesn't exist + +**Step 3: Implement tokenizer and parser** + +Add to `src/expr.rs`: + +- `Span` struct +- `ParseError` enum +- `Token` enum (private) +- `Lexer` struct (private) — iterates chars, produces tokens with spans +- `Parser` struct (private) — Pratt parser consuming tokens +- `Expr::parse(input: &str) -> Result` public entry point + +The Pratt parser uses binding powers: +- `+`, `-`: left bp = 1, right bp = 2 +- `*`, `/`: left bp = 3, right bp = 4 +- prefix `-`: right bp = 5 +- `^`: left bp = 7, right bp = 6 (right-assoc: left > right) + +Function names resolved via a match on the lowercased ident. 
+ +**Step 4: Run tests to verify they pass** + +Run: `cargo test expr --lib` +Expected: all pass + +**Step 5: Commit** + +```bash +git add src/expr.rs src/unit_tests/expr.rs +git commit -m "feat: add expression parser with Pratt precedence climbing" +``` + +--- + +### Task 3: Display and Serde + +**Files:** +- Modify: `src/expr.rs` +- Modify: `src/unit_tests/expr.rs` + +**Step 1: Write failing tests for Display and round-tripping** + +Add to `src/unit_tests/expr.rs`: + +```rust +#[test] +fn test_display_num_integer() { + let expr = Expr::Num(3.0); + assert_eq!(expr.to_string(), "3"); +} + +#[test] +fn test_display_num_float() { + let expr = Expr::Num(1.44); + assert_eq!(expr.to_string(), "1.44"); +} + +#[test] +fn test_display_var() { + let expr = Expr::Var("num_vertices".into()); + assert_eq!(expr.to_string(), "num_vertices"); +} + +#[test] +fn test_display_add() { + let expr = Expr::parse("a + b").unwrap(); + assert_eq!(expr.to_string(), "a + b"); +} + +#[test] +fn test_display_precedence() { + let expr = Expr::parse("a + b * c").unwrap(); + assert_eq!(expr.to_string(), "a + b * c"); +} + +#[test] +fn test_display_parens_needed() { + let expr = Expr::parse("(a + b) * c").unwrap(); + assert_eq!(expr.to_string(), "(a + b) * c"); +} + +#[test] +fn test_display_pow() { + let expr = Expr::parse("1.44 ^ n").unwrap(); + assert_eq!(expr.to_string(), "1.44 ^ n"); +} + +#[test] +fn test_display_neg() { + let expr = Expr::parse("-x").unwrap(); + assert_eq!(expr.to_string(), "-x"); +} + +#[test] +fn test_display_neg_compound() { + let expr = Expr::parse("-(a + b)").unwrap(); + assert_eq!(expr.to_string(), "-(a + b)"); +} + +#[test] +fn test_display_func() { + let expr = Expr::parse("log2(n)").unwrap(); + assert_eq!(expr.to_string(), "log2(n)"); +} + +#[test] +fn test_display_func_two_args() { + let expr = Expr::parse("max(a, b)").unwrap(); + assert_eq!(expr.to_string(), "max(a, b)"); +} + +#[test] +fn test_roundtrip_complex() { + let cases = vec![ + "3 * n ^ 2 + 1.44 ^ m", + 
"log2(n) * m", + "max(n, m) + 1", + "floor(n / 2)", + "-(a + b) * c", + "a ^ b ^ c", + "a - b - c", + ]; + let size = ProblemSize::new(vec![("n", 4), ("m", 3), ("a", 2), ("b", 3), ("c", 5)]); + for case in cases { + let expr1 = Expr::parse(case).unwrap(); + let displayed = expr1.to_string(); + let expr2 = Expr::parse(&displayed).unwrap(); + let v1 = expr1.evaluate(&size).unwrap(); + let v2 = expr2.evaluate(&size).unwrap(); + assert!((v1 - v2).abs() < 1e-10, "Round-trip failed for {case}: displayed as {displayed}"); + } +} + +#[test] +fn test_serde_roundtrip() { + let expr = Expr::parse("3 * n ^ 2 + 1").unwrap(); + let json = serde_json::to_string(&expr).unwrap(); + let back: Expr = serde_json::from_str(&json).unwrap(); + let size = ProblemSize::new(vec![("n", 5)]); + assert_eq!(expr.evaluate(&size).unwrap(), back.evaluate(&size).unwrap()); +} +``` + +**Step 2: Run tests to verify they fail** + +Run: `cargo test expr --lib` +Expected: `Display` not implemented, serde not implemented + +**Step 3: Implement Display** + +Add `impl fmt::Display for Expr` to `src/expr.rs`: +- Use a helper that takes parent precedence context +- Parenthesize when child precedence < parent precedence +- For `^`, parenthesize LHS if same precedence (right-assoc) +- Integer floats display without decimal point +- Function names: lowercase canonical form + +Also add `impl fmt::Display for EvalError` and `impl fmt::Display for ParseError`. 
+ +**Step 4: Implement Serde** + +Add custom `Serialize`/`Deserialize` for `Expr`: +- `Serialize`: delegates to `Display` (serializes as string) +- `Deserialize`: calls `Expr::parse` (deserializes from string) + +**Step 5: Run tests to verify they pass** + +Run: `cargo test expr --lib` +Expected: all pass + +**Step 6: Commit** + +```bash +git add src/expr.rs src/unit_tests/expr.rs +git commit -m "feat: add Display and Serde for Expr (string round-trip)" +``` + +--- + +### Task 4: Migrate ReductionOverhead and Registry + +**Files:** +- Modify: `src/rules/registry.rs` +- Modify: `src/unit_tests/rules/registry.rs` + +**Step 1: Write updated tests** + +Update `src/unit_tests/rules/registry.rs` — change from `poly!` to string specs: + +```rust +use crate::rules::registry::ReductionOverhead; +use crate::types::ProblemSize; + +#[test] +fn test_reduction_overhead_evaluate() { + let overhead = ReductionOverhead::new(vec![("n", "3 * m"), ("m", "m ^ 2")]); + let input = ProblemSize::new(vec![("m", 4)]); + let output = overhead.evaluate_output_size(&input).unwrap(); + assert_eq!(output.get("n"), Some(12)); + assert_eq!(output.get("m"), Some(16)); +} +``` + +Update the `ReductionEntry` test similarly — use `ReductionOverhead::new(vec![("n", "2 * n")])`. 
+ +**Step 2: Run tests to verify they fail** + +Run: `cargo test registry --lib` +Expected: type mismatch (still expects `Polynomial`) + +**Step 3: Update `src/rules/registry.rs`** + +Replace: +```rust +use crate::polynomial::Polynomial; +``` +with: +```rust +use crate::expr::Expr; +``` + +Change `ReductionOverhead`: +```rust +pub struct ReductionOverhead { + pub output_size: Vec<(&'static str, Expr)>, +} + +impl ReductionOverhead { + pub fn new(specs: Vec<(&'static str, &'static str)>) -> Self { + Self { + output_size: specs + .into_iter() + .map(|(field, expr_str)| { + let expr = Expr::parse(expr_str).unwrap_or_else(|e| { + panic!("invalid overhead expression for '{field}': {e}") + }); + (field, expr) + }) + .collect(), + } + } + + pub fn evaluate_output_size(&self, input: &ProblemSize) -> Result { + let mut fields = Vec::new(); + for (name, expr) in &self.output_size { + let val = expr.evaluate(input)?; + let rounded = val.round(); + if !rounded.is_finite() || rounded < 0.0 || rounded > usize::MAX as f64 { + return Err(crate::expr::EvalError::Domain { + func: crate::expr::Func::Floor, + detail: format!("overhead for '{name}' produced out-of-range value: {val}").into(), + }); + } + fields.push((*name, rounded as usize)); + } + Ok(ProblemSize::new(fields)) + } +} +``` + +Keep `Default` impl producing empty `output_size`. 
+ +**Step 4: Run tests to verify they pass** + +Run: `cargo test registry --lib` +Expected: pass + +**Step 5: Commit** + +```bash +git add src/rules/registry.rs src/unit_tests/rules/registry.rs +git commit -m "refactor: ReductionOverhead uses Expr parsed from strings" +``` + +--- + +### Task 5: Migrate Export System + +**Files:** +- Modify: `src/export.rs` +- Modify: `src/unit_tests/export.rs` + +**Step 1: Write updated tests** + +Update `src/unit_tests/export.rs` — `overhead_to_json` now returns `Vec` where each entry has `field` and `expression` (a string): + +```rust +use crate::export::overhead_to_json; +use crate::rules::registry::ReductionOverhead; + +#[test] +fn test_overhead_to_json_empty() { + let overhead = ReductionOverhead::default(); + let entries = overhead_to_json(&overhead); + assert!(entries.is_empty()); +} + +#[test] +fn test_overhead_to_json_single_field() { + let overhead = ReductionOverhead::new(vec![("num_vertices", "n + m")]); + let entries = overhead_to_json(&overhead); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].field, "num_vertices"); + assert_eq!(entries[0].expression, "n + m"); +} + +#[test] +fn test_overhead_to_json_multiple_fields() { + let overhead = ReductionOverhead::new(vec![ + ("num_vertices", "n ^ 2"), + ("num_edges", "1.44 ^ n"), + ]); + let entries = overhead_to_json(&overhead); + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].expression, "n ^ 2"); + assert_eq!(entries[1].expression, "1.44 ^ n"); +} +``` + +**Step 2: Run tests to verify they fail** + +Run: `cargo test export --lib` +Expected: `MonomialJson` still expected + +**Step 3: Update `src/export.rs`** + +Remove `MonomialJson`. 
Simplify `OverheadEntry`: +```rust +#[derive(Serialize, Clone, Debug)] +pub struct OverheadEntry { + pub field: String, + pub expression: String, +} +``` + +Simplify `overhead_to_json`: +```rust +pub fn overhead_to_json(overhead: &ReductionOverhead) -> Vec { + overhead + .output_size + .iter() + .map(|(field, expr)| OverheadEntry { + field: field.to_string(), + expression: expr.to_string(), + }) + .collect() +} +``` + +Also update `ReductionData.overhead` field type from `Vec` — this should already work since `OverheadEntry` is still `Serialize`. + +**Step 4: Run tests to verify they pass** + +Run: `cargo test export --lib` +Expected: pass + +**Step 5: Commit** + +```bash +git add src/export.rs src/unit_tests/export.rs +git commit -m "refactor: simplify export overhead to expression strings" +``` + +--- + +### Task 6: Migrate Cost Functions + +**Files:** +- Modify: `src/rules/cost.rs` +- Modify: `src/unit_tests/rules/cost.rs` + +**Step 1: Update tests** + +In `src/unit_tests/rules/cost.rs`, change the helper: +```rust +fn test_overhead() -> ReductionOverhead { + ReductionOverhead::new(vec![ + ("n", "2 * n"), + ("m", "m"), + ]) +} +``` + +Remove the `use crate::polynomial::Polynomial;` import. + +Existing test assertions should still hold since the values are the same. + +**Step 2: Run tests to verify they fail** + +Run: `cargo test cost --lib` +Expected: fail on `Polynomial` import + +**Step 3: Update `src/rules/cost.rs`** + +`evaluate_output_size` now returns `Result`. For cost functions, we should unwrap since overhead expressions are known-good at this point. Change each call from: +```rust +overhead.evaluate_output_size(size).get(self.0).unwrap_or(0) as f64 +``` +to: +```rust +overhead.evaluate_output_size(size) + .expect("overhead evaluation failed") + .get(self.0).unwrap_or(0) as f64 +``` + +Apply same pattern to `MinimizeWeighted`, `MinimizeMax`, `MinimizeLexicographic`. 
+ +**Step 4: Run tests to verify they pass** + +Run: `cargo test cost --lib` +Expected: pass + +**Step 5: Commit** + +```bash +git add src/rules/cost.rs src/unit_tests/rules/cost.rs +git commit -m "refactor: migrate cost functions to Expr-based overhead" +``` + +--- + +### Task 7: Migrate Graph Module + +**Files:** +- Modify: `src/rules/graph.rs` + +**Step 1: Update evaluate_output_size call** + +At line ~420, change: +```rust +let new_size = edge.overhead.evaluate_output_size(&current_size); +``` +to: +```rust +let new_size = edge.overhead.evaluate_output_size(&current_size) + .expect("overhead evaluation failed during path finding"); +``` + +At line ~712, the `to_json` method uses `poly.to_string()` — since `Expr` also implements `Display`, this already works. Just verify the field name change: the `OverheadFieldJson.formula` field now comes from `Expr::to_string()` which should produce equivalent output. + +**Step 2: Run: `cargo test graph --lib`** + +Expected: pass (no behavioral change) + +**Step 3: Commit** + +```bash +git add src/rules/graph.rs +git commit -m "refactor: migrate graph module to Expr-based overhead" +``` + +--- + +### Task 8: Migrate All 30 Reduction Files (Simple Ones) + +**Files:** +- Modify: 28 reduction files that use `poly!()` macro + +These all follow the same mechanical transformation. 
The `poly!` expressions map to string literals: + +| Old `poly!` syntax | New string literal | +|---|---| +| `poly!(num_vertices)` | `"num_vertices"` | +| `poly!(num_vertices ^ 2)` | `"num_vertices ^ 2"` | +| `poly!(3 * num_vars)` | `"3 * num_vars"` | +| `poly!(num_vertices * num_edges)` | `"num_vertices * num_edges"` | +| `poly!(3 * num_vars) + poly!(num_clauses)` | `"3 * num_vars + num_clauses"` | +| `poly!(num_clauses) + poly!(num_literals)` | `"num_clauses + num_literals"` | +| `poly!(num_clauses).scale(-5.0)` | `"-5 * num_clauses"` | +| `poly!(2 * num_vars) + poly!(5 * num_literals) + poly!(num_clauses).scale(-5.0) + poly!(3)` | `"2 * num_vars + 5 * num_literals - 5 * num_clauses + 3"` | + +**Step 1: Migrate files alphabetically** + +For each file, replace the `ReductionOverhead::new(vec![...])` body. Remove any `use crate::polynomial::Polynomial;` or `use crate::poly;` imports. The `#[reduction(overhead = { ... })]` wrapper stays the same. + +Example — `src/rules/minimumvertexcover_maximumindependentset.rs`: + +Old: +```rust +ReductionOverhead::new(vec![ + ("num_vertices", poly!(num_vertices)), + ("num_edges", poly!(num_edges)), +]) +``` + +New: +```rust +ReductionOverhead::new(vec![ + ("num_vertices", "num_vertices"), + ("num_edges", "num_edges"), +]) +``` + +**Step 2: Run `cargo test --lib` after each batch of ~5 files** + +Expected: pass + +**Step 3: Commit after all simple files** + +```bash +git add src/rules/*.rs +git commit -m "refactor: migrate 28 reduction files from poly! to string expressions" +``` + +--- + +### Task 9: Migrate Complex Reduction Files + +**Files:** +- Modify: `src/rules/travelingsalesman_ilp.rs` +- Modify: `src/rules/factoring_ilp.rs` + +These two files construct `Polynomial`/`Monomial` structs directly. 
+ +**Step 1: Migrate `travelingsalesman_ilp.rs`** + +Old (manual Monomial/Polynomial construction): +```rust +("num_vars", Polynomial::var_pow("num_vertices", 2) + Polynomial { + terms: vec![Monomial { + coefficient: 2.0, + variables: vec![("num_vertices", 1), ("num_edges", 1)], + }] +}), +("num_constraints", Polynomial::var_pow("num_vertices", 3) + Polynomial { + terms: vec![ + Monomial { coefficient: -1.0, variables: vec![("num_vertices", 2)] }, + Monomial { coefficient: 2.0, variables: vec![("num_vertices", 1)] }, + Monomial { coefficient: 4.0, variables: vec![("num_vertices", 1), ("num_edges", 1)] }, + ] +}), +``` + +New: +```rust +("num_vars", "num_vertices ^ 2 + 2 * num_vertices * num_edges"), +("num_constraints", "num_vertices ^ 3 - num_vertices ^ 2 + 2 * num_vertices + 4 * num_vertices * num_edges"), +``` + +Remove `use crate::polynomial::{Monomial, Polynomial};`. + +**Step 2: Migrate `factoring_ilp.rs`** + +Old: +```rust +("num_vars", Polynomial { terms: vec![ + Monomial::var("num_bits_first").scale(2.0), + Monomial::var("num_bits_second").scale(2.0), + Monomial { coefficient: 1.0, variables: vec![("num_bits_first", 1), ("num_bits_second", 1)] }, +] }), +("num_constraints", Polynomial { terms: vec![ + Monomial { coefficient: 3.0, variables: vec![("num_bits_first", 1), ("num_bits_second", 1)] }, + Monomial::var("num_bits_first"), + Monomial::var("num_bits_second"), + Monomial::constant(1.0), +] }), +``` + +New: +```rust +("num_vars", "2 * num_bits_first + 2 * num_bits_second + num_bits_first * num_bits_second"), +("num_constraints", "3 * num_bits_first * num_bits_second + num_bits_first + num_bits_second + 1"), +``` + +Remove `use crate::polynomial::{Monomial, Polynomial};`. 
+ +**Step 3: Run `cargo test --lib`** + +Expected: pass + +**Step 4: Commit** + +```bash +git add src/rules/travelingsalesman_ilp.rs src/rules/factoring_ilp.rs +git commit -m "refactor: migrate complex reduction overheads to string expressions" +``` + +--- + +### Task 10: Delete Polynomial Module and poly! Macro + +**Files:** +- Delete: `src/polynomial.rs` +- Delete: `src/unit_tests/polynomial.rs` +- Modify: `src/lib.rs` — remove `pub mod polynomial;` +- Modify: `src/rules/mod.rs` — remove any `poly!` re-export if present + +**Step 1: Remove `pub mod polynomial;` from `src/lib.rs`** + +**Step 2: Delete `src/polynomial.rs` and `src/unit_tests/polynomial.rs`** + +**Step 3: Search for any remaining references** + +Run: `grep -r "polynomial\|poly!" src/ --include="*.rs"` +Expected: no matches (only in test data or docs) + +**Step 4: Run full test suite** + +Run: `make test` +Expected: all pass + +**Step 5: Commit** + +```bash +git rm src/polynomial.rs src/unit_tests/polynomial.rs +git add src/lib.rs +git commit -m "refactor: remove Polynomial/Monomial/poly! (replaced by Expr)" +``` + +--- + +### Task 11: Update Examples + +**Files:** +- Modify: all `examples/reduction_*.rs` files (~30 files) + +These files call `overhead_to_json(&overhead)`. The function signature is unchanged, but the output format changed (from `MonomialJson` array to `expression` string). The `ReductionData` struct's `overhead` field type is `Vec` which is still `Serialize`. + +**Step 1: Verify examples compile** + +Run: `cargo build --examples` +Expected: pass (no source changes needed — examples call `overhead_to_json` which still works) + +If any example directly constructs `MonomialJson` or `Polynomial`, update it. 
+ +**Step 2: Regenerate example JSON outputs** + +Run: `make examples` +Expected: JSON files regenerated with new `expression` string format + +**Step 3: Verify tests** + +Run: `make test` +Expected: all pass + +**Step 4: Commit** + +```bash +git add examples/ docs/paper/examples/ +git commit -m "chore: regenerate example JSON with expression string format" +``` + +--- + +### Task 12: Run Full Verification + +**Step 1: Format check** + +Run: `make fmt-check` +Expected: pass + +**Step 2: Clippy** + +Run: `make clippy` +Expected: no warnings + +**Step 3: Full test suite** + +Run: `make test` +Expected: all pass + +**Step 4: Build docs (includes reduction graph export)** + +Run: `make doc` +Expected: builds successfully, reduction_graph.json regenerated with expression strings + +**Step 5: Fix any issues found** + +**Step 6: Final commit if needed** + +```bash +git add -A +git commit -m "chore: fix formatting and clippy warnings" +``` diff --git a/problemreductions-cli/src/commands/reduce.rs b/problemreductions-cli/src/commands/reduce.rs index 211560d8..53032303 100644 --- a/problemreductions-cli/src/commands/reduce.rs +++ b/problemreductions-cli/src/commands/reduce.rs @@ -34,7 +34,10 @@ fn load_path_file(path_file: &Path) -> Result { anyhow::bail!("Path file must contain at least one reduction step"); } - Ok(ReductionPath { steps }) + Ok(ReductionPath { + steps, + overheads: vec![], + }) } fn parse_path_node(node: &serde_json::Value) -> Result { diff --git a/problemreductions-cli/src/commands/solve.rs b/problemreductions-cli/src/commands/solve.rs index 1a06352c..63b8ae0b 100644 --- a/problemreductions-cli/src/commands/solve.rs +++ b/problemreductions-cli/src/commands/solve.rs @@ -45,13 +45,9 @@ pub fn solve(input: &Path, solver_name: &str, timeout: u64, out: &OutputConfig) let (tx, rx) = std::sync::mpsc::channel(); std::thread::spawn(move || { let result = match parsed { - SolveInput::Problem(pj) => solve_problem( - &pj.problem_type, - &pj.variant, - pj.data, - 
&solver_name, - &out, - ), + SolveInput::Problem(pj) => { + solve_problem(&pj.problem_type, &pj.variant, pj.data, &solver_name, &out) + } SolveInput::Bundle(b) => solve_bundle(b, &solver_name, &out), }; tx.send(result).ok(); @@ -167,6 +163,7 @@ fn solve_bundle(bundle: ReductionBundle, solver_name: &str, out: &OutputConfig) variant: s.variant.clone(), }) .collect(), + overheads: vec![], }; let chain = graph diff --git a/src/export.rs b/src/export.rs index 32639855..3a8237a9 100644 --- a/src/export.rs +++ b/src/export.rs @@ -5,7 +5,7 @@ //! - `.json` — reduction structure (source, target, overhead) //! - `.result.json` — runtime solutions //! -//! The schema mirrors the internal types: `ReductionOverhead` for polynomials, +//! The schema mirrors the internal types: `ReductionOverhead` for overhead expressions, //! `Problem::variant()` for problem variants, and `Problem::NAME` for problem names. use crate::rules::registry::ReductionOverhead; @@ -26,18 +26,11 @@ pub struct ProblemSide { pub instance: serde_json::Value, } -/// A monomial in JSON: coefficient × Π(variable^exponent). -#[derive(Serialize, Clone, Debug)] -pub struct MonomialJson { - pub coefficient: f64, - pub variables: Vec<(String, u8)>, -} - -/// One output field mapped to a polynomial. +/// One output field mapped to an expression string. #[derive(Serialize, Clone, Debug)] pub struct OverheadEntry { pub field: String, - pub polynomial: Vec, + pub expression: String, } /// Top-level reduction structure (written to `.json`). 
@@ -66,20 +59,9 @@ pub fn overhead_to_json(overhead: &ReductionOverhead) -> Vec { overhead .output_size .iter() - .map(|(field, poly)| OverheadEntry { + .map(|(field, expr)| OverheadEntry { field: field.to_string(), - polynomial: poly - .terms - .iter() - .map(|m| MonomialJson { - coefficient: m.coefficient, - variables: m - .variables - .iter() - .map(|(name, exp)| (name.to_string(), *exp)) - .collect(), - }) - .collect(), + expression: expr.to_string(), }) .collect() } diff --git a/src/expr.rs b/src/expr.rs new file mode 100644 index 00000000..84e3833f --- /dev/null +++ b/src/expr.rs @@ -0,0 +1,761 @@ +//! Symbolic expression system for reduction overhead. +//! +//! Provides a DSL for expressing how problem sizes transform during reductions. +//! Supports arithmetic, exponentiation, and built-in math functions. + +use crate::types::ProblemSize; +use std::collections::HashSet; +use std::fmt; + +/// A symbolic expression over named variables. +#[derive(Clone, Debug)] +pub enum Expr { + /// Numeric literal. + Num(f64), + /// Named variable (e.g., `num_vertices`). + Var(Box), + /// Binary operation. + BinOp { + op: BinOp, + lhs: Box, + rhs: Box, + }, + /// Unary negation. + Neg(Box), + /// Built-in function call. + Call { func: Func, args: Vec }, +} + +/// Binary operators. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Pow, +} + +/// Built-in functions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Func { + Log2, + Log10, + Ln, + Exp, + Sqrt, + Min, + Max, + Floor, + Ceil, + Abs, +} + +/// Evaluation error. 
+#[derive(Debug)] +pub enum EvalError { + UnknownVar(Box), + DivideByZero, + Arity { + func: Func, + expected: usize, + got: usize, + }, + Domain { + func: Func, + detail: Box, + }, +} + +impl fmt::Display for EvalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + EvalError::UnknownVar(name) => write!(f, "unknown variable: {name}"), + EvalError::DivideByZero => write!(f, "division by zero"), + EvalError::Arity { + func, + expected, + got, + } => write!(f, "{func:?} expects {expected} args, got {got}"), + EvalError::Domain { func, detail } => write!(f, "{func:?}: {detail}"), + } + } +} + +impl std::error::Error for EvalError {} + +// ── AST construction ── + +impl Expr { + /// Convenience constructor for binary operations. + pub fn binop(op: BinOp, lhs: Expr, rhs: Expr) -> Self { + Expr::BinOp { + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + } + } +} + +// ── Analysis & Transformation ── + +impl Expr { + /// Collect all variable names referenced in this expression. + pub fn variable_names(&self) -> HashSet<&str> { + let mut names = HashSet::new(); + self.collect_variable_names(&mut names); + names + } + + fn collect_variable_names<'a>(&'a self, names: &mut HashSet<&'a str>) { + match self { + Expr::Num(_) => {} + Expr::Var(name) => { + names.insert(name); + } + Expr::Neg(inner) => inner.collect_variable_names(names), + Expr::BinOp { lhs, rhs, .. } => { + lhs.collect_variable_names(names); + rhs.collect_variable_names(names); + } + Expr::Call { args, .. } => { + for arg in args { + arg.collect_variable_names(names); + } + } + } + } + + /// Substitute variables in this expression using a mapping. + /// Variables found in the mapping are replaced by clones of the mapped expression; + /// variables not in the mapping are left unchanged. 
+ pub fn substitute(&self, mapping: &std::collections::HashMap<&str, &Expr>) -> Expr { + match self { + Expr::Num(v) => Expr::Num(*v), + Expr::Var(name) => { + if let Some(replacement) = mapping.get(name.as_ref()) { + (*replacement).clone() + } else { + Expr::Var(name.clone()) + } + } + Expr::Neg(inner) => Expr::Neg(Box::new(inner.substitute(mapping))), + Expr::BinOp { op, lhs, rhs } => Expr::BinOp { + op: *op, + lhs: Box::new(lhs.substitute(mapping)), + rhs: Box::new(rhs.substitute(mapping)), + }, + Expr::Call { func, args } => Expr::Call { + func: *func, + args: args.iter().map(|a| a.substitute(mapping)).collect(), + }, + } + } +} + +// ── Evaluator ── + +impl Expr { + /// Evaluate the expression given variable bindings from `ProblemSize`. + pub fn evaluate(&self, size: &ProblemSize) -> Result { + match self { + Expr::Num(v) => Ok(*v), + Expr::Var(name) => Ok(size.get(name).unwrap_or(0) as f64), + Expr::Neg(inner) => Ok(-inner.evaluate(size)?), + Expr::BinOp { op, lhs, rhs } => { + let l = lhs.evaluate(size)?; + let r = rhs.evaluate(size)?; + match op { + BinOp::Add => Ok(l + r), + BinOp::Sub => Ok(l - r), + BinOp::Mul => Ok(l * r), + BinOp::Div => { + if r == 0.0 { + return Err(EvalError::DivideByZero); + } + Ok(l / r) + } + BinOp::Pow => { + if l < 0.0 && r.fract() != 0.0 { + return Err(EvalError::Domain { + func: Func::Sqrt, + detail: "negative base with non-integer exponent".into(), + }); + } + let result = l.powf(r); + if result.is_nan() || result.is_infinite() { + return Err(EvalError::Domain { + func: Func::Exp, + detail: format!("{l} ^ {r} produced non-finite result").into(), + }); + } + Ok(result) + } + } + } + Expr::Call { func, args } => eval_func(*func, args, size), + } + } +} + +fn eval_func(func: Func, args: &[Expr], size: &ProblemSize) -> Result { + let (min_args, max_args) = match func { + Func::Min | Func::Max => (2, 2), + _ => (1, 1), + }; + if args.len() < min_args || args.len() > max_args { + return Err(EvalError::Arity { + func, + expected: 
min_args, + got: args.len(), + }); + } + + let a = args[0].evaluate(size)?; + match func { + Func::Log2 => { + if a <= 0.0 { + return Err(EvalError::Domain { + func, + detail: "log2 of non-positive".into(), + }); + } + Ok(a.log2()) + } + Func::Log10 => { + if a <= 0.0 { + return Err(EvalError::Domain { + func, + detail: "log10 of non-positive".into(), + }); + } + Ok(a.log10()) + } + Func::Ln => { + if a <= 0.0 { + return Err(EvalError::Domain { + func, + detail: "ln of non-positive".into(), + }); + } + Ok(a.ln()) + } + Func::Exp => { + let result = a.exp(); + if result.is_infinite() { + return Err(EvalError::Domain { + func, + detail: "exp overflow".into(), + }); + } + Ok(result) + } + Func::Sqrt => { + if a < 0.0 { + return Err(EvalError::Domain { + func, + detail: "sqrt of negative".into(), + }); + } + Ok(a.sqrt()) + } + Func::Abs => Ok(a.abs()), + Func::Floor => Ok(a.floor()), + Func::Ceil => Ok(a.ceil()), + Func::Min => { + let b = args[1].evaluate(size)?; + Ok(a.min(b)) + } + Func::Max => { + let b = args[1].evaluate(size)?; + Ok(a.max(b)) + } + } +} + +// ── Parser ── + +/// A source span for error reporting. +#[derive(Clone, Copy, Debug)] +pub struct Span { + pub start: usize, + pub end: usize, +} + +/// Parse error. 
+#[derive(Debug)] +pub enum ParseError { + UnexpectedChar { ch: char, pos: usize }, + UnexpectedEof, + UnexpectedToken { span: Span, detail: String }, + UnknownFunction { name: String, span: Span }, + TrailingInput { span: Span }, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseError::UnexpectedChar { ch, pos } => { + write!(f, "unexpected character '{ch}' at position {pos}") + } + ParseError::UnexpectedEof => write!(f, "unexpected end of input"), + ParseError::UnexpectedToken { span, detail } => { + write!(f, "{detail} at {}..{}", span.start, span.end) + } + ParseError::UnknownFunction { name, span } => { + write!( + f, + "unknown function '{name}' at {}..{}", + span.start, span.end + ) + } + ParseError::TrailingInput { span } => { + write!(f, "trailing input at {}..{}", span.start, span.end) + } + } + } +} + +impl std::error::Error for ParseError {} + +#[derive(Clone, Debug)] +enum Token { + Num(f64), + Ident(String), + Plus, + Minus, + Star, + Slash, + Caret, + LParen, + RParen, + Comma, + Eof, +} + +#[derive(Clone, Debug)] +struct SpannedToken { + token: Token, + span: Span, +} + +struct Lexer<'a> { + input: &'a [u8], + pos: usize, +} + +impl<'a> Lexer<'a> { + fn new(input: &'a str) -> Self { + Self { + input: input.as_bytes(), + pos: 0, + } + } + + fn skip_whitespace(&mut self) { + while self.pos < self.input.len() && self.input[self.pos].is_ascii_whitespace() { + self.pos += 1; + } + } + + fn next_token(&mut self) -> Result { + self.skip_whitespace(); + let start = self.pos; + + if self.pos >= self.input.len() { + return Ok(SpannedToken { + token: Token::Eof, + span: Span { start, end: start }, + }); + } + + let ch = self.input[self.pos] as char; + match ch { + '+' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::Plus, + span: Span { + start, + end: self.pos, + }, + }) + } + '-' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::Minus, + span: Span { + start, + end: 
self.pos, + }, + }) + } + '*' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::Star, + span: Span { + start, + end: self.pos, + }, + }) + } + '/' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::Slash, + span: Span { + start, + end: self.pos, + }, + }) + } + '^' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::Caret, + span: Span { + start, + end: self.pos, + }, + }) + } + '(' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::LParen, + span: Span { + start, + end: self.pos, + }, + }) + } + ')' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::RParen, + span: Span { + start, + end: self.pos, + }, + }) + } + ',' => { + self.pos += 1; + Ok(SpannedToken { + token: Token::Comma, + span: Span { + start, + end: self.pos, + }, + }) + } + c if c.is_ascii_digit() || c == '.' => self.lex_number(start), + c if c.is_ascii_alphabetic() || c == '_' => self.lex_ident(start), + _ => Err(ParseError::UnexpectedChar { ch, pos: start }), + } + } + + fn lex_number(&mut self, start: usize) -> Result { + while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() { + self.pos += 1; + } + if self.pos < self.input.len() && self.input[self.pos] == b'.' 
{ + self.pos += 1; + while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() { + self.pos += 1; + } + } + let s = std::str::from_utf8(&self.input[start..self.pos]).unwrap(); + let val: f64 = s.parse().map_err(|_| ParseError::UnexpectedChar { + ch: s.chars().next().unwrap_or('?'), + pos: start, + })?; + Ok(SpannedToken { + token: Token::Num(val), + span: Span { + start, + end: self.pos, + }, + }) + } + + fn lex_ident(&mut self, start: usize) -> Result { + while self.pos < self.input.len() + && (self.input[self.pos].is_ascii_alphanumeric() || self.input[self.pos] == b'_') + { + self.pos += 1; + } + let s = std::str::from_utf8(&self.input[start..self.pos]).unwrap(); + Ok(SpannedToken { + token: Token::Ident(s.to_string()), + span: Span { + start, + end: self.pos, + }, + }) + } +} + +struct Parser { + tokens: Vec, + pos: usize, +} + +impl Parser { + fn new(input: &str) -> Result { + let mut lexer = Lexer::new(input); + let mut tokens = Vec::new(); + loop { + let tok = lexer.next_token()?; + let is_eof = matches!(tok.token, Token::Eof); + tokens.push(tok); + if is_eof { + break; + } + } + Ok(Self { tokens, pos: 0 }) + } + + fn peek(&self) -> &SpannedToken { + &self.tokens[self.pos] + } + + fn advance(&mut self) -> &SpannedToken { + let tok = &self.tokens[self.pos]; + if self.pos < self.tokens.len() - 1 { + self.pos += 1; + } + tok + } + + fn expect_rparen(&mut self) -> Result<(), ParseError> { + match &self.peek().token { + Token::RParen => { + self.advance(); + Ok(()) + } + _ => Err(ParseError::UnexpectedToken { + span: self.peek().span, + detail: "expected ')'".to_string(), + }), + } + } + + fn parse_expr(&mut self, min_bp: u8) -> Result { + let mut lhs = self.parse_prefix()?; + + loop { + let (op, l_bp, r_bp) = match &self.peek().token { + Token::Plus => (BinOp::Add, 1, 2), + Token::Minus => (BinOp::Sub, 1, 2), + Token::Star => (BinOp::Mul, 3, 4), + Token::Slash => (BinOp::Div, 3, 4), + Token::Caret => (BinOp::Pow, 7, 6), // right-assoc + _ => 
break, + }; + + if l_bp < min_bp { + break; + } + + self.advance(); + let rhs = self.parse_expr(r_bp)?; + lhs = Expr::binop(op, lhs, rhs); + } + + Ok(lhs) + } + + fn parse_prefix(&mut self) -> Result { + let tok = self.peek().clone(); + match &tok.token { + Token::Num(v) => { + let v = *v; + self.advance(); + Ok(Expr::Num(v)) + } + Token::Minus => { + self.advance(); + let inner = self.parse_expr(5)?; // unary minus bp + Ok(Expr::Neg(Box::new(inner))) + } + Token::LParen => { + self.advance(); + let inner = self.parse_expr(0)?; + self.expect_rparen()?; + Ok(inner) + } + Token::Ident(name) => { + let name = name.clone(); + let span = tok.span; + self.advance(); + + // Check if followed by '(' — function call + if matches!(self.peek().token, Token::LParen) { + self.advance(); // consume '(' + let func = resolve_func(&name, span)?; + let mut args = Vec::new(); + if !matches!(self.peek().token, Token::RParen) { + args.push(self.parse_expr(0)?); + while matches!(self.peek().token, Token::Comma) { + self.advance(); + args.push(self.parse_expr(0)?); + } + } + self.expect_rparen()?; + Ok(Expr::Call { func, args }) + } else { + Ok(Expr::Var(name.into())) + } + } + Token::Eof => Err(ParseError::UnexpectedEof), + _ => Err(ParseError::UnexpectedToken { + span: tok.span, + detail: "expected expression".to_string(), + }), + } + } +} + +fn resolve_func(name: &str, span: Span) -> Result { + match name.to_ascii_lowercase().as_str() { + "log2" => Ok(Func::Log2), + "log10" => Ok(Func::Log10), + "ln" => Ok(Func::Ln), + "exp" => Ok(Func::Exp), + "sqrt" => Ok(Func::Sqrt), + "min" => Ok(Func::Min), + "max" => Ok(Func::Max), + "floor" => Ok(Func::Floor), + "ceil" => Ok(Func::Ceil), + "abs" => Ok(Func::Abs), + _ => Err(ParseError::UnknownFunction { + name: name.to_string(), + span, + }), + } +} + +impl Expr { + /// Parse an expression from a string. 
+ pub fn parse(input: &str) -> Result { + if input.trim().is_empty() { + return Err(ParseError::UnexpectedEof); + } + let mut parser = Parser::new(input)?; + let expr = parser.parse_expr(0)?; + if !matches!(parser.peek().token, Token::Eof) { + return Err(ParseError::TrailingInput { + span: parser.peek().span, + }); + } + Ok(expr) + } +} + +// ── Display ── + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + display_expr(self, f, 0) + } +} + +/// Display with parent context to decide parenthesization. +/// `parent_bp` is the binding power of the parent context. +fn display_expr(expr: &Expr, f: &mut fmt::Formatter<'_>, parent_bp: u8) -> fmt::Result { + match expr { + Expr::Num(v) => { + let rounded = v.round() as i64; + if (*v - rounded as f64).abs() < 1e-10 && v.is_finite() { + write!(f, "{rounded}") + } else { + write!(f, "{v}") + } + } + Expr::Var(name) => write!(f, "{name}"), + Expr::Neg(inner) => { + write!(f, "-")?; + // Wrap compound inner expressions + let needs_parens = matches!(inner.as_ref(), Expr::BinOp { .. 
}); + if needs_parens { + write!(f, "(")?; + display_expr(inner, f, 0)?; + write!(f, ")") + } else { + display_expr(inner, f, 5) + } + } + Expr::BinOp { op, lhs, rhs } => { + let (l_bp, r_bp) = match op { + BinOp::Add | BinOp::Sub => (1, 2), + BinOp::Mul | BinOp::Div => (3, 4), + BinOp::Pow => (7, 6), + }; + + let needs_parens = l_bp < parent_bp; + if needs_parens { + write!(f, "(")?; + } + + display_expr(lhs, f, l_bp)?; + let op_str = match op { + BinOp::Add => " + ", + BinOp::Sub => " - ", + BinOp::Mul => " * ", + BinOp::Div => " / ", + BinOp::Pow => " ^ ", + }; + write!(f, "{op_str}")?; + display_expr(rhs, f, r_bp)?; + + if needs_parens { + write!(f, ")")?; + } + Ok(()) + } + Expr::Call { func, args } => { + let name = match func { + Func::Log2 => "log2", + Func::Log10 => "log10", + Func::Ln => "ln", + Func::Exp => "exp", + Func::Sqrt => "sqrt", + Func::Min => "min", + Func::Max => "max", + Func::Floor => "floor", + Func::Ceil => "ceil", + Func::Abs => "abs", + }; + write!(f, "{name}(")?; + for (i, arg) in args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + display_expr(arg, f, 0)?; + } + write!(f, ")") + } + } +} + +// ── Serde ── + +impl serde::Serialize for Expr { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> serde::Deserialize<'de> for Expr { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Expr::parse(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +#[path = "unit_tests/expr.rs"] +mod tests; diff --git a/src/lib.rs b/src/lib.rs index 30dd2a65..775b7e30 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,9 +20,10 @@ pub mod config; pub mod error; pub mod export; +pub mod expr; pub mod io; pub mod models; -pub(crate) mod polynomial; + pub mod registry; pub mod rules; pub mod solvers; diff --git a/src/polynomial.rs b/src/polynomial.rs deleted file mode 100644 index f17e3087..00000000 --- a/src/polynomial.rs +++ /dev/null @@ 
-1,340 +0,0 @@ -//! Polynomial representation for reduction overhead. - -use crate::types::ProblemSize; -use std::collections::{HashMap, HashSet}; -use std::fmt; -use std::ops::Add; - -/// A monomial: coefficient × Π(variable^exponent) -#[derive(Clone, Debug, PartialEq, serde::Serialize)] -pub struct Monomial { - pub coefficient: f64, - pub variables: Vec<(&'static str, u8)>, -} - -impl Monomial { - pub fn constant(c: f64) -> Self { - Self { - coefficient: c, - variables: vec![], - } - } - - pub fn var(name: &'static str) -> Self { - Self { - coefficient: 1.0, - variables: vec![(name, 1)], - } - } - - pub fn var_pow(name: &'static str, exp: u8) -> Self { - Self { - coefficient: 1.0, - variables: vec![(name, exp)], - } - } - - pub fn scale(mut self, c: f64) -> Self { - self.coefficient *= c; - self - } - - pub fn evaluate(&self, size: &ProblemSize) -> f64 { - let var_product: f64 = self - .variables - .iter() - .map(|(name, exp)| { - let val = size.get(name).unwrap_or(0) as f64; - val.powi(*exp as i32) - }) - .product(); - self.coefficient * var_product - } - - /// Multiply two monomials. - pub fn mul(&self, other: &Monomial) -> Monomial { - let mut variables = self.variables.clone(); - variables.extend_from_slice(&other.variables); - Monomial { - coefficient: self.coefficient * other.coefficient, - variables, - } - } - - /// Normalize: sort variables by name, merge duplicate entries. - pub fn normalize(&mut self) { - self.variables.sort_by_key(|(name, _)| *name); - let mut merged: Vec<(&'static str, u8)> = Vec::new(); - for &(name, exp) in &self.variables { - if let Some(last) = merged.last_mut() { - if last.0 == name { - last.1 += exp; - continue; - } - } - merged.push((name, exp)); - } - // Remove zero-exponent variables - merged.retain(|&(_, exp)| exp > 0); - self.variables = merged; - } - - /// Variable signature for like-term comparison (after normalization). 
- fn var_signature(&self) -> &[(&'static str, u8)] { - &self.variables - } -} - -/// A polynomial: Σ monomials -#[derive(Clone, Debug, PartialEq, serde::Serialize)] -pub struct Polynomial { - pub terms: Vec, -} - -impl Polynomial { - pub fn zero() -> Self { - Self { terms: vec![] } - } - - pub fn constant(c: f64) -> Self { - Self { - terms: vec![Monomial::constant(c)], - } - } - - pub fn var(name: &'static str) -> Self { - Self { - terms: vec![Monomial::var(name)], - } - } - - pub fn var_pow(name: &'static str, exp: u8) -> Self { - Self { - terms: vec![Monomial::var_pow(name, exp)], - } - } - - /// Create a polynomial with a single monomial that is a product of two variables. - pub fn var_product(a: &'static str, b: &'static str) -> Self { - Self { - terms: vec![Monomial { - coefficient: 1.0, - variables: vec![(a, 1), (b, 1)], - }], - } - } - - pub fn scale(mut self, c: f64) -> Self { - for term in &mut self.terms { - term.coefficient *= c; - } - self - } - - pub fn evaluate(&self, size: &ProblemSize) -> f64 { - self.terms.iter().map(|m| m.evaluate(size)).sum() - } - - /// Collect all variable names referenced by this polynomial. - pub fn variable_names(&self) -> HashSet<&'static str> { - self.terms - .iter() - .flat_map(|m| m.variables.iter().map(|(name, _)| *name)) - .collect() - } - - /// Multiply two polynomials. - pub fn mul(&self, other: &Polynomial) -> Polynomial { - let mut terms = Vec::new(); - for a in &self.terms { - for b in &other.terms { - terms.push(a.mul(b)); - } - } - let mut result = Polynomial { terms }; - result.normalize(); - result - } - - /// Raise to a non-negative integer power. - pub fn pow(&self, n: u8) -> Polynomial { - match n { - 0 => Polynomial::constant(1.0), - 1 => self.clone(), - _ => { - let mut result = self.clone(); - for _ in 1..n { - result = result.mul(self); - } - result - } - } - } - - /// Substitute variables with polynomials. 
- /// - /// Each variable in the polynomial is replaced by the corresponding - /// polynomial from the mapping. Variables not in the mapping are left as-is. - pub fn substitute(&self, mapping: &HashMap<&str, &Polynomial>) -> Polynomial { - let mut result = Polynomial::zero(); - for mono in &self.terms { - // Start with the coefficient - let mut term_poly = Polynomial::constant(mono.coefficient); - // Multiply by each variable's substitution raised to its exponent - for &(name, exp) in &mono.variables { - let var_poly = if let Some(&replacement) = mapping.get(name) { - replacement.pow(exp) - } else { - Polynomial::var_pow(name, exp) - }; - term_poly = term_poly.mul(&var_poly); - } - result = result + term_poly; - } - result.normalize(); - result - } - - /// Normalize: normalize all monomials, then combine like terms. - pub fn normalize(&mut self) { - for term in &mut self.terms { - term.normalize(); - } - // Combine like terms - let mut combined: Vec = Vec::new(); - for term in &self.terms { - if let Some(existing) = combined - .iter_mut() - .find(|m| m.var_signature() == term.var_signature()) - { - existing.coefficient += term.coefficient; - } else { - combined.push(term.clone()); - } - } - // Remove zero-coefficient terms - combined.retain(|m| m.coefficient.abs() > 1e-15); - self.terms = combined; - } - - /// Return a normalized copy. 
- pub fn normalized(&self) -> Polynomial { - let mut p = self.clone(); - p.normalize(); - p - } -} - -impl fmt::Display for Monomial { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let coeff_i = self.coefficient.round() as i64; - let is_int = (self.coefficient - coeff_i as f64).abs() < 1e-10; - if self.variables.is_empty() { - if is_int { - write!(f, "{coeff_i}") - } else { - write!(f, "{}", self.coefficient) - } - } else { - let has_coeff = if is_int { - match coeff_i { - 1 => false, - -1 => { - write!(f, "-")?; - false - } - _ => { - write!(f, "{coeff_i}")?; - true - } - } - } else { - write!(f, "{}", self.coefficient)?; - true - }; - for (i, (name, exp)) in self.variables.iter().enumerate() { - if has_coeff || i > 0 { - write!(f, " * ")?; - } - write!(f, "{name}")?; - if *exp > 1 { - write!(f, "^{exp}")?; - } - } - Ok(()) - } - } -} - -impl fmt::Display for Polynomial { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.terms.is_empty() { - write!(f, "0") - } else { - for (i, term) in self.terms.iter().enumerate() { - if i > 0 { - if term.coefficient < 0.0 { - write!(f, " - ")?; - let negated = Monomial { - coefficient: -term.coefficient, - variables: term.variables.clone(), - }; - write!(f, "{negated}")?; - } else { - write!(f, " + ")?; - write!(f, "{term}")?; - } - } else { - write!(f, "{term}")?; - } - } - Ok(()) - } - } -} - -impl Add for Polynomial { - type Output = Self; - - fn add(mut self, other: Self) -> Self { - self.terms.extend(other.terms); - self - } -} - -/// Convenience macro for building polynomials. -#[macro_export] -macro_rules! 
poly { - // Single variable: poly!(n) - ($name:ident) => { - $crate::polynomial::Polynomial::var(stringify!($name)) - }; - // Variable with exponent: poly!(n^2) - ($name:ident ^ $exp:literal) => { - $crate::polynomial::Polynomial::var_pow(stringify!($name), $exp) - }; - // Constant: poly!(5) - ($c:literal) => { - $crate::polynomial::Polynomial::constant($c as f64) - }; - // Scaled variable: poly!(3 * n) - ($c:literal * $name:ident) => { - $crate::polynomial::Polynomial::var(stringify!($name)).scale($c as f64) - }; - // Scaled variable with exponent: poly!(9 * n^2) - ($c:literal * $name:ident ^ $exp:literal) => { - $crate::polynomial::Polynomial::var_pow(stringify!($name), $exp).scale($c as f64) - }; - // Product of two variables: poly!(a * b) - ($a:ident * $b:ident) => { - $crate::polynomial::Polynomial::var_product(stringify!($a), stringify!($b)) - }; - // Scaled product of two variables: poly!(3 * a * b) - ($c:literal * $a:ident * $b:ident) => { - $crate::polynomial::Polynomial::var_product(stringify!($a), stringify!($b)).scale($c as f64) - }; -} - -#[cfg(test)] -#[path = "unit_tests/polynomial.rs"] -mod tests; diff --git a/src/rules/circuit_ilp.rs b/src/rules/circuit_ilp.rs index 04e89f25..54f82606 100644 --- a/src/rules/circuit_ilp.rs +++ b/src/rules/circuit_ilp.rs @@ -16,7 +16,7 @@ use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; use crate::models::specialized::{BooleanExpr, BooleanOp, CircuitSAT}; -use crate::poly; + use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -175,8 +175,8 @@ impl ILPBuilder { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_variables) + poly!(num_assignments)), - ("num_constraints", poly!(num_variables) + poly!(num_assignments)), + ("num_vars", "num_variables + num_assignments"), + ("num_constraints", "num_variables + num_assignments"), ]) } )] diff --git a/src/rules/circuit_spinglass.rs 
b/src/rules/circuit_spinglass.rs index 30a019bf..1fbf7778 100644 --- a/src/rules/circuit_spinglass.rs +++ b/src/rules/circuit_spinglass.rs @@ -8,7 +8,6 @@ use crate::models::optimization::SpinGlass; use crate::models::specialized::{Assignment, BooleanExpr, BooleanOp, CircuitSAT}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -416,8 +415,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_spins", poly!(num_assignments)), - ("num_interactions", poly!(num_assignments)), + ("num_spins", "num_assignments"), + ("num_interactions", "num_assignments"), ]) } )] diff --git a/src/rules/coloring_ilp.rs b/src/rules/coloring_ilp.rs index f9e982da..89a8c68e 100644 --- a/src/rules/coloring_ilp.rs +++ b/src/rules/coloring_ilp.rs @@ -9,7 +9,6 @@ use crate::models::graph::KColoring; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -125,8 +124,8 @@ fn reduce_kcoloring_to_ilp( #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_vertices ^ 2)), - ("num_constraints", poly!(num_vertices) + poly!(num_vertices * num_edges)), + ("num_vars", "num_vertices ^ 2"), + ("num_constraints", "num_vertices + num_vertices * num_edges"), ]) } )] diff --git a/src/rules/coloring_qubo.rs b/src/rules/coloring_qubo.rs index 4a2f954d..d436eac3 100644 --- a/src/rules/coloring_qubo.rs +++ b/src/rules/coloring_qubo.rs @@ -10,7 +10,6 @@ use crate::models::graph::KColoring; use crate::models::optimization::QUBO; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -107,7 +106,7 @@ fn reduce_kcoloring_to_qubo( // Register only the KN variant in the reduction graph #[reduction( - overhead = { 
ReductionOverhead::new(vec![("num_vars", poly!(num_vertices ^ 2))]) } + overhead = { ReductionOverhead::new(vec![("num_vars", "num_vertices ^ 2")]) } )] impl ReduceTo> for KColoring { type Result = ReductionKColoringToQUBO; diff --git a/src/rules/cost.rs b/src/rules/cost.rs index 0cbf1bf1..33463312 100644 --- a/src/rules/cost.rs +++ b/src/rules/cost.rs @@ -14,7 +14,59 @@ pub struct Minimize(pub &'static str); impl PathCostFn for Minimize { fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 { - overhead.evaluate_output_size(size).get(self.0).unwrap_or(0) as f64 + overhead + .evaluate_output_size(size) + .expect("overhead evaluation failed") + .get(self.0) + .unwrap_or(0) as f64 + } +} + +/// Minimize weighted sum of output fields. +pub struct MinimizeWeighted(pub Vec<(&'static str, f64)>); + +impl PathCostFn for MinimizeWeighted { + fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 { + let output = overhead + .evaluate_output_size(size) + .expect("overhead evaluation failed"); + self.0 + .iter() + .map(|(field, weight)| weight * output.get(field).unwrap_or(0) as f64) + .sum() + } +} + +/// Minimize the maximum of specified fields. +pub struct MinimizeMax(pub Vec<&'static str>); + +impl PathCostFn for MinimizeMax { + fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 { + let output = overhead + .evaluate_output_size(size) + .expect("overhead evaluation failed"); + self.0 + .iter() + .map(|field| output.get(field).unwrap_or(0) as f64) + .fold(0.0, f64::max) + } +} + +/// Lexicographic: minimize first field, break ties with subsequent. 
+pub struct MinimizeLexicographic(pub Vec<&'static str>); + +impl PathCostFn for MinimizeLexicographic { + fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 { + let output = overhead + .evaluate_output_size(size) + .expect("overhead evaluation failed"); + let mut cost = 0.0; + let mut scale = 1.0; + for field in &self.0 { + cost += scale * output.get(field).unwrap_or(0) as f64; + scale *= 1e-10; + } + cost } } diff --git a/src/rules/factoring_circuit.rs b/src/rules/factoring_circuit.rs index e43a8853..cddba9bd 100644 --- a/src/rules/factoring_circuit.rs +++ b/src/rules/factoring_circuit.rs @@ -8,7 +8,6 @@ //! carry propagation, building up partial products row by row. use crate::models::specialized::{Assignment, BooleanExpr, Circuit, CircuitSAT, Factoring}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -178,8 +177,8 @@ fn build_multiplier_cell( #[reduction(overhead = { ReductionOverhead::new(vec![ - ("num_variables", poly!(num_bits_first * num_bits_second)), - ("num_assignments", poly!(num_bits_first * num_bits_second)), + ("num_variables", "num_bits_first * num_bits_second"), + ("num_assignments", "num_bits_first * num_bits_second"), ]) })] impl ReduceTo for Factoring { diff --git a/src/rules/factoring_ilp.rs b/src/rules/factoring_ilp.rs index 491c0627..8ed4dcaf 100644 --- a/src/rules/factoring_ilp.rs +++ b/src/rules/factoring_ilp.rs @@ -19,7 +19,6 @@ use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; use crate::models::specialized::Factoring; -use crate::polynomial::{Monomial, Polynomial}; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -95,31 +94,8 @@ impl ReductionResult for ReductionFactoringToILP { #[reduction(overhead = { ReductionOverhead::new(vec![ - // num_vars = m + n + m*n + num_carries where num_carries = max(m+n, 
target_bits) - // For feasible instances, target_bits <= m+n, so this is 2(m+n) + m*n - ("num_vars", Polynomial { - terms: vec![ - Monomial::var("num_bits_first").scale(2.0), - Monomial::var("num_bits_second").scale(2.0), - Monomial { - coefficient: 1.0, - variables: vec![("num_bits_first", 1), ("num_bits_second", 1)], - }, - ] - }), - // num_constraints = 3*m*n + num_bit_positions + 1 - // For feasible instances (target_bits <= m+n), this is 3*m*n + (m+n) + 1 - ("num_constraints", Polynomial { - terms: vec![ - Monomial { - coefficient: 3.0, - variables: vec![("num_bits_first", 1), ("num_bits_second", 1)], - }, - Monomial::var("num_bits_first"), - Monomial::var("num_bits_second"), - Monomial::constant(1.0), - ] - }), + ("num_vars", "2 * num_bits_first + 2 * num_bits_second + num_bits_first * num_bits_second"), + ("num_constraints", "3 * num_bits_first * num_bits_second + num_bits_first + num_bits_second + 1"), ]) })] impl ReduceTo for Factoring { diff --git a/src/rules/graph.rs b/src/rules/graph.rs index a1c5c746..13705adf 100644 --- a/src/rules/graph.rs +++ b/src/rules/graph.rs @@ -97,7 +97,7 @@ pub(crate) struct EdgeJson { pub(crate) source: usize, /// Index into the `nodes` array for the target problem variant. pub(crate) target: usize, - /// Reduction overhead: output size as polynomials of input size. + /// Reduction overhead: output size as expressions of input size. pub(crate) overhead: Vec, /// Relative rustdoc path for the reduction module. pub(crate) doc_path: String, @@ -108,6 +108,8 @@ pub(crate) struct EdgeJson { pub struct ReductionPath { /// Variant-level steps in the path. pub steps: Vec, + /// Overhead for each edge in the path (length = steps.len() - 1). + pub overheads: Vec, } impl ReductionPath { @@ -145,6 +147,15 @@ impl ReductionPath { } names } + + /// Evaluate the end-to-end overhead by chaining each step's overhead. 
+ pub fn evaluate(&self, input: &ProblemSize) -> Result<ProblemSize, EvalError> { + let mut current = input.clone(); + for overhead in &self.overheads { + current = overhead.evaluate_output_size(&current)?; + } + Ok(current) + } } impl std::fmt::Display for ReductionPath { @@ -502,7 +513,10 @@ impl ReductionGraph { let edge_cost = cost_fn.edge_cost(overhead, &current_size); let new_cost = cost.0 + edge_cost; - let new_size = overhead.evaluate_output_size(&current_size); + let new_size = match overhead.evaluate_output_size(&current_size) { + Ok(s) => s, + Err(_) => continue, + }; if new_cost < *costs.get(&next).unwrap_or(&f64::INFINITY) { costs.insert(next, new_cost); @@ -528,7 +542,14 @@ impl ReductionGraph { } }) .collect(); - ReductionPath { steps } + let overheads = node_path + .windows(2) + .map(|w| { + let edge_idx = self.graph.find_edge(w[0], w[1]).unwrap(); + self.graph[edge_idx].overhead.clone() + }) + .collect(); + ReductionPath { steps, overheads } } /// Find all simple paths between two specific problem variants. @@ -856,7 +877,11 @@ impl ReductionGraph { ) -> NeighborTree { let children = node_children .get(&idx) - .map(|cs| cs.iter().map(|&c| build(c, node_children, nodes, graph)).collect()) + .map(|cs| { + cs.iter() + .map(|&c| build(c, node_children, nodes, graph)) + .collect() + }) .unwrap_or_default(); let node = &nodes[graph[idx]]; NeighborTree { diff --git a/src/rules/ilp_qubo.rs b/src/rules/ilp_qubo.rs index dbbb5c00..2e86b0d2 100644 --- a/src/rules/ilp_qubo.rs +++ b/src/rules/ilp_qubo.rs @@ -10,7 +10,6 @@ //! Slack variables: ceil(log2(slack_range)) bits per inequality constraint. 
use crate::models::optimization::{Comparison, ObjectiveSense, ILP, QUBO}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -37,7 +36,7 @@ impl ReductionResult for ReductionILPToQUBO { } #[reduction( - overhead = { ReductionOverhead::new(vec![("num_vars", poly!(num_vars))]) } + overhead = { ReductionOverhead::new(vec![("num_vars", "num_vars")]) } )] impl ReduceTo> for ILP { type Result = ReductionILPToQUBO; diff --git a/src/rules/ksatisfiability_qubo.rs b/src/rules/ksatisfiability_qubo.rs index 2cfb4f98..61f86d73 100644 --- a/src/rules/ksatisfiability_qubo.rs +++ b/src/rules/ksatisfiability_qubo.rs @@ -14,7 +14,6 @@ use crate::models::optimization::QUBO; use crate::models::satisfiability::KSatisfiability; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -293,7 +292,7 @@ fn build_qubo_matrix( } #[reduction( - overhead = { ReductionOverhead::new(vec![("num_vars", poly!(num_vars))]) } + overhead = { ReductionOverhead::new(vec![("num_vars", "num_vars")]) } )] impl ReduceTo> for KSatisfiability { type Result = ReductionKSatToQUBO; @@ -311,7 +310,7 @@ impl ReduceTo> for KSatisfiability { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_vars) + poly!(num_clauses)), + ("num_vars", "num_vars + num_clauses"), ]) } )] impl ReduceTo> for KSatisfiability { diff --git a/src/rules/maximumclique_ilp.rs b/src/rules/maximumclique_ilp.rs index 08f5026e..e30862ea 100644 --- a/src/rules/maximumclique_ilp.rs +++ b/src/rules/maximumclique_ilp.rs @@ -8,7 +8,6 @@ use crate::models::graph::MaximumClique; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -45,8 +44,8 @@ impl ReductionResult for 
ReductionCliqueToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_vertices)), - ("num_constraints", poly!(num_vertices ^ 2)), + ("num_vars", "num_vertices"), + ("num_constraints", "num_vertices ^ 2"), ]) } )] diff --git a/src/rules/maximumindependentset_gridgraph.rs b/src/rules/maximumindependentset_gridgraph.rs index 8d0ee475..d2ff47ab 100644 --- a/src/rules/maximumindependentset_gridgraph.rs +++ b/src/rules/maximumindependentset_gridgraph.rs @@ -4,7 +4,6 @@ //! Maps an arbitrary graph's MIS problem to an equivalent weighted MIS on a grid graph. use crate::models::graph::MaximumIndependentSet; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -34,8 +33,8 @@ impl ReductionResult for ReductionISSimpleToGrid { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_vertices * num_vertices)), - ("num_edges", poly!(num_vertices * num_vertices)), + ("num_vertices", "num_vertices * num_vertices"), + ("num_edges", "num_vertices * num_vertices"), ]) } )] @@ -81,8 +80,8 @@ impl ReductionResult for ReductionISUnitDiskToGrid { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_vertices * num_vertices)), - ("num_edges", poly!(num_vertices * num_vertices)), + ("num_vertices", "num_vertices * num_vertices"), + ("num_edges", "num_vertices * num_vertices"), ]) } )] diff --git a/src/rules/maximumindependentset_ilp.rs b/src/rules/maximumindependentset_ilp.rs index 220cd7e7..136b59b6 100644 --- a/src/rules/maximumindependentset_ilp.rs +++ b/src/rules/maximumindependentset_ilp.rs @@ -7,7 +7,6 @@ use crate::models::graph::MaximumIndependentSet; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -44,8 +43,8 @@ impl 
ReductionResult for ReductionISToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_vertices)), - ("num_constraints", poly!(num_edges)), + ("num_vars", "num_vertices"), + ("num_constraints", "num_edges"), ]) } )] diff --git a/src/rules/maximumindependentset_maximumsetpacking.rs b/src/rules/maximumindependentset_maximumsetpacking.rs index 0cdbbb02..0dc97e52 100644 --- a/src/rules/maximumindependentset_maximumsetpacking.rs +++ b/src/rules/maximumindependentset_maximumsetpacking.rs @@ -5,7 +5,6 @@ use crate::models::graph::MaximumIndependentSet; use crate::models::set::MaximumSetPacking; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -39,8 +38,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_sets", poly!(num_vertices)), - ("universe_size", poly!(num_vertices)), + ("num_sets", "num_vertices"), + ("universe_size", "num_vertices"), ]) } )] @@ -90,8 +89,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_sets)), - ("num_edges", poly!(num_sets)), + ("num_vertices", "num_sets"), + ("num_edges", "num_sets"), ]) } )] diff --git a/src/rules/maximumindependentset_qubo.rs b/src/rules/maximumindependentset_qubo.rs index ea1e5d08..0c159c83 100644 --- a/src/rules/maximumindependentset_qubo.rs +++ b/src/rules/maximumindependentset_qubo.rs @@ -7,7 +7,6 @@ use crate::models::graph::MaximumIndependentSet; use crate::models::optimization::QUBO; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -32,7 +31,7 @@ impl ReductionResult for ReductionISToQUBO { } #[reduction( - overhead = { ReductionOverhead::new(vec![("num_vars", poly!(num_vertices))]) } + overhead = { ReductionOverhead::new(vec![("num_vars", "num_vertices")]) } )] impl ReduceTo> for MaximumIndependentSet { type Result = 
ReductionISToQUBO; diff --git a/src/rules/maximumindependentset_triangular.rs b/src/rules/maximumindependentset_triangular.rs index 09d6a85e..091a994c 100644 --- a/src/rules/maximumindependentset_triangular.rs +++ b/src/rules/maximumindependentset_triangular.rs @@ -5,7 +5,6 @@ //! triangular lattice grid graph. use crate::models::graph::MaximumIndependentSet; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -36,8 +35,8 @@ impl ReductionResult for ReductionISSimpleToTriangular { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_vertices * num_vertices)), - ("num_edges", poly!(num_vertices * num_vertices)), + ("num_vertices", "num_vertices * num_vertices"), + ("num_edges", "num_vertices * num_vertices"), ]) } )] diff --git a/src/rules/maximummatching_ilp.rs b/src/rules/maximummatching_ilp.rs index dc016861..d07069cb 100644 --- a/src/rules/maximummatching_ilp.rs +++ b/src/rules/maximummatching_ilp.rs @@ -8,7 +8,6 @@ use crate::models::graph::MaximumMatching; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -45,8 +44,8 @@ impl ReductionResult for ReductionMatchingToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_edges)), - ("num_constraints", poly!(num_vertices)), + ("num_vars", "num_edges"), + ("num_constraints", "num_vertices"), ]) } )] diff --git a/src/rules/maximummatching_maximumsetpacking.rs b/src/rules/maximummatching_maximumsetpacking.rs index c477fc66..fa15a9f5 100644 --- a/src/rules/maximummatching_maximumsetpacking.rs +++ b/src/rules/maximummatching_maximumsetpacking.rs @@ -5,7 +5,6 @@ use crate::models::graph::MaximumMatching; use crate::models::set::MaximumSetPacking; -use crate::poly; use crate::reduction; use 
crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -40,8 +39,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_sets", poly!(num_edges)), - ("universe_size", poly!(num_vertices)), + ("num_sets", "num_edges"), + ("universe_size", "num_vertices"), ]) } )] diff --git a/src/rules/maximumsetpacking_ilp.rs b/src/rules/maximumsetpacking_ilp.rs index b5f22d74..b3dfbedb 100644 --- a/src/rules/maximumsetpacking_ilp.rs +++ b/src/rules/maximumsetpacking_ilp.rs @@ -7,7 +7,6 @@ use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; use crate::models::set::MaximumSetPacking; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -43,8 +42,8 @@ impl ReductionResult for ReductionSPToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_sets)), - ("num_constraints", poly!(num_sets ^ 2)), + ("num_vars", "num_sets"), + ("num_constraints", "num_sets ^ 2"), ]) } )] diff --git a/src/rules/maximumsetpacking_qubo.rs b/src/rules/maximumsetpacking_qubo.rs index 2e5a48c0..1da95d93 100644 --- a/src/rules/maximumsetpacking_qubo.rs +++ b/src/rules/maximumsetpacking_qubo.rs @@ -8,7 +8,6 @@ use crate::models::optimization::QUBO; use crate::models::set::MaximumSetPacking; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -33,7 +32,7 @@ impl ReductionResult for ReductionSPToQUBO { } #[reduction( - overhead = { ReductionOverhead::new(vec![("num_vars", poly!(num_sets))]) } + overhead = { ReductionOverhead::new(vec![("num_vars", "num_sets")]) } )] impl ReduceTo> for MaximumSetPacking { type Result = ReductionSPToQUBO; diff --git a/src/rules/minimumdominatingset_ilp.rs b/src/rules/minimumdominatingset_ilp.rs index b6c526b8..9c5a1b37 100644 --- a/src/rules/minimumdominatingset_ilp.rs +++ 
b/src/rules/minimumdominatingset_ilp.rs @@ -8,7 +8,6 @@ use crate::models::graph::MinimumDominatingSet; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -46,8 +45,8 @@ impl ReductionResult for ReductionDSToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_vertices)), - ("num_constraints", poly!(num_vertices)), + ("num_vars", "num_vertices"), + ("num_constraints", "num_vertices"), ]) } )] diff --git a/src/rules/minimumsetcovering_ilp.rs b/src/rules/minimumsetcovering_ilp.rs index 1a3889b7..1f9c3b01 100644 --- a/src/rules/minimumsetcovering_ilp.rs +++ b/src/rules/minimumsetcovering_ilp.rs @@ -7,7 +7,6 @@ use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; use crate::models::set::MinimumSetCovering; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -43,8 +42,8 @@ impl ReductionResult for ReductionSCToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_sets)), - ("num_constraints", poly!(universe_size)), + ("num_vars", "num_sets"), + ("num_constraints", "universe_size"), ]) } )] diff --git a/src/rules/minimumvertexcover_ilp.rs b/src/rules/minimumvertexcover_ilp.rs index 18780fd5..f4459971 100644 --- a/src/rules/minimumvertexcover_ilp.rs +++ b/src/rules/minimumvertexcover_ilp.rs @@ -7,7 +7,6 @@ use crate::models::graph::MinimumVertexCover; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -44,8 +43,8 @@ impl ReductionResult for ReductionVCToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", 
poly!(num_vertices)), - ("num_constraints", poly!(num_edges)), + ("num_vars", "num_vertices"), + ("num_constraints", "num_edges"), ]) } )] diff --git a/src/rules/minimumvertexcover_maximumindependentset.rs b/src/rules/minimumvertexcover_maximumindependentset.rs index 0715c474..a3d62380 100644 --- a/src/rules/minimumvertexcover_maximumindependentset.rs +++ b/src/rules/minimumvertexcover_maximumindependentset.rs @@ -3,7 +3,6 @@ //! These problems are complements: a set S is an independent set iff V\S is a vertex cover. use crate::models::graph::{MaximumIndependentSet, MinimumVertexCover}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -37,8 +36,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_vertices)), - ("num_edges", poly!(num_edges)), + ("num_vertices", "num_vertices"), + ("num_edges", "num_edges"), ]) } )] @@ -80,8 +79,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_vertices)), - ("num_edges", poly!(num_edges)), + ("num_vertices", "num_vertices"), + ("num_edges", "num_edges"), ]) } )] diff --git a/src/rules/minimumvertexcover_minimumsetcovering.rs b/src/rules/minimumvertexcover_minimumsetcovering.rs index e2f130f1..46ceb105 100644 --- a/src/rules/minimumvertexcover_minimumsetcovering.rs +++ b/src/rules/minimumvertexcover_minimumsetcovering.rs @@ -5,7 +5,6 @@ use crate::models::graph::MinimumVertexCover; use crate::models::set::MinimumSetCovering; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -39,8 +38,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_sets", poly!(num_vertices)), - ("universe_size", poly!(num_edges)), + ("num_sets", "num_vertices"), + ("universe_size", "num_edges"), ]) } )] diff --git a/src/rules/minimumvertexcover_qubo.rs 
b/src/rules/minimumvertexcover_qubo.rs index b0422e57..45d1b498 100644 --- a/src/rules/minimumvertexcover_qubo.rs +++ b/src/rules/minimumvertexcover_qubo.rs @@ -8,7 +8,6 @@ use crate::models::graph::MinimumVertexCover; use crate::models::optimization::QUBO; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -34,7 +33,7 @@ impl ReductionResult for ReductionVCToQUBO { } #[reduction( - overhead = { ReductionOverhead::new(vec![("num_vars", poly!(num_vertices))]) } + overhead = { ReductionOverhead::new(vec![("num_vars", "num_vertices")]) } )] impl ReduceTo> for MinimumVertexCover { type Result = ReductionVCToQUBO; diff --git a/src/rules/qubo_ilp.rs b/src/rules/qubo_ilp.rs index d43ccc93..c785d4fb 100644 --- a/src/rules/qubo_ilp.rs +++ b/src/rules/qubo_ilp.rs @@ -15,7 +15,7 @@ //! minimize Σ_i Q_ii · x_i + Σ_{i, + /// Output size as symbolic expressions of input size variables. + /// Each entry is (output_field_name, expression). + pub output_size: Vec<(&'static str, Expr)>, } impl ReductionOverhead { - pub fn new(output_size: Vec<(&'static str, Polynomial)>) -> Self { - Self { output_size } + pub fn new(specs: Vec<(&'static str, &'static str)>) -> Self { + Self { + output_size: specs + .into_iter() + .map(|(field, expr_str)| { + let expr = Expr::parse(expr_str).unwrap_or_else(|e| { + panic!("invalid overhead expression for '{field}': {e}") + }); + (field, expr) + }) + .collect(), + } } /// Identity overhead: each output field equals the same-named input field. /// Used by variant cast reductions where problem size doesn't change. 
pub fn identity(fields: &[&'static str]) -> Self { Self { - output_size: fields.iter().map(|&f| (f, Polynomial::var(f))).collect(), + output_size: fields + .iter() + .map(|&f| { + let expr = Expr::parse(f) + .unwrap_or_else(|e| panic!("invalid identity field name '{f}': {e}")); + (f, expr) + }) + .collect(), } } /// Evaluate output size given input size. /// - /// Uses `round()` for the f64 to usize conversion because polynomial coefficients + /// Uses `round()` for the f64 to usize conversion because expression coefficients /// are typically integers (1, 2, 3, 7, 21, etc.) and any fractional results come /// from floating-point arithmetic imprecision, not intentional fractions. - /// For problem sizes, rounding to nearest integer is the most intuitive behavior. - pub fn evaluate_output_size(&self, input: &ProblemSize) -> ProblemSize { - let fields: Vec<_> = self - .output_size - .iter() - .map(|(name, poly)| (*name, poly.evaluate(input).round() as usize)) - .collect(); - ProblemSize::new(fields) + pub fn evaluate_output_size(&self, input: &ProblemSize) -> Result<ProblemSize, EvalError> { + let mut fields = Vec::new(); + for (name, expr) in &self.output_size { + let val = expr.evaluate(input)?; + let rounded = val.round(); + if !rounded.is_finite() || rounded < 0.0 || rounded > usize::MAX as f64 { + return Err(EvalError::Domain { + func: Func::Floor, + detail: format!("overhead for '{name}' produced out-of-range value: {val}") + .into(), + }); + } + fields.push((*name, rounded as usize)); + } + Ok(ProblemSize::new(fields)) } - /// Collect all input variable names referenced by the overhead polynomials. - pub fn input_variable_names(&self) -> HashSet<&'static str> { + /// Collect all input variable names referenced by the overhead expressions. 
+ pub fn input_variable_names(&self) -> HashSet<&str> { self.output_size .iter() - .flat_map(|(_, poly)| poly.variable_names()) + .flat_map(|(_, expr)| expr.variable_names()) .collect() } /// Compose two overheads: substitute self's output into `next`'s input. /// - /// Returns a new overhead whose polynomials map from self's input variables + /// Returns a new overhead whose expressions map from self's input variables /// directly to `next`'s output variables. pub fn compose(&self, next: &ReductionOverhead) -> ReductionOverhead { use std::collections::HashMap; - // Build substitution map: output field name → output polynomial - let mapping: HashMap<&str, &Polynomial> = self + // Build substitution map: output field name → output expression + let mapping: HashMap<&str, &Expr> = self .output_size .iter() - .map(|(name, poly)| (*name, poly)) + .map(|(name, expr)| (*name, expr)) .collect(); let composed = next .output_size .iter() - .map(|(name, poly)| (*name, poly.substitute(&mapping))) + .map(|(name, expr)| (*name, expr.substitute(&mapping))) .collect(); ReductionOverhead { @@ -75,12 +99,12 @@ impl ReductionOverhead { } } - /// Get the polynomial for a named output field. - pub fn get(&self, name: &str) -> Option<&Polynomial> { + /// Get the expression for a named output field. 
+ pub fn get(&self, name: &str) -> Option<&Expr> { self.output_size .iter() .find(|(n, _)| *n == name) - .map(|(_, p)| p) + .map(|(_, e)| e) } } diff --git a/src/rules/sat_circuitsat.rs b/src/rules/sat_circuitsat.rs index b82084f8..1cd199e8 100644 --- a/src/rules/sat_circuitsat.rs +++ b/src/rules/sat_circuitsat.rs @@ -5,7 +5,7 @@ use crate::models::satisfiability::Satisfiability; use crate::models::specialized::{Assignment, BooleanExpr, Circuit, CircuitSAT}; -use crate::poly; + use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -38,8 +38,8 @@ impl ReductionResult for ReductionSATToCircuit { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_variables", poly!(num_vars) + poly!(num_clauses) + poly!(1)), - ("num_assignments", poly!(num_clauses) + poly!(2)), + ("num_variables", "num_vars + num_clauses + 1"), + ("num_assignments", "num_clauses + 2"), ]) } )] diff --git a/src/rules/sat_coloring.rs b/src/rules/sat_coloring.rs index 8337c5b9..e7000672 100644 --- a/src/rules/sat_coloring.rs +++ b/src/rules/sat_coloring.rs @@ -10,7 +10,6 @@ use crate::models::graph::KColoring; use crate::models::satisfiability::Satisfiability; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::sat_maximumindependentset::BoolVar; @@ -300,9 +299,9 @@ impl ReductionSATToColoring { overhead = { ReductionOverhead::new(vec![ // 2*num_vars + 3 (base) + 5*(num_literals - num_clauses) (OR gadgets) - ("num_vertices", poly!(2 * num_vars) + poly!(5 * num_literals) + poly!(num_clauses).scale(-5.0) + poly!(3)), + ("num_vertices", "2 * num_vars + 5 * num_literals - 5 * num_clauses + 3"), // 3 (triangle) + 3*num_vars + 11*(num_literals - num_clauses) (OR gadgets) + 2*num_clauses (set_true) - ("num_edges", poly!(3 * num_vars) + poly!(11 * num_literals) + poly!(num_clauses).scale(-9.0) + poly!(3)), + ("num_edges", "3 * num_vars + 11 * num_literals - 9 * num_clauses + 
3"), ]) } )] diff --git a/src/rules/sat_ksat.rs b/src/rules/sat_ksat.rs index 5321fde6..a75045f5 100644 --- a/src/rules/sat_ksat.rs +++ b/src/rules/sat_ksat.rs @@ -7,7 +7,6 @@ //! K-SAT -> SAT: Trivial embedding (K-SAT is a special case of SAT) use crate::models::satisfiability::{CNFClause, KSatisfiability, Satisfiability}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -113,11 +112,11 @@ fn add_clause_to_ksat( macro_rules! impl_sat_to_ksat { ($ktype:ty, $k:expr) => { #[reduction(overhead = { - ReductionOverhead::new(vec![ - ("num_clauses", poly!(num_clauses) + poly!(num_literals)), - ("num_vars", poly!(num_vars) + poly!(num_literals)), - ]) - })] + ReductionOverhead::new(vec![ + ("num_clauses", "num_clauses + num_literals"), + ("num_vars", "num_vars + num_literals"), + ]) + })] impl ReduceTo> for Satisfiability { type Result = ReductionSATToKSAT<$ktype>; @@ -188,9 +187,9 @@ macro_rules! impl_ksat_to_sat { ($ktype:ty) => { #[reduction(overhead = { ReductionOverhead::new(vec![ - ("num_clauses", poly!(num_clauses)), - ("num_vars", poly!(num_vars)), - ("num_literals", poly!(num_literals)), + ("num_clauses", "num_clauses"), + ("num_vars", "num_vars"), + ("num_literals", "num_literals"), ]) })] impl ReduceTo for KSatisfiability<$ktype> { diff --git a/src/rules/sat_maximumindependentset.rs b/src/rules/sat_maximumindependentset.rs index f678bf79..8ec7c83d 100644 --- a/src/rules/sat_maximumindependentset.rs +++ b/src/rules/sat_maximumindependentset.rs @@ -10,7 +10,6 @@ use crate::models::graph::MaximumIndependentSet; use crate::models::satisfiability::Satisfiability; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -112,8 +111,8 @@ impl ReductionSATToIS { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_literals)), - ("num_edges", 
poly!(num_literals ^ 2)), + ("num_vertices", "num_literals"), + ("num_edges", "num_literals ^ 2"), ]) } )] diff --git a/src/rules/sat_minimumdominatingset.rs b/src/rules/sat_minimumdominatingset.rs index 07646902..aa693c63 100644 --- a/src/rules/sat_minimumdominatingset.rs +++ b/src/rules/sat_minimumdominatingset.rs @@ -16,7 +16,6 @@ use crate::models::graph::MinimumDominatingSet; use crate::models::satisfiability::Satisfiability; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::sat_maximumindependentset::BoolVar; @@ -116,8 +115,8 @@ impl ReductionSATToDS { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(3 * num_vars) + poly!(num_clauses)), - ("num_edges", poly!(3 * num_vars) + poly!(num_literals)), + ("num_vertices", "3 * num_vars + num_clauses"), + ("num_edges", "3 * num_vars + num_literals"), ]) } )] diff --git a/src/rules/spinglass_maxcut.rs b/src/rules/spinglass_maxcut.rs index ef6ae91b..748fd595 100644 --- a/src/rules/spinglass_maxcut.rs +++ b/src/rules/spinglass_maxcut.rs @@ -5,7 +5,6 @@ use crate::models::graph::MaxCut; use crate::models::optimization::SpinGlass; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -46,8 +45,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_spins", poly!(num_vertices)), - ("num_interactions", poly!(num_edges)), + ("num_spins", "num_vertices"), + ("num_interactions", "num_edges"), ]) } )] @@ -137,8 +136,8 @@ where #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_spins)), - ("num_edges", poly!(num_interactions)), + ("num_vertices", "num_spins"), + ("num_edges", "num_interactions"), ]) } )] diff --git a/src/rules/spinglass_qubo.rs b/src/rules/spinglass_qubo.rs index 2d39981f..5c408401 100644 --- a/src/rules/spinglass_qubo.rs +++ b/src/rules/spinglass_qubo.rs @@ -6,7 +6,6 @@ //! 
Transformation: s = 2x - 1 (so x=0 -> s=-1, x=1 -> s=+1) use crate::models::optimization::{SpinGlass, QUBO}; -use crate::poly; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -35,7 +34,7 @@ impl ReductionResult for ReductionQUBOToSG { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_spins", poly!(num_vars)), + ("num_spins", "num_vars"), ]) } )] @@ -112,7 +111,7 @@ impl ReductionResult for ReductionSGToQUBO { #[reduction( overhead = { ReductionOverhead::new(vec![ - ("num_vars", poly!(num_spins)), + ("num_vars", "num_spins"), ]) } )] diff --git a/src/rules/travelingsalesman_ilp.rs b/src/rules/travelingsalesman_ilp.rs index af2b37e2..bb12dc96 100644 --- a/src/rules/travelingsalesman_ilp.rs +++ b/src/rules/travelingsalesman_ilp.rs @@ -7,7 +7,6 @@ use crate::models::graph::TravelingSalesman; use crate::models::optimization::{LinearConstraint, ObjectiveSense, VarBounds, ILP}; -use crate::polynomial::{Monomial, Polynomial}; use crate::reduction; use crate::rules::registry::ReductionOverhead; use crate::rules::traits::{ReduceTo, ReductionResult}; @@ -75,30 +74,8 @@ impl ReductionResult for ReductionTSPToILP { #[reduction( overhead = { ReductionOverhead::new(vec![ - // num_vars = n^2 + 2*m*n - ("num_vars", Polynomial::var_pow("num_vertices", 2) + Polynomial { - terms: vec![Monomial { - coefficient: 2.0, - variables: vec![("num_vertices", 1), ("num_edges", 1)], - }] - }), - // num_constraints = 2n + n(n(n-1) - 2m) + 6mn = n^3 - n^2 + 2n + 4mn - ("num_constraints", Polynomial::var_pow("num_vertices", 3) + Polynomial { - terms: vec![ - Monomial { - coefficient: -1.0, - variables: vec![("num_vertices", 2)], - }, - Monomial { - coefficient: 2.0, - variables: vec![("num_vertices", 1)], - }, - Monomial { - coefficient: 4.0, - variables: vec![("num_vertices", 1), ("num_edges", 1)], - }, - ] - }), + ("num_vars", "num_vertices ^ 2 + 2 * num_vertices * num_edges"), + ("num_constraints", 
"num_vertices ^ 3 - num_vertices ^ 2 + 2 * num_vertices + 4 * num_vertices * num_edges"), ]) } )] diff --git a/src/unit_tests/export.rs b/src/unit_tests/export.rs index 25c43c63..23cb4aab 100644 --- a/src/unit_tests/export.rs +++ b/src/unit_tests/export.rs @@ -1,5 +1,4 @@ use super::*; -use crate::polynomial::Polynomial; use crate::rules::registry::ReductionOverhead; #[test] @@ -11,61 +10,33 @@ fn test_overhead_to_json_empty() { #[test] fn test_overhead_to_json_single_field() { - let overhead = ReductionOverhead::new(vec![( - "num_vertices", - Polynomial::var("n") + Polynomial::var("m"), - )]); + let overhead = ReductionOverhead::new(vec![("num_vertices", "n + m")]); let entries = overhead_to_json(&overhead); assert_eq!(entries.len(), 1); assert_eq!(entries[0].field, "num_vertices"); - assert_eq!(entries[0].polynomial.len(), 2); - - // Check first monomial: 1*n - assert_eq!(entries[0].polynomial[0].coefficient, 1.0); - assert_eq!( - entries[0].polynomial[0].variables, - vec![("n".to_string(), 1)] - ); - - // Check second monomial: 1*m - assert_eq!(entries[0].polynomial[1].coefficient, 1.0); - assert_eq!( - entries[0].polynomial[1].variables, - vec![("m".to_string(), 1)] - ); + assert_eq!(entries[0].expression, "n + m"); } #[test] -fn test_overhead_to_json_constant_monomial() { - let overhead = ReductionOverhead::new(vec![("num_vars", Polynomial::constant(42.0))]); +fn test_overhead_to_json_constant() { + let overhead = ReductionOverhead::new(vec![("num_vars", "42")]); let entries = overhead_to_json(&overhead); assert_eq!(entries.len(), 1); assert_eq!(entries[0].field, "num_vars"); - assert_eq!(entries[0].polynomial.len(), 1); - assert_eq!(entries[0].polynomial[0].coefficient, 42.0); - assert!(entries[0].polynomial[0].variables.is_empty()); + assert_eq!(entries[0].expression, "42"); } #[test] fn test_overhead_to_json_scaled_power() { - let overhead = - ReductionOverhead::new(vec![("num_edges", Polynomial::var_pow("n", 2).scale(3.0))]); + let overhead = 
ReductionOverhead::new(vec![("num_edges", "3 * n ^ 2")]); let entries = overhead_to_json(&overhead); assert_eq!(entries.len(), 1); - assert_eq!(entries[0].polynomial.len(), 1); - assert_eq!(entries[0].polynomial[0].coefficient, 3.0); - assert_eq!( - entries[0].polynomial[0].variables, - vec![("n".to_string(), 2)] - ); + assert_eq!(entries[0].expression, "3 * n ^ 2"); } #[test] fn test_overhead_to_json_multiple_fields() { - let overhead = ReductionOverhead::new(vec![ - ("num_vertices", Polynomial::var("n")), - ("num_edges", Polynomial::var_pow("n", 2)), - ]); + let overhead = ReductionOverhead::new(vec![("num_vertices", "n"), ("num_edges", "n ^ 2")]); let entries = overhead_to_json(&overhead); assert_eq!(entries.len(), 2); assert_eq!(entries[0].field, "num_vertices"); @@ -141,7 +112,6 @@ fn test_write_example_creates_files() { write_example("_test_export", &data, &results); - // Verify files exist and contain valid JSON let reduction_path = "docs/paper/examples/_test_export.json"; let results_path = "docs/paper/examples/_test_export.result.json"; @@ -157,7 +127,6 @@ fn test_write_example_creates_files() { serde_json::json!([1, 0, 1]) ); - // Clean up test files let _ = fs::remove_file(reduction_path); let _ = fs::remove_file(results_path); } @@ -190,15 +159,12 @@ fn test_reduction_data_serialization() { }, overhead: vec![OverheadEntry { field: "num_vertices".to_string(), - polynomial: vec![MonomialJson { - coefficient: 1.0, - variables: vec![("n".to_string(), 1)], - }], + expression: "n".to_string(), }], }; let json = serde_json::to_value(&data).unwrap(); assert_eq!(json["overhead"][0]["field"], "num_vertices"); - assert_eq!(json["overhead"][0]["polynomial"][0]["coefficient"], 1.0); + assert_eq!(json["overhead"][0]["expression"], "n"); } #[test] diff --git a/src/unit_tests/expr.rs b/src/unit_tests/expr.rs new file mode 100644 index 00000000..dbe67271 --- /dev/null +++ b/src/unit_tests/expr.rs @@ -0,0 +1,417 @@ +use super::*; +use crate::types::ProblemSize; + +// === 
Task 1: AST and Evaluator tests === + +#[test] +fn test_eval_num() { + let expr = Expr::Num(42.0); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 42.0); +} + +#[test] +fn test_eval_var() { + let expr = Expr::Var("n".into()); + let size = ProblemSize::new(vec![("n", 10)]); + assert_eq!(expr.evaluate(&size).unwrap(), 10.0); +} + +#[test] +fn test_eval_unknown_var_defaults_to_zero() { + let expr = Expr::Var("missing".into()); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 0.0); +} + +#[test] +fn test_eval_add() { + let expr = Expr::binop(BinOp::Add, Expr::Num(3.0), Expr::Num(4.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 7.0); +} + +#[test] +fn test_eval_sub() { + let expr = Expr::binop(BinOp::Sub, Expr::Num(10.0), Expr::Num(3.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 7.0); +} + +#[test] +fn test_eval_mul() { + let expr = Expr::binop(BinOp::Mul, Expr::Num(3.0), Expr::Var("n".into())); + let size = ProblemSize::new(vec![("n", 5)]); + assert_eq!(expr.evaluate(&size).unwrap(), 15.0); +} + +#[test] +fn test_eval_div() { + let expr = Expr::binop(BinOp::Div, Expr::Num(10.0), Expr::Num(4.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 2.5); +} + +#[test] +fn test_eval_div_by_zero() { + let expr = Expr::binop(BinOp::Div, Expr::Num(1.0), Expr::Num(0.0)); + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::DivideByZero))); +} + +#[test] +fn test_eval_pow() { + let expr = Expr::binop(BinOp::Pow, Expr::Num(2.0), Expr::Num(10.0)); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 1024.0); +} + +#[test] +fn test_eval_pow_fractional_base_negative() { + let expr = Expr::binop(BinOp::Pow, Expr::Num(-2.0), Expr::Num(0.5)); + let size = ProblemSize::new(vec![]); + assert!(matches!( + expr.evaluate(&size), + 
Err(EvalError::Domain { .. }) + )); +} + +#[test] +fn test_eval_neg() { + let expr = Expr::Neg(Box::new(Expr::Num(5.0))); + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), -5.0); +} + +#[test] +fn test_eval_log2() { + let expr = Expr::Call { + func: Func::Log2, + args: vec![Expr::Num(8.0)], + }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 3.0); +} + +#[test] +fn test_eval_log2_negative() { + let expr = Expr::Call { + func: Func::Log2, + args: vec![Expr::Num(-1.0)], + }; + let size = ProblemSize::new(vec![]); + assert!(matches!( + expr.evaluate(&size), + Err(EvalError::Domain { .. }) + )); +} + +#[test] +fn test_eval_sqrt() { + let expr = Expr::Call { + func: Func::Sqrt, + args: vec![Expr::Num(25.0)], + }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 5.0); +} + +#[test] +fn test_eval_min() { + let expr = Expr::Call { + func: Func::Min, + args: vec![Expr::Num(3.0), Expr::Num(7.0)], + }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 3.0); +} + +#[test] +fn test_eval_max() { + let expr = Expr::Call { + func: Func::Max, + args: vec![Expr::Num(3.0), Expr::Num(7.0)], + }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 7.0); +} + +#[test] +fn test_eval_floor() { + let expr = Expr::Call { + func: Func::Floor, + args: vec![Expr::Num(3.7)], + }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 3.0); +} + +#[test] +fn test_eval_ceil() { + let expr = Expr::Call { + func: Func::Ceil, + args: vec![Expr::Num(3.2)], + }; + let size = ProblemSize::new(vec![]); + assert_eq!(expr.evaluate(&size).unwrap(), 4.0); +} + +#[test] +fn test_eval_arity_error() { + let expr = Expr::Call { + func: Func::Log2, + args: vec![Expr::Num(1.0), Expr::Num(2.0)], + }; + let size = ProblemSize::new(vec![]); + assert!(matches!(expr.evaluate(&size), Err(EvalError::Arity { .. 
}))); +} + +#[test] +fn test_eval_complex() { + // 3 * n ^ 2 + 1.44 ^ m + let expr = Expr::binop( + BinOp::Add, + Expr::binop( + BinOp::Mul, + Expr::Num(3.0), + Expr::binop(BinOp::Pow, Expr::Var("n".into()), Expr::Num(2.0)), + ), + Expr::binop(BinOp::Pow, Expr::Num(1.44), Expr::Var("m".into())), + ); + let size = ProblemSize::new(vec![("n", 4), ("m", 3)]); + let result = expr.evaluate(&size).unwrap(); + let expected = 3.0 * 16.0 + 1.44_f64.powi(3); + assert!((result - expected).abs() < 1e-10); +} + +// === Task 2: Parser tests === + +#[test] +fn test_parse_num() { + let expr = Expr::parse("42").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 42.0); +} + +#[test] +fn test_parse_float() { + let expr = Expr::parse("1.44").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 1.44); +} + +#[test] +fn test_parse_var() { + let expr = Expr::parse("num_vertices").unwrap(); + let size = ProblemSize::new(vec![("num_vertices", 10)]); + assert_eq!(expr.evaluate(&size).unwrap(), 10.0); +} + +#[test] +fn test_parse_add() { + let expr = Expr::parse("3 + 4").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 7.0); +} + +#[test] +fn test_parse_precedence_mul_add() { + // 2 + 3 * 4 = 14 (not 20) + let expr = Expr::parse("2 + 3 * 4").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 14.0); +} + +#[test] +fn test_parse_precedence_pow() { + // 2 ^ 3 ^ 2 = 2 ^ 9 = 512 (right-associative) + let expr = Expr::parse("2 ^ 3 ^ 2").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 512.0); +} + +#[test] +fn test_parse_unary_neg() { + let expr = Expr::parse("-5").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), -5.0); +} + +#[test] +fn test_parse_neg_pow() { + // -2 ^ 2 = -(2^2) = -4 + let expr = Expr::parse("-2 ^ 2").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), -4.0); +} + +#[test] +fn test_parse_parens() { + let expr = 
Expr::parse("(2 + 3) * 4").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 20.0); +} + +#[test] +fn test_parse_function_log2() { + let expr = Expr::parse("log2(8)").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 3.0); +} + +#[test] +fn test_parse_function_max() { + let expr = Expr::parse("max(3, 7)").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 7.0); +} + +#[test] +fn test_parse_function_case_insensitive() { + let expr = Expr::parse("Log2(8)").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 3.0); +} + +#[test] +fn test_parse_complex_expression() { + let expr = Expr::parse("3 * num_vertices ^ 2 + 1.44 ^ num_edges").unwrap(); + let size = ProblemSize::new(vec![("num_vertices", 4), ("num_edges", 3)]); + let expected = 3.0 * 16.0 + 1.44_f64.powi(3); + assert!((expr.evaluate(&size).unwrap() - expected).abs() < 1e-10); +} + +#[test] +fn test_parse_nested_functions() { + let expr = Expr::parse("floor(log2(16))").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 4.0); +} + +#[test] +fn test_parse_unknown_function() { + let result = Expr::parse("foo(3)"); + assert!(matches!(result, Err(ParseError::UnknownFunction { .. 
}))); +} + +#[test] +fn test_parse_unexpected_eof() { + let result = Expr::parse("3 +"); + assert!(result.is_err()); +} + +#[test] +fn test_parse_empty() { + let result = Expr::parse(""); + assert!(result.is_err()); +} + +#[test] +fn test_parse_leading_dot() { + let expr = Expr::parse(".5").unwrap(); + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 0.5); +} + +#[test] +fn test_parse_subtraction() { + let expr = Expr::parse("10 - 3 - 2").unwrap(); + // left-associative: (10 - 3) - 2 = 5 + assert_eq!(expr.evaluate(&ProblemSize::new(vec![])).unwrap(), 5.0); +} + +#[test] +fn test_parse_division() { + let expr = Expr::parse("num_vertices / 2").unwrap(); + let size = ProblemSize::new(vec![("num_vertices", 10)]); + assert_eq!(expr.evaluate(&size).unwrap(), 5.0); +} + +// === Task 3: Display and Serde tests === + +#[test] +fn test_display_num_integer() { + let expr = Expr::Num(3.0); + assert_eq!(expr.to_string(), "3"); +} + +#[test] +fn test_display_num_float() { + let expr = Expr::Num(1.44); + assert_eq!(expr.to_string(), "1.44"); +} + +#[test] +fn test_display_var() { + let expr = Expr::Var("num_vertices".into()); + assert_eq!(expr.to_string(), "num_vertices"); +} + +#[test] +fn test_display_add() { + let expr = Expr::parse("a + b").unwrap(); + assert_eq!(expr.to_string(), "a + b"); +} + +#[test] +fn test_display_precedence() { + let expr = Expr::parse("a + b * c").unwrap(); + assert_eq!(expr.to_string(), "a + b * c"); +} + +#[test] +fn test_display_parens_needed() { + let expr = Expr::parse("(a + b) * c").unwrap(); + assert_eq!(expr.to_string(), "(a + b) * c"); +} + +#[test] +fn test_display_pow() { + let expr = Expr::parse("1.44 ^ n").unwrap(); + assert_eq!(expr.to_string(), "1.44 ^ n"); +} + +#[test] +fn test_display_neg() { + let expr = Expr::parse("-x").unwrap(); + assert_eq!(expr.to_string(), "-x"); +} + +#[test] +fn test_display_neg_compound() { + let expr = Expr::parse("-(a + b)").unwrap(); + assert_eq!(expr.to_string(), "-(a + b)"); +} + +#[test] 
+fn test_display_func() { + let expr = Expr::parse("log2(n)").unwrap(); + assert_eq!(expr.to_string(), "log2(n)"); +} + +#[test] +fn test_display_func_two_args() { + let expr = Expr::parse("max(a, b)").unwrap(); + assert_eq!(expr.to_string(), "max(a, b)"); +} + +#[test] +fn test_roundtrip_complex() { + let cases = vec![ + "3 * n ^ 2 + 1.44 ^ m", + "log2(n) * m", + "max(n, m) + 1", + "floor(n / 2)", + "-(a + b) * c", + "a ^ b ^ c", + "a - b - c", + ]; + let size = ProblemSize::new(vec![("n", 4), ("m", 3), ("a", 2), ("b", 3), ("c", 5)]); + for case in cases { + let expr1 = Expr::parse(case).unwrap(); + let displayed = expr1.to_string(); + let expr2 = Expr::parse(&displayed).unwrap(); + let v1 = expr1.evaluate(&size).unwrap(); + let v2 = expr2.evaluate(&size).unwrap(); + assert!( + (v1 - v2).abs() < 1e-10, + "Round-trip failed for {case}: displayed as {displayed}" + ); + } +} + +#[test] +fn test_serde_roundtrip() { + let expr = Expr::parse("3 * n ^ 2 + 1").unwrap(); + let json = serde_json::to_string(&expr).unwrap(); + let back: Expr = serde_json::from_str(&json).unwrap(); + let size = ProblemSize::new(vec![("n", 5)]); + assert_eq!(expr.evaluate(&size).unwrap(), back.evaluate(&size).unwrap()); +} diff --git a/src/unit_tests/polynomial.rs b/src/unit_tests/polynomial.rs deleted file mode 100644 index 04dfd7f5..00000000 --- a/src/unit_tests/polynomial.rs +++ /dev/null @@ -1,180 +0,0 @@ -use super::*; - -#[test] -fn test_monomial_constant() { - let m = Monomial::constant(5.0); - let size = ProblemSize::new(vec![("n", 10)]); - assert_eq!(m.evaluate(&size), 5.0); -} - -#[test] -fn test_monomial_variable() { - let m = Monomial::var("n"); - let size = ProblemSize::new(vec![("n", 10)]); - assert_eq!(m.evaluate(&size), 10.0); -} - -#[test] -fn test_monomial_var_pow() { - let m = Monomial::var_pow("n", 2); - let size = ProblemSize::new(vec![("n", 5)]); - assert_eq!(m.evaluate(&size), 25.0); -} - -#[test] -fn test_polynomial_add() { - // 3n + 2m - let p = 
Polynomial::var("n").scale(3.0) + Polynomial::var("m").scale(2.0); - - let size = ProblemSize::new(vec![("n", 10), ("m", 5)]); - assert_eq!(p.evaluate(&size), 40.0); // 3*10 + 2*5 -} - -#[test] -fn test_polynomial_complex() { - // n^2 + 3m - let p = Polynomial::var_pow("n", 2) + Polynomial::var("m").scale(3.0); - - let size = ProblemSize::new(vec![("n", 4), ("m", 2)]); - assert_eq!(p.evaluate(&size), 22.0); // 16 + 6 -} - -#[test] -fn test_poly_macro() { - let size = ProblemSize::new(vec![("n", 5), ("m", 3)]); - - assert_eq!(poly!(n).evaluate(&size), 5.0); - assert_eq!(poly!(n ^ 2).evaluate(&size), 25.0); - assert_eq!(poly!(3 * n).evaluate(&size), 15.0); - assert_eq!(poly!(2 * m ^ 2).evaluate(&size), 18.0); -} - -#[test] -fn test_missing_variable() { - let p = Polynomial::var("missing"); - let size = ProblemSize::new(vec![("n", 10)]); - assert_eq!(p.evaluate(&size), 0.0); // missing var = 0 -} - -#[test] -fn test_polynomial_zero() { - let p = Polynomial::zero(); - let size = ProblemSize::new(vec![("n", 100)]); - assert_eq!(p.evaluate(&size), 0.0); -} - -#[test] -fn test_polynomial_constant() { - let p = Polynomial::constant(42.0); - let size = ProblemSize::new(vec![("n", 100)]); - assert_eq!(p.evaluate(&size), 42.0); -} - -#[test] -fn test_monomial_scale() { - let m = Monomial::var("n").scale(3.0); - let size = ProblemSize::new(vec![("n", 10)]); - assert_eq!(m.evaluate(&size), 30.0); -} - -#[test] -fn test_polynomial_scale() { - let p = Polynomial::var("n").scale(5.0); - let size = ProblemSize::new(vec![("n", 10)]); - assert_eq!(p.evaluate(&size), 50.0); -} - -#[test] -fn test_monomial_multi_variable() { - // n * m^2 - let m = Monomial { - coefficient: 1.0, - variables: vec![("n", 1), ("m", 2)], - }; - let size = ProblemSize::new(vec![("n", 2), ("m", 3)]); - assert_eq!(m.evaluate(&size), 18.0); // 2 * 9 -} - -#[test] -fn test_display_monomial_constant_int() { - assert_eq!(format!("{}", Monomial::constant(5.0)), "5"); -} - -#[test] -fn 
test_display_monomial_constant_float() { - assert_eq!(format!("{}", Monomial::constant(3.5)), "3.5"); -} - -#[test] -fn test_display_monomial_single_var() { - assert_eq!(format!("{}", Monomial::var("n")), "n"); -} - -#[test] -fn test_display_monomial_neg_one_coeff() { - assert_eq!(format!("{}", Monomial::var("n").scale(-1.0)), "-n"); -} - -#[test] -fn test_display_monomial_scaled_var() { - assert_eq!(format!("{}", Monomial::var("n").scale(3.0)), "3 * n"); -} - -#[test] -fn test_display_monomial_var_pow() { - assert_eq!(format!("{}", Monomial::var_pow("n", 2)), "n^2"); -} - -#[test] -fn test_display_monomial_multi_var() { - let m = Monomial { - coefficient: 2.0, - variables: vec![("n", 1), ("m", 2)], - }; - assert_eq!(format!("{m}"), "2 * n * m^2"); -} - -#[test] -fn test_display_monomial_float_coeff_var() { - let m = Monomial { - coefficient: 1.5, - variables: vec![("n", 1)], - }; - assert_eq!(format!("{m}"), "1.5 * n"); -} - -#[test] -fn test_display_polynomial_zero() { - assert_eq!(format!("{}", Polynomial::zero()), "0"); -} - -#[test] -fn test_display_polynomial_single_term() { - assert_eq!(format!("{}", Polynomial::var("n").scale(3.0)), "3 * n"); -} - -#[test] -fn test_display_polynomial_addition() { - let p = Polynomial::var("n").scale(3.0) + Polynomial::var("m").scale(2.0); - assert_eq!(format!("{p}"), "3 * n + 2 * m"); -} - -#[test] -fn test_display_polynomial_subtraction() { - let p = Polynomial::var("n").scale(3.0) + Polynomial::var("m").scale(-2.0); - assert_eq!(format!("{p}"), "3 * n - 2 * m"); -} - -#[test] -fn test_poly_macro_product() { - let size = ProblemSize::new(vec![("a", 3), ("b", 4)]); - assert_eq!(poly!(a * b).evaluate(&size), 12.0); - assert_eq!(format!("{}", poly!(a * b)), "a * b"); -} - -#[test] -fn test_poly_macro_scaled_product() { - let size = ProblemSize::new(vec![("a", 3), ("b", 4)]); - assert_eq!(poly!(5 * a * b).evaluate(&size), 60.0); - assert_eq!(format!("{}", poly!(5 * a * b)), "5 * a * b"); -} diff --git 
a/src/unit_tests/reduction_graph.rs b/src/unit_tests/reduction_graph.rs index ed29b5f4..7e28a3e9 100644 --- a/src/unit_tests/reduction_graph.rs +++ b/src/unit_tests/reduction_graph.rs @@ -1,7 +1,6 @@ //! Tests for ReductionGraph: discovery, path finding, and typed API. use crate::models::satisfiability::KSatisfiability; -use crate::poly; use crate::prelude::*; use crate::rules::{MinimizeSteps, ReductionGraph, TraversalDirection}; use crate::topology::{SimpleGraph, TriangularSubgraph}; @@ -327,37 +326,34 @@ fn test_3sat_to_mis_triangular_overhead() { assert_eq!(edges.len(), 3); // Edge 0: K3SAT → SAT (identity) + assert_eq!(edges[0].get("num_vars").unwrap().to_string(), "num_vars"); assert_eq!( - edges[0].get("num_vars").unwrap().normalized(), - poly!(num_vars) + edges[0].get("num_clauses").unwrap().to_string(), + "num_clauses" ); assert_eq!( - edges[0].get("num_clauses").unwrap().normalized(), - poly!(num_clauses) - ); - assert_eq!( - edges[0].get("num_literals").unwrap().normalized(), - poly!(num_literals) + edges[0].get("num_literals").unwrap().to_string(), + "num_literals" ); // Edge 1: SAT → MIS{SimpleGraph,i32} assert_eq!( - edges[1].get("num_vertices").unwrap().normalized(), - poly!(num_literals) + edges[1].get("num_vertices").unwrap().to_string(), + "num_literals" ); assert_eq!( - edges[1].get("num_edges").unwrap().normalized(), - poly!(num_literals ^ 2) + edges[1].get("num_edges").unwrap().to_string(), + "num_literals ^ 2" ); // Edge 2: MIS{SimpleGraph,i32} → MIS{TriangularSubgraph,i32} assert_eq!( - edges[2].get("num_vertices").unwrap().normalized(), - poly!(num_vertices ^ 2) + edges[2].get("num_vertices").unwrap().to_string(), + "num_vertices * num_vertices" ); assert_eq!( - edges[2].get("num_edges").unwrap().normalized(), - poly!(num_vertices ^ 2) + edges[2].get("num_edges").unwrap().to_string(), + "num_vertices * num_vertices" ); // Compose overheads symbolically along the path. 
@@ -370,12 +366,12 @@ fn test_3sat_to_mis_triangular_overhead() { // Composed: num_vertices = L², num_edges = L² let composed = graph.compose_path_overhead(&path); assert_eq!( - composed.get("num_vertices").unwrap().normalized(), - poly!(num_literals ^ 2) + composed.get("num_vertices").unwrap().to_string(), + "num_literals * num_literals" ); assert_eq!( - composed.get("num_edges").unwrap().normalized(), - poly!(num_literals ^ 2) + composed.get("num_edges").unwrap().to_string(), + "num_literals * num_literals" ); } @@ -387,8 +383,8 @@ fn test_validate_overhead_variables_valid() { use crate::rules::validate_overhead_variables; let overhead = ReductionOverhead::new(vec![ - ("num_vertices", poly!(num_vars)), - ("num_edges", poly!(num_vars ^ 2)), + ("num_vertices", "num_vars"), + ("num_edges", "num_vars ^ 2"), ]); // Should not panic: inputs {num_vars} ⊆ source, outputs {num_vertices, num_edges} ⊆ target validate_overhead_variables( @@ -406,7 +402,7 @@ fn test_validate_overhead_variables_missing_input() { use crate::rules::registry::ReductionOverhead; use crate::rules::validate_overhead_variables; - let overhead = ReductionOverhead::new(vec![("num_vertices", poly!(num_colors))]); + let overhead = ReductionOverhead::new(vec![("num_vertices", "num_colors")]); validate_overhead_variables( "Source", "Target", @@ -422,7 +418,7 @@ fn test_validate_overhead_variables_missing_output() { use crate::rules::registry::ReductionOverhead; use crate::rules::validate_overhead_variables; - let overhead = ReductionOverhead::new(vec![("num_gates", poly!(num_vars))]); + let overhead = ReductionOverhead::new(vec![("num_gates", "num_vars")]); validate_overhead_variables( "Source", "Target", @@ -437,7 +433,7 @@ fn test_validate_overhead_variables_skips_output_when_empty() { use crate::rules::registry::ReductionOverhead; use crate::rules::validate_overhead_variables; - let overhead = ReductionOverhead::new(vec![("anything", poly!(num_vars))]); + let overhead = 
ReductionOverhead::new(vec![("anything", "num_vars")]); // Should not panic: target_size_names is empty so output check is skipped validate_overhead_variables("Source", "Target", &overhead, &["num_vars"], &[]); } diff --git a/src/unit_tests/rules/cost.rs b/src/unit_tests/rules/cost.rs index 1be5f44b..a6e915b5 100644 --- a/src/unit_tests/rules/cost.rs +++ b/src/unit_tests/rules/cost.rs @@ -1,11 +1,7 @@ use super::*; -use crate::polynomial::Polynomial; fn test_overhead() -> ReductionOverhead { - ReductionOverhead::new(vec![ - ("n", Polynomial::var("n").scale(2.0)), - ("m", Polynomial::var("m")), - ]) + ReductionOverhead::new(vec![("n", "2 * n"), ("m", "m")]) } #[test] @@ -29,7 +25,9 @@ fn test_minimize_steps() { #[test] fn test_custom_cost() { let cost_fn = CustomCost(|overhead: &ReductionOverhead, size: &ProblemSize| { - let output = overhead.evaluate_output_size(size); + let output = overhead + .evaluate_output_size(size) + .expect("overhead evaluation failed"); (output.get("n").unwrap_or(0) + output.get("m").unwrap_or(0)) as f64 }); let size = ProblemSize::new(vec![("n", 10), ("m", 5)]); diff --git a/src/unit_tests/rules/graph.rs b/src/unit_tests/rules/graph.rs index 41c1bf98..9134fc47 100644 --- a/src/unit_tests/rules/graph.rs +++ b/src/unit_tests/rules/graph.rs @@ -400,7 +400,10 @@ fn test_all_categories_present() { #[test] fn test_empty_path_source_target() { - let path = ReductionPath { steps: vec![] }; + let path = ReductionPath { + steps: vec![], + overheads: vec![], + }; assert!(path.is_empty()); assert_eq!(path.len(), 0); assert!(path.source().is_none()); @@ -415,6 +418,7 @@ fn test_single_node_path() { name: "MaximumIndependentSet".to_string(), variant: BTreeMap::new(), }], + overheads: vec![], }; assert!(!path.is_empty()); assert_eq!(path.len(), 0); // No reductions, just one type diff --git a/src/unit_tests/rules/registry.rs b/src/unit_tests/rules/registry.rs index 6a33f912..a2685abf 100644 --- a/src/unit_tests/rules/registry.rs +++ 
b/src/unit_tests/rules/registry.rs @@ -1,5 +1,4 @@ use super::*; -use crate::poly; /// Dummy reduce_fn for unit tests that don't exercise runtime reduction. fn dummy_reduce_fn(_: &dyn std::any::Any) -> Box { @@ -8,10 +7,10 @@ fn dummy_reduce_fn(_: &dyn std::any::Any) -> Box