mirror of
https://github.com/JasonYANG170/CodeGeeX4.git
synced 2024-11-23 12:16:33 +00:00
修改输出格式 (#1)
* Revise the output format * Revise the output format * Revise the output format * Add files via upload
This commit is contained in:
parent
19b44b7b4c
commit
1ba5426235
50
candle_demo/Cargo.lock
generated
50
candle_demo/Cargo.lock
generated
|
@ -132,6 +132,17 @@ version = "0.22.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "bindgen_cuda"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f8489af5b7d17a81bffe37e0f4d6e1e4de87c87329d05447f22c35d95a1227d"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"num_cpus",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
|
@ -204,6 +215,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-kernels",
|
||||
"cudarc",
|
||||
"gemm",
|
||||
"half",
|
||||
"memmap2",
|
||||
|
@ -239,6 +252,15 @@ dependencies = [
|
|||
"tokenizers",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
|
||||
dependencies = [
|
||||
"bindgen_cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-nn"
|
||||
version = "0.6.0"
|
||||
|
@ -329,7 +351,7 @@ checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70"
|
|||
name = "codegeex4-candle"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bindgen_cuda",
|
||||
"candle-core",
|
||||
"candle-examples",
|
||||
"candle-nn",
|
||||
|
@ -437,6 +459,16 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.11.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56ee2a3fbbd981e1c7ea73cc2af136e754eb22d17436de37155227ee4dbe0cf4"
|
||||
dependencies = [
|
||||
"half",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.10"
|
||||
|
@ -904,6 +936,12 @@ version = "0.29.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.26"
|
||||
|
@ -1172,6 +1210,16 @@ version = "0.2.155"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
|
|
|
@ -59,16 +59,22 @@ impl TextGeneration {
|
|||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => panic!("cannot find the endoftext token"),
|
||||
};
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = 151329;
|
||||
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().expect("output flush error");
|
||||
let start_gen = std::time::Instant::now();
|
||||
println!("start_gen");
|
||||
|
||||
println!("\n start_gen");
|
||||
println!("samplelen {}",sample_len);
|
||||
let mut count = 0;
|
||||
let mut result = vec!();
|
||||
for index in 0..sample_len {
|
||||
count += 1;
|
||||
println!("sample count {}",count);
|
||||
|
@ -96,7 +102,8 @@ impl TextGeneration {
|
|||
}
|
||||
println!("raw generate token {}",next_token);
|
||||
let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
|
||||
print!("{token}");
|
||||
println!("[token:{token}]");
|
||||
result.push(token);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
|
@ -104,6 +111,10 @@ impl TextGeneration {
|
|||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
println!("Result:");
|
||||
for tokens in result {
|
||||
print!("{tokens}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -208,7 +219,12 @@ fn main() -> Result<(),()> {
|
|||
let start = std::time::Instant::now();
|
||||
let config = Config::codegeex4();
|
||||
let device = candle_examples::device(args.cpu).unwrap();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device).unwrap() };
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
|
||||
let model = Model::new(&config, vb).unwrap();
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
|
Loading…
Reference in New Issue
Block a user