The following is a summary of notes taken from the Udemy course Pytorch: Deep Learning and Artificial Intelligence.
Transfer Learning
Basic Concepts
Recall: Features are hierarchical!
computer vision
ImageNet - Large-scale image dataset
Training on it from scratch is rather hard.
Training on ImageNet
2-part CNN
Feature transformer("body")
ANN classifier("head")
Replace the "head": remove the original one and attach a new head.
The new head can be logistic regression, an ANN, etc.
Freeze the "body"
train only the head (much faster)
Advantages of transfer learning:
Don't need a lot of data to build a state-of-the-art model
With transfer learning, this work has been done for us: the earlier features were already trained on lots of data (ImageNet)
Small dataset + a lot less weights helps us train fast
Some pre-trained models (VGG, ResNet, Inception, MobileNet)
VGG - Visual Geometry Group
VGG16, VGG19
No different from a plain CNN, just bigger.
ends with 3 fully-connected layers
ResNet
a cnn with branches(one branch is the identity function, so the other learns the residual)
ResNet50, ResNet101, ResNet152
ResNetv2-50, ResNetv2-101, ResNetv2-152
ResNeXt
The variants differ in the number of layers.
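The "identity branch plus residual branch" idea can be sketched as a single block. This is a minimal illustration under assumptions, not the exact block used in any ResNet variant (real ResNets also use batch normalization and, in the deeper variants, bottleneck layers); the class name `ResidualBlock` and the channel count are made up for the example.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical minimal block: output = x + F(x)."""

    def __init__(self, channels):
        super().__init__()
        # The non-identity branch; it only has to learn the residual F(x)
        self.residual = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Identity branch (x) added to the residual branch
        return x + self.residual(x)


block = ResidualBlock(channels=8)
x = torch.rand(2, 8, 16, 16)
out = block(x)
print(out.shape)  # same shape as x, so the two branches can be added
```

If the residual branch outputs zeros, the block reduces to the identity function, which is one common intuition for why very deep stacks of such blocks remain trainable.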
Inception
similar to ResNet
Multiple convolutions in parallel branches
filter sizes (1x1, 3x3, 5x5, etc)
Which filter size works best depends on the data and other factors, so several are tried in parallel.
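The parallel-branch idea can be sketched as follows. This is a simplified illustration, not the actual Inception module (which also includes pooling branches and 1x1 reductions before the larger filters); `ParallelConvBlock` and its arguments are hypothetical names.

```python
import torch
import torch.nn as nn


class ParallelConvBlock(nn.Module):
    """Hypothetical block: several filter sizes applied to the same input."""

    def __init__(self, in_channels, branch_channels):
        super().__init__()
        # Padding is chosen so every branch keeps the spatial size
        self.branch1x1 = nn.Conv2d(in_channels, branch_channels, kernel_size=1)
        self.branch3x3 = nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=1)
        self.branch5x5 = nn.Conv2d(in_channels, branch_channels, kernel_size=5, padding=2)

    def forward(self, x):
        outputs = [self.branch1x1(x), self.branch3x3(x), self.branch5x5(x)]
        # Concatenate the branch outputs along the channel dimension
        return torch.cat(outputs, dim=1)


block = ParallelConvBlock(in_channels=3, branch_channels=4)
out = block(torch.rand(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 12, 16, 16])
```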
MobileNet
Lightweight: makes a tradeoff between speed and accuracy
meant for less powerful machines(mobile, embedded)
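Large datasets in PyTorch
Various datasets
jpg, png
from torchvision import datasets
# train_transform: a torchvision transform pipeline, defined separately
train_dataset = datasets.ImageFolder(
    'data/train',
    transform=train_transform
)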
Approaches to Transfer Learning
Two parts
2-part CNN
Imagine the computation in 2 parts:
part 1: z = f(x) # pre-trained CNN - slow
part 2: y_hat = softmax(Wz + b) # logistic regression on z - fast
for epoch in epochs:
    shuffle(batches)
    for x, y in batches:
        z = vgg_body(x)                  # frozen body: forward pass only
        y_hat = softmax(z.dot(w) + b)
        gw = grad(error(y, y_hat), w)
        gb = grad(error(y, y_hat), b)
        update w and b using gw and gb
vgg_body is not trained here: its weights stay fixed.
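The pseudocode above can be turned into a runnable PyTorch sketch. To keep it small and free of downloads, a tiny random CNN stands in for `vgg_body`, and the data is synthetic; both are assumptions for illustration, not the course's setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the pre-trained body (assumption: NOT a real VGG)
body = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
for param in body.parameters():
    param.requires_grad = False  # freeze the body
body.eval()

head = nn.Linear(4, 2)  # new head: logistic regression with 2 classes
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)  # only head weights

X = torch.rand(16, 3, 8, 8)        # synthetic images
y = torch.randint(0, 2, (16,))     # synthetic labels

body_before = [p.clone() for p in body.parameters()]
head_before = [p.clone() for p in head.parameters()]

for epoch in range(3):
    permutation = torch.randperm(len(X))      # shuffle(batches)
    for start in range(0, len(X), 8):
        idx = permutation[start:start + 8]
        with torch.no_grad():
            z = body(X[idx])                  # z = vgg_body(x)
        loss = criterion(head(z), y[idx])     # softmax is inside the loss
        optimizer.zero_grad()
        loss.backward()                       # gradients for w and b only
        optimizer.step()
```

Note that the body's forward pass still runs for every batch in every epoch, which is the slow part of this approach.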
All data is the same: because the body is frozen, each x produces the same z in every epoch.
z = f(x) # use the CNN to precompute all z's at once
turn (z, y) into a tabular dataset
fit logistic regression to (z, y); never have to look at x again!
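A minimal sketch of this precompute-then-fit idea, assuming a tiny random CNN as a stand-in for the frozen body and synthetic data (neither comes from the course material):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the frozen pre-trained body f(x) (assumption, not VGG)
body = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
).eval()

X = torch.rand(32, 3, 8, 8)        # synthetic images
y = torch.randint(0, 2, (32,))     # synthetic binary labels

# Part 1 (slow, done once): precompute the tabular features Z
with torch.no_grad():
    Z = body(X)

# Part 2 (fast): logistic regression on (Z, y); X is no longer needed
logreg = nn.Linear(Z.shape[1], 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(logreg.parameters(), lr=0.1)

losses = []
for epoch in range(50):
    loss = criterion(logreg(Z), y)   # full-batch gradient descent
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(Z.shape)                 # torch.Size([32, 4])
print(losses[0], losses[-1])   # training loss at the first and last epoch
```

Each epoch now touches only a small matrix of features, so the cost of the CNN is paid a single time instead of once per epoch.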
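How can we use data augmentation?
Two approaches
1. use data augmentation with ImageFolder Dataset + DataLoader
the training loop passes every augmented batch through the entire CNN
2. precompute Z without data augmentation
train the head on the tabular (Z, y)
Pros and cons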
| | with data augmentation | without data augmentation |
|---|---|---|
| Speed | slow (must pass through the entire CNN) | fast (only need to pass through 1 dense layer) |
| Generalization / accuracy | possibly better for generalization | possibly worse for generalization |
Transfer learning PyTorch code
PyTorch transfer learning with data augmentation
binary classifier
The Normalize mean and std are the standard ImageNet values.
train_dataset = datasets.ImageFolder(
    'data/train',
    transform=train_transform
)
test_dataset = datasets.ImageFolder(
    'data/test',
    transform=test_transform
)
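PyTorch VGG16 model, pre-trained
# newer torchvision versions use the weights= argument instead of pretrained=
model = models.vgg16(pretrained=True)
# freeze the VGG weights
for param in model.parameters():
    param.requires_grad = False
binary classification
# input size of the first Linear layer in the original VGG classifier
n_features = model.classifier[0].in_features
model.classifier = nn.Linear(n_features, 2)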
PyTorch transfer learning without data augmentation
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
build the model
vgg = models.vgg16(pretrained=True)
class VGGFeatures(nn.Module):
    def __init__(self, vgg):
        super(VGGFeatures, self).__init__()
        self.vgg = vgg
    def forward(self, x):
        # use the wrapped model (self.vgg), not the global variable vgg
        out = self.vgg.features(x)
        out = self.vgg.avgpool(out)
        out = out.view(out.size(0), -1)  # flatten
        return out
vggf = VGGFeatures(vgg)
out = vggf(torch.rand(1, 3, 224, 224))
out.shape  # torch.Size([1, 25088]), i.e. 512 * 7 * 7
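With a feature extractor like `vggf`, the whole dataset can be converted to features batch by batch. The helper `extract_features` below is a hypothetical name, and a tiny random CNN replaces `vggf` so the sketch runs without downloading weights; with the real `vggf`, each row of `Z` would have 25088 entries instead.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def extract_features(feature_model, loader):
    """Hypothetical helper: run every batch through feature_model once."""
    feature_model.eval()
    features, targets = [], []
    with torch.no_grad():  # no gradients needed, the extractor is frozen
        for x, y in loader:
            features.append(feature_model(x))
            targets.append(y)
    return torch.cat(features), torch.cat(targets)


# Tiny stand-in for vggf (assumption: NOT a real VGG)
stand_in = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((7, 7)),
    nn.Flatten(),
)

images = torch.rand(10, 3, 32, 32)       # synthetic images
labels = torch.randint(0, 2, (10,))      # synthetic labels
loader = DataLoader(TensorDataset(images, labels), batch_size=4)

Z, y = extract_features(stand_in, loader)
print(Z.shape, y.shape)  # torch.Size([10, 392]) torch.Size([10])
```

The resulting `(Z, y)` pair is the tabular dataset on which a logistic regression head can then be fitted.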