This commit is contained in:
donjuanplatinum 2024-07-12 13:29:57 -04:00
parent 6b8b74b8f8
commit 0d716a791c
2 changed files with 2 additions and 2 deletions

View File

@ -68,7 +68,7 @@ impl RotaryEmbedding {
let inv_freq_len = inv_freq.len(); let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)? let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
.to_dtype(dtype)? .to_dtype(dtype).expect("unalbe to dytpe in Rotray Embedding new")
.reshape((cfg.seq_length, 1))?; .reshape((cfg.seq_length, 1))?;
let freqs = t.matmul(&inv_freq)?; let freqs = t.matmul(&inv_freq)?;
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;

View File

@ -174,7 +174,7 @@ fn main() -> Result<(),()> {
); );
println!( println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.), args.temperature.unwrap_or(0.95),
args.repeat_penalty, args.repeat_penalty,
args.repeat_last_n args.repeat_last_n
); );