From 5cc1c11b1ac77fc44a28481b9b176c0f7d29cfce Mon Sep 17 00:00:00 2001 From: TwinProduction Date: Tue, 13 Jul 2021 22:59:43 -0400 Subject: [PATCH] Move all transactions to the exported methods --- storage/store/database/database.go | 91 +++++++++++++++--------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/storage/store/database/database.go b/storage/store/database/database.go index 8edfd868..6402ea0d 100644 --- a/storage/store/database/database.go +++ b/storage/store/database/database.go @@ -12,6 +12,10 @@ import ( _ "modernc.org/sqlite" ) +////////////////////////////////////////////////////////////////////////////////////////////////// +// Note that only exported functions in this file may create, commit, or rollback a transaction // +////////////////////////////////////////////////////////////////////////////////////////////////// + const ( arraySeparator = "|~|" ) @@ -121,12 +125,27 @@ func (s *Store) createSchema() error { // GetAllServiceStatusesWithResultPagination returns all monitored core.ServiceStatus // with a subset of core.Result defined by the page and pageSize parameters func (s *Store) GetAllServiceStatusesWithResultPagination(page, pageSize int) map[string]*core.ServiceStatus { - serviceStatuses := s.getAllServiceStatuses(0, 0, page, pageSize) - m := make(map[string]*core.ServiceStatus, len(serviceStatuses)) - for _, serviceStatus := range serviceStatuses { - m[serviceStatus.Key] = serviceStatus + tx, err := s.db.Begin() + if err != nil { + return nil } - return m + keys, err := s.getAllServiceKeys(tx) + if err != nil { + _ = tx.Rollback() + return nil + } + serviceStatuses := make(map[string]*core.ServiceStatus, len(keys)) + for _, key := range keys { + serviceStatus, err := s.getServiceStatusByKey(tx, key, 0, 0, page, pageSize) + if err != nil { + continue + } + serviceStatuses[key] = serviceStatus + } + if err = tx.Commit(); err != nil { + _ = tx.Rollback() + } + return serviceStatuses } // GetServiceStatus returns the service status for a given service name in the given group @@ -136,7 +155,18 @@ func (s *Store) GetServiceStatus(groupName, serviceName string) *core.ServiceSta // GetServiceStatusByKey returns the service status for a given key func (s *Store) GetServiceStatusByKey(key string) *core.ServiceStatus { - serviceStatus, _ := s.getServiceStatusByKey(key, 1, core.MaximumNumberOfEvents, 1, core.MaximumNumberOfResults) + tx, err := s.db.Begin() + if err != nil { + return nil + } + serviceStatus, err := s.getServiceStatusByKey(tx, key, 1, core.MaximumNumberOfEvents, 1, core.MaximumNumberOfResults) + if err != nil { + _ = tx.Rollback() + return nil + } + if err = tx.Commit(); err != nil { + _ = tx.Rollback() + } return serviceStatus } @@ -265,24 +295,8 @@ func (s *Store) Close() { _ = s.db.Close() } -func (s *Store) getAllServiceStatuses(eventsPage, eventsPageSize, resultsPage, resultsPageSize int) []*core.ServiceStatus { - var serviceStatuses []*core.ServiceStatus - keys, err := s.getAllServiceKeys() - if err != nil { - return nil - } - for _, key := range keys { - serviceStatus, err := s.getServiceStatusByKey(key, eventsPage, eventsPageSize, resultsPage, resultsPageSize) - if err != nil { - continue - } - serviceStatuses = append(serviceStatuses, serviceStatus) - } - return serviceStatuses -} - -func (s *Store) getAllServiceKeys() (keys []string, err error) { - rows, err := s.db.Query("SELECT service_key FROM service") +func (s *Store) getAllServiceKeys(tx *sql.Tx) (keys []string, err error) { + rows, err := tx.Query("SELECT service_key FROM service") if err != nil { return nil, err } @@ -295,8 +309,8 @@ func (s *Store) getAllServiceKeys() (keys []string, err error) { return } -func (s *Store) getServiceStatusByKey(key string, eventsPage, eventsPageSize, resultsPage, resultsPageSize int) (*core.ServiceStatus, error) { // TODO: add uptimePage? - serviceID, serviceName, serviceGroup, err := s.getServiceIDGroupAndNameByKey(key) +func (s *Store) getServiceStatusByKey(tx *sql.Tx, key string, eventsPage, eventsPageSize, resultsPage, resultsPageSize int) (*core.ServiceStatus, error) { // TODO: add uptimePage? + serviceID, serviceName, serviceGroup, err := s.getServiceIDGroupAndNameByKey(tx, key) if err != nil { return nil, err } @@ -307,12 +321,12 @@ func (s *Store) getServiceStatusByKey(key string, eventsPage, eventsPageSize, re Uptime: nil, } if eventsPageSize > 0 { - if serviceStatus.Events, err = s.getEventsByServiceID(serviceID, eventsPage, eventsPageSize); err != nil { + if serviceStatus.Events, err = s.getEventsByServiceID(tx, serviceID, eventsPage, eventsPageSize); err != nil { log.Printf("[database][getServiceStatusByKey] Failed to retrieve events for key=%s: %s", key, err.Error()) } } if resultsPageSize > 0 { - if serviceStatus.Results, err = s.getResultsByServiceID(serviceID, resultsPage, resultsPageSize); err != nil { + if serviceStatus.Results, err = s.getResultsByServiceID(tx, serviceID, resultsPage, resultsPageSize); err != nil { log.Printf("[database][getServiceStatusByKey] Failed to retrieve results for key=%s: %s", key, err.Error()) } } @@ -321,8 +335,8 @@ func (s *Store) getServiceStatusByKey(key string, eventsPage, eventsPageSize, re return serviceStatus, nil } -func (s *Store) getServiceIDGroupAndNameByKey(key string) (id int64, group, name string, err error) { - rows, err := s.db.Query("SELECT service_id, service_group, service_name FROM service WHERE service_key = $1 LIMIT 1", key) +func (s *Store) getServiceIDGroupAndNameByKey(tx *sql.Tx, key string) (id int64, group, name string, err error) { + rows, err := tx.Query("SELECT service_id, service_group, service_name FROM service WHERE service_key = $1 LIMIT 1", key) if err != nil { return 0, "", "", err } @@ -336,8 +350,8 @@ func (s *Store) getServiceIDGroupAndNameByKey(key string) (id int64, group, name return } -func (s *Store) getEventsByServiceID(serviceID int64, page, pageSize int) (events []*core.Event, err error) { - rows, err := s.db.Query( +func (s *Store) getEventsByServiceID(tx *sql.Tx, serviceID int64, page, pageSize int) (events []*core.Event, err error) { + rows, err := tx.Query( ` SELECT event_type, event_timestamp FROM service_event @@ -361,11 +375,7 @@ func (s *Store) getEventsByServiceID(serviceID int64, page, pageSize int) (event return } -func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (results []*core.Result, err error) { - tx, err := s.db.Begin() - if err != nil { - return - } +func (s *Store) getResultsByServiceID(tx *sql.Tx, serviceID int64, page, pageSize int) (results []*core.Result, err error) { rows, err := tx.Query( ` SELECT service_result_id, success, errors, connected, status, dns_rcode, certificate_expiration, hostname, ip, duration, timestamp @@ -379,7 +389,6 @@ func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (resu (page-1)*pageSize, ) if err != nil { - _ = tx.Rollback() return nil, err } idResultMap := make(map[int64]*core.Result) @@ -404,7 +413,6 @@ func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (resu serviceResultID, ) if err != nil { - _ = tx.Rollback() return } for rows.Next() { @@ -414,10 +422,6 @@ func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (resu } _ = rows.Close() } - if err = tx.Commit(); err != nil { - _ = tx.Rollback() - return - } return } @@ -564,7 +568,6 @@ func (s *Store) insertConditionResults(tx *sql.Tx, serviceResultID int64, condit cr.Success, ) if err != nil { - _ = tx.Rollback() return err } }