FINE TUNING LÀ GÌ

1. Introduction

1.1 Fine-tuning là gì ?

Chắc hẳn những ai thao tác làm việc với những Model trong deep learning rất nhiều đã nghe/quen thuộc với khái niệm Transfer learning với Fine tuning. Khái niệm tổng quát: Transfer learning là tận dụng tri thức học tập được từ là một vấn đề nhằm vận dụng vào 1 vấn đề gồm liên quan khác. Một ví dụ solo giản: rứa vì chưng train 1 model bắt đầu hoàn toàn đến bài xích toán thù phân các loại chó/mèo, tín đồ ta rất có thể tận dụng tối đa 1 model đã có được train ở ImageNet dataphối cùng với hằng triệu hình họa. Pre-trained Model này sẽ tiến hành train tiếp bên trên tập dataset chó/mèo, quy trình train này diễn ra nkhô hanh rộng, tác dụng thường xuất sắc hơn. Có không hề ít hình trạng Transfer learning, những chúng ta có thể tìm hiểu thêm trong bài bác này: Tổng thích hợp Transfer learning. Trong bài xích này, mình sẽ viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine tuning là gì

quý khách hàng đã xem: Fine tuning là gì

Hiểu đơn giản, fine-tuning là bạn đem 1 pre-trained mã sản phẩm, tận dụng 1 phần hoặc cục bộ những layer, thêm/sửa/xoá 1 vài layer/nhánh nhằm tạo thành 1 mã sản phẩm new. Thường các layer đầu của mã sản phẩm được freeze (đóng góp băng) lại - tức weight các layer này đã không xẩy ra thay đổi quý hiếm vào quá trình train. Lý vị vì chưng các layer này vẫn có chức năng trích xuất thông tin nút trìu tượng thấp , khả năng này được học từ quá trình training trước đó. Ta freeze lại nhằm tận dụng được kĩ năng này với góp Việc train ra mắt nhanh khô rộng (model chỉ phải update weight sinh sống những layer cao). Có tương đối nhiều những Object detect model được xây cất dựa trên những Classifier mã sản phẩm. VD Retina model (Object detect) được tạo ra cùng với backbone là Resnet.

*

1.2 Tại sao pytorch nắm vì chưng Keras ?

Chủ đề bài viết từ bây giờ, bản thân vẫn trả lời fine-tuning Resnet50 - 1 pre-trained Model được hỗ trợ sẵn vào torchvision của pytorch. Tại sao là pytorch mà chưa phải Keras ? Lý vì do câu hỏi fine-tuning model trong keras khôn cùng đơn giản và dễ dàng. Dưới đấy là 1 đoạn code minch hoạ cho Việc xây đắp 1 Unet dựa trên Resnet vào Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer("activation_9").outputlayer_7 = resnet.get_layer("activation_21").outputlayer_13 = resnet.get_layer("activation_39").outputlayer_16 = resnet.get_layer("activation_48").output#Adding outputs decoder with encoder layersfcn1 = Conv2D(...)(layer_16)fcn2 = Conv2DTranspose(...)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(...)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(...)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(...)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Quý Khách rất có thể thấy, fine-tuning mã sản phẩm trong Keras đích thực rất dễ dàng, dễ dàng có tác dụng, dễ hiểu. Việc add thêm những nhánh rất dễ bởi cú pháp đơn giản. Trong pytorch thì ngược chở lại, tạo ra 1 Model Unet giống như sẽ rất vất vả với tinh vi. Người new học tập vẫn chạm chán trở ngại vì chưng trên mạng rất hiếm các hướng dẫn cho việc này. Vậy bắt buộc bài xích này mình đã lý giải cụ thể bí quyết fine-tune trong pytorch để vận dụng vào bài bác tân oán Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?


*

khi chú ý vào 1 tấm hình, mắt thông thường sẽ có Xu thế triệu tập quan sát vào 1 vài ba đơn vị bao gồm. Hình ảnh bên trên đó là 1 minh hoạ, màu tiến thưởng được thực hiện để bộc lộ cường độ đam mê. Saliency prediction là bài toán thù mô phỏng sự triệu tập của mắt người Lúc quan lại ngay cạnh 1 tấm hình. Cụ thể, bài bác tân oán yên cầu gây ra 1 Model, mã sản phẩm này dìm hình ảnh đầu vào, trả về 1 mask tế bào rộp cường độ nóng bỏng. do vậy, Model nhấn vào 1 đầu vào image với trả về 1 mask gồm kích cỡ tương đương.

Để rõ hơn về bài toán thù này, bạn cũng có thể hiểu bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataphối thông dụng nhất: SALICON DATASET

2.2 Unet

Note: quý khách hàng rất có thể bỏ qua phần này trường hợp đã biết về Unet

Đây là 1 trong bài xích tân oán Image-to-Image. Để xử lý bài tân oán này, bản thân sẽ xây dựng 1 model theo phong cách xây dựng Unet. Unet là một trong bản vẽ xây dựng được áp dụng các vào bài xích tân oán Image-to-image như: semantic segmentation, tự động color, super resolution ... Kiến trúc của Unet tất cả điểm tựa như cùng với phong cách thiết kế Encoder-Decoder đối xứng, nhận thêm các skip connection từ bỏ Encode sang trọng Decode khớp ứng. Về cơ bản, những layer càng cao càng trích xuất biết tin tại mức trìu tượng cao, điều này đồng nghĩa tương quan cùng với Việc các biết tin mức trìu tượng rẻ nhỏng đường đường nét, màu sắc, độ sắc nét... sẽ bị mất đuối đi vào quá trình lan truyền. Người ta thêm các skip-connection vào để giải quyết vấn đề này.

Với phần Encode, feature-maps được downscale bởi các Convolution. trái lại, ở phần decode, feature-bản đồ được upscale do các Upsampling layer, trong bài xích này mình sử dụng các Convolution Transpose.

*

2.3 Resnet

Để giải quyết bài toán, bản thân sẽ xây dựng Model Unet với backbone là Resnet50. quý khách yêu cầu mày mò về Resnet ví như chưa biết về bản vẽ xây dựng này. Hãy quan cạnh bên hình minch hoạ dưới đây. Resnet50 được chia thành những khối lớn . Unet được kiến tạo cùng với Encoder là Resnet50. Ta vẫn kéo ra output của từng kân hận, sản xuất các skip-connection liên kết từ Encoder quý phái Decoder. Decoder được tạo vì chưng các Convolution Transpose layer (đan xen trong những số đó là các lớp Convolution nhằm mục tiêu mục đích sút số chanel của feature maps -> bớt số lượng weight đến model).

Theo quan điểm cá nhân, pytorch rất đơn giản code, dễ hiểu hơn tương đối nhiều đối với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, việc fine-tuning model trong pytorch lại khó khăn rộng không ít so với Keras. Trong Keras, ta ko đề xuất vượt quan tâm tới phong cách thiết kế, luồng giải pháp xử lý của Model, chỉ việc lôi ra các output tại một số layer một mực làm cho skip-connection, ghép nối cùng tạo thành model new.

Xem thêm: Traần Dần Là Ai - Nhà Tiên Tri Trần Dần Là Ai


*

3. Code

Tất cả code của bản thân được gói gọn vào file notebook Salicon_main.ipynb. Bạn hoàn toàn có thể download về với run code theo links github: github/trungthanhnguyen0502 . Trong nội dung bài viết mình vẫn chỉ đưa ra những đoạn code chủ yếu.

Import những package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ....

3.1 utils functions

Trong pytorch, tài liệu có lắp thêm trường đoản cú dimension không giống với Keras/TF/numpy. Đôi khi cùng với numpy tốt keras, hình họa có dimension theo sản phẩm công nghệ tự (batchkích cỡ,h,w,chanel)(batchform size, h, w, chanel)(batchsize,h,w,chanel). Thứ đọng từ bỏ vào Pytorch ngược lại là (batchform size,chanel,h,w)(batchform size, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng dựng 2 hàm toTensor và toNumpy để thay đổi qua lại thân hai format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): ... ## display multi imagedef plot_imgs(imgs): ...

3.2 Define model

3.2.1 Conv and Deconv

Mình sẽ xây dựng dựng 2 function trả về module Convolution cùng Convolution Transpose (Deconv)

def Deconv(n_input đầu vào, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input đầu vào, n_output, kernel_size=k_kích cỡ, stride=stride, padding=padding, bias=False) bloông chồng = return nn.Sequential(*block) def Conv(n_đầu vào, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input đầu vào, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) bloông xã = return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta đã copy những layer nên giữ từ resnet50 vào unet. Sau kia khởi tạo những Conv / Deconv layer cùng các layer cần thiết.

Forward function: phải bảo vệ luồng giải pháp xử lý của resnet50 được giữ nguyên như là code gốc (trừ Fully-connected layer). Sau kia ta ghép nối những layer lại theo bản vẽ xây dựng Unet sẽ biểu hiện trong phần 2.

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet to make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use lớn reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device("cuda")resnet50 = models.resnet50(pretrained=True)Model = Unet(resnet50)mã sản phẩm.to(device)## Freeze resnet50"s layers in Unetfor i, child in enumerate(Model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataphối and Dataloader

Datamix trả nhận 1 các mục những image_path với mask_dir, trả về image và mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split("/").split(".") mask_fn = f"self.mask_dir/img_name.png" img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = "image": img, "mask": mask sample = self.transforms(**sample) img = sample mask = sample # to lớn Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob("./Salicon_dataset/image/train/*.jpg")mask_dir = "./Salicon_dataset/mask/train"train_transkhung = A.Compose(, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataphối = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_datamix, batch_size=4, shuffle=True, drop_last=True)# Test datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài toán dễ dàng cùng khiến cho dễ nắm bắt, bản thân đang train Theo phong cách dễ dàng và đơn giản duy nhất, ko validate trong qúa trình train mà chỉ lưu lại Model sau 1 số ít epoch tuyệt nhất định

train_params = optimizer = torch.optyên.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5model.train()saved_dir = "model"os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce="mean")for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 Test model

img_fns = glob("./Salicon_dataset/image/val/*.jpg")mask_dir = "./Salicon_dataset/mask/val"val_transkhung = A.Compose()mã sản phẩm.eval()val_dataset = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_datamix, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết trái thu được: