Skip to content

Rust Candle 框架与 Pytorch 的张量的索引、切片、连接、变异等价操作

kingzcheung
Published date:
Edit this post

书接上回, 深度学习中,张量还存在大量的索引、切片、连接、变异等操作。由于 Python 的特性,这些张量的操作相对来说是比较方便的。但是在 Candle 中可能会有和 Pytorch 中有差异的地方。

索引与切片

定义一个3x3的矩阵,下面有些操作不提示的话,则会基于这个矩阵:

	[
       [1,2,3],
       [4,5,6],
       [7,8,9], 
    ]

pytorch 中取第一行的数据、取第一列的数据分别为:

    print(x[0]) #tensor([1, 2, 3])
    print(x[:,:1])  #tensor([[1],
                    #        [4],
                    #        [7]])  

candle 中取第一行的数据,取第一列的数据分别为:

    let data = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9];
    let x = Tensor::from_vec(data, &[3, 3], &Device::Cpu)?;

    let y = x.i(0)?;
    println!("{y}"); //[1, 2, 3]

    let y =  x.i((.., ..1))?;
    println!("{y}",);   //[[1],
                        // [4],
                        // [7]]

select

select 等价于切片。

在 pytorch 中,tensor.select(0, index) 等价于 tensor[index],而 tensor.select(2, index) 等价于 tensor[:,:,index]

在 candle 中, 等价于 i() 方法。

改变张量的形状

pytorch 中使用 view或者 reshape 重构张量的形状:

    x = torch.tensor(data=[1,2,3,4,5,6,7,8,9])
    y = x.view(3,3)
    y = x.reshape(3,3)
    
    # 输出
    # tensor([[1, 2, 3],
	#         [4, 5, 6],
	#         [7, 8, 9]])
    print(y)

candle 中只有 reshape api 可以重构张量的形状:

    let data = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9];
    let x = Tensor::new(data, &Device::Cpu)?;
    let y = x.reshape((3,3))?;
    
    // 输出的结果为:
    //[[1, 2, 3],
    //[4, 5, 6],
    //[7, 8, 9]]
    println!("{y}");

连接张量

很多时候我们需要在给定维度中连接 tensors 中给定的张量序列。连接 操作有 2 种。

cat

假设两个张量 a 和 b 形状均为 (2, 3),分别按第 0 维和第 1 维拼接,pytorch 代码如下:

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[7, 8, 9], [10, 11, 12]])
    
    # 沿行(dim=0)拼接 → 形状 (4, 3)
    y = torch.cat([a,b], dim=0)
    # 结果为:
    # tensor([[ 1,  2,  3],
    #         [ 4,  5,  6],
    #         [ 7,  8,  9],
    #         [10, 11, 12]])
    print(y)
    
    # 沿行(dim=0)拼接 → 形状 (2,6)
    y = torch.cat([a,b], dim=1)
    # 结果为:
    # tensor([[ 1,  2,  3,  7,  8,  9],
    #         [ 4,  5,  6, 10, 11, 12]])
    print(y)

candle 大致一样:

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let b_data = vec![7i64, 8, 9,10,11,12];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    let b = Tensor::from_vec(b_data, (2,3),&Device::Cpu)?;
    
    // 沿行(dim=0)拼接 → 形状 (4, 3)
    let y = candle_core::Tensor::cat(&[&a,&b], 0)?;
    // 结果如下:
    // [[ 1,  2,  3],
    // [ 4,  5,  6],
    // [ 7,  8,  9],
    // [10, 11, 12]]
    println!("{y}");
    
    // 沿行(dim=1)拼接 → 形状 (2,6)
    let y = candle_core::Tensor::cat(&[&a,&b], 1)?;
    
    // 结果如下:
    // [[ 1,  2,  3,  7,  8,  9],
    // [ 4,  5,  6, 10, 11, 12]]
    println!("{y}");

stack

stack 是另一个拼接方法,同样假设两个张量 a 和 b 形状均为 (2, 3),分别按第 0 维和第 1 维拼接,pytorch 代码如下:

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[7, 8, 9], [10, 11, 12]])
    
    # 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    y = torch.stack([a,b], dim=0)
    # 结果为:
    # [
    #   [[ 1,  2,  3],[ 4,  5,  6]],
    #   [[ 7,  8,  9],[ 10, 11, 12]]
    # ]
    print(y)
    
    # 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    y = torch.stack([a,b], dim=1)
    # 结果为:
    # [
    #   [[ 1,  2,  3], [ 7,  8,  9]],
    #   [[ 4,  5,  6], [10, 11, 12]]
    # ]
    print(y)

candle 同理:

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let b_data = vec![7i64, 8, 9,10,11,12];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    let b = Tensor::from_vec(b_data, (2,3),&Device::Cpu)?;
    
    // 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    let y = candle_core::Tensor::stack(&[&a,&b], 0)?;
    // 结果为:
    // [
    //   [[ 1,  2,  3],[ 4,  5,  6]],
    //   [[ 7,  8,  9],[ 10, 11, 12]]
    // ]
    println!("{y}");
    
    // 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    let y = candle_core::Tensor::stack(&[&a,&b], 1)?;

    // 结果为:
    // [
    //   [[ 1,  2,  3], [ 7,  8,  9]],
    //   [[ 4,  5,  6], [10, 11, 12]]
    // ]
    println!("{y}");

cat 与 stack 的区别?

在PyTorch中,torch.cat 和 torch.stack 都用于张量拼接,但关键区别在于是否创建新维度。 可以看到,stack 直接增加了一个维度。

不过,stack 可以使用 cat 代替,因为我们可以手动增加维度。

pytorch:

    # 下面这两个操作等价
    y1 = torch.stack([a,b], dim=0)
    y2 = torch.cat([a.unsqueeze(0),b.unsqueeze(0)], dim=0)
    # 结果为:
    # [
    #   [[ 1,  2,  3],[ 4,  5,  6]],
    #   [[ 7,  8,  9],[ 10, 11, 12]]
    # ]
    assert y1.equal(y2)

candle:

// 下面这两个操作等价
    let y1 = candle_core::Tensor::stack(&[&a,&b], 1)?;
    let y2 = candle_core::Tensor::cat(&[&a.unsqueeze(0)?,&b.unsqueeze(0)?], 1)?;

    println!("{y1}");
    println!("{y2}");

从上面也可以看出, 两边的升维 API 都是 unsqueeze

有长维就有降维,降维操作是 squeeze:

pytorch

    a = torch.tensor([[1, 2, 3]])
    
    # (2,3) -> (3)
    b = a.squeeze()
    
    # tensor([1, 2, 3])
    print(b.shape)
   

candle

    let a = Tensor::new(vec![&[1i64,2,3]],  &Device::Cpu)?;
    // 降维,参数为维度,默认为0
    // Tensor[2, 3; i64]
    let y = a.squeeze(0)?;

    //Tensor[1, 2, 3; i64]
    println!("{:?}",y);

分割张量

把张量分类为指定的块,默认会返回指定大小的块,但是 如果最后一块不能被指定块大小整除,可能会小于指定的块,可以理解为 除法中的余数。

pytorch

	a = torch.tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
    # 从第 0 维分割
    b = a.chunk(2, dim=0)
    # (tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]),)
    print(b)
    # 从第 1 维分割
    c = a.chunk(2, dim=1)
    # (tensor([[0, 1, 2, 3, 4, 5]]), tensor([[ 6,  7,  8,  9, 10]]))
    print(c)

candle

    let a = Tensor::new(vec![&[[0i64,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]],  &Device::Cpu)?;
    // 从第 0 维分割
    let b = a.chunk(2, 0)?;
    // (tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]),)
    println!("a: {a}");
    // 从第 1 维分割
    let c = a.chunk(2, 1)?;
    // [tensor([[0, 1, 2, 3, 4, 5]]), tensor([[ 6,  7,  8,  9, 10]])]
    println!("c: {:?}",c);

转置张量

通过给定张量的 2 个维度,对张量进行转置,使用 transpose。 pytorch

	data = [
       [1,2,3],
       [4,5,6],
    ]
    a = torch.tensor(data)
    # 对第 0 维和第 1 维进行转置
    b = a.transpose(0,1)
    
    # [[1, 4],
    #  [2, 5],
    #  [3, 6]]
    print(b)

candle

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    // 对矩阵进行转置
    let y = a.transpose(0,1)?;

    // 打打印结果:
    //     [[1, 4],
    //      [2, 5],
    //      [3, 6]]
    println!("{y}");

以上操作相对来说是比较常见的一些张量操作,还有更多不太常见的操作不在这里展开。这些是模型推理的一些基础,我们了解到这些基础才能更好地了解模型上的一些操作意义。

Previous
Equivalent Operations of Indexing, Slicing, Concatenation, and Mutation for Tensors in Rust Candle Framework and Pytorch
Next
Basic Equivalent Operations of Tensors in Rust AI Inference Framework Candle and Pytorch