Add CUDA support

donjuanplatinum 2024-07-15 23:03:55 -04:00
parent 1ba5426235
commit 84dea1e0bb
2 changed files with 13 additions and 3 deletions


@@ -1,3 +1,11 @@
# Run on CPU
```
cargo run --release -- --prompt "your prompt"
```
# Run with CUDA
- Note: requires CUDA version 12.4 or newer
```
cargo build --release --features cuda
./target/release/codegeex4-candle --prompt "your prompt"
```
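Under the hood, a candle binary built with the `cuda` feature still picks its device at runtime. A minimal sketch of that selection, assuming candle-core's `Device::cuda_if_available` (which falls back to CPU when no GPU is visible); this is illustrative, not necessarily this repo's exact code:

```rust
use candle_core::{Device, Result};

// Sketch (assumption, not this repo's exact code): select the first CUDA
// device when the `cuda` feature is compiled in and a GPU is present,
// otherwise fall back to the CPU.
fn pick_device() -> Result<Device> {
    Device::cuda_if_available(0)
}
```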


@@ -1,4 +1,3 @@
-//use anyhow::{Error as E, Result};
use clap::Parser;
use codegeex4_candle::codegeex4::*;
@@ -18,6 +17,7 @@ struct TextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
+dtype: DType,
}
impl TextGeneration {
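The new `dtype` field lets the generation loop convert logits to a precision matching the active backend instead of hard-coding `DType::F32` (see the hunk further down). A sketch of how such a dtype could be derived from the device; the helper name is hypothetical, not from this commit:

```rust
use candle_core::{DType, Device};

// Hypothetical helper: BF16 is the usual choice on CUDA backends,
// F32 on CPU. `Device::is_cuda` is a real candle-core method.
fn dtype_for(device: &Device) -> DType {
    if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    }
}
```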
@@ -32,6 +32,7 @@ impl TextGeneration {
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
+dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
@@ -42,6 +43,7 @@ impl TextGeneration {
repeat_last_n,
verbose_prompt,
device: device.clone(),
+dtype,
}
}
@@ -49,7 +51,6 @@ impl TextGeneration {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
println!("run starting the token 57");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
@@ -82,7 +83,7 @@ impl TextGeneration {
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
let logits = self.model.forward(&input).unwrap();
-let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
+let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
let logits = if self.repeat_penalty == 1. {
logits
} else {
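The hunk is cut off inside the `else` branch. In candle's example code this branch typically applies `candle_transformers::utils::apply_repeat_penalty` over the last `repeat_last_n` tokens; a sketch of that common pattern (an assumption, not necessarily what this file contains):

```rust
use candle_core::{Result, Tensor};

// Sketch (assumed; the hunk cuts off here): penalize tokens that
// appeared in the last `repeat_last_n` positions of the context.
fn penalize(
    logits: &Tensor,
    tokens: &[u32],
    repeat_penalty: f32,
    repeat_last_n: usize,
) -> Result<Tensor> {
    let start_at = tokens.len().saturating_sub(repeat_last_n);
    candle_transformers::utils::apply_repeat_penalty(logits, repeat_penalty, &tokens[start_at..])
}
```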
@@ -239,6 +240,7 @@ fn main() -> Result<(),()> {
args.repeat_last_n,
args.verbose_prompt,
&device,
+dtype,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
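For context, the logits produced above are ultimately consumed by the `LogitsProcessor` built in `new()` (seed, temperature, top-p, as on the `LogitsProcessor::new(seed, temp, top_p)` line). A hedged sketch of that final sampling step:

```rust
use candle_core::{Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

// Sketch: draw the next token id from the (possibly penalized) logits.
// `sample` applies the temperature/top-p configured at construction time.
fn next_token(processor: &mut LogitsProcessor, logits: &Tensor) -> Result<u32> {
    processor.sample(logits)
}
```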