Tensor Math Operations: The Calculator Inside Your AI Brain 🧮
Imagine you have a magical calculator. But this isn’t a normal calculator that works with single numbers. This calculator can work with boxes full of numbers all at once! That’s what PyTorch tensor operations do.
Think of it like this: instead of adding 2 + 3, you can add one whole basket of apples to another basket of apples, all in one go!
The Kitchen Analogy 🍳
Throughout this guide, we’ll think of tensors like ingredient trays in a kitchen:
- Each tray holds ingredients (numbers)
- Math operations are like recipes that transform ingredients
- Broadcasting is like a chef who automatically adjusts portion sizes
1. Tensor Arithmetic Operations ➕➖✖️➗
What Is It?
Just like you learned +, -, ×, ÷ in school, tensors can do the same—but with many numbers at once!
Simple Example
import torch
# Two small trays of numbers
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# Add them
print(a + b) # [5, 7, 9]
# Subtract
print(a - b) # [-3, -3, -3]
# Multiply
print(a * b) # [4, 10, 18]
# Divide
print(a / b) # [0.25, 0.4, 0.5]
What’s Happening?
| Position | a | b | a + b |
|---|---|---|---|
| 1st | 1 | 4 | 5 |
| 2nd | 2 | 5 | 7 |
| 3rd | 3 | 6 | 9 |
Each number pairs up with its partner and does the math!
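The same pairing works when one side is just a single number. A quick sketch (reusing a and b from above; this is a sneak peek at broadcasting, covered next):
# A single number pairs up with every element
print(a + 10) # [11, 12, 13]
print(a * 2)  # [2, 4, 6]
print(b ** 2) # [16, 25, 36]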
2. Broadcasting Rules 📢
What Is It?
Sometimes your trays have different sizes. Broadcasting is like a magical copy machine that makes the smaller tray match the bigger one.
The Rule (Super Simple!)
PyTorch looks at shapes from right to left. Dimensions match if:
- They’re equal, OR
- One of them is 1 (gets stretched)
Example
# Big tray: 3 rows, 3 columns
big = torch.tensor([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
# Small tray: just 1 row
small = torch.tensor([10, 20, 30])
# Magic! Small gets copied 3 times
result = big + small
# [[11, 22, 33],
# [14, 25, 36],
# [17, 28, 39]]
graph TD A["small: [10, 20, 30]"] --> B["Broadcast!"] B --> C["[10, 20, 30]<br>[10, 20, 30]<br>[10, 20, 30]"] C --> D["Now matches big!"]
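The "one of them is 1" part of the rule works too. Here's a small sketch (the names col and row are just for illustration): a (3, 1) column plus a (1, 3) row stretches both sides into a (3, 3) grid.
col = torch.tensor([[1], [2], [3]])  # shape (3, 1)
row = torch.tensor([[10, 20, 30]])   # shape (1, 3)
# Comparing shapes right to left: 1 vs 3 stretches, 3 vs 1 stretches
grid = col + row                     # shape (3, 3)
# [[11, 21, 31],
#  [12, 22, 32],
#  [13, 23, 33]]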
3. In-Place Operations ⚡
What Is It?
Normal operations create a new tray. In-place operations change the original tray directly. They have an underscore _ at the end!
Why Use Them?
- Saves memory (no extra tray needed)
- Faster for big data
Example
x = torch.tensor([1.0, 2.0, 3.0])
# Normal: creates new tensor
y = x + 5 # x is still [1, 2, 3]
# In-place: changes x directly
x.add_(5) # x becomes [6, 7, 8]
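Python's shorthand operators like += and *= also work in place on tensors; a quick sketch:
x = torch.tensor([1.0, 2.0, 3.0])
x += 5   # in-place, same effect as x.add_(5)
x *= 2   # in-place, same effect as x.mul_(2)
print(x) # [12, 14, 16]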
Common In-Place Operations
| Normal | In-Place |
|---|---|
| add() | add_() |
| sub() | sub_() |
| mul() | mul_() |
| div() | div_() |
⚠️ Warning: In-place operations can break gradient computation in neural networks. Use carefully!
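Here's a minimal sketch of what that warning means in practice: PyTorch refuses an in-place change on a leaf tensor that requires gradients.
w = torch.ones(3, requires_grad=True)
try:
    w.add_(1.0)  # in-place on a tensor that autograd is tracking
except RuntimeError as e:
    print("PyTorch stops you:", e)
# The safe way: make a new tensor instead
w2 = w + 1.0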
4. Matrix and Batch Operations 📊
What Is It?
Matrices are 2D grids of numbers. Matrix multiplication is like a special handshake where rows meet columns.
Matrix Multiplication
# 2x3 matrix (2 rows, 3 columns)
A = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
# 3x2 matrix (3 rows, 2 columns)
B = torch.tensor([
[7, 8],
[9, 10],
[11, 12]
])
# Matrix multiply: result is 2x2
C = torch.matmul(A, B)
# Or use @ shortcut
C = A @ B
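To see the row-meets-column handshake with real numbers, here's how the top-left entry of C is built (a quick check you can run yourself):
# C[0][0] = (row 0 of A) paired with (column 0 of B), then summed
#         = 1*7 + 2*9 + 3*11 = 58
print(C)
# [[ 58,  64],
#  [139, 154]]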
Batch Operations
When you have many matrices to multiply at once:
# 4 batches of 2x3 matrices
batch_A = torch.randn(4, 2, 3)
# 4 batches of 3x5 matrices
batch_B = torch.randn(4, 3, 5)
# All 4 multiply at once!
result = torch.bmm(batch_A, batch_B)
# Shape: (4, 2, 5)
graph TD
    A["Batch of 4 matrices A"] --> C["bmm"]
    B["Batch of 4 matrices B"] --> C
    C --> D["4 result matrices"]
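As a side note, torch.matmul (and the @ shortcut) understands batch dimensions too, so this sketch should give the same answer:
result2 = batch_A @ batch_B            # also shape (4, 2, 5)
print(torch.allclose(result, result2)) # True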
5. torch.einsum: The Swiss Army Knife 🔧
What Is It?
einsum stands for Einstein Summation. It’s like writing a recipe in shorthand. One line can do what 10 lines of code would normally need!
How It Works
You write a pattern that describes:
- Which dimensions to match
- Which to sum over
- What shape to output
Simple Examples
Matrix multiplication:
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# 'ik,kj->ij' means:
# A has dims i,k | B has dims k,j
# Sum over k, output has i,j
result = torch.einsum('ik,kj->ij', A, B)
Dot product:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
dot = torch.einsum('i,i->', a, b)
# 1*4 + 2*5 + 3*6 = 32
Transpose:
M = torch.randn(3, 4)
M_T = torch.einsum('ij->ji', M)
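Batch matrix multiplication (from section 4), sketched as an einsum pattern; the letter b labels the batch dimension that stays put:
batch_A = torch.randn(4, 2, 3)
batch_B = torch.randn(4, 3, 5)
# 'bij,bjk->bik': keep b, i, and k, sum over j
out = torch.einsum('bij,bjk->bik', batch_A, batch_B)
print(out.shape) # torch.Size([4, 2, 5])
print(torch.allclose(out, torch.bmm(batch_A, batch_B))) # True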
6. Tensor Reduction Operations 📉
What Is It?
Reduction means squishing many numbers into fewer numbers. Like summarizing a book into one sentence!
Common Reductions
x = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
# Sum everything
torch.sum(x) # 21
# Sum each column
torch.sum(x, dim=0) # [5, 7, 9]
# Sum each row
torch.sum(x, dim=1) # [6, 15]
# Other reductions
torch.mean(x.float()) # Average: 3.5
torch.max(x) # Biggest: 6
torch.min(x) # Smallest: 1
torch.prod(x) # Product: 720
graph TD
    A["[[1,2,3],<br>[4,5,6]]"] --> B["sum#40;dim=0#41;"]
    A --> C["sum#40;dim=1#41;"]
    B --> D["[5, 7, 9]"]
    C --> E["[6, 15]"]
The dim Parameter
- dim=0: squish down the rows, leaving one value per column
- dim=1: squish across the columns, leaving one value per row
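Two handy extras, sketched with the same x as above: keepdim=True keeps the squished dimension around (with size 1), and max with a dim gives you both the winners and where they sit.
torch.sum(x, dim=0, keepdim=True) # [[5, 7, 9]], shape (1, 3)
values, indices = torch.max(x, dim=1)
print(values)  # [3, 6]  biggest in each row
print(indices) # [2, 2]  the column where each one lives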
7. Element-Wise Math Operations 🔢
What Is It?
These operations work on each number individually—like giving every ingredient the same treatment.
Common Operations
x = torch.tensor([1.0, 4.0, 9.0])
# Square root
torch.sqrt(x) # [1, 2, 3]
# Power
torch.pow(x, 2) # [1, 16, 81]
# Absolute value
torch.abs(torch.tensor([-1, 2, -3]))
# [1, 2, 3]
# Exponential (e^x)
torch.exp(torch.tensor([0.0, 1.0, 2.0]))
# [1, 2.718, 7.389]
# Natural log
torch.log(torch.tensor([1, 2.718, 7.389]))
# [0, 1, 2]
# Trigonometry
torch.sin(torch.tensor([0, 3.14159/2]))
# [0, 1]
Special Functions
# Sigmoid (squashes to 0-1)
torch.sigmoid(torch.tensor([0.0, 2.0, -2.0]))
# [0.5, 0.88, 0.12]
# ReLU (negative becomes 0)
torch.relu(torch.tensor([-1, 0, 1]))
# [0, 0, 1]
8. Rounding and Clamping 🎯
Rounding: Making Numbers Neat
x = torch.tensor([1.2, 2.5, 3.7, -1.4])
# Round to nearest integer (PyTorch rounds halves to the even number, so 2.5 -> 2)
torch.round(x) # [1, 2, 4, -1]
# Always round down (floor)
torch.floor(x) # [1, 2, 3, -2]
# Always round up (ceiling)
torch.ceil(x) # [2, 3, 4, -1]
# Cut off decimal (truncate)
torch.trunc(x) # [1, 2, 3, -1]
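All four of these keep the float dtype; a small sketch if you want actual integer tensors afterwards:
torch.round(x)        # [1., 2., 4., -1.]  still floats
torch.round(x).long() # [1, 2, 4, -1]      now integers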
Clamping: Setting Boundaries
Clamping is like putting guardrails. Numbers can’t go below min or above max.
x = torch.tensor([1, 5, 10, 15, 20])
# Keep between 5 and 15
torch.clamp(x, min=5, max=15)
# [5, 5, 10, 15, 15]
# Only set minimum
torch.clamp(x, min=8)
# [8, 8, 10, 15, 20]
# Only set maximum
torch.clamp(x, max=12)
# [1, 5, 10, 12, 12]
graph LR
    A["1"] --> B["clamp#40;5,15#41;"]
    B --> C["5"]
    D["20"] --> E["clamp#40;5,15#41;"]
    E --> F["15"]
    G["10"] --> H["clamp#40;5,15#41;"]
    H --> I["10 ✓"]
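Clamping pairs nicely with the in-place style from section 3. A quick sketch for keeping values in a 0-to-1 range (think pixel intensities; the name pixels is just for illustration):
pixels = torch.tensor([-0.2, 0.5, 1.3])
pixels.clamp_(0.0, 1.0) # in-place clamp
print(pixels)           # [0.0, 0.5, 1.0]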
Quick Reference Table 📋
| Category | Operation | Example |
|---|---|---|
| Arithmetic | +, -, *, / | a + b |
| In-place | add_(), mul_() | x.add_(5) |
| Matrix | matmul, @, bmm | A @ B |
| Einstein | einsum | einsum('ij->ji', M) |
| Reduction | sum, mean, max | sum(x, dim=0) |
| Element | sqrt, exp, log | torch.sqrt(x) |
| Round | round, floor, ceil | torch.round(x) |
| Clamp | clamp | clamp(x, 0, 1) |
You Did It! 🎉
You’ve learned how PyTorch’s tensor math works! Remember:
- Arithmetic = Basic math on many numbers at once
- Broadcasting = Auto-stretching smaller tensors
- In-place = Change without making copies (use _)
- Matrix ops = Row-meets-column multiplication
- einsum = Powerful shorthand for complex operations
- Reductions = Summarize many numbers into few
- Element-wise = Apply functions to each number
- Rounding/Clamping = Control number precision and range
Now go build something amazing! 🚀