修改输出格式 (#1)

* Revise the output format

* Revise the output format

* Revise the output format

* Add files via upload
This commit is contained in:
JasonYANG17 2024-07-16 02:37:16 +08:00 committed by GitHub
parent 19b44b7b4c
commit 1ba5426235
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 69 additions and 5 deletions

50
candle_demo/Cargo.lock generated
View File

@ -132,6 +132,17 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen_cuda"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f8489af5b7d17a81bffe37e0f4d6e1e4de87c87329d05447f22c35d95a1227d"
dependencies = [
"glob",
"num_cpus",
"rayon",
]
[[package]]
name = "bit-set"
version = "0.5.3"
@ -204,6 +215,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
dependencies = [
"byteorder",
"candle-kernels",
"cudarc",
"gemm",
"half",
"memmap2",
@ -239,6 +252,15 @@ dependencies = [
"tokenizers",
]
[[package]]
name = "candle-kernels"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-nn"
version = "0.6.0"
@ -329,7 +351,7 @@ checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70"
name = "codegeex4-candle"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen_cuda",
"candle-core",
"candle-examples",
"candle-nn",
@ -437,6 +459,16 @@ dependencies = [
"memchr",
]
[[package]]
name = "cudarc"
version = "0.11.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56ee2a3fbbd981e1c7ea73cc2af136e754eb22d17436de37155227ee4dbe0cf4"
dependencies = [
"half",
"libloading",
]
[[package]]
name = "darling"
version = "0.20.10"
@ -904,6 +936,12 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "h2"
version = "0.3.26"
@ -1172,6 +1210,16 @@ version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "libloading"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
]
[[package]]
name = "libm"
version = "0.2.8"

View File

@ -59,16 +59,22 @@ impl TextGeneration {
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = 151329;
print!("{prompt}");
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
println!("start_gen");
println!("\n start_gen");
println!("samplelen {}",sample_len);
let mut count = 0;
let mut result = vec!();
for index in 0..sample_len {
count += 1;
println!("sample count {}",count);
@ -96,7 +102,8 @@ impl TextGeneration {
}
println!("raw generate token {}",next_token);
let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
print!("{token}");
println!("[token:{token}]");
result.push(token);
std::io::stdout().flush().unwrap();
}
let dt = start_gen.elapsed();
@ -104,6 +111,10 @@ impl TextGeneration {
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
Ok(())
}
}
@ -208,7 +219,12 @@ fn main() -> Result<(),()> {
let start = std::time::Instant::now();
let config = Config::codegeex4();
let device = candle_examples::device(args.cpu).unwrap();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device).unwrap() };
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
let model = Model::new(&config, vb).unwrap();
println!("loaded the model in {:?}", start.elapsed());