Mirror of https://github.com/JasonYANG170/CodeGeeX4.git
Synced 2024-11-23 12:16:33 +00:00

Commit: Add CUDA support (加入cuda支持)

This commit is contained in:
parent 1ba5426235
commit 84dea1e0bb
Documentation hunk — adds the CUDA build/run steps alongside the existing CPU instructions:

````diff
@@ -1,3 +1,11 @@
+# Run on CPU
 ```
 cargo run --release -- --prompt your prompt
 ```
+
+# Run with CUDA
+- Note: requires CUDA 12.4 or later
+```
+cargo build --release --features cuda
+./target/release/codegeex4-candle --prompt your prompt
+```
````
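For context, this is a minimal sketch of how a candle-based binary typically picks its device when built with `--features cuda`; the helper name and the CPU-fallback policy are assumptions, not code from this commit:

```rust
use candle_core::{Device, Result};

// Hypothetical helper: prefer GPU 0 when the crate was built with
// `--features cuda` and a device is visible, otherwise fall back to CPU.
fn select_device(force_cpu: bool) -> Result<Device> {
    if force_cpu {
        return Ok(Device::Cpu);
    }
    // cuda_if_available returns Device::Cpu when CUDA support is not
    // compiled in or no GPU is found.
    Device::cuda_if_available(0)
}
```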
Source hunks — thread a `dtype: DType` field through `TextGeneration` so the logits conversion uses the configured dtype instead of a hard-coded `DType::F32`, and drop a leftover debug `println!`:

```diff
@@ -1,4 +1,3 @@
-//use anyhow::{Error as E, Result};
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
 
@@ -18,6 +17,7 @@ struct TextGeneration {
     repeat_penalty: f32,
     repeat_last_n: usize,
     verbose_prompt: bool,
+    dtype: DType,
 }
 
 impl TextGeneration {
@@ -32,6 +32,7 @@ impl TextGeneration {
         repeat_last_n: usize,
         verbose_prompt: bool,
         device: &Device,
+        dtype: DType,
     ) -> Self {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
@@ -42,6 +43,7 @@ impl TextGeneration {
             repeat_last_n,
             verbose_prompt,
             device: device.clone(),
+            dtype,
         }
     }
 
@@ -49,7 +51,6 @@ impl TextGeneration {
         use std::io::Write;
         println!("starting the inference loop");
         let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
-        println!("run starting the token 57");
         if tokens.is_empty() {
             panic!("Empty prompts are not supported in the chatglm model.")
         }
@@ -82,7 +83,7 @@ impl TextGeneration {
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
             let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
             let logits = self.model.forward(&input).unwrap();
-            let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
+            let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
             let logits = if self.repeat_penalty == 1. {
                 logits
             } else {
@@ -239,6 +240,7 @@ fn main() -> Result<(),()> {
         args.repeat_last_n,
         args.verbose_prompt,
         &device,
+        dtype,
     );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
```
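The net effect of the `dtype` plumbing is that the logits stay in the model's configured dtype rather than always being cast to `DType::F32`. A minimal sketch of how the caller in `main` might choose that dtype; the BF16-on-GPU policy is an assumption, since the selection logic sits outside this diff:

```rust
use candle_core::{DType, Device};

// Hypothetical dtype policy: half precision on CUDA for speed and memory,
// full F32 precision on CPU. The commit only shows `dtype` being passed
// into TextGeneration::new, not how it is chosen.
fn pick_dtype(device: &Device) -> DType {
    if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    }
}
```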