186 lines
4.4 KiB
Go
186 lines
4.4 KiB
Go
// SPDX-License-Identifier: Apache-2.0
|
|
// Modifications by M/S Omukk
|
|
|
|
package host
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
|
)
|
|
|
|
// Filesystem location and MMDS connection settings used by this package.
const (
	// WrennRunDir is the directory where sandbox metadata files are stored.
	WrennRunDir = "/run/wrenn"

	// mmdsDefaultAddress is the host that MMDS requests are sent to.
	mmdsDefaultAddress = "169.254.169.254"
	// mmdsTokenExpiration is the TTL requested for each MMDS token.
	mmdsTokenExpiration = 60 * time.Second

	// mmdsAccessTokenRequestClientTimeout is the request timeout configured
	// on mmdsAccessTokenClient.
	mmdsAccessTokenRequestClientTimeout = 10 * time.Second
)
|
|
|
|
var mmdsAccessTokenClient = &http.Client{
|
|
Timeout: mmdsAccessTokenRequestClientTimeout,
|
|
Transport: &http.Transport{
|
|
DisableKeepAlives: true,
|
|
},
|
|
}
|
|
|
|
// MMDSOpts is the sandbox metadata document read from MMDS. The JSON keys
// differ from the Go field names, so the struct tags are significant.
type MMDSOpts struct {
	// SandboxID is decoded from the "instanceID" key.
	SandboxID string `json:"instanceID"`
	// TemplateID is decoded from the "envID" key.
	TemplateID string `json:"envID"`
	// LogsCollectorAddress is decoded from the "address" key.
	LogsCollectorAddress string `json:"address"`
	// AccessTokenHash is decoded from the "accessTokenHash" key.
	AccessTokenHash string `json:"accessTokenHash"`
}
|
|
|
|
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
|
|
opts.SandboxID = sandboxID
|
|
opts.TemplateID = templateID
|
|
opts.LogsCollectorAddress = collectorAddress
|
|
}
|
|
|
|
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
|
|
parsed := make(map[string]any)
|
|
|
|
err := json.Unmarshal(jsonLogs, &parsed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parsed["instanceID"] = opts.SandboxID
|
|
parsed["envID"] = opts.TemplateID
|
|
|
|
data, err := json.Marshal(parsed)
|
|
|
|
return data, err
|
|
}
|
|
|
|
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
|
|
request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
|
|
|
|
response, err := client.Do(request)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
body, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
token := string(body)
|
|
|
|
if len(token) == 0 {
|
|
return "", fmt.Errorf("mmds token is an empty string")
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
|
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
request.Header["X-metadata-token"] = []string{token}
|
|
request.Header["Accept"] = []string{"application/json"}
|
|
|
|
response, err := client.Do(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer response.Body.Close()
|
|
|
|
body, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var opts MMDSOpts
|
|
|
|
err = json.Unmarshal(body, &opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &opts, nil
|
|
}
|
|
|
|
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
|
|
// This is used to validate that /init requests come from the orchestrator.
|
|
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
|
|
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get MMDS token: %w", err)
|
|
}
|
|
|
|
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
|
|
}
|
|
|
|
return opts.AccessTokenHash, nil
|
|
}
|
|
|
|
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
|
|
httpClient := &http.Client{}
|
|
defer httpClient.CloseIdleConnections()
|
|
|
|
ticker := time.NewTicker(50 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
|
|
|
|
return
|
|
case <-ticker.C:
|
|
token, err := getMMDSToken(ctx, httpClient)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
|
|
|
|
continue
|
|
}
|
|
|
|
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
|
|
|
|
continue
|
|
}
|
|
|
|
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
|
|
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
|
|
|
|
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
|
|
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
|
|
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
|
|
}
|
|
|
|
if mmdsOpts.LogsCollectorAddress != "" {
|
|
mmdsChan <- mmdsOpts
|
|
}
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|