加入cuda支持

This commit is contained in:
donjuanplatinum 2024-07-15 23:03:55 -04:00
parent 1ba5426235
commit 84dea1e0bb
2 changed files with 13 additions and 3 deletions

View File

@ -1,3 +1,11 @@
# CPU运行
```
cargo run --release -- --prompt your prompt
```
# Cuda运行
- 注意: 需要 CUDA 版本 >= 12.4
```
cargo build --release --features cuda
./target/release/codegeex4-candle --prompt your prompt
```

View File

@ -1,4 +1,3 @@
//use anyhow::{Error as E, Result};
use clap::Parser;
use codegeex4_candle::codegeex4::*;
@ -18,6 +17,7 @@ struct TextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}
impl TextGeneration {
@ -32,6 +32,7 @@ impl TextGeneration {
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
@ -42,6 +43,7 @@ impl TextGeneration {
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}
@ -49,7 +51,6 @@ impl TextGeneration {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
println!("run starting the token 57");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
@ -82,7 +83,7 @@ impl TextGeneration {
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
let logits = self.model.forward(&input).unwrap();
let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap(); let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
let logits = if self.repeat_penalty == 1. {
logits
} else {
@ -239,6 +240,7 @@ fn main() -> Result<(),()> {
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())