diff --git a/README.md b/README.md index ce4dc78..34f920d 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ A golang library for extracting and creating archives. ## Roadmap - [ ] Automatically detect the archive format -- [ ] Add mutexes to work with concurrency +- [x] Add mutexes to work with concurrency ## Supported Formats - `.tar.gz` diff --git a/archive.go b/archive.go index e27779f..a18d799 100644 --- a/archive.go +++ b/archive.go @@ -10,6 +10,7 @@ import ( "io/fs" "os" "regexp" + "sync" ) type Type string @@ -36,6 +37,7 @@ type Archive struct { tarReader *tar.Reader // Used for anything with .tar due to how tar.Reader cannot be reset. files map[string]*File archiveFile *bytes.Reader + mu sync.Mutex } // Filesystem represents a standard interface for filesystems. @@ -109,6 +111,8 @@ type File struct { // It takes the path of the file in the Archive as its parameter. // The function returns ArchiveFile and an error, if any. func (a *Archive) GetFile(path string) (*File, error) { + a.mu.Lock() + defer a.mu.Unlock() file, ok := a.files[path] if !ok { return nil, ErrArchiveFileNotFound diff --git a/archive_test.go b/archive_test.go index 058f12a..759d307 100644 --- a/archive_test.go +++ b/archive_test.go @@ -10,6 +10,7 @@ import ( "os" "regexp" "strconv" + "sync" "testing" ) @@ -19,6 +20,43 @@ const ( var archiveRegex = regexp.MustCompile(`(?m)test[1|5]`) +func TestArchiveConcurrency(t *testing.T) { + err := os.MkdirAll(testArchiveBaseDir, os.ModePerm) + assert.NoError(t, err) + + testGenerateZip(t) + + archive, err := Open(Zip, testArchiveBaseDir+"/test.zip") + assert.NoError(t, err) + + var wg sync.WaitGroup + + // Slam it, see if it breaks. + for _ = range 100_000 { + go func() { + wg.Add(1) + defer wg.Done() + file, err := archive.GetFile("test0.txt") + assert.NoError(t, err) + assert.Equal(t, "test0.txt", file.FileName) + + err = file.Extract(ExtractFileOptions{ + Overwrite: true, + Folder: testArchiveBaseDir + "/extracted/zip", + }) + assert.NoError(t, err) + }() + } + + wg.Wait() + + err = os.RemoveAll(testArchiveBaseDir + "/extracted") + assert.NoError(t, err) + + err = archive.Close() + assert.NoError(t, err) +} + func TestArchiveExtract(t *testing.T) { err := os.MkdirAll(testArchiveBaseDir, os.ModePerm) assert.NoError(t, err) @@ -59,6 +97,9 @@ func TestArchiveExtract(t *testing.T) { }) assert.NoError(t, err) + err = os.RemoveAll(testArchiveBaseDir + "/extracted") + assert.NoError(t, err) + } func TestArchiveZip(t *testing.T) { diff --git a/extract.go b/extract.go index 19a9c7f..cf8cdca 100644 --- a/extract.go +++ b/extract.go @@ -22,6 +22,9 @@ type extractOptions struct { } func extract(filesystem Filesystem, opts ExtractOptions, archive *Archive) error { + archive.mu.Lock() + defer archive.mu.Unlock() + if filesystem.billyFS == nil { filesystem.file = true } @@ -55,6 +58,9 @@ func extract(filesystem Filesystem, opts ExtractOptions, archive *Archive) error } func extractFile(filesystem Filesystem, opts ExtractFileOptions, file *File) error { + file.archive.mu.Lock() + defer file.archive.mu.Unlock() + if filesystem.billyFS == nil { filesystem.file = true }