From 84dea1e0bb4334eaab08dac02404e74f4e78acaf Mon Sep 17 00:00:00 2001 From: donjuanplatinum Date: Mon, 15 Jul 2024 23:03:55 -0400 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E5=85=A5cuda=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- candle_demo/README.md | 8 ++++++++ candle_demo/src/main.rs | 8 +++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/candle_demo/README.md b/candle_demo/README.md index 8c682e0..c359bb5 100644 --- a/candle_demo/README.md +++ b/candle_demo/README.md @@ -1,3 +1,11 @@ +# CPU运行 ``` cargo run --release -- --prompt your prompt ``` + +# Cuda运行 +- 注意 需要cuda为>=12.4以上的版本 +``` +cargo build --release --features cuda +./target/release/codegeex4-candle --prompt your prompt +``` diff --git a/candle_demo/src/main.rs b/candle_demo/src/main.rs index 69793c0..034cc52 100755 --- a/candle_demo/src/main.rs +++ b/candle_demo/src/main.rs @@ -1,4 +1,3 @@ -//use anyhow::{Error as E, Result}; use clap::Parser; use codegeex4_candle::codegeex4::*; @@ -18,6 +17,7 @@ struct TextGeneration { repeat_penalty: f32, repeat_last_n: usize, verbose_prompt: bool, + dtype: DType, } impl TextGeneration { @@ -32,6 +32,7 @@ impl TextGeneration { repeat_last_n: usize, verbose_prompt: bool, device: &Device, + dtype: DType, ) -> Self { let logits_processor = LogitsProcessor::new(seed, temp, top_p); Self { @@ -42,6 +43,7 @@ impl TextGeneration { repeat_last_n, verbose_prompt, device: device.clone(), + dtype, } } @@ -49,7 +51,6 @@ impl TextGeneration { use std::io::Write; println!("starting the inference loop"); let tokens = self.tokenizer.encode(prompt, true).expect("tokens error"); - println!("run starting the token 57"); if tokens.is_empty() { panic!("Empty prompts are not supported in the chatglm model.") } @@ -82,7 +83,7 @@ impl TextGeneration { let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error"); let logits = self.model.forward(&input).unwrap(); - let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap(); + let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap(); let logits = if self.repeat_penalty == 1. { logits } else { @@ -239,6 +240,7 @@ fn main() -> Result<(),()> { args.repeat_last_n, args.verbose_prompt, &device, + dtype, ); pipeline.run(&args.prompt, args.sample_len)?; Ok(())