1515package auth
1616
1717import (
18+ "context"
1819 "encoding/json"
1920 "errors"
2021 "fmt"
22+ "io"
2123 "io/fs"
2224 "net/http"
2325 "net/url"
@@ -75,46 +77,52 @@ type Credential struct {
7577 Audience string `toml:"audience,omitempty"`
7678}
7779
78- func (t * Credential ) Refresh () bool {
80+ func (t * Credential ) Refresh () error {
7981 switch t .Type {
8082 case TypeApiKey :
81- rsp , err := http .DefaultClient .Do (& http.Request {
82- Method : http .MethodGet ,
83- URL : (* url .URL )(& t .AuthURI ),
84- Header : http.Header {
85- "authorization" : []string {"Bearer " + t .ApiKey },
86- },
87- })
88- if err != nil || rsp .StatusCode != http .StatusOK {
89- return false
83+ req , err := http .NewRequestWithContext (context .Background (),
84+ http .MethodGet , (* url .URL )(& t .AuthURI ).String (), nil )
85+ if err != nil {
86+ return fmt .Errorf ("apikey refresh: %w" , err )
87+ }
88+ req .Header .Set ("Authorization" , "Bearer " + t .ApiKey )
89+
90+ rsp , err := http .DefaultClient .Do (req )
91+ if err != nil {
92+ return fmt .Errorf ("apikey refresh: %w" , err )
9093 }
9194 defer rsp .Body .Close ()
95+ if rsp .StatusCode != http .StatusOK {
96+ return fmt .Errorf ("apikey refresh: status %s" , rsp .Status )
97+ }
9298
9399 var tokenResp struct {
94100 Token string `json:"access_token"`
95101 }
96102 if err := json .NewDecoder (rsp .Body ).Decode (& tokenResp ); err != nil {
97- return false
103+ return err
98104 }
99105
100106 t .Token = tokenResp .Token
101- return true
107+ return nil
102108 case TypeToken :
103109 if err := refreshOauth (t ); err != nil {
104- return false
110+ return fmt . Errorf ( "oauth refresh: %w" , err )
105111 }
106- return true
112+ return nil
107113 }
108114
109- return false
115+ return fmt . Errorf ( "unsupported credential type: %s" , t . Type )
110116}
111117
118+ // GetAuthToken returns the current token, refreshing if needed.
119+ // Must not be called while credMu is held (Refresh may acquire it via UpdateCreds).
112120func (t * Credential ) GetAuthToken () string {
113121 if t .Token != "" {
114122 return t .Token
115123 }
116124
117- if t .Refresh () {
125+ if err := t .Refresh (); err == nil {
118126 _ = UpdateCreds ()
119127 return t .Token
120128 }
@@ -127,18 +135,30 @@ var (
127135 credentialErr error
128136 loaded sync.Once
129137 credPath string
138+ credPathMu sync.Mutex
139+ credMu sync.RWMutex
130140)
131141
132- func init () {
133- var err error
134- credPath , err = internal .GetCredentialPath ()
135- if err != nil {
136- panic (fmt .Sprintf ("failed to get credential path: %s" , err ))
142+ func getCredPath () (string , error ) {
143+ credPathMu .Lock ()
144+ defer credPathMu .Unlock ()
145+ if credPath == "" {
146+ var err error
147+ credPath , err = internal .GetCredentialPath ()
148+ if err != nil {
149+ return "" , fmt .Errorf ("failed to get credential path: %w" , err )
150+ }
137151 }
152+ return credPath , nil
138153}
139154
140155func loadCreds () ([]Credential , error ) {
141- credFile , err := os .Open (credPath )
156+ cp , err := getCredPath ()
157+ if err != nil {
158+ return nil , fmt .Errorf ("failed to get credential path: %w" , err )
159+ }
160+
161+ credFile , err := os .Open (cp )
142162 if err != nil {
143163 if errors .Is (err , os .ErrNotExist ) {
144164 return []Credential {}, nil
@@ -163,6 +183,8 @@ func GetCredentials(u *url.URL) (*Credential, error) {
163183 return nil , err
164184 }
165185
186+ credMu .RLock ()
187+ defer credMu .RUnlock ()
166188 for i , cred := range loadedCredentials {
167189 if cred .RegistryURL .Host == u .Host {
168190 return & loadedCredentials [i ], nil
@@ -174,7 +196,10 @@ func GetCredentials(u *url.URL) (*Credential, error) {
174196
175197func LoadCredentials () error {
176198 loaded .Do (func () {
177- loadedCredentials , credentialErr = loadCreds ()
199+ creds , err := loadCreds ()
200+ credMu .Lock ()
201+ loadedCredentials , credentialErr = creds , err
202+ credMu .Unlock ()
178203 })
179204 return credentialErr
180205}
@@ -184,6 +209,9 @@ func AddCredential(cred Credential, allowOverwrite bool) error {
184209 return err
185210 }
186211
212+ credMu .Lock ()
213+ defer credMu .Unlock ()
214+
187215 idx := slices .IndexFunc (loadedCredentials , func (c Credential ) bool {
188216 return c .RegistryURL .Host == cred .RegistryURL .Host
189217 })
@@ -196,14 +224,17 @@ func AddCredential(cred Credential, allowOverwrite bool) error {
196224 } else {
197225 loadedCredentials = append (loadedCredentials , cred )
198226 }
199- return UpdateCreds ()
227+ return writeCreds ()
200228}
201229
202230func RemoveCredential (host Uri ) error {
203231 if err := LoadCredentials (); err != nil {
204232 return err
205233 }
206234
235+ credMu .Lock ()
236+ defer credMu .Unlock ()
237+
207238 idx := slices .IndexFunc (loadedCredentials , func (c Credential ) bool {
208239 return c .RegistryURL .Host == host .Host
209240 })
@@ -213,20 +244,22 @@ func RemoveCredential(host Uri) error {
213244 }
214245
215246 loadedCredentials = append (loadedCredentials [:idx ], loadedCredentials [idx + 1 :]... )
216- return UpdateCreds ()
247+ return writeCreds ()
217248}
218249
219- func UpdateCreds () error {
220- if err := LoadCredentials (); err != nil {
221- return err
250+ // writeCreds persists loadedCredentials to disk.
251+ // Caller must hold credMu.
252+ func writeCreds () error {
253+ cp , err := getCredPath ()
254+ if err != nil {
255+ return fmt .Errorf ("failed to get credential path: %w" , err )
222256 }
223257
224- err := os .MkdirAll (filepath .Dir (credPath ), 0o700 )
225- if err != nil {
258+ if err := os .MkdirAll (filepath .Dir (cp ), 0o700 ); err != nil {
226259 return err
227260 }
228261
229- f , err := os .OpenFile (credPath , os .O_CREATE | os .O_TRUNC | os .O_RDWR , 0o600 )
262+ f , err := os .OpenFile (cp , os .O_CREATE | os .O_TRUNC | os .O_RDWR , 0o600 )
230263 if err != nil {
231264 return err
232265 }
@@ -239,13 +272,28 @@ func UpdateCreds() error {
239272 })
240273}
241274
275+ func UpdateCreds () error {
276+ if err := LoadCredentials (); err != nil {
277+ return err
278+ }
279+
280+ credMu .Lock ()
281+ defer credMu .Unlock ()
282+ return writeCreds ()
283+ }
284+
242285func PurgeCredentials () error {
286+ cp , err := getCredPath ()
287+ if err != nil {
288+ return fmt .Errorf ("failed to get credential path: %w" , err )
289+ }
290+
243291 var fileList = []string {
244292 "credentials.toml" ,
245293 "columnar.lic" ,
246294 }
247295
248- prefix := filepath .Dir (credPath )
296+ prefix := filepath .Dir (cp )
249297
250298 for _ , file := range fileList {
251299 fullPath := filepath .Join (prefix , file )
@@ -268,16 +316,23 @@ var (
268316 ErrLicenseAlreadyExists = errors .New ("license already exists (use --force to overwrite)" )
269317)
270318
271- func LicensePath () string {
272- return filepath .Join (filepath .Dir (credPath ), "columnar.lic" )
319+ func LicensePath () (string , error ) {
320+ cp , err := getCredPath ()
321+ if err != nil {
322+ return "" , err
323+ }
324+ return filepath .Join (filepath .Dir (cp ), "columnar.lic" ), nil
273325}
274326
275327func InstallLicenseFromFile (srcPath string , force bool ) error {
276328 if ! force && filepath .Base (srcPath ) != "columnar.lic" {
277329 return ErrLicenseWrongFilename
278330 }
279331
280- destPath := LicensePath ()
332+ destPath , err := LicensePath ()
333+ if err != nil {
334+ return fmt .Errorf ("failed to determine license path: %w" , err )
335+ }
281336
282337 if ! force {
283338 if _ , err := os .Stat (destPath ); err == nil {
@@ -302,8 +357,13 @@ func InstallLicenseFromFile(srcPath string, force bool) error {
302357}
303358
304359func FetchColumnarLicense (cred * Credential ) error {
305- licensePath := filepath .Join (filepath .Dir (credPath ), "columnar.lic" )
306- _ , err := os .Stat (licensePath )
360+ cp , err := getCredPath ()
361+ if err != nil {
362+ return fmt .Errorf ("failed to get credential path: %w" , err )
363+ }
364+
365+ licensePath := filepath .Join (filepath .Dir (cp ), "columnar.lic" )
366+ _ , err = os .Stat (licensePath )
307367 if err == nil { // license exists already
308368 return nil
309369 }
@@ -332,7 +392,7 @@ func FetchColumnarLicense(cred *Credential) error {
332392 return fmt .Errorf ("unsupported credential type: %s" , cred .Type )
333393 }
334394
335- req , err := http .NewRequest ( http .MethodGet , licenseURI , nil )
395+ req , err := http .NewRequestWithContext ( context . Background (), http .MethodGet , licenseURI , nil )
336396 if err != nil {
337397 return err
338398 }
@@ -355,14 +415,19 @@ func FetchColumnarLicense(cred *Credential) error {
355415 }
356416 }
357417
358- licenseFile , err := os .OpenFile ( licensePath , os . O_CREATE | os . O_TRUNC | os . O_RDWR , 0o600 )
418+ tmp , err := os .CreateTemp ( filepath . Dir ( licensePath ), ".lic.*" )
359419 if err != nil {
360420 return err
361421 }
362- defer licenseFile .Close ()
363- if _ , err = licenseFile .ReadFrom (resp .Body ); err != nil {
364- licenseFile .Close ()
365- os .Remove (licensePath )
422+ tmpName := tmp .Name ()
423+ defer os .Remove (tmpName )
424+
425+ if _ , err := io .Copy (tmp , resp .Body ); err != nil {
426+ tmp .Close ()
427+ return fmt .Errorf ("write license: %w" , err )
428+ }
429+ if err := tmp .Close (); err != nil {
430+ return fmt .Errorf ("close license temp file: %w" , err )
366431 }
367- return err
432+ return os . Rename ( tmpName , licensePath )
368433}
0 commit comments