From c755910f28ffaff399da348f02949b53a1d4b320 Mon Sep 17 00:00:00 2001 From: lionarius Date: Tue, 1 Oct 2024 00:04:55 +0300 Subject: [PATCH] . --- .gitignore | 1 + src/cli.rs | 55 +++++++++++++++++--- src/generator.rs | 44 ++++++++++++++++ src/main.rs | 125 ++++++++-------------------------------------- src/translator.rs | 50 +++++++++++++++++++ 5 files changed, 164 insertions(+), 111 deletions(-) create mode 100644 src/generator.rs create mode 100644 src/translator.rs diff --git a/.gitignore b/.gitignore index ea8c4bf..84c4bf5 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/test \ No newline at end of file diff --git a/src/cli.rs b/src/cli.rs index e693321..0c5ebaf 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,5 +1,7 @@ use std::path::PathBuf; +use clap::CommandFactory; + #[derive(Debug, clap::Parser)] #[command(author, version, about, long_about = None)] pub struct Args { @@ -10,19 +12,60 @@ pub struct Args { #[derive(Debug, clap::Subcommand)] pub enum Command { Generate { - #[clap(short, long)] output: PathBuf, - #[clap(short, long, value_parser = clap::value_parser!(u32).range(1..))] + #[clap(short, long, default_value = "10", value_parser = clap::value_parser!(u32).range(1..))] lines: u32, - #[clap(long = "min-op", value_parser = clap::value_parser!(u32).range(1..))] + #[clap(long = "min-op", default_value = "2", value_parser = clap::value_parser!(u32).range(1..))] min_operands: u32, - #[clap(long = "max-op", value_parser = clap::value_parser!(u32).range(1..))] + #[clap(long = "max-op", default_value = "7", value_parser = clap::value_parser!(u32).range(1..))] max_operands: u32, }, Translate { - #[clap(short, long)] input: PathBuf, - #[clap(short, long)] output: PathBuf, }, } + +pub fn validate_args(args: Args) -> Args { + match _validate_args(args) { + Ok(args) => args, + Err(err) => { + let mut command = Args::command(); + err.format(&mut command).exit(); + } + } +} + +fn _validate_args(args: Args) -> Result { + match &args.command { + Command::Generate { + min_operands, + max_operands, + .. + } => { + if min_operands > max_operands { + return Err(clap::Error::raw( + clap::error::ErrorKind::ValueValidation, + "min operands must not be greater than max operands", + )); + } + } + Command::Translate { input, output } => { + if !input.is_file() { + return Err(clap::Error::raw( + clap::error::ErrorKind::ValueValidation, + "input must be an existing file", + )); + } + + if input == output { + return Err(clap::Error::raw( + clap::error::ErrorKind::ValueValidation, + "input and output must not be the same", + )); + } + } + } + + Ok(args) +} diff --git a/src/generator.rs b/src/generator.rs new file mode 100644 index 0000000..3a0541f --- /dev/null +++ b/src/generator.rs @@ -0,0 +1,44 @@ +pub struct FileGenerator<'a, W: std::io::Write, R: rand::Rng> { + writer: &'a mut W, + rand: R, +} + +impl<'a, W: std::io::Write, R: rand::Rng> FileGenerator<'a, W, R> { + const OPERANDS: &'static str = "+-*/"; + const DIGITS: &'static str = "123456789"; + + pub fn new(writer: &'a mut W, rand: R) -> Self { + Self { writer, rand } + } + + pub fn generate( + &mut self, + lines: u32, + min_operands: u32, + max_operands: u32, + ) -> anyhow::Result<()> { + for _ in 0..lines { + let operands = self.rand.gen_range(min_operands..max_operands + 1); + for _ in 0..operands { + let digit_idx = self.rand.gen_range(0..Self::DIGITS.len()); + let operand_idx = self.rand.gen_range(0..Self::OPERANDS.len()); + + let digit = Self::DIGITS.get(digit_idx..digit_idx + 1).unwrap(); + let operand = Self::OPERANDS.get(operand_idx..operand_idx + 1).unwrap(); + + self.writer.write_all(digit.as_bytes())?; + self.writer.write_all(b" ")?; + self.writer.write_all(operand.as_bytes())?; + self.writer.write_all(b" ")?; + } + + let digit = self.rand.gen_range(0..Self::DIGITS.len()); + let digit = Self::DIGITS.get(digit..digit + 1).unwrap(); + self.writer.write_all(digit.as_bytes())?; + + self.writer.write_all(b"\n")?; + } + + Ok(()) + } +} diff --git a/src/main.rs b/src/main.rs index 687108f..8ecd79e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,54 +1,13 @@ -use std::{ - collections::HashMap, - io::{Read, Write}, - sync::OnceLock, -}; +use std::{fs, io}; -use clap::{CommandFactory, Parser}; -use rand::Rng; +use clap::Parser; mod cli; +mod generator; +mod translator; -static OPERANDS: &str = "+-*/"; -static DIGITS: &str = "123456789"; - -static TRANSLATION_TABLE: OnceLock> = OnceLock::new(); - -fn valudate_args(args: cli::Args) -> Result { - match &args.command { - cli::Command::Generate { - min_operands, - max_operands, - .. - } => { - if min_operands > max_operands { - return Err(clap::Error::raw( - clap::error::ErrorKind::ValueValidation, - "min operands must not be greater than max operands", - )); - } - } - cli::Command::Translate { input, output } => { - if input == output { - return Err(clap::Error::raw( - clap::error::ErrorKind::ValueValidation, - "input and output must not be the same", - )); - } - } - } - - Ok(args) -} - -fn main() { - let args = match valudate_args(cli::Args::parse()) { - Ok(args) => args, - Err(err) => { - let mut command = cli::Args::command(); - err.format(&mut command).exit(); - } - }; +fn main() -> anyhow::Result<()> { + let args = cli::validate_args(cli::Args::parse()); match args.command { cli::Command::Generate { @@ -57,68 +16,24 @@ fn main() { min_operands, max_operands, } => { - let mut rand = rand::thread_rng(); - let mut output = std::fs::File::create(output).unwrap(); - for _ in 0..lines { - let operands = rand.gen_range(min_operands..max_operands + 1); - for _ in 0..operands { - let digit = rand.gen_range(0..DIGITS.len()); - let operand = rand.gen_range(0..OPERANDS.len()); - - let digit = DIGITS.get(digit..digit + 1).unwrap(); - let operand = OPERANDS.get(operand..operand + 1).unwrap(); - - output.write(digit.as_bytes()).unwrap(); - output.write(b" ").unwrap(); - output.write(operand.as_bytes()).unwrap(); - output.write(b" ").unwrap(); - } - - let digit = rand.gen_range(0..DIGITS.len()); - let digit = DIGITS.get(digit..digit + 1).unwrap(); - output.write(digit.as_bytes()).unwrap(); - - output.write(b"\n").unwrap(); + if !fs::exists(output.parent().unwrap())? { + fs::create_dir_all(output.parent().unwrap())?; } + let mut output = io::BufWriter::new(fs::File::create(output)?); + + let rand = rand::thread_rng(); + + let mut generator = generator::FileGenerator::new(&mut output, rand); + generator.generate(lines, min_operands, max_operands)?; } cli::Command::Translate { input, output } => { - let table = TRANSLATION_TABLE.get_or_init(|| { - HashMap::from([ - ("+", "add"), - ("-", "subtract"), - ("*", "multiply by"), - ("/", "divide by"), - ("0", "zero"), - ("1", "one"), - ("2", "two"), - ("3", "three"), - ("4", "four"), - ("5", "five"), - ("6", "six"), - ("7", "seven"), - ("8", "eight"), - ("9", "nine"), - ]) - }); + let mut input = io::BufReader::new(fs::File::open(input)?); + let mut output = io::BufWriter::new(fs::File::create(output)?); - let mut input = std::fs::File::open(input).unwrap(); - let mut output = std::fs::File::create(output).unwrap(); - - let mut buf = [0u8; 1]; - - loop { - match input.read(&mut buf) { - Ok(0) => break, - Ok(_) => { - let digit = std::str::from_utf8(&buf).unwrap(); - match table.get(digit) { - Some(translation) => output.write(translation.as_bytes()).unwrap(), - None => output.write(digit.as_bytes()).unwrap(), - }; - } - Err(err) => panic!("{}", err), - } - } + let mut translator = translator::FileTranslator::new(&mut input, &mut output); + translator.translate()?; } } + + Ok(()) } diff --git a/src/translator.rs b/src/translator.rs new file mode 100644 index 0000000..6045564 --- /dev/null +++ b/src/translator.rs @@ -0,0 +1,50 @@ +use std::{collections::HashMap, sync::LazyLock}; + +pub struct FileTranslator<'a, R: std::io::Read, W: std::io::Write> { + reader: &'a mut R, + writer: &'a mut W, +} + +static TRANSLATION_TABLE: LazyLock> = LazyLock::new(|| { + HashMap::from([ + ("+", "add"), + ("-", "subtract"), + ("*", "multiply by"), + ("/", "divide by"), + ("0", "zero"), + ("1", "one"), + ("2", "two"), + ("3", "three"), + ("4", "four"), + ("5", "five"), + ("6", "six"), + ("7", "seven"), + ("8", "eight"), + ("9", "nine"), + ]) +}); + +impl<'a, R: std::io::Read, W: std::io::Write> FileTranslator<'a, R, W> { + pub fn new(reader: &'a mut R, writer: &'a mut W) -> Self { + Self { reader, writer } + } + + pub fn translate(&mut self) -> anyhow::Result<()> { + let mut buf = [0u8; 1]; + + loop { + match self.reader.read(&mut buf)? { + 0 => break, + _ => { + let digit = std::str::from_utf8(&buf)?; + match TRANSLATION_TABLE.get(digit) { + Some(translation) => self.writer.write_all(translation.as_bytes())?, + None => self.writer.write_all(digit.as_bytes())?, + }; + } + } + } + + Ok(()) + } +}