本文共 27370 字,大约阅读时间需要 91 分钟。
因为本人水平有限,所以这里通过一行行 debug 来阅读代码:先记录整个运行逻辑,后面再一点点补充每个 trick 的原理。
def main(opt): setlogging(RANK) if RANK in [-1,0]: print(colorstr('train: ')+', '.join(f'{k}={v}' for k,v in vars(opt).items())) check_git_status() check_requirements(exclude=['thop'])
这里的 RANK 默认为 -1(不使用 DDP 分布式训练时的取值)。
然后调用 general.py 中的 set_logging 函数配置日志输出(上面代码里写成了 setlogging,应为 set_logging):
def set_logging(rank=-1, verbose=True):
    """Configure the root logger for this process.

    Messages are emitted bare (format "%(message)s"). The INFO level is used
    only when ``verbose`` is true and ``rank`` is -1 or 0; in every other case
    the level is WARN.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so only the first effective call in a process applies.

    Args:
        rank: process rank (-1 or 0 selects the verbose level).
        verbose: whether INFO messages may be shown at all.
    """
    is_main_process = rank in (-1, 0)
    level = logging.INFO if verbose and is_main_process else logging.WARN
    logging.basicConfig(format="%(message)s", level=level)
其余两个函数分别用于检查 git 仓库状态(check_git_status)和依赖包是否安装齐全(check_requirements,这里排除了 thop);我这里没有开启 git 仓库相关的功能。
def main(opt): setlogging(RANK) if RANK in [-1,0]: print(colorstr('train: ')+', '.join(f'{k}={v}' for k,v in vars(opt).items())) check_git_status() check_requirements(exclude=['thop']) #Resume wandb_run=check_wandb_resume(opt) #None #所以这里的if就不上了 opt.data,opt.cfg,opt.hyp=check_file(opt.data),check_file(opt.cfg),check_file(opt.hyp)
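check_file 函数(位于 general.py)用于校验并解析传入的文件路径(opt.data、opt.cfg、opt.hyp):路径指向已存在的文件(或为空字符串)时直接返回;以 http/https 开头时先下载到当前目录再返回本地文件名;
否则进入 else 分支,用 glob 在当前目录下递归查找同名文件,找不到或找到多个都会触发 assert 报错。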
def check_file(file):
    """Resolve ``file`` to a usable local path, downloading or searching if needed.

    Resolution order:
      1. ``file`` is an existing file, or the empty string: returned unchanged.
      2. ``file`` starts with ``http:/`` or ``https:/``: downloaded into the
         current directory under its URL basename (query string stripped),
         and that local name is returned.
      3. Otherwise ``./**/<file>`` is searched recursively and the single
         match is returned.

    Args:
        file: a path, URL or bare file name (anything ``str()`` accepts).

    Returns:
        str: the resolved local path ('' if ``file`` was empty).

    Raises:
        AssertionError: the download produced a missing/empty file, or the
            recursive search found zero or more than one match.
    """
    file = str(file)
    if Path(file).is_file() or file == '':  # exists
        return file
    if file.startswith(('http:/', 'https:/')):  # download
        # Passing a URL through Path() collapses '//' to '/'; restore the scheme separator.
        url = str(Path(file)).replace(':/', '://')
        file = Path(urllib.parse.unquote(file)).name.split('?')[0]  # '%2F' -> '/'
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'
        return file
    files = glob.glob('./**/' + file, recursive=True)  # find file
    assert len(files), f'File not found:{file}'
    # The original message mixed ' and " quotes, which made this line a SyntaxError.
    assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"
    return files[0]
def main(opt): setlogging(RANK) if RANK in [-1,0]: print(colorstr('train: ')+', '.join(f'{k}={v}' for k,v in vars(opt).items())) check_git_status() check_requirements(exclude=['thop']) #Resume wandb_run=check_wandb_resume(opt) #None #所以这里的if就不上了 opt.data,opt.cfg,opt.hyp=check_file(opt.data),check_file(opt.cfg),check_file(opt.hyp) assert len(opt.cfg) or len(opt.weights),'either --cfg or --weights must be specified' opt.img_size.extend([opt.img_size[-1]]*(2-len(opt.img_size))) opt.name='evolve' if opt.evolve else opt.name opt.save_dir=str(increment_path(Path(opt.project)/opt.name, exist_ok=opt.exist_ok | opt.evolve))
这句 opt.img_size.extend([opt.img_size[-1]]*(2-len(opt.img_size)))
的作用是把 img_size 补齐为两个元素,分别对应 train 和 test 的输入尺寸:如果只传了一个值,就用最后一个值补齐,例如 [640] 变成 [640, 640];如果已经传了两个值,列表乘以 0 得到空列表,不做任何扩展。
opt.save_dir=str(increment_path(Path(opt.project)/opt.name, exist_ok=opt.exist_ok | opt.evolve))
定义保存路径。分析一下参数:opt.project 为 runs/train,opt.name 为 exp(开启 evolve 时会被改为 'evolve')。
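pathlib 的 Path 对象可以直接用 / 运算符拼接路径。
然后这个 increment_path(位于 general.py)其实是一个路径自增函数:如果 runs/train/exp 已存在且没有设置 exist_ok,就在同级目录中找出 exp2、exp3…… 中最大的编号并加 1 作为新路径(还没有带编号的目录时从 2 开始),
以免覆盖之前的训练结果。它并不是用来按 epoch 保存模型的。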
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Return a non-clashing variant of ``path`` by appending an increasing run number.

    If ``path`` exists and ``exist_ok`` is False, sibling entries named
    ``<name><sep><digits><suffix>`` are scanned and the next free number is
    used, e.g. ``runs/train/exp`` -> ``runs/train/exp2`` -> ``runs/train/exp3``.
    Numbering starts at 2 when no numbered sibling exists yet.

    Args:
        path: file or directory path (str or Path).
        exist_ok: if True, return ``path`` unchanged even when it exists.
        sep: separator placed between the name and the number.
        mkdir: if True, create the resulting directory (or, for a path with a
            suffix, its parent directory).

    Returns:
        pathlib.Path: the resulting path.
    """
    path = Path(path)
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        # Escape so names containing glob/regex metacharacters (e.g. '+', '[') match literally.
        candidates = glob.glob(glob.escape(f"{path}{sep}") + "*")  # similar paths
        pattern = re.compile(rf"{re.escape(path.name)}{re.escape(sep)}(\d+)")
        # Match against the basename only, anchored at its start, so digits
        # appearing elsewhere in the parent path cannot be picked up.
        matches = (pattern.match(Path(c).name) for c in candidates)
        indices = [int(m.group(1)) for m in matches if m]
        n = max(indices) + 1 if indices else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    # Create the path itself for directories, the parent for file-like paths.
    directory = path if path.suffix == '' else path.parent
    if mkdir:
        directory.mkdir(parents=True, exist_ok=True)
    return path
正则表达式 \d+ 匹配一个或多个连续数字;
path.stem 取的是路径最后一个组成部分去掉后缀之后的名字,例如:
test_path = Path('/Users/xxx/Desktop/project/data/')
print(test_path.stem)  # 输出:'data'
def main(opt): setlogging(RANK) if RANK in [-1,0]: print(colorstr('train: ')+', '.join(f'{k}={v}' for k,v in vars(opt).items())) check_git_status() check_requirements(exclude=['thop']) #Resume wandb_run=check_wandb_resume(opt) #None #所以这里的if就不上了 opt.data,opt.cfg,opt.hyp=check_file(opt.data),check_file(opt.cfg),check_file(opt.hyp) assert len(opt.cfg) or len(opt.weights),'either --cfg or --weights must be specified' opt.img_size.extend([opt.img_size[-1]]*(2-len(opt.img_size))) opt.name='evolve' if opt.evolve else opt.name opt.save_dir=str(increment_path(Path(opt.project)/opt.name, exist_ok=opt.exist_ok | opt.evolve)) device = select_device(opt.device,batch_size=opt.batch_size)
这里调用 torch_utils.py 中的 select_device 函数选择训练设备(CPU 或指定编号的 GPU),多卡时还会检查 batch_size 能否被 GPU 数量整除:
def select_device(device='', batch_size=None):
    """Select the torch device to run on and log a summary of the chosen hardware.

    Args:
        device: 'cpu', a comma-separated list of CUDA indices such as '0' or
            '0,1,2,3', or '' to use CUDA when available (otherwise CPU).
        batch_size: when given and more than one GPU is selected, it must be
            a multiple of the GPU count.

    Returns:
        torch.device: 'cuda:0' when CUDA is used, otherwise 'cpu'.

    Raises:
        AssertionError: a GPU was requested but CUDA is unavailable, or the
            batch size does not split evenly across the selected GPUs.

    Side effects:
        Sets the CUDA_VISIBLE_DEVICES environment variable when ``device`` is
        non-empty ('-1' when ``device`` is 'cpu').
    """
    # Trailing space separates the header from the device description appended below.
    s = f'YOLOv5 {git_describe() or date_modified()} torch {torch.__version__} '
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide all GPUs from CUDA
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        assert torch.cuda.is_available(), f'CUDA unavailable,invalid device {device} requested'

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else ['0']  # e.g. '0,1' -> ['0', '1']
        n = len(devices)
        if n > 1 and batch_size:  # each GPU must receive an equal share of the batch
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            # Query by position i: CUDA_VISIBLE_DEVICES renumbers the visible GPUs from 0.
            # (The original called the misspelled get_device_properities -> AttributeError.)
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    # Strip non-ASCII characters on Windows, whose consoles may not render them.
    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
    return torch.device('cuda:0' if cuda else 'cpu')
def main(opt): setlogging(RANK) if RANK in [-1,0]: print(colorstr('train: ')+', '.join(f'{k}={v}' for k,v in vars(opt).items())) check_git_status() check_requirements(exclude=['thop']) #Resume wandb_run=check_wandb_resume(opt) #None #所以这里的if就不上了 opt.data,opt.cfg,opt.hyp=check_file(opt.data),check_file(opt.cfg),check_file(opt.hyp) assert len(opt.cfg) or len(opt.weights),'either --cfg or --weights must be specified' opt.img_size.extend([opt.img_size[-1]]*(2-len(opt.img_size))) opt.name='evolve' if opt.evolve else opt.name opt.save_dir=str(increment_path(Path(opt.project)/opt.name, exist_ok=opt.exist_ok | opt.evolve)) device = select_device(opt.device,batch_size=opt.batch_size) #DDP那部分先忽略。。。。。 if not opt.evolve: train(opt.hyp,opt,device)
此处 opt.hyp 为超参数配置文件 data/hyp.scratch.yaml。
-------分割线---------------------------------------------------------------------------------------------------------
然后进入train函数def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt'
到这里为止,完成了参数解析,并定义了权重保存目录 weights、最后一轮模型 last.pt、最优模型 best.pt 以及结果文件 results.txt 的保存路径。
def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False)
此处 yaml.safe_load(f) 是加载 YAML 的标准接口,它把文件内容解析成 Python 对象(这里是超参数字典)。yaml.safe_dump() 则把 Python 对象序列化成 YAML 写入文件,这里用来把本次运行的超参数和命令行参数分别保存到 hyp.yaml 和 opt.yaml。
vars(opt)
的作用是把数据类型是Namespace的数据转换为字典的形式。 def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1
此处的 init_seeds(位于 general.py)用于初始化 random、numpy 和 torch 的随机种子,目的是让同一训练配置的结果可以复现:
def init_seeds(seed=0):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    The PyTorch part is delegated to ``init_torch_seeds``, which also chooses
    the cuDNN mode based on ``seed``.

    Args:
        seed: the seed shared by all three generators.
    """
    for seeder in (random.seed, np.random.seed, init_torch_seeds):
        seeder(seed)
init_torch_seeds 位于 torch_utils.py:
def init_torch_seeds(seed=0):
    """Seed PyTorch's RNG and pick cuDNN's reproducibility/speed trade-off.

    seed == 0: cudnn.benchmark=False, cudnn.deterministic=True
               (deterministic convolution algorithms, reproducible results).
    otherwise: cudnn.benchmark=True, cudnn.deterministic=False
               (cuDNN autotunes for speed; results may vary between runs).

    Args:
        seed: value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)
    reproducible = seed == 0
    cudnn.benchmark = not reproducible
    cudnn.deterministic = reproducible
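cudnn.deterministic=True 会让 cuDNN 只使用确定性的卷积算法,避免随机性,使结果可复现;而 cudnn.benchmark=True 会让 cuDNN 对多种卷积算法做基准测试并挑选最快的,速度更快,但结果不保证可复现。注意 train 中调用的是 init_seeds(2+RANK),RANK=-1 时种子为 1(不为 0),所以实际走的是 benchmark=True 的分支。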
cudnn.benchmark =True 随机模式def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1 #导入数据 with open(data) as f: data_dict=yaml.safe_load(f) #Loggers loggers={ 'wandb':None,'tb':None} #loggers dict if RANK in [-1,0]: #TensorBoard if not evolve: prefix=colorstr('tensorboard: ') logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/") loggers['tb']=SummaryWriter(str(save_dir)) # W&B opt.hyp=hyp #add hyperparameters run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id=run_id if opt.resume else None wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict) logger['wandb']=wandb_logger.wandb if logger['wandb']: data_dict=wandb_logger.data_dict weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp
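SummaryWriter(str(save_dir)) 用于设置 TensorBoard 日志的保存位置;W&B(wandb)默认不使用,这里先不展开。注意:上面代码中 logger['wandb']=wandb_logger.wandb 和 if logger['wandb'] 里的 logger 应为 loggers。logger 是日志对象,不能按键索引。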
def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1 #导入数据 with open(data) as f: data_dict=yaml.safe_load(f) #Loggers loggers={ 'wandb':None,'tb':None} #loggers dict if RANK in [-1,0]: #TensorBoard if not evolve: prefix=colorstr('tensorboard: ') logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/") loggers['tb']=SummaryWriter(str(save_dir)) # W&B opt.hyp=hyp #add hyperparameters run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id=run_id if opt.resume else None wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict) logger['wandb']=wandb_logger.wandb if logger['wandb']: data_dict=wandb_logger.data_dict weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp nc=1 if single_cls else int(data_dict['nc']) # 类别数量 names=['item'] if single_cls and len(data_dict['names'])!=1 else data_dict['names'] assert len(names)==nc, '%g names found for nc=%g dataset in %s'%(len(names),nc,data) #check is_coco=data.endswith('coco.yaml') and nc==80
读取类别数 nc(single_cls 时固定为 1)和类别名称 names,并检查二者数量一致;is_coco 用于判断当前数据集是否为 COCO。
def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1 #导入数据 with open(data) as f: data_dict=yaml.safe_load(f) #Loggers loggers={ 'wandb':None,'tb':None} #loggers dict if RANK in [-1,0]: #TensorBoard if not evolve: prefix=colorstr('tensorboard: ') logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/") loggers['tb']=SummaryWriter(str(save_dir)) # W&B opt.hyp=hyp #add hyperparameters run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id=run_id if opt.resume else None wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict) logger['wandb']=wandb_logger.wandb if logger['wandb']: data_dict=wandb_logger.data_dict weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp nc=1 if single_cls else int(data_dict['nc']) # 类别数量 names=['item'] if single_cls and len(data_dict['names'])!=1 else data_dict['names'] assert len(names)==nc, '%g names found for nc=%g dataset in %s'%(len(names),nc,data) #check is_coco=data.endswith('coco.yaml') and nc==80 # Model pretrained=weights.endswith('.pt') if pretrained: with torch_distributed_zero_first(RANK): 
weights=attempt_download(weights) ckpt=torch.load(weights,map_location=device) model=Model(cfg or ckpt['model'].yaml,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) #创建模型 exclude=['anchors'] if (cfg or hyp.get('anchors')) and not resume else [] #anchors state_dict=ckpt['model'].float().state_dict() state_dict=intersect_dicts(state_dict,model.state_dict(),exclude=exclude) model.load_state_dict(state_dict,strict=False) logger.info('Transferred %g/%g items from %s' %(len(state_dict),len(model.state_dict()),weights)) else: model=Model(cfg,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device)
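其中 intersect_dicts 用于求两个 state_dict 的交集:只保留预训练权重 da 中满足三个条件的参数,即同时存在于当前模型 db、形状一致、且键名不包含 exclude 中任意字符串(这里是 'anchors')。这样即使类别数等设置与预训练模型不同、导致部分层的形状对不上,也仍然可以加载其余匹配的权重(配合 load_state_dict 的 strict=False)。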
torch_utils.py
def intersect_dicts(da, db, exclude=()):
    """Return the entries of ``da`` that can be loaded into a model whose state is ``db``.

    An entry ``key -> value`` of ``da`` is kept when all of these hold, checked
    in this order:
      * ``key`` is also present in ``db``;
      * no string in ``exclude`` occurs as a substring of ``key``;
      * ``value.shape`` equals ``db[key].shape``.

    Values are taken from ``da``, and ``da``'s iteration order is preserved.

    Args:
        da: source mapping (e.g. a pretrained state_dict).
        db: target mapping (e.g. the new model's state_dict).
        exclude: substrings whose presence in a key drops that entry.

    Returns:
        dict: the compatible subset of ``da``.
    """
    def _compatible(key, value):
        if key not in db:
            return False
        if any(token in key for token in exclude):
            return False
        return value.shape == db[key].shape

    return {key: value for key, value in da.items() if _compatible(key, value)}
def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1 #导入数据 with open(data) as f: data_dict=yaml.safe_load(f) #Loggers loggers={ 'wandb':None,'tb':None} #loggers dict if RANK in [-1,0]: #TensorBoard if not evolve: prefix=colorstr('tensorboard: ') logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/") loggers['tb']=SummaryWriter(str(save_dir)) # W&B opt.hyp=hyp #add hyperparameters run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id=run_id if opt.resume else None wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict) logger['wandb']=wandb_logger.wandb if logger['wandb']: data_dict=wandb_logger.data_dict weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp nc=1 if single_cls else int(data_dict['nc']) # 类别数量 names=['item'] if single_cls and len(data_dict['names'])!=1 else data_dict['names'] assert len(names)==nc, '%g names found for nc=%g dataset in %s'%(len(names),nc,data) #check is_coco=data.endswith('coco.yaml') and nc==80 # Model pretrained=weights.endswith('.pt') if pretrained: with torch_distributed_zero_first(RANK): 
weights=attempt_download(weights) ckpt=torch.load(weights,map_location=device) model=Model(cfg or ckpt['model'].yaml,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) #创建模型 exclude=['anchors'] if (cfg or hyp.get('anchors')) and not resume else [] #anchors state_dict=ckpt['model'].float().state_dict() state_dict=intersect_dicts(state_dict,model.state_dict(),exclude=exclude) model.load_state_dict(state_dict,strict=False) logger.info('Transferred %g/%g items from %s' %(len(state_dict),len(model.state_dict()),weights)) else: model=Model(cfg,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) # 组装 训练数据路径和测试数据路径 train_path=data_dict['train'] test_path=data_dict['val'] #动态冻结某层 freeze=[] for k,v in model.named_parameters(): v.requires_grad=True if any(x in k for x in freeze): print('freezing %s'%k) v.requires_grad=False
总结一下:
train.py 中到目前的工作是做了以下工作:def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1 #导入数据 with open(data) as f: data_dict=yaml.safe_load(f) #Loggers loggers={ 'wandb':None,'tb':None} #loggers dict if RANK in [-1,0]: #TensorBoard if not evolve: prefix=colorstr('tensorboard: ') logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/") loggers['tb']=SummaryWriter(str(save_dir)) # W&B opt.hyp=hyp #add hyperparameters run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id=run_id if opt.resume else None wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict) logger['wandb']=wandb_logger.wandb if logger['wandb']: data_dict=wandb_logger.data_dict weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp nc=1 if single_cls else int(data_dict['nc']) # 类别数量 names=['item'] if single_cls and len(data_dict['names'])!=1 else data_dict['names'] assert len(names)==nc, '%g names found for nc=%g dataset in %s'%(len(names),nc,data) #check is_coco=data.endswith('coco.yaml') and nc==80 # Model pretrained=weights.endswith('.pt') if pretrained: with 
torch_distributed_zero_first(RANK): weights=attempt_download(weights) ckpt=torch.load(weights,map_location=device) model=Model(cfg or ckpt['model'].yaml,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) #创建模型 exclude=['anchors'] if (cfg or hyp.get('anchors')) and not resume else [] #anchors state_dict=ckpt['model'].float().state_dict() state_dict=intersect_dicts(state_dict,model.state_dict(),exclude=exclude) model.load_state_dict(state_dict,strict=False) logger.info('Transferred %g/%g items from %s' %(len(state_dict),len(model.state_dict()),weights)) else: model=Model(cfg,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) # 组装 训练数据路径和测试数据路径 train_path=data_dict['train'] test_path=data_dict['val'] #动态冻结某层 freeze=[] for k,v in model.named_parameters(): v.requires_grad=True if any(x in k for x in freeze): print('freezing %s'%k) v.requires_grad=False #配置优化器参数 nbs=64 #nominal batch size accumulate=max(round(nbs/batch_size),1) #32 batch=2 hyp['weight_decay']*=batch_size*accumulate/nbs #0.0005 logger.info(f'Scaled weight_decay={hyp['weight_decay']}') pg0,pg1,pg2=[],[],[] # optimizer parameter groups for k,v in model.named_modules(): if hasattr(v,'bias') and isinstance(v.bias,nn.Parameter): pg2.append(v.bias) #biases if isinstance(v,nn.BatchNorm2d): pg0.append(v.weight) #no decay elif hasattr(v,'weight') and isinstance(v.weight,nn.Parameter): pg1.append(v.weight) #apply decay if opt.adam: optimizer=optim.Adam(pg0,lr=hyp['lr0'],betas=(hyp['momentum'],0.999)) #adjust beta1 to momentum else: optimizer=optim.SGD(pg0,lr=hyp['lr0'],momentum=hyp['momentum'],nesterov=True) #配置decay和biases 这一步的操作是会在optimizer中的param_groups增加一个字典 optimizer.add_param_group({ 'params':pg1,'weight_decay':hyp['weight_decay']}) optimizer.add_param_group({ 'params':pg2}) logger.info('Optimizer groups:%g .bias, %g conv.weight, %g other' %(len(pg2),len(pg1),len(pg0)) del pg0,pg1,pg2 #配置学习率 if opt.linear_lr: lf=lambda x:(1-x/(epochs-1))*(1.0-hyp['lrf'])+hyp['lrf'] else:#OneCycleLR 
lf=one_cycle(1,hyp['lrf'],epochs) #cosine 1->hyp['lrf'] scheduler=lr_scheduler.LambdaLR(optimizer,lr_lambda=lf)
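此处的 one_cycle 位于 general.py,它返回一个学习率系数函数 lf(x):x 从 0 增加到 steps 的过程中,系数沿半个余弦曲线从 y1 平滑变化到 y2。这里传入的是 one_cycle(1, hyp['lrf'], epochs),即系数从 1 逐渐降到 hyp['lrf'],LambdaLR 再用 lr0 乘以这个系数作为每个 epoch 的学习率。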
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Build a half-cosine ramp from ``y1`` (at x=0) to ``y2`` (at x=steps).

    Args:
        y1: value at the start of the schedule.
        y2: value at the end of the schedule.
        steps: x at which the ramp reaches ``y2``.

    Returns:
        callable: ``f(x) -> float``, suitable as a ``LambdaLR`` multiplier.
    """
    span = y2 - y1

    def schedule(x):
        ramp = (1 - math.cos(x * math.pi / steps)) / 2  # rises 0 -> 1 along a half cosine
        return ramp * span + y1

    return schedule
def train(hyp,opt,device): #解析参数 save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\ opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers #Directories save_dir=Path(save_dir) wdir=save_dir/'weights' wdir.mkdir(parents=True,exist_ok=True) last=wdir/'last.pt' best=wdir/'best.pt' results_file=save_dir/'results.txt' #Hyperparameters if isinstance(hyp,str): with open(hyp) as f: hyp=yaml.safe_load(f) logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items())) #save run settings with open(save_dir/'hyp.yaml','w') as f: yaml.safe_dump(hyp,f,sort_keys=False) with open(save_dir/ 'opt.yaml','w') as f: yaml.safe_dump(vars(opt),f,sort_keys=False) #Configure plots=not evolve #create plots cuda=device.type!='cpu' init_seeds(2+RANK) #RANK=-1 #导入数据 with open(data) as f: data_dict=yaml.safe_load(f) #Loggers loggers={ 'wandb':None,'tb':None} #loggers dict if RANK in [-1,0]: #TensorBoard if not evolve: prefix=colorstr('tensorboard: ') logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/") loggers['tb']=SummaryWriter(str(save_dir)) # W&B opt.hyp=hyp #add hyperparameters run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id=run_id if opt.resume else None wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict) logger['wandb']=wandb_logger.wandb if logger['wandb']: data_dict=wandb_logger.data_dict weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp nc=1 if single_cls else int(data_dict['nc']) # 类别数量 names=['item'] if single_cls and len(data_dict['names'])!=1 else data_dict['names'] assert len(names)==nc, '%g names found for nc=%g dataset in %s'%(len(names),nc,data) #check is_coco=data.endswith('coco.yaml') and nc==80 # Model pretrained=weights.endswith('.pt') if pretrained: with torch_distributed_zero_first(RANK): 
weights=attempt_download(weights) ckpt=torch.load(weights,map_location=device) model=Model(cfg or ckpt['model'].yaml,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) #创建模型 exclude=['anchors'] if (cfg or hyp.get('anchors')) and not resume else [] #anchors state_dict=ckpt['model'].float().state_dict() state_dict=intersect_dicts(state_dict,model.state_dict(),exclude=exclude) model.load_state_dict(state_dict,strict=False) logger.info('Transferred %g/%g items from %s' %(len(state_dict),len(model.state_dict()),weights)) else: model=Model(cfg,cn=3,nc=nc,anchors=hyp.get('anchors')).to(device) # 组装 训练数据路径和测试数据路径 train_path=data_dict['train'] test_path=data_dict['val'] #动态冻结某层 freeze=[] for k,v in model.named_parameters(): v.requires_grad=True if any(x in k for x in freeze): print('freezing %s'%k) v.requires_grad=False #配置优化器参数 nbs=64 #nominal batch size accumulate=max(round(nbs/batch_size),1) #32 batch=2 hyp['weight_decay']*=batch_size*accumulate/nbs #0.0005 logger.info(f'Scaled weight_decay={hyp['weight_decay']}') pg0,pg1,pg2=[],[],[] # optimizer parameter groups for k,v in model.named_modules(): if hasattr(v,'bias') and isinstance(v.bias,nn.Parameter): pg2.append(v.bias) #biases if isinstance(v,nn.BatchNorm2d): pg0.append(v.weight) #no decay elif hasattr(v,'weight') and isinstance(v.weight,nn.Parameter): pg1.append(v.weight) #apply decay if opt.adam: optimizer=optim.Adam(pg0,lr=hyp['lr0'],betas=(hyp['momentum'],0.999)) #adjust beta1 to momentum else: optimizer=optim.SGD(pg0,lr=hyp['lr0'],momentum=hyp['momentum'],nesterov=True) #配置decay和biases 这一步的操作是会在optimizer中的param_groups增加一个字典 optimizer.add_param_group({ 'params':pg1,'weight_decay':hyp['weight_decay']}) optimizer.add_param_group({ 'params':pg2}) logger.info('Optimizer groups:%g .bias, %g conv.weight, %g other' %(len(pg2),len(pg1),len(pg0)) del pg0,pg1,pg2 #配置学习率 if opt.linear_lr: lf=lambda x:(1-x/(epochs-1))*(1.0-hyp['lrf'])+hyp['lrf'] else:#OneCycleLR lf=one_cycle(1,hyp['lrf'],epochs) #cosine 1->hyp['lrf'] 
scheduler=lr_scheduler.LambdaLR(optimizer,lr_lambda=lf) #EMA ema=modelEMA(model) if RANK in [-1,0] else None
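这里的 EMA 全称 Exponential Moving Average(指数滑动平均),它维护一份模型参数的滑动平均副本。每次调用 update 时,对每个浮点型参数执行 ema = d·ema + (1−d)·model,其中 d = decay·(1−exp(−updates/2000)) 会随更新次数从 0 逐渐增大到 decay。因此训练初期 EMA 基本跟随当前模型,后期则变得平滑稳定。(上面代码中的 modelEMA 应为 ModelEMA。)
代码在 torch_utils.py: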
class ModelEMA:
    """Keep an exponential moving average (EMA) of a model's floating-point state.

    ``ema`` is a frozen copy of the model taken at construction (using
    ``model.module`` when ``is_parallel(model)`` is true), put in eval mode,
    with gradients disabled. Each ``update`` call replaces every
    floating-point entry of its state_dict with ``d * ema + (1 - d) * model``,
    where ``d = decay * (1 - exp(-updates / 2000))``. ``d`` therefore ramps
    from 0 towards ``decay``, so early updates follow the live model closely.
    """

    def __init__(self, model, decay=0.999, updates=0.):
        """Create the EMA copy.

        Args:
            model: the model to track.
            decay: asymptotic decay factor approached as ``updates`` grows.
            updates: number of updates already performed (e.g. when resuming).
        """
        source = model.module if is_parallel(model) else model
        self.ema = deepcopy(source).eval()  # frozen copy in eval mode
        self.updates = updates  # EMA updates performed so far
        base_decay = decay

        def ramped_decay(step):
            # Rises from 0 towards base_decay as step grows.
            return base_decay * (1 - math.exp(-step / 2000))

        self.decay = ramped_decay
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy (in place)."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            live_state = (model.module if is_parallel(model) else model).state_dict()
            for name, ema_value in self.ema.state_dict().items():
                if not ema_value.dtype.is_floating_point:
                    continue  # integer buffers are left untouched
                ema_value *= d
                ema_value += (1. - d) * live_state[name].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy attributes from ``model`` onto the EMA copy via ``copy_attr``."""
        copy_attr(self.ema, model, include, exclude)
因为篇幅太长、博客编辑器太卡,剩下的部分放到后续文章中继续。
转载地址:http://wiwsi.baihongyu.com/