mirror of
https://github.com/JasonYANG170/CodeGeeX4.git
synced 2024-11-23 12:16:33 +00:00
加入cuda支持
This commit is contained in:
parent
1ba5426235
commit
84dea1e0bb
|
@ -1,3 +1,11 @@
|
|||
# CPU运行
|
||||
```
|
||||
cargo run --release -- --prompt your prompt
|
||||
```
|
||||
|
||||
# Cuda运行
|
||||
- 注意 需要cuda为>=12.4以上的版本
|
||||
```
|
||||
cargo build --release --features cuda
|
||||
./target/release/codegeex4-candle --prompt your prompt
|
||||
```
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
//use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use codegeex4_candle::codegeex4::*;
|
||||
|
||||
|
@ -18,6 +17,7 @@ struct TextGeneration {
|
|||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
|
@ -32,6 +32,7 @@ impl TextGeneration {
|
|||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
|
@ -42,6 +43,7 @@ impl TextGeneration {
|
|||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,7 +51,6 @@ impl TextGeneration {
|
|||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
|
||||
println!("run starting the token 57");
|
||||
if tokens.is_empty() {
|
||||
panic!("Empty prompts are not supported in the chatglm model.")
|
||||
}
|
||||
|
@ -82,7 +83,7 @@ impl TextGeneration {
|
|||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
|
||||
let logits = self.model.forward(&input).unwrap();
|
||||
let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
|
||||
let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
|
@ -239,6 +240,7 @@ fn main() -> Result<(),()> {
|
|||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
dtype,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
|
|
Loading…
Reference in New Issue
Block a user