FINE TUNING LÀ GÌ

     
1. Introduction

1.1 Fine-tuning là gì ?

Chắc hẳn phần nhiều ai thao tác làm việc với các model trong deep learning phần nhiều đã nghe/quen với tư tưởng Transfer learning với Fine tuning. Quan niệm tổng quát: Transfer learning là tận dụng học thức học được từ 1 vấn đề để áp dụng vào 1 vấn đề có liên quan khác. Một ví dụ 1-1 giản: thay vì train 1 model mới hoàn toàn cho việc phân một số loại chó/mèo, tín đồ ta hoàn toàn có thể tận dụng 1 model đã được train bên trên ImageNet dataset với hằng triệu ảnh. Pre-trained mã sản phẩm này sẽ được train tiếp bên trên tập dataset chó/mèo, quá trình train này diễn ra nhanh hơn, hiệu quả thường tốt hơn. Có không ít kiểu Transfer learning, các chúng ta có thể tham khảo trong bài bác này: Tổng vừa lòng Transfer learning. Trong bài bác này, mình sẽ viết về 1 dạng transfer learning phổ biến: Fine-tuning.Bạn sẽ xem: Fine tuning là gì

Hiểu solo giản, fine-tuning là bạn lấy 1 pre-trained model, tận dụng một trong những phần hoặc toàn cục các layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo nên 1 mã sản phẩm mới. Thường những layer đầu của mã sản phẩm được freeze (đóng băng) lại - tức weight các layer này sẽ không bị chuyển đổi giá trị trong quy trình train. Lý do bởi các layer này đã có khả năng trích xuất thông tin mức trìu tượng rẻ , năng lực này được học tập từ quá trình training trước đó. Ta freeze lại để tận dụng được khả năng này với giúp vấn đề train ra mắt nhanh hơn (model chỉ cần update weight ở các layer cao). Có rất nhiều các Object detect mã sản phẩm được phát hành dựa trên các Classifier model. VD Retina mã sản phẩm (Object detect) được tạo ra với backbone là Resnet.

Bạn đang xem: Fine tuning là gì


*

1.2 vì sao pytorch thay vì chưng Keras ?

Chủ đề bài viết hôm nay, bản thân sẽ chỉ dẫn fine-tuning Resnet50 - 1 pre-trained model được hỗ trợ sẵn trong torchvision của pytorch. Nguyên nhân là pytorch mà không hẳn Keras ? lý do bởi vấn đề fine-tuning model trong keras rất đối chọi giản. Dưới đó là 1 đoạn code minh hoạ cho việc xây dựng 1 Unet dựa trên Resnet trong Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer("activation_9").outputlayer_7 = resnet.get_layer("activation_21").outputlayer_13 = resnet.get_layer("activation_39").outputlayer_16 = resnet.get_layer("activation_48").output#Adding outputs decoder with encoder layersfcn1 = Conv2D(...)(layer_16)fcn2 = Conv2DTranspose(...)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(...)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(...)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(...)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Bạn rất có thể thấy, fine-tuning mã sản phẩm trong Keras đích thực rất đơn giản, dễ dàng làm, dễ dàng hiểu. Việc địa chỉ cửa hàng thêm các nhánh rất dễ bởi cú pháp 1-1 giản. Vào pytorch thì ngược lại, kiến thiết 1 model Unet tương tự sẽ khá vất vả và phức tạp. Bạn mới học sẽ gặp gỡ khó khăn bởi vì trên mạng không nhiều các hướng dẫn cho vấn đề này. Vậy nên bài bác này mình sẽ hướng dẫn cụ thể cách fine-tune trong pytorch để áp dụng vào việc Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?


*

Khi chú ý vào 1 bức ảnh, mắt thường có xu hướng tập trung nhìn vào 1 vài cửa hàng chính. Ảnh trên đây là 1 minh hoạ, màu đá quý được sử dụng để biểu lộ mức độ thu hút. Saliency prediction là câu hỏi mô bỏng sự triệu tập của mắt bạn khi quan gần cạnh 1 bức ảnh. Gắng thể, bài bác toán yên cầu xây dựng 1 model, model này nhận ảnh đầu vào, trả về 1 mask mô bỏng mức độ thu hút. Như vậy, model nhận vào 1 input đầu vào image cùng trả về 1 mask có kích cỡ tương đương.

Để rõ hơn về câu hỏi này, chúng ta cũng có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataset phổ biến nhất: SALICON DATASET

2.2 Unet

Note: Bạn hoàn toàn có thể bỏ qua phần này nếu đã biết về Unet

Đây là 1 trong những bài toán Image-to-Image. Để giải quyết bài toán này, mình sẽ xây dựng dựng 1 mã sản phẩm theo phong cách thiết kế Unet. Unet là một kiến trúc được áp dụng nhiều trong bài toán Image-to-image như: semantic segmentation, tự động hóa color, super resolution ... Kiến trúc của Unet có điểm giống như với kiến trúc Encoder-Decoder đối xứng, được thêm các skip connection từ Encode sang Decode tương ứng. Về cơ bản, các layer càng cao càng trích xuất thông tin ở nấc trìu tượng cao, điều ấy đồng nghĩa với việc những thông tin nút trìu tượng rẻ như đường nét, color sắc, độ phân giải... Sẽ bị mất non đi trong quy trình lan truyền. Fan ta thêm các skip-connection vào để giải quyết và xử lý vấn đề này.

Với phần Encode, feature-map được downscale bằng những Convolution. Ngược lại, ở chỗ decode, feature-map được upscale bởi những Upsampling layer, trong bài xích này mình sử dụng các Convolution Transpose.

*

2.3 Resnet

Để xử lý bài toán, mình sẽ xây dựng model Unet với backbone là Resnet50. Các bạn nên tò mò về Resnet nếu chưa chắc chắn về kiến trúc này. Hãy quan gần kề hình minh hoạ bên dưới đây. Resnet50 được chia thành các khối béo . Unet được xây đắp với Encoder là Resnet50. Ta sẽ lôi ra output của từng khối, tạo những skip-connection liên kết từ Encoder sang trọng Decoder. Decoder được phát hành bởi các Convolution Transpose layer (xen kẽ trong các số ấy là những lớp Convolution nhằm mục tiêu mục đích sút số chanel của feature bản đồ -> giảm con số weight mang đến model).

Theo quan điểm cá nhân, pytorch rất dễ dàng code, dễ hiểu hơn rất nhiều so cùng với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, việc fine-tuning model vào pytorch lại khó khăn hơn không ít so với Keras. Vào Keras, ta không nên quá thân thiết tới kiến trúc, luồng xử trí của model, chỉ việc lấy ra những output tại 1 số layer cố định làm skip-connection, ghép nối và chế tạo ra mã sản phẩm mới.


*

3. Code

Tất cả code của chính bản thân mình được gói gọn trong tệp tin notebook Salicon_main.ipynb. Bạn có thể tải về và run code theo links github: github/trungthanhnguyen0502 . Trong nội dung bài viết mình đã chỉ đưa ra hồ hết đoạn code chính.

Xem thêm: Nụ Cười Rạng Rỡ, Xinh Tự Nhiên Không Cần &Apos;Diễn Sâu&Apos; Của 9X Hà Thành

Import những package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ....

3.1 utils functions

Trong pytorch, dữ liệu có đồ vật tự dimension khác với Keras/TF/numpy. Thường thì với numpy giỏi keras, ảnh có dimension theo đồ vật tự (batchsize,h,w,chanel)(batchsize, h, w, chanel)(batchsize,h,w,chanel). Thứ tự vào Pytorch trái lại là (batchsize,chanel,h,w)(batchsize, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng dựng 2 hàm toTensor cùng toNumpy để biến đổi qua lại thân hai format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): ... ## display multi imagedef plot_imgs(imgs): ...

3.2 Define model

3.2.1 Conv and Deconv

Mình sẽ xây dựng dựng 2 function trả về module Convolution và Convolution Transpose (Deconv)

def Deconv(n_input, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = return nn.Sequential(*block) def Conv(n_input, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta sẽ copy các layer nên giữ trường đoản cú resnet50 vào unet. Sau đó khởi tạo những Conv / Deconv layer và những layer buộc phải thiết.

Forward function: cần đảm bảo luồng cách xử trí của resnet50 được không thay đổi giống code cội (trừ Fully-connected layer). Tiếp nối ta ghép nối các layer lại theo kiến trúc Unet đã miêu tả trong phần 2.

Tạo model: cần load resnet50 và truyền vào Unet. Đừng quên Freeze các layer của resnet50 trong Unet.

Xem thêm: Hoa Hậu Dương Thùy Linh Sinh Năm Bao Nhiêu, Hoa Hậu Dương Thùy Linh

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet khổng lồ make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use khổng lồ reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device("cuda")resnet50 = models.resnet50(pretrained=True)model = Unet(resnet50)model.to(device)## Freeze resnet50"s layers in Unetfor i, child in enumerate(model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset and Dataloader

Dataset trả dìm 1 list những image_path cùng mask_dir, trả về image cùng mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split("/").split(".") mask_fn = f"self.mask_dir/img_name.png" img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = "image": img, "mask": mask sample = self.transforms(**sample) img = sample mask = sample # khổng lồ Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob("./Salicon_dataset/image/train/*.jpg")mask_dir = "./Salicon_dataset/mask/train"train_transform = A.Compose(, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataset = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)# kiểm tra datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài bác toán đơn giản dễ dàng và làm cho dễ hiểu, mình đã train theo cách dễ dàng nhất, không validate trong qúa trình train cơ mà chỉ lưu model sau 1 số ít epoch độc nhất vô nhị định

train_params = optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5model.train()saved_dir = "model"os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce="mean")for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 kiểm tra model

img_fns = glob("./Salicon_dataset/image/val/*.jpg")mask_dir = "./Salicon_dataset/mask/val"val_transform = A.Compose()model.eval()val_dataset = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết quả thu được: