﻿{"id":3809,"date":"2025-11-03T13:20:11","date_gmt":"2025-11-03T05:20:11","guid":{"rendered":"https:\/\/www.leexinghai.com\/aic\/?p=3809"},"modified":"2025-11-03T15:06:13","modified_gmt":"2025-11-03T07:06:13","slug":"p3cliponpv","status":"publish","type":"post","link":"https:\/\/www.leexinghai.com\/aic\/p3cliponpv\/","title":{"rendered":"P5110381-CLIP\u6a21\u578b\u5728PlantVillage\u690d\u7269\u75c5\u5bb3\u8bc6\u522b\u4efb\u52a1\u4e2d\u7684\u5e94\u7528\u63a2\u7a76"},"content":{"rendered":"\n<h2 class=\"wp-block-heading has-text-align-center\">0.Github<\/h2>\n\n\n\n<p><a href=\"https:\/\/github.com\/CrystalChanB31\/clip_on_plantvillage\/\">CrystalChanB31\/clip_on_plantvillage: CLIP\u6a21\u578b\u5728PlantVillage\u690d\u7269\u75c5\u5bb3\u8bc6\u522b\u4efb\u52a1\u4e2d\u7684\u5e94\u7528\u63a2\u7a76<\/a><\/p>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity is-style-dots\"\/>\n\n\n\n<h2 class=\"wp-block-heading has-text-align-center\">1.\u73af\u5883\u51c6\u5907<\/h2>\n\n\n\n<p class=\"has-medium-font-size\"><strong>1.1 \u6570\u636e\u96c6<\/strong><\/p>\n\n\n\n<p><a href=\"https:\/\/www.kaggle.com\/datasets\/abdallahalidev\/plantvillage-dataset\">PlantVillage Dataset<\/a><\/p>\n\n\n\n<p>\u663e\u5361\uff1aNvidia Geforce RTX5090 @ 32GB * 1<\/p>\n\n\n\n<p><strong>1.2 \u8f6f\u4ef6\u73af\u5883\u914d\u7f6e<\/strong><\/p>\n\n\n\n<p>Linux\uff1aUbuntu 24.04LTS\uff08WSL2\uff09<\/p>\n\n\n\n<p>Anaconda\uff1a\u6700\u65b0\u7248\u672c<\/p>\n\n\n\n<p>CUDA\uff1a13.0<\/p>\n\n\n\n<p>Python version info: 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0]<br>PyTorch version info: 2.10.0.dev20251026+cu130<\/p>\n\n\n\n<p><strong>1.3 requirements.txt<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>torch&gt;=1.12.0\ntorchvision&gt;=0.13.0\nscikit-learn&gt;=1.0.0\ntqdm&gt;=4.0.0\npillow&gt;=8.0.0\nnumpy&gt;=1.19.0\n# OpenAI CLIP: install from the official GitHub repo\n# This installs the `clip` package used in the code (ViT-B\/32, etc.).\n# If you prefer a released wheel or 
your environment already contains CLIP, you can omit the line below.\ngit+https:\/\/github.com\/openai\/CLIP.git@main#egg=clip<\/code><\/pre>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity is-style-wide\"\/>\n\n\n\n<h2 class=\"wp-block-heading has-text-align-center\">2.\u6570\u636e\u5904\u7406<\/h2>\n\n\n\n<p><strong>2.1 \u5148\u8fdb\u884c\u6570\u636e\u96c6\u7684\u5212\u5206\uff08\u6d4b\u8bd5\u96c6\uff0c\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\uff09<\/strong><\/p>\n\n\n\n<p>\u6570\u636e\u5206\u7c7b\u65b9\u6cd5\uff1a<\/p>\n\n\n\n<p>\u4e0b\u8f7d\u7684\u6570\u636e\u96c6\u4e2d\u5206\u4e3a <code>color<\/code> , <code>grayscale<\/code> , <code>segmented<\/code> \u4e09\u4e2a\u6587\u4ef6\u5939\uff0c\u8fd9\u91cc\u4ee5 <code>color<\/code> \u6587\u4ef6\u5939\u4e3a\u4f8b\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u8bad\u7ec3\u96c6\u6bd4\u7387\uff1a70%<\/li>\n\n\n\n<li>\u9a8c\u8bc1\u96c6\u6bd4\u7387\uff1a20%<\/li>\n\n\n\n<li>\u6d4b\u8bd5\u96c6\u6bd4\u7387\uff1a10%<\/li>\n<\/ul>\n\n\n\n<p><strong>2.2 \u521b\u5efa\u6570\u636e\u5212\u5206\u65b9\u6cd5\u6587\u4ef6<code>split_data.py<\/code><\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># Plantvillage\/split_data.py\nimport os, shutil, random, sys\nfrom pathlib import Path\n\n# ===== \u914d\u7f6e\u533a =====\nSRC_DIR = Path(\".\/dataset\/color\")   # \u4f60\u7684\u6e90\u6570\u636e\uff1acolor \u6587\u4ef6\u5939\u8def\u5f84\nDEST_DIR = Path(\".\/Plantvillage\")                # \u76ee\u6807\u6839\u76ee\u5f55\uff1a\u4f1a\u751f\u6210 train\/val\/test\nTRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.7, 0.2, 0.1\nSEED = 42\nCLEAR_DEST = False   # \u82e5\u4f60\u591a\u6b21\u5c1d\u8bd5\uff0c\u60f3\u5148\u6e05\u7a7a\u518d\u91cd\u65b0\u62f7\u8d1d\uff0c\u6539\u4e3a True\uff08\u5c0f\u5fc3\uff01\u4f1a\u5220\u9664\u76ee\u6807\u76ee\u5f55\uff09\n\n# ===== \u5de5\u5177\u51fd\u6570 =====\nIMG_EXTS = {\".jpg\", \".jpeg\", \".png\", \".bmp\", \".tif\", \".tiff\"}\n\ndef list_images(d: Path):\n    return &#91;p for p in 
d.iterdir() if p.is_file() and p.suffix.lower() in IMG_EXTS]\n\ndef ensure_dirs(*dirs):\n    for d in dirs:\n        d.mkdir(parents=True, exist_ok=True)\n\ndef copy_many(paths, target_dir: Path):\n    ensure_dirs(target_dir)\n    for p in paths:\n        shutil.copy2(p, target_dir \/ p.name)\n\ndef split_indices(n, tr=TRAIN_RATIO, vr=VAL_RATIO, te=TEST_RATIO):\n    \"\"\"\u5bf9\u957f\u5ea6\u4e3a n \u7684\u6570\u7ec4\u7d22\u5f15\uff0c\u8fd4\u56de (train_idx, val_idx, test_idx)\"\"\"\n    idx = list(range(n))\n    random.shuffle(idx)\n\n    if n == 0:\n        return &#91;], &#91;], &#91;]\n    if n == 1:\n        return idx, &#91;], &#91;]            # 1\u5f20\uff1a\u5168\u653etrain\n    if n == 2:\n        return idx&#91;:1], idx&#91;1:], &#91;]   # 2\u5f20\uff1a1\/1\/0\n    if n == 3:\n        return idx&#91;:2], idx&#91;2:], &#91;]   # 3\u5f20\uff1a2\/1\/0\n    if n == 4:\n        return idx&#91;:3], idx&#91;3:], &#91;]   # 4\u5f20\uff1a3\/1\/0\n\n    # n &gt;= 5 \u7528\u6bd4\u4f8b\n    n_train = max(1, int(round(tr * n)))\n    n_val   = max(1, int(round(vr * n)))\n    # \u786e\u4fdd\u4e0d\u8d85\n    if n_train + n_val &gt;= n:\n        n_val = max(1, n - n_train - 1)\n    n_test  = n - n_train - n_val\n    if n_test &lt; 0:\n        n_test = 0\n        # \u518d\u6b21\u7ea0\u504f\n        n_val = min(n_val, n - n_train)\n\n    tr_idx = idx&#91;:n_train]\n    va_idx = idx&#91;n_train:n_train+n_val]\n    te_idx = idx&#91;n_train+n_val:]\n    return tr_idx, va_idx, te_idx\n\ndef main():\n    random.seed(SEED)\n\n    if not SRC_DIR.exists():\n        print(f\"&#91;ERR] \u6e90\u76ee\u5f55\u4e0d\u5b58\u5728\uff1a{SRC_DIR.resolve()}\")\n        sys.exit(1)\n\n    if CLEAR_DEST and DEST_DIR.exists():\n        shutil.rmtree(DEST_DIR)\n    ensure_dirs(DEST_DIR \/ \"train\", DEST_DIR \/ \"val\", DEST_DIR \/ \"test\")\n\n    class_dirs = &#91;p for p in SRC_DIR.iterdir() if p.is_dir()]\n    if not class_dirs:\n        print(f\"&#91;ERR] \u5728 {SRC_DIR} 
\u4e0b\u672a\u627e\u5230\u7c7b\u522b\u6587\u4ef6\u5939\u3002\u8bf7\u786e\u8ba4\u8def\u5f84\u662f\u5426\u6b63\u786e\uff08\u5e94\u4e3a color\/ \u4e0b\u7684\u5404\u7c7b\u522b\u76ee\u5f55\uff09\u3002\")\n        sys.exit(1)\n\n    total_train = total_val = total_test = 0\n    skipped = 0\n\n    for cls_dir in sorted(class_dirs):\n        imgs = list_images(cls_dir)\n        if len(imgs) == 0:\n            print(f\"&#91;WARN] \u7c7b\u522b {cls_dir.name} \u65e0\u56fe\u7247\uff0c\u8df3\u8fc7\u3002\")\n            skipped += 1\n            continue\n\n        tr_idx, va_idx, te_idx = split_indices(len(imgs))\n        tr_imgs = &#91;imgs&#91;i] for i in tr_idx]\n        va_imgs = &#91;imgs&#91;i] for i in va_idx]\n        te_imgs = &#91;imgs&#91;i] for i in te_idx]\n\n        copy_many(tr_imgs, DEST_DIR \/ \"train\" \/ cls_dir.name)\n        copy_many(va_imgs, DEST_DIR \/ \"val\"   \/ cls_dir.name)\n        copy_many(te_imgs, DEST_DIR \/ \"test\"  \/ cls_dir.name)\n\n        total_train += len(tr_imgs)\n        total_val   += len(va_imgs)\n        total_test  += len(te_imgs)\n\n        print(f\"&#91;OK] {cls_dir.name}: {len(imgs)} =&gt; train {len(tr_imgs)}, val {len(va_imgs)}, test {len(te_imgs)}\")\n\n    print(\"\\n====== \u6c47\u603b ======\")\n    print(f\"\u7c7b\u522b\u603b\u6570\uff1a{len(class_dirs)}\uff08\u8df3\u8fc7\u7a7a\u7c7b {skipped}\uff09\")\n    print(f\"Train: {total_train} | Val: {total_val} | Test: {total_test}\")\n    print(f\"\u8f93\u51fa\u76ee\u5f55\uff1a{DEST_DIR.resolve()}\")\n\nif __name__ == \"__main__\":\n    main()<\/code><\/pre>\n\n\n\n<p>\u73b0\u5728\u5f53\u524d\u5de5\u4f5c\u76ee\u5f55\u4e0b\u5e94\u5f53\u4f1a\u770b\u5230 .\/PlantVillage\u6587\u4ef6\u5939\uff0c\u6709\u4e09\u4e2a\u5b50\u6587\u4ef6\u5939\uff1atest,train\u548cval\uff0c\u4f7f\u7528\u547d\u4ee4<code>ls -l | grep '^-' | wc -l<\/code> 
\u53ef\u4ee5\u68c0\u67e5\u6587\u4ef6\u5939\u5185\u6587\u4ef6\u6570\u91cf\u60c5\u51b5\uff0c\u786e\u4fdd\u6d4b\u8bd5\u96c6:\u9a8c\u8bc1\u96c6:\u8bad\u7ec3\u96c6\u4e3a1:2:7\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"1899\" height=\"276\" src=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/image.png\" alt=\"\" class=\"wp-image-3810\" srcset=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/image.png 1899w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/image-300x44.png 300w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/image-1024x149.png 1024w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/image-768x112.png 768w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/image-1536x223.png 1536w\" sizes=\"auto, (max-width: 1899px) 100vw, 1899px\" \/><\/figure>\n\n\n\n<p><strong>2.3 \u5bf9\u5212\u5206\u540e\u7684\u6570\u636e\u96c6\u8fdb\u884c\u89c4\u8303\u5316\u5904\u7406<code>preprocess.py<\/code><\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import os\nfrom pathlib import Path\nfrom PIL import Image\nfrom tqdm import tqdm\n\n# 1. \u5b9a\u4e49\u4f60\u7684\u539f\u59cb\u6570\u636e\u96c6\u8def\u5f84\nsource_dir = Path(\".\/Plantvillage\")\n\n# 2. \u5b9a\u4e49\u4f60\u60f3\u8981\u4fdd\u5b58\u65b0\u6570\u636e\u96c6\u7684\u8def\u5f84\ntarget_dir = Path(\".\/Plantvillage_224\")\n\n# 3. 
\u5b9a\u4e49\u6211\u4eec\u60f3\u8981\u7684\u7edf\u4e00\u5c3a\u5bf8\nnew_size = (224, 224)\n\n# \u786e\u4fdd PIL \u4f7f\u7528\u9ad8\u8d28\u91cf\u7684\u7f29\u653e\u7b97\u6cd5\nresample_filter = Image.Resampling.BILINEAR\n\ndef preprocess_images():\n    # \u904d\u5386 train, val, test \u6587\u4ef6\u5939\n    for split in &#91;\"train\", \"val\", \"test\"]:\n        split_path = source_dir \/ split\n        target_split_path = target_dir \/ split\n        \n        if not split_path.is_dir():\n            print(f\"Skipping {split_path}, not a directory.\")\n            continue\n\n        # \u83b7\u53d6\u6240\u6709\u7c7b\u522b\u6587\u4ef6\u5939 (e.g., \"Tomato___Bacterial_spot\")\n        class_dirs = &#91;d for d in split_path.iterdir() if d.is_dir()]\n        print(f\"Found {len(class_dirs)} classes in {split}...\")\n\n        # \u4f7f\u7528 tqdm \u663e\u793a\u603b\u8fdb\u5ea6\n        for class_dir in tqdm(class_dirs, desc=f\"Processing {split} set\"):\n            # \u5728\u65b0\u76ee\u5f55\u4e2d\u521b\u5efa\u5bf9\u5e94\u7684\u7c7b\u522b\u6587\u4ef6\u5939\n            target_class_path = target_split_path \/ class_dir.name\n            target_class_path.mkdir(parents=True, exist_ok=True)\n            \n            # \u904d\u5386\u8fd9\u4e2a\u7c7b\u522b\u4e2d\u7684\u6240\u6709\u56fe\u7247\n            # (\u5047\u8bbe\u662f .jpg, .JPG, .jpeg, .png)\n            image_files = list(class_dir.glob(\"*.jpg\")) + \\\n                          list(class_dir.glob(\"*.JPG\")) + \\\n                          list(class_dir.glob(\"*.jpeg\")) + \\\n                          list(class_dir.glob(\"*.png\"))\n\n            for image_path in image_files:\n                try:\n                    with Image.open(image_path) as img:\n                        # 1. \u8f6c\u6362\u4e3a \"RGB\" (\u9632\u6b62\u6709\u4e9b\u662f P \u6a21\u5f0f\u6216 RGBA)\n                        # 2. \u7f29\u653e\n                        # 3. 
\u4fdd\u5b58\n                        img_rgb = img.convert(\"RGB\")\n                        img_resized = img_rgb.resize(new_size, resample_filter)\n                        #base_name = image_path.stem\n                        # \u5b9a\u4e49\u65b0\u56fe\u7247\u7684\u4fdd\u5b58\u8def\u5f84\n                        new_image_path = target_class_path \/ image_path.name\n                        img_resized.save(new_image_path, \"JPEG\",quality=95)\n                        \n                except Exception as e:\n                    print(f\"Error processing {image_path}: {e}\")\n\n    print(\"--- Pre-processing Complete!(V2) ---\")\n    print(f\"All images resized and saved to {target_dir}\")\n\nif __name__ == \"__main__\":\n    preprocess_images()<\/code><\/pre>\n\n\n\n<p><strong>2.4 \u521b\u5efa\u6570\u636e\u52a0\u8f7d\u6587\u4ef6<code>data_loader.py<\/code><\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch.utils.data import DataLoader\nfrom torchvision import datasets, transforms\nfrom pathlib import Path\nNW = 32\nCLIP_MEAN = &#91;0.48145466, 0.4578275, 0.40821073]\nCLIP_STD = &#91;0.26862954, 0.26130258, 0.27577711]\ndef load_data(data_dir, batch_size=384):\n    \"\"\"\n    \u52a0\u8f7d\u8bad\u7ec3\u3001\u9a8c\u8bc1\u548c\u6d4b\u8bd5\u6570\u636e\n    \"\"\"\n    data_dir = Path(data_dir)\n    # \u6570\u636e\u589e\u5f3a\u548c\u9884\u5904\u7406\n    transform = transforms.Compose(&#91;\n        #transforms.Resize((224, 224)),  # \u8c03\u6574\u5927\u5c0f\n        transforms.ToTensor(),  # \u8f6c\u6362\u4e3a Tensor\n        transforms.Normalize(mean=CLIP_MEAN,std=CLIP_STD)  # \u6807\u51c6\u5316\n    ])\n\n    # \u4f7f\u7528 ImageFolder \u52a0\u8f7d\u6570\u636e\u96c6\n    train_data = datasets.ImageFolder(root=data_dir \/ 'train', transform=transform)\n    val_data = datasets.ImageFolder(root=data_dir \/ 'val', transform=transform)\n    test_data = datasets.ImageFolder(root=data_dir \/ 'test', transform=transform)\n\n    # 
\u521b\u5efa DataLoader\n    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,num_workers=NW,pin_memory=True)\n    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False,num_workers=NW,pin_memory=True)\n    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,num_workers=NW,pin_memory=True)\n\n    return train_loader, val_loader, test_loader\n\n# \u68c0\u67e5\u52a0\u8f7d\u7684\u6570\u636e\u96c6\nif __name__ == \"__main__\":\n    data_dir = \".\/Plantvillage_224\"  # \u4f60\u7684\u6570\u636e\u96c6\u8def\u5f84\n    train_loader, val_loader, test_loader = load_data(data_dir)\n\n    # \u6253\u5370\u4e00\u4e9bbatch\u6570\u636e\u68c0\u67e5\u52a0\u8f7d\u662f\u5426\u6b63\u786e\n    data_iter = iter(train_loader)\n    images, labels = next(data_iter)\n    print(f\"Batch of images shape: {images.shape}\")\n    print(f\"Batch of labels shape: {labels.shape}\")\n<\/code><\/pre>\n\n\n\n<p><strong>2.5 \u521b\u5efa\u6a21\u578b<code>model.py<\/code><\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nimport torch.nn as nn\n\nclass PlantDiseaseModel(nn.Module):\n    def __init__(self, in_channels_img=512, out_channels_img=256, num_classes=38):\n        \"\"\"\n        \u4e00\u4e2a\u6807\u51c6\u7684\u56fe\u50cf\u5206\u7c7b\u6a21\u578b\uff0c\u5b83\u63a5\u6536\u6765\u81ea CLIP \u7684 512 \u7ef4\u7279\u5f81\u3002\n        \"\"\"\n        super(PlantDiseaseModel, self).__init__()\n        \n        # 1. \u56fe\u50cf\u7279\u5f81\u5904\u7406\u5c42\n        # \u8f93\u5165 512 (\u6765\u81ea CLIP), \u8f93\u51fa 256\n        self.image_fc = nn.Linear(in_channels_img, out_channels_img)\n        \n        # 2. \u6700\u7ec8\u5206\u7c7b\u5c42\n        # \u8f93\u5165 256 (\u6765\u81ea image_fc), \u8f93\u51fa num_classes\n        self.fc = nn.Linear(out_channels_img, num_classes)\n        \n        # 3. &#91;\u5220\u9664] \u4e0d\u518d\u9700\u8981 text_fc\n        # self.text_fc = ...\n        \n        # 4. 
&#91;\u5220\u9664] \u4e0d\u518d\u9700\u8981\u5728\u8fd9\u91cc\u52a0\u8f7d CLIP\n        # self.model, self.transform = ...\n    \n    def forward(self, image_features):\n        \"\"\"\n        \u5b9a\u4e49\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u3002\n        \u8f93\u5165 'image_features' \u662f CLIP \u5df2\u7ecf\u63d0\u53d6\u597d\u7684 &#91;batch_size, 512] \u7279\u5f81\u3002\n        \"\"\"\n        # 1. \u901a\u8fc7\u56fe\u50cf\u5c42\n        # &#91;B, 512] -&gt; &#91;B, 256]\n        x = torch.relu(self.image_fc(image_features.view(image_features.size(0), -1)))\n        \n        # 2. \u901a\u8fc7\u6700\u7ec8\u5206\u7c7b\u5c42\n        # &#91;B, 256] -&gt; &#91;B, num_classes]\n        output = self.fc(x)\n        \n        return output<\/code><\/pre>\n\n\n\n<p><strong>2.6 \u521b\u5efa\u8bad\u7ec3\u6587\u4ef6<code>train.py<\/code><\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>#train\nimport torch\nimport torch.optim as optim\nfrom sklearn.metrics import accuracy_score, confusion_matrix,classification_report\nfrom tqdm import tqdm\nfrom model import PlantDiseaseModel  # \u5bfc\u5165 *\u4fee\u6539\u540e* \u7684\u6a21\u578b\nfrom data_loader import load_data\nimport clip\n\n# \u9009\u62e9\u8bbe\u5907\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nprint(device)\n# \u52a0\u8f7d\u6570\u636e\u96c6\ndata_dir = \".\/Plantvillage_224\"\ntrain_loader, val_loader, test_loader = load_data(data_dir)\n\n# --- &#91;\u4fee\u6539] ---\n# (PlantVillage \u662f 38 \u7c7b)\nnum_classes = 38\nmodel = PlantDiseaseModel(in_channels_img=512, out_channels_img=256, num_classes=num_classes).to(device)\n# --- &#91;\u4fee\u6539\u7ed3\u675f] ---\n\n# \u5f3a\u5236\u6a21\u578b\u4e3a float32\nmodel = model.float()\n\n# \u8bbe\u7f6e\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668\ncriterion = torch.nn.CrossEntropyLoss()\noptimizer = optim.Adam(model.parameters(), lr=1e-4)\n\n# \u52a0\u8f7d CLIP \u6a21\u578b (\u8fd9\u90e8\u5206\u4fdd\u7559\uff0c\u7528\u4e8e\u5728 
*\u8bad\u7ec3\u811a\u672c* \u4e2d\u63d0\u53d6\u7279\u5f81)\nclip_model, preprocess = clip.load(\"ViT-B\/32\", device=device)\n\n# \u8bad\u7ec3\u51fd\u6570\ndef train(model, train_loader, val_loader, num_epochs=10):\n    best_accuracy = 0.0  # \u8ddf\u8e2a\u6700\u4f73\u51c6\u786e\u7387\n    best_model_path = \"best_model.pth\" # \u5b9a\u4e49\u6a21\u578b\u4fdd\u5b58\u8def\u5f84\n\n    for epoch in range(num_epochs):\n        model.train()  # \u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f\n        running_loss = 0.0\n        \n        # \u4f7f\u7528 tqdm \u5305\u88c5 train_loader\n        for images, labels in tqdm(train_loader, desc=f\"Epoch {epoch+1}\/{num_epochs} Training\"):\n            images, labels = images.to(device), labels.to(device)\n            images = images.float()\n            labels = labels.long()\n\n            # 1. \u83b7\u53d6\u56fe\u50cf\u7279\u5f81 (\u6765\u81ea CLIP)\n            # (\u5728 no_grad() \u4e2d\u8fd0\u884c clip_model \u4ee5\u8282\u7701\u663e\u5b58\u548c\u65f6\u95f4)\n            with torch.no_grad():\n                image_features = clip_model.encode_image(images)\n            image_features = image_features.float()\n            \n            # 2. \u83b7\u53d6\u6a21\u578b\u8f93\u51fa (\u524d\u5411\u4f20\u64ad)\n            outputs = model(image_features)\n            \n            # 3. 
\u8ba1\u7b97\u635f\u5931 (\u5728\u8fd9\u91cc\u5b9a\u4e49 'loss')\n            loss = criterion(outputs, labels)\n            \n            # --- &#91;\u8fd9\u662f\u4f60\u9057\u6f0f\u7684\u90e8\u5206 END] ---\n\n            # \u53cd\u5411\u4f20\u64ad\u5e76\u66f4\u65b0\u6743\u91cd\n            optimizer.zero_grad()  # \u6e05\u96f6\u68af\u5ea6\n            loss.backward()  # \u8ba1\u7b97\u68af\u5ea6 (\u73b0\u5728 'loss' \u5df2\u88ab\u5b9a\u4e49)\n            optimizer.step()  # \u66f4\u65b0\u6743\u91cd\n            \n            running_loss += loss.item() # \u7d2f\u52a0\u635f\u5931\n            \n        print(f\"\\nEpoch &#91;{epoch+1}\/{num_epochs}], Loss: {running_loss\/len(train_loader)}\")\n\n        # \u6bcf\u4e2a epoch \u540e\u8fdb\u884c\u9a8c\u8bc1\n        val_accuracy = validate(model, val_loader)\n        \n        # \u68c0\u67e5\u8fd9\u662f\u5426\u662f\u8fc4\u4eca\u4e3a\u6b62\u6700\u597d\u7684\u6a21\u578b\n        if val_accuracy &gt; best_accuracy:\n            best_accuracy = val_accuracy\n            # \u4fdd\u5b58\u5f53\u524d\u6a21\u578b\u7684\u6743\u91cd\n            torch.save(model.state_dict(), best_model_path)\n            print(f\"*** \u65b0\u7684\u6700\u4f73\u6a21\u578b\u5df2\u4fdd\u5b58\uff0c\u51c6\u786e\u7387: {best_accuracy * 100:.2f}% ***\")\n\n# \u9a8c\u8bc1\u51fd\u6570\ndef validate(model, val_loader):\n    model.eval()\n    all_preds = &#91;]\n    all_labels = &#91;]\n    \n    # \u4f7f\u7528 tqdm \u5305\u88c5 val_loader\n    with torch.no_grad():\n        for images, labels in tqdm(val_loader, desc=\"Validating\"):\n            images, labels = images.to(device), labels.to(device)\n            images = images.float()\n            labels = labels.long()\n\n            # 1. \u83b7\u53d6\u56fe\u50cf\u7279\u5f81\n            image_features = clip_model.encode_image(images)\n            image_features = image_features.float()\n            \n            # 2. 
\u83b7\u53d6\u6a21\u578b\u8f93\u51fa\n            outputs = model(image_features)\n            _, preds = torch.max(outputs, 1)\n            \n            all_preds.extend(preds.cpu().numpy())\n            all_labels.extend(labels.cpu().numpy())\n            \n    # \u8ba1\u7b97\u51c6\u786e\u7387\n    accuracy = accuracy_score(all_labels, all_preds)\n    cm = confusion_matrix(all_labels, all_preds)\n    \n    print(f\"Validation Accuracy: {accuracy * 100:.2f}%\")\n    print(\"\u6df7\u6dc6\u77e9\u9635 (Validation):\")\n    print(cm)\n    \n    return accuracy  # &lt;-- &#91;\u4fee\u6539] \u8fd4\u56de\u8ba1\u7b97\u51fa\u7684\u51c6\u786e\u7387\n# \u6d4b\u8bd5\u51fd\u6570\ndef test(model, test_loader):\n    print(\"\\n--- \u542f\u52a8\u6d4b\u8bd5\u9636\u6bb5 ---\")\n    model.eval()  # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n    all_preds = &#91;]\n    all_labels = &#91;]\n    \n    # \u4ece test_loader \u4e2d\u83b7\u53d6\u7c7b\u522b\u540d\u79f0\uff0c\u7528\u4e8e\u62a5\u544a\n    try:\n        class_names = test_loader.dataset.classes\n    except:\n        class_names = &#91;str(i) for i in range(num_classes)] # \u5907\u7528\u65b9\u6848\n\n    with torch.no_grad():\n        # \u4f7f\u7528 tqdm \u663e\u793a\u8fdb\u5ea6\u6761\n        for images, labels in tqdm(test_loader, desc=\"Testing\"): \n            images, labels = images.to(device), labels.to(device)\n            images = images.float()\n            labels = labels.long()\n\n            # 1. \u83b7\u53d6\u56fe\u50cf\u7279\u5f81 (clip_model \u662f\u5168\u5c40\u53d8\u91cf)\n            image_features = clip_model.encode_image(images)\n            image_features = image_features.float()\n            \n            # 2. \u83b7\u53d6\u6a21\u578b\u8f93\u51fa\n            outputs = model(image_features)\n            \n            # 3. 
\u83b7\u53d6\u9884\u6d4b\n            _, preds = torch.max(outputs, 1)\n            \n            all_preds.extend(preds.cpu().numpy())\n            all_labels.extend(labels.cpu().numpy())\n            \n    # \u8ba1\u7b97\u6307\u6807\n    accuracy = accuracy_score(all_labels, all_preds)\n    cm = confusion_matrix(all_labels, all_preds)\n    \n    print(f\"\\n--- \u6d4b\u8bd5\u7ed3\u679c ---\")\n    print(f\"Test Accuracy: {accuracy * 100:.2f}%\")\n    \n    print(\"\\n\u6df7\u6dc6\u77e9\u9635 (Test):\")\n    print(cm)\n    \n    # \u6253\u5370\u5206\u7c7b\u62a5\u544a (\u5305\u542b\u7cbe\u786e\u7387, \u53ec\u56de\u7387, F1-score)\n    print(\"\\n\u5206\u7c7b\u62a5\u544a (Test):\")\n    print(classification_report(all_labels, all_preds, target_names=class_names, digits=4))\n# \u5f00\u59cb\u8bad\u7ec3\nif __name__ == \"__main__\":\n    best_model_path = \"best_model.pth\"\n\n    # 1. \u8bad\u7ec3\u6a21\u578b (\u73b0\u5728\u5b83\u4f1a\u81ea\u52a8\u4fdd\u5b58 'best_model.pth')\n    train(model, train_loader, val_loader, num_epochs=20)\n    \n    print(\"\\n--- \u8bad\u7ec3\u5b8c\u6210 ---\")\n    print(\"\u6b63\u5728\u52a0\u8f7d\u6700\u4f73\u6a21\u578b\u6743\u91cd\u7528\u4e8e\u6d4b\u8bd5...\")\n\n    # 2. \u52a0\u8f7d\u4fdd\u5b58\u7684 *\u6700\u4f73* \u6a21\u578b\u6743\u91cd\n    model.load_state_dict(torch.load(best_model_path))\n\n    # 3. 
\u4f7f\u7528\u52a0\u8f7d\u7684 *\u6700\u4f73* \u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\n    test(model, test_loader)<\/code><\/pre>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity is-style-wide\"\/>\n\n\n\n<h2 class=\"wp-block-heading has-text-align-center\">3.\u4f7f\u7528\u6559\u7a0b<\/h2>\n\n\n\n<p>0.\u6587\u4ef6\u76ee\u5f55\u7ed3\u6784\uff1a <\/p>\n\n\n\n<p>\uff08\u5de5\u4f5c\uff09\u6839\u76ee\u5f55<\/p>\n\n\n\n<p>-dataset<\/p>\n\n\n\n<p>--color<\/p>\n\n\n\n<p>-data_loader.py<\/p>\n\n\n\n<p>-split_data.py<\/p>\n\n\n\n<p>-model.py<\/p>\n\n\n\n<p>-train.py<\/p>\n\n\n\n<p><\/p>\n\n\n\n<p>1.\u5148\u8fd0\u884c<code>pip install -r requirements.txt<\/code> \u5b89\u88c5\u4f9d\u8d56<\/p>\n\n\n\n<p>2.\u8fd0\u884c<code>split_data.py<\/code>\u5212\u5206\u6570\u636e\u96c6<\/p>\n\n\n\n<p>3.\u8fd0\u884c<code>train.py<\/code>\u8bad\u7ec3<\/p>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity is-style-wide\"\/>\n\n\n\n<h2 class=\"wp-block-heading has-text-align-center\">4.\u8bad\u7ec3\u7ed3\u679c<\/h2>\n\n\n\n<p>\u5728Epoch\u4e3a20\u65f6\uff0c\u6709\u6700\u9ad8\u51c6\u786e\u7387\u4e3a93.18%<\/p>\n\n\n\n<p>\u6a21\u578b\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5b9e\u73b0\u4e8693.49%\u7684\u51c6\u786e\u7387\u3002<\/p>\n\n\n\n<p class=\"has-text-align-right\">precision   recall   f1-score support    <\/p>\n\n\n\n<p class=\"has-text-align-right\">accuracy                                    0.9349        5435<br>macro avg    0.9030     0.8849      0.8910        5435<br>weighted avg    0.9322     0.9349      0.9320        5435<\/p>\n\n\n\n<p class=\"has-text-align-left\">\u8bad\u7ec3\u635f\u5931\u548c\u9a8c\u8bc1\u51c6\u786e\u7387\u4e0eEpoch\u5173\u7cfb\u5982\u4e0b\uff1a<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"1200\" height=\"500\" src=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot.png\" alt=\"\" class=\"wp-image-3813\" 
srcset=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot.png 1200w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-300x125.png 300w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-1024x427.png 1024w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-768x320.png 768w\" sizes=\"auto, (max-width: 1200px) 100vw, 1200px\" \/><\/figure>\n\n\n\n<p>\u5728Epoch\u4e3a100\u65f6\uff0c\u6709\u6700\u9ad8\u51c6\u786e\u738797.84%<\/p>\n\n\n\n<p>\u6a21\u578b\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5b9e\u73b0\u4e8697.88%\u7684\u51c6\u786e\u7387\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"1200\" height=\"500\" src=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-1.png\" alt=\"\" class=\"wp-image-3821\" srcset=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-1.png 1200w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-1-300x125.png 300w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-1-1024x427.png 1024w, https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/training_history_plot-1-768x320.png 768w\" sizes=\"auto, (max-width: 1200px) 100vw, 1200px\" \/><\/figure>\n\n\n\n<p>\u5177\u4f53\u8bad\u7ec3\u7ed3\u679c\u53ef\u4ee5\u770b\u8fd9\u91cc\uff1a<\/p>\n\n\n\n<div class=\"wp-block-file\"><a id=\"wp-block-file--media-ac80082a-5aec-41a0-abc9-c41f1c9b20cc\" href=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/results-1.xlsx\">results<\/a><a href=\"https:\/\/www.leexinghai.com\/aic\/wp-content\/uploads\/2025\/11\/results-1.xlsx\" class=\"wp-block-file__button wp-element-button\" download 
aria-describedby=\"wp-block-file--media-ac80082a-5aec-41a0-abc9-c41f1c9b20cc\">\u4e0b\u8f7d<\/a><\/div>\n\n\n\n<p><\/p>\n","protected":false},"excerpt":{"rendered":"<p>0.Github CrystalChanB31\/clip_on_plantvillage: CLIP\u6a21\u578b\u5728Pl [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":3819,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[80],"tags":[83],"class_list":["post-3809","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-xmk","tag-zuhui4"],"_links":{"self":[{"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/posts\/3809","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/comments?post=3809"}],"version-history":[{"count":6,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/posts\/3809\/revisions"}],"predecessor-version":[{"id":3823,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/posts\/3809\/revisions\/3823"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/media\/3819"}],"wp:attachment":[{"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/media?parent=3809"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/categories?post=3809"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.leexinghai.com\/aic\/wp-json\/wp\/v2\/tags?post=3809"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}