diff --git a/candle_demo/Cargo.toml b/candle_demo/Cargo.toml
index 406f363..0d80876 100755
--- a/candle_demo/Cargo.toml
+++ b/candle_demo/Cargo.toml
@@ -11,7 +11,7 @@ description = "Codegeex4"
 # candle-transformers = {path = "../candle/candle-transformers"}
 # candle-core = {path = "../candle/candle-core"}
 # candle-nn = {path = "../candle/candle-nn"}
-anyhow = "1.0.86"
+#anyhow = "1.0.86"
 hf-hub = "0.3.2"
 #tokenizer = "0.1.2"
 clap = { version = "4.5.6", features = ["derive"] }
@@ -25,4 +25,4 @@ candle-transformers = "0.6.0"
 candle-examples = "0.6.0"
 candle-nn = "0.6.0"
 safetensors = "0.4.3"
-#safetensors = {path ="../safetensors/safetensors"}
\ No newline at end of file
+#safetensors = {path ="../safetensors/safetensors"}
diff --git a/candle_demo/src/main.rs b/candle_demo/src/main.rs
index d269f6e..b20beee 100755
--- a/candle_demo/src/main.rs
+++ b/candle_demo/src/main.rs
@@ -1,4 +1,4 @@
-use anyhow::{Error as E, Result};
+//use anyhow::{Error as E, Result};
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
@@ -45,13 +45,13 @@ impl TextGeneration {
         }
     }
 
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(),()> {
         use std::io::Write;
         println!("starting the inference loop");
-        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
+        let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
         println!("run starting the token 57");
         if tokens.is_empty() {
-            anyhow::bail!("Empty prompts are not supported in the chatglm model.")
+            panic!("Empty prompts are not supported in the chatglm model.")
         }
         if self.verbose_prompt {
             for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
@@ -61,12 +61,10 @@ impl TextGeneration {
         }
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
-            Some(token) => *token,
-            None => anyhow::bail!("cannot find the endoftext token"),
-        };
+        let eos_token = 151329;
+
         print!("{prompt}");
-        std::io::stdout().flush()?;
+        std::io::stdout().flush().expect("output flush error");
         let start_gen = std::time::Instant::now();
         println!("start_gen");
         println!("samplelen {}",sample_len);
@@ -76,9 +74,9 @@ impl TextGeneration {
             println!("sample count {}",count);
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input)?;
-            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
+            let logits = self.model.forward(&input).unwrap();
+            let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
             let logits = if self.repeat_penalty == 1. {
                 logits
             } else {
@@ -87,18 +85,19 @@
                     &logits,
                     self.repeat_penalty,
                     &tokens[start_at..],
-                )?
+                ).unwrap()
             };
-            let next_token = self.logits_processor.sample(&logits)?;
+            let next_token = self.logits_processor.sample(&logits).unwrap();
             tokens.push(next_token);
             generated_tokens += 1;
             if next_token == eos_token {
                 break;
             }
-            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+            println!("raw generate token {}",next_token);
+            let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
             print!("{token}");
-            std::io::stdout().flush()?;
+            std::io::stdout().flush().unwrap();
         }
         let dt = start_gen.elapsed();
         println!(
@@ -163,7 +162,7 @@ struct Args {
     repeat_last_n: usize,
 }
 
-fn main() -> Result<()> {
+fn main() -> Result<(),()> {
     let args = Args::parse();
 
     println!(
@@ -182,7 +181,7 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     println!("cache path {}",args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())).build()?;
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())).build().unwrap();
 
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
@@ -196,21 +195,21 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => api
-            .model("donjuanplatinum1/tokenizer".to_string())
-            .get("chatglm-tokenizer.json")?,
+            .model("THUDM/codegeex4-all-9b".to_string())
+            .get("tokenizer.json").unwrap(),
     };
     let filenames = match args.weight_file {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap(),
     };
     println!("retrieved the files in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
 
     let start = std::time::Instant::now();
     let config = Config::codegeex4();
-    let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let device = candle_examples::device(args.cpu).unwrap();
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device).unwrap() };
+    let model = Model::new(&config, vb).unwrap();
     println!("loaded the model in {:?}", start.elapsed());
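
A note on the eos_token change above: the removed code looked the id up in the tokenizer vocabulary and bailed out when the "<|endoftext|>" entry was missing, while the patch hardcodes 151329, the codegeex4 id for that token. A minimal sketch of a middle ground, assuming the same `tokenizers::Tokenizer` value as in main.rs, keeps the runtime lookup and falls back to the hardcoded id only when the entry is absent:

    // Sketch only: resolve the EOS id from the vocab, falling back to the
    // codegeex4 value the patch hardcodes (151329 == "<|endoftext|>").
    let eos_token = tokenizer
        .get_vocab(true)
        .get("<|endoftext|>")
        .copied()
        .unwrap_or(151329);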
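
More broadly, the patch swaps anyhow's `?` propagation for `.unwrap()`, `.expect()`, and `panic!`, so the first failure now aborts the process instead of returning an error to the caller; `Result<(),()>` keeps `run` and `main` compiling because `()` implements `Debug`, which is all `Termination` requires of the error type. For contrast, a minimal sketch of the style being removed, using a hypothetical `encode_prompt` helper rather than the actual `TextGeneration::run` body:

    use anyhow::{Error as E, Result};
    use tokenizers::Tokenizer;

    // Hypothetical helper illustrating the removed pattern: the tokenizer's
    // boxed error converts into anyhow::Error via map_err(E::msg), and `?`
    // propagates it to the caller instead of panicking at the call site.
    fn encode_prompt(tokenizer: &Tokenizer, prompt: &str) -> Result<Vec<u32>> {
        let tokens = tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the chatglm model.")
        }
        Ok(tokens.get_ids().to_vec())
    }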