// 分块下载 (chunked / multi-part HTTP download)

package downloader

import (
	"fmt"
	"github.com/k0kubun/go-ansi"
	"github.com/schollz/progressbar/v3"
	"io"
	"net/http"
	"os"
	"path"
	"strings"
	"sync"
)

// Downloader fetches an HTTP resource into a local file, splitting the
// transfer into parallel byte-range requests when the server allows it.
type Downloader struct {
	// concurrency is how many byte ranges (and part files) a ranged
	// download is split into.
	concurrency int
	// resume makes ranged downloads reuse bytes already present in part
	// files instead of requesting them again.
	resume bool
	// bar reports progress; setBar replaces it at the start of a transfer.
	bar *progressbar.ProgressBar
}

// NewDownloader returns a Downloader that splits range-capable downloads
// into concurrency parts and, when resume is true, reuses the bytes already
// present in existing part files.
func NewDownloader(concurrency int, resume bool) *Downloader {
	d := new(Downloader)
	d.concurrency = concurrency
	d.resume = resume
	return d
}

// Download fetches strUrl into fileName. When fileName is empty, it is
// derived from the last path segment of the URL (query string and fragment
// excluded) and printed.
//
// A HEAD request probes the server first. If it answers 200 with
// "Accept-Ranges: bytes" and a known, positive Content-Length, the file is
// fetched in d.concurrency parallel ranges; otherwise a single GET is used.
func (d *Downloader) Download(strUrl, fileName string) error {
	if fileName == "" {
		fileName = fileNameFromURL(strUrl)
		fmt.Println(fileName)
	}
	resp, err := http.Head(strUrl)
	if err != nil {
		return err
	}
	// A HEAD response has no payload, but its body must still be closed so
	// the underlying connection is released.
	resp.Body.Close()

	// ContentLength is -1 when the server does not report it; splitting an
	// unknown length (or splitting into < 1 parts) is impossible.
	rangeCapable := resp.StatusCode == http.StatusOK &&
		resp.Header.Get("Accept-Ranges") == "bytes" &&
		resp.ContentLength > 0 &&
		d.concurrency > 0
	if rangeCapable {
		return d.multiDownload(strUrl, fileName, int(resp.ContentLength))
	}
	return d.singleDownload(strUrl, fileName)
}

// fileNameFromURL returns the last '/'-separated segment of rawURL after
// dropping any "?query" or "#fragment" suffix, so ".../b.zip?token=1" yields
// "b.zip" rather than "b.zip?token=1".
func fileNameFromURL(rawURL string) string {
	if i := strings.IndexAny(rawURL, "?#"); i >= 0 {
		rawURL = rawURL[:i]
	}
	return path.Base(rawURL)
}

// setBar replaces d.bar with a fresh byte-counting progress bar whose total
// is length, drawn 50 columns wide with green fill on an ANSI-aware stdout.
func (d *Downloader) setBar(length int) {
	theme := progressbar.Theme{
		BarStart:      "[",
		BarEnd:        "]",
		Saucer:        "[green]=[reset]",
		SaucerHead:    "[green]>[reset]",
		SaucerPadding: " ",
	}
	opts := []progressbar.Option{
		progressbar.OptionSetWriter(ansi.NewAnsiStdout()),
		progressbar.OptionEnableColorCodes(true),
		progressbar.OptionShowBytes(true),
		progressbar.OptionSetWidth(50),
		progressbar.OptionSetDescription("downloading..."),
		progressbar.OptionSetTheme(theme),
	}
	d.bar = progressbar.NewOptions(length, opts...)
}

// multiDownload fetches strUrl as d.concurrency byte ranges in parallel,
// writing one part file per range under d.getPartDir(fileName), and then
// merges the parts into fileName via d.merge.
//
// contentLen is the total size in bytes. Ranges are inclusive on both ends,
// as in the HTTP Range header, and tile [0, contentLen-1] exactly; the last
// part absorbs the remainder of the integer division. The part directory is
// removed only after a successful merge, so a failed run leaves its parts
// behind for a later resume.
func (d *Downloader) multiDownload(strUrl, fileName string, contentLen int) error {
	// Nothing to split without a known size or at least one worker.
	if contentLen <= 0 || d.concurrency < 1 {
		return d.singleDownload(strUrl, fileName)
	}
	d.setBar(contentLen)

	partDir := d.getPartDir(fileName)
	// MkdirAll succeeds if the directory already exists (resume case).
	if err := os.MkdirAll(partDir, 0777); err != nil {
		return fmt.Errorf("creating part directory %q: %w", partDir, err)
	}

	partSize := contentLen / d.concurrency
	errs := make([]error, d.concurrency) // one slot per goroutine: no data race
	var wg sync.WaitGroup
	wg.Add(d.concurrency)
	for i := 0; i < d.concurrency; i++ {
		rangeStart := i * partSize
		rangeEnd := rangeStart + partSize - 1
		if i == d.concurrency-1 {
			rangeEnd = contentLen - 1
		}
		go func(i, rangeStart, rangeEnd int) {
			defer wg.Done()
			errs[i] = d.fetchPart(strUrl, fileName, rangeStart, rangeEnd, i)
		}(i, rangeStart, rangeEnd)
	}
	wg.Wait()

	for i, err := range errs {
		if err != nil {
			return fmt.Errorf("downloading part %d: %w", i, err)
		}
	}
	if err := d.merge(fileName); err != nil {
		return err
	}
	// Best-effort cleanup: fileName is already complete at this point.
	os.RemoveAll(partDir)
	return nil
}

// fetchPart prepares part file i for the inclusive byte range
// [rangeStart, rangeEnd] and downloads whatever portion of it is missing.
// rangeEnd may be rangeStart-1 for an empty range; the part file is still
// created because merge expects one file per part.
func (d *Downloader) fetchPart(strUrl, fileName string, rangeStart, rangeEnd, i int) error {
	partFileName := d.getPartFileName(fileName, i)
	partLen := rangeEnd - rangeStart + 1

	done := 0
	if d.resume {
		// A part larger than its range is stale (e.g. written with another
		// concurrency) and is discarded below.
		if info, err := os.Stat(partFileName); err == nil && info.Size() <= int64(partLen) {
			done = int(info.Size())
		}
	}
	if done == 0 {
		// Start this part from an empty file, dropping any stale content.
		if err := os.WriteFile(partFileName, nil, 0666); err != nil {
			return err
		}
	}
	d.bar.Add(done)
	if done == partLen {
		return nil
	}
	return d.downloadPartial(strUrl, fileName, rangeStart+done, rangeEnd, i)
}

// merge concatenates part files 0..d.concurrency-1 of fileName, in order,
// into fileName, replacing any previous content. Part files are deleted only
// after every part has been copied and fileName has been closed without
// error, so a failed merge leaves them in place for a retry.
func (d *Downloader) merge(fileName string) error {
	destFile, err := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
	if err != nil {
		return err
	}
	for i := 0; i < d.concurrency; i++ {
		if err := appendFileTo(destFile, d.getPartFileName(fileName, i)); err != nil {
			destFile.Close()
			return fmt.Errorf("merging part %d into %q: %w", i, fileName, err)
		}
	}
	// Close can report delayed write errors; check it before deleting parts.
	if err := destFile.Close(); err != nil {
		return err
	}
	for i := 0; i < d.concurrency; i++ {
		// Best-effort cleanup: the merged file is already complete.
		os.Remove(d.getPartFileName(fileName, i))
	}
	return nil
}

// appendFileTo copies the entire content of the file named src to dst.
func appendFileTo(dst io.Writer, src string) error {
	f, err := os.Open(src)
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(dst, f)
	return err
}

// singleDownload fetches strUrl with one plain GET and streams the body into
// fileName, replacing any existing content, while advancing a progress bar
// sized by the response's Content-Length (-1 when unknown). A non-200
// response is returned as an error instead of being written to disk.
func (d *Downloader) singleDownload(strUrl, fileName string) error {
	resp, err := http.Get(strUrl)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("GET %q: unexpected status %s", strUrl, resp.Status)
	}

	d.setBar(int(resp.ContentLength))

	f, err := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
	if err != nil {
		return err
	}
	buf := make([]byte, 32*1024)
	if _, err := io.CopyBuffer(io.MultiWriter(f, d.bar), resp.Body, buf); err != nil {
		f.Close()
		return err
	}
	// Close reports delayed write errors, so its result is the final verdict.
	return f.Close()
}

// downloadPartial requests the inclusive byte range [rangeStart, rangeEnd]
// of strUrl and writes it to part file i of fileName, advancing d.bar.
//
// With d.resume the bytes are appended to whatever the part file already
// holds (callers pass a rangeStart that skips those bytes); otherwise the
// part file is truncated first. An empty range (rangeStart > rangeEnd) is a
// no-op. Any status other than 206 Partial Content is an error, so that an
// error page or a full-body 200 never ends up inside a part file.
func (d *Downloader) downloadPartial(strUrl, fileName string, rangeStart, rangeEnd, i int) error {
	// Range ends are inclusive: rangeStart == rangeEnd is one valid byte.
	if rangeStart > rangeEnd {
		return nil
	}
	req, err := http.NewRequest(http.MethodGet, strUrl, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", rangeStart, rangeEnd))
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusPartialContent {
		return fmt.Errorf("range %d-%d of %q: unexpected status %s", rangeStart, rangeEnd, strUrl, resp.Status)
	}

	flags := os.O_CREATE | os.O_WRONLY
	if d.resume {
		flags |= os.O_APPEND // continue after the bytes already on disk
	} else {
		flags |= os.O_TRUNC // discard leftovers from an earlier run
	}
	partFile, err := os.OpenFile(d.getPartFileName(fileName, i), flags, 0666)
	if err != nil {
		return err
	}
	buf := make([]byte, 32*1024)
	// io.CopyBuffer reports a clean end of body as nil, never io.EOF.
	if _, err := io.CopyBuffer(io.MultiWriter(partFile, d.bar), resp.Body, buf); err != nil {
		partFile.Close()
		return err
	}
	return partFile.Close()
}

// getPartDir returns the directory that holds fileName's part files while a
// multi-part download is in progress.
//
// The directory sits next to fileName and is named after its base name up
// to the first '.', e.g. "dl/x.tar.gz" -> "dl/x". When that stem would be
// empty (".bashrc") or identical to the base name ("README", which would
// collide with the output file itself), "<base>.parts" is used instead.
// Paths are split on '/' only.
func (d *Downloader) getPartDir(fileName string) string {
	dir, base := path.Split(fileName)
	stem := strings.SplitN(base, ".", 2)[0]
	if stem == "" || stem == base {
		stem = base + ".parts"
	}
	return path.Join(dir, stem)
}

// getPartFileName returns the path of part partNum of fileName: a file named
// "<base>-<partNum>" directly inside d.getPartDir(fileName), where <base> is
// the last '/'-separated element of fileName. Using only the base name keeps
// the part file inside the part directory even when fileName has directory
// components.
func (d *Downloader) getPartFileName(fileName string, partNum int) string {
	partName := fmt.Sprintf("%s-%d", path.Base(fileName), partNum)
	return path.Join(d.getPartDir(fileName), partName)
}
package downloader

import (
	"bytes"
	"net/http"
	"net/http/httptest"
	"os"
	"runtime"
	"testing"
	"time"
)

// TestNewDownloader downloads a generated payload from a local HTTP server
// that supports byte ranges and checks that the merged file matches it byte
// for byte, for several concurrency levels. httptest keeps the test
// hermetic: no network access and no dependency on a third-party mirror.
func TestNewDownloader(t *testing.T) {
	payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) // 64 KiB
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ServeContent answers HEAD, advertises "Accept-Ranges: bytes" and
		// serves Range requests with 206 Partial Content.
		http.ServeContent(w, r, "payload.bin", time.Time{}, bytes.NewReader(payload))
	}))
	defer srv.Close()

	// Downloads are written relative to the working directory, so each run
	// happens in a scratch directory; the original one is restored at the end.
	wd, err := os.Getwd()
	if err != nil {
		t.Fatal(err)
	}
	defer os.Chdir(wd)

	const fileName = "payload.bin"
	for _, concurrency := range []int{1, 3, runtime.NumCPU()} {
		if err := os.Chdir(t.TempDir()); err != nil {
			t.Fatal(err)
		}
		if err := NewDownloader(concurrency, true).Download(srv.URL+"/"+fileName, fileName); err != nil {
			t.Fatalf("concurrency %d: Download: %v", concurrency, err)
		}
		got, err := os.ReadFile(fileName)
		if err != nil {
			t.Fatalf("concurrency %d: %v", concurrency, err)
		}
		if !bytes.Equal(got, payload) {
			t.Fatalf("concurrency %d: got %d bytes, want %d identical bytes", concurrency, len(got), len(payload))
		}
	}
}

// 最后更新于 ("last updated on" — the date is missing from the source page)