ソースコードから理解する技術-UnderSourceCode

手を動かす(プログラムを組む)ことで技術を理解するブログ

aws-sdk-goのMockを使ったテスト

Mocking Techniques for Go. Go provides you all the tools you need… | by kyleyost | CBI Engineering | Jun, 2020 | Medium

こちらの記事を読んでいたところ、以下のような記述がありました。

When working with the aws-sdk, they provide interfaces for all of their major services

(https://medium.com/cbi-engineering/mocking-techniques-for-go-805c10f1676b より)

上記の記事ではDynamoDBを使う場合のテストについて書いてありますが
自分は最近S3を使ったので、S3のinterfaceを使ったテストを書いてみました。

以下、ソースについてです。

ソース

「ListObjectsPages」を使い、指定された「フォルダ」内の.logファイルを取得するメソッドと、そのテストです。
(厳密にはS3には「フォルダ」はないですが、まあサンプルということで・・・)

メソッドを実装した「logfile.go」と、テストの「logfile_test.go」から成ります。

logfile.go

package logfile

import (
	"errors"
	"fmt"
	"path"
	"path/filepath"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
)

// FindLogFiles ...
func FindLogFiles(svc s3iface.S3API, bucket, folder string) ([]string, error) {
	fmt.Printf("find logFiles start. srcBucket = %s, folder = %s\n", bucket, folder)

	var logFiles []string
	params := &s3.ListObjectsInput{
		Bucket: aws.String(bucket),
		Prefix: aws.String(folder),
	}

	err := svc.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) (shouldContinue bool) {
		for _, obj := range p.Contents {
			fileName := path.Base(*obj.Key)
			if filepath.Ext(fileName) != ".log" {
				continue
			}
			logFiles = append(logFiles, *obj.Key)
		}
		return true
	})
	if err != nil {
		return nil, errors.New("failed to list objects")
	}

	fmt.Println("find logFiles complete.")
	return logFiles, nil
}

logfile_test.go

package logfile

import (
	"fmt"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
)

type mockS3Client struct {
	s3iface.S3API
	objects []*s3.Object
}

func (m *mockS3Client) ListObjectsPages(input *s3.ListObjectsInput, fn func(p *s3.ListObjectsOutput, last bool) bool) error {
	output := &s3.ListObjectsOutput{}
	output.Contents = m.objects
	last := false
	fn(output, last)
	return nil
}

func TestFindLogFiles(t *testing.T) {
	test := func(objects []*s3.Object, expectCount int) {
		svc := &mockS3Client{objects: objects}
		logFiles, err := FindLogFiles(svc, "sample-bucket", "log-folder/sample-folder")
		if err != nil {
			t.Fatalf("error raise. %#v", err)
		}
		if len(logFiles) != expectCount {
			t.Fatalf("not expect logs, result count = %v\n", len(logFiles))
		}
		fmt.Printf("logFiles = %v\n", logFiles)
	}

	objects := []*s3.Object{}
	objects = append(objects, &s3.Object{Key: aws.String("log_test1.log")})
	objects = append(objects, &s3.Object{Key: aws.String("log_test2.log")})
	objects = append(objects, &s3.Object{Key: aws.String("log_test3.log")})
	objects = append(objects, &s3.Object{Key: aws.String("log_test4.err")})
	test(objects, 3)
}

FindLogFiles()ではS3のサービスのオブジェクトを「s3iface.S3API」インターフェース型の引数として受け取り、そのメソッドを呼び出して一覧を取得するようにしています。
テストではこの「s3iface.S3API」インターフェースを満たすMockを作り(「mockS3Client」)、これをFindLogFiles()に渡すことで実際にS3を呼び出さずにテストをパスするようにしています。